diff --git a/espnet_onnx/export/tts/models/tts_models/fastspeech2.py b/espnet_onnx/export/tts/models/tts_models/fastspeech2.py
index 29197e2..0022e51 100644
--- a/espnet_onnx/export/tts/models/tts_models/fastspeech2.py
+++ b/espnet_onnx/export/tts/models/tts_models/fastspeech2.py
@@ -10,7 +10,7 @@ from espnet_onnx.utils.torch_function import MakePadMask, normalize
 
 
 class OnnxLengthRegurator(nn.Module):
-    def __init__(self, alpha=1.0, max_seq_len=512):
+    def __init__(self, alpha=1.0, max_seq_len=1000):
         super().__init__()
         self.alpha = alpha
         # The maximum length of the make_pad_mask is the
@@ -59,7 +59,7 @@ class OnnxFastSpeech2(nn.Module, AbsExportModel):
     def __init__(
         self,
         model,
-        max_seq_len: int = 512,
+        max_seq_len: int = 1000,
         alpha: float = 1.0,
         use_cache: bool = True,
         **kwargs,
diff --git a/espnet_onnx/tts/abs_tts_model.py b/espnet_onnx/tts/abs_tts_model.py
index 591947a..9de64a5 100644
--- a/espnet_onnx/tts/abs_tts_model.py
+++ b/espnet_onnx/tts/abs_tts_model.py
@@ -86,20 +86,20 @@ class AbsTTSModel(AbsModel):
         self._build_normalizer()
         self._build_vocoder(providers, use_quantized)
 
-    def _check_ort_version(self, providers: List[str]):
+    def _check_ort_version(self, providers: List):
         # check cpu
         if (
             onnxruntime.get_device() == "CPU"
             and "CPUExecutionProvider" not in providers
-        ):
-            raise RuntimeError(
-                "If you want to use GPU, then follow `How to use GPU on espnet_onnx` chapter in readme to install onnxruntime-gpu."
-            )
+        ): pass
+#            raise RuntimeError(
+#                "If you want to use GPU, then follow `How to use GPU on espnet_onnx` chapter in readme to install onnxruntime-gpu."
+#            )
 
         # check GPU
-        if onnxruntime.get_device() == "GPU" and providers == ["CPUExecutionProvider"]:
-            warnings.warn(
-                "Inference will be executed on the CPU. Please provide gpu providers. Read `How to use GPU on espnet_onnx` in readme in detail."
-            )
+        if onnxruntime.get_device() == "GPU" and providers == ["CPUExecutionProvider"]: pass
+#            warnings.warn(
+#                "Inference will be executed on the CPU. Please provide gpu providers. Read `How to use GPU on espnet_onnx` in readme in detail."
+#            )
 
-        logging.info(f'Providers [{" ,".join(providers)}] detected.')
+#        logging.info(f'Providers [{" ,".join(providers)}] detected.')
diff --git a/espnet_onnx/tts/tts_model.py b/espnet_onnx/tts/tts_model.py
index 78023f5..de4ebba 100644
--- a/espnet_onnx/tts/tts_model.py
+++ b/espnet_onnx/tts/tts_model.py
@@ -14,7 +14,7 @@ class Text2Speech(AbsTTSModel):
         self,
         tag_name: str = None,
         model_dir: Union[Path, str] = None,
-        providers: List[str] = ["CPUExecutionProvider"],
+        providers: List = ["CPUExecutionProvider"],
         use_quantized: bool = False,
     ):
         assert check_argument_types()
diff --git a/espnet_onnx/utils/abs_model.py b/espnet_onnx/utils/abs_model.py
index 1270468..4aa63c6 100644
--- a/espnet_onnx/utils/abs_model.py
+++ b/espnet_onnx/utils/abs_model.py
@@ -46,23 +46,23 @@ class AbsModel(ABC):
     def _build_model(self, providers, use_quantized):
         raise NotImplementedError
 
-    def _check_ort_version(self, providers: List[str]):
+    def _check_ort_version(self, providers: List):
         # check cpu
         if (
             onnxruntime.get_device() == "CPU"
             and "CPUExecutionProvider" not in providers
-        ):
-            raise RuntimeError(
-                "If you want to use GPU, then follow `How to use GPU on espnet_onnx` chapter in readme to install onnxruntime-gpu."
-            )
+        ): pass
+#            raise RuntimeError(
+#                "If you want to use GPU, then follow `How to use GPU on espnet_onnx` chapter in readme to install onnxruntime-gpu."
+#            )
 
         # check GPU
-        if onnxruntime.get_device() == "GPU" and providers == ["CPUExecutionProvider"]:
-            warnings.warn(
-                "Inference will be executed on the CPU. Please provide gpu providers. Read `How to use GPU on espnet_onnx` in readme in detail."
-            )
+        if onnxruntime.get_device() == "GPU" and providers == ["CPUExecutionProvider"]: pass
+#            warnings.warn(
+#                "Inference will be executed on the CPU. Please provide gpu providers. Read `How to use GPU on espnet_onnx` in readme in detail."
+#            )
 
-        logging.info(f'Providers [{" ,".join(providers)}] detected.')
+#        logging.info(f'Providers [{" ,".join(providers)}] detected.')
 
 
 class AbsExportModel(ABC):
diff --git a/espnet_onnx/utils/torch_function.py b/espnet_onnx/utils/torch_function.py
index c274346..f88f8bc 100644
--- a/espnet_onnx/utils/torch_function.py
+++ b/espnet_onnx/utils/torch_function.py
@@ -6,7 +6,7 @@ import torch.nn as nn
 
 
 class MakePadMask(nn.Module):
-    def __init__(self, max_seq_len=512, flip=True):
+    def __init__(self, max_seq_len=1000, flip=True):
         super().__init__()
         if flip:
             self.mask_pad = torch.Tensor(1 - np.tri(max_seq_len)).type(torch.bool)
diff --git a/setup.py b/setup.py
index 483b062..ee37d37 100644
--- a/setup.py
+++ b/setup.py
@@ -4,9 +4,9 @@ requirements = {
     "install": [
         "setuptools>=38.5.1",
         "librosa>=0.8.0",
-        "onnxruntime",
+        "onnxruntime-gpu",
         "sentencepiece>=0.1.91,!=0.1.92",
-        "typeguard==2.13.0",
+        "typeguard==2.13.3",
         "PyYAML>=5.1.2",
         "g2p-en",
         "jamo==0.4.1",  # For kss