espnet_onnx.patch 4.27 KB
Newer Older
Nikhilesh Bhatnagar's avatar
Nikhilesh Bhatnagar committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
diff --git a/espnet_onnx/tts/abs_tts_model.py b/espnet_onnx/tts/abs_tts_model.py
index 591947a..9de64a5 100644
--- a/espnet_onnx/tts/abs_tts_model.py
+++ b/espnet_onnx/tts/abs_tts_model.py
@@ -86,20 +86,20 @@ class AbsTTSModel(AbsModel):
         self._build_normalizer()
         self._build_vocoder(providers, use_quantized)
 
-    def _check_ort_version(self, providers: List[str]):
+    def _check_ort_version(self, providers: List):
         # check cpu
         if (
             onnxruntime.get_device() == "CPU"
             and "CPUExecutionProvider" not in providers
-        ):
-            raise RuntimeError(
-                "If you want to use GPU, then follow `How to use GPU on espnet_onnx` chapter in readme to install onnxruntime-gpu."
-            )
+        ): pass
+#            raise RuntimeError(
+#                "If you want to use GPU, then follow `How to use GPU on espnet_onnx` chapter in readme to install onnxruntime-gpu."
+#            )
 
         # check GPU
-        if onnxruntime.get_device() == "GPU" and providers == ["CPUExecutionProvider"]:
-            warnings.warn(
-                "Inference will be executed on the CPU. Please provide gpu providers. Read `How to use GPU on espnet_onnx` in readme in detail."
-            )
+        if onnxruntime.get_device() == "GPU" and providers == ["CPUExecutionProvider"]: pass
+#            warnings.warn(
+#                "Inference will be executed on the CPU. Please provide gpu providers. Read `How to use GPU on espnet_onnx` in readme in detail."
+#            )
 
-        logging.info(f'Providers [{" ,".join(providers)}] detected.')
+#        logging.info(f'Providers [{", ".join(providers)}] detected.')
diff --git a/espnet_onnx/tts/tts_model.py b/espnet_onnx/tts/tts_model.py
index 78023f5..de4ebba 100644
--- a/espnet_onnx/tts/tts_model.py
+++ b/espnet_onnx/tts/tts_model.py
@@ -14,7 +14,7 @@ class Text2Speech(AbsTTSModel):
         self,
         tag_name: str = None,
         model_dir: Union[Path, str] = None,
-        providers: List[str] = ["CPUExecutionProvider"],
+        providers: List = ["CPUExecutionProvider"],
         use_quantized: bool = False,
     ):
         assert check_argument_types()
diff --git a/espnet_onnx/utils/abs_model.py b/espnet_onnx/utils/abs_model.py
index 1270468..4aa63c6 100644
--- a/espnet_onnx/utils/abs_model.py
+++ b/espnet_onnx/utils/abs_model.py
@@ -46,23 +46,23 @@ class AbsModel(ABC):
     def _build_model(self, providers, use_quantized):
         raise NotImplementedError
 
-    def _check_ort_version(self, providers: List[str]):
+    def _check_ort_version(self, providers: List):
         # check cpu
         if (
             onnxruntime.get_device() == "CPU"
             and "CPUExecutionProvider" not in providers
-        ):
-            raise RuntimeError(
-                "If you want to use GPU, then follow `How to use GPU on espnet_onnx` chapter in readme to install onnxruntime-gpu."
-            )
+        ): pass
+#            raise RuntimeError(
+#                "If you want to use GPU, then follow `How to use GPU on espnet_onnx` chapter in readme to install onnxruntime-gpu."
+#            )
 
         # check GPU
-        if onnxruntime.get_device() == "GPU" and providers == ["CPUExecutionProvider"]:
-            warnings.warn(
-                "Inference will be executed on the CPU. Please provide gpu providers. Read `How to use GPU on espnet_onnx` in readme in detail."
-            )
+        if onnxruntime.get_device() == "GPU" and providers == ["CPUExecutionProvider"]: pass
+#            warnings.warn(
+#                "Inference will be executed on the CPU. Please provide gpu providers. Read `How to use GPU on espnet_onnx` in readme in detail."
+#            )
 
-        logging.info(f'Providers [{" ,".join(providers)}] detected.')
+#        logging.info(f'Providers [{", ".join(providers)}] detected.')
 
 
 class AbsExportModel(ABC):
diff --git a/setup.py b/setup.py
index 483b062..ee37d37 100644
--- a/setup.py
+++ b/setup.py
@@ -4,9 +4,9 @@ requirements = {
     "install": [
         "setuptools>=38.5.1",
         "librosa>=0.8.0",
-        "onnxruntime",
+        "onnxruntime-gpu",
         "sentencepiece>=0.1.91,!=0.1.92",
-        "typeguard==2.13.0",
+        "typeguard==2.13.3",
         "PyYAML>=5.1.2",
         "g2p-en",
         "jamo==0.4.1",  # For kss