diff --git a/espnet_onnx/tts/abs_tts_model.py b/espnet_onnx/tts/abs_tts_model.py index 591947a..9de64a5 100644 --- a/espnet_onnx/tts/abs_tts_model.py +++ b/espnet_onnx/tts/abs_tts_model.py @@ -86,20 +86,20 @@ class AbsTTSModel(AbsModel): self._build_normalizer() self._build_vocoder(providers, use_quantized) - def _check_ort_version(self, providers: List[str]): + def _check_ort_version(self, providers: List): # check cpu if ( onnxruntime.get_device() == "CPU" and "CPUExecutionProvider" not in providers - ): - raise RuntimeError( - "If you want to use GPU, then follow `How to use GPU on espnet_onnx` chapter in readme to install onnxruntime-gpu." - ) + ): pass +# raise RuntimeError( +# "If you want to use GPU, then follow `How to use GPU on espnet_onnx` chapter in readme to install onnxruntime-gpu." +# ) # check GPU - if onnxruntime.get_device() == "GPU" and providers == ["CPUExecutionProvider"]: - warnings.warn( - "Inference will be executed on the CPU. Please provide gpu providers. Read `How to use GPU on espnet_onnx` in readme in detail." - ) + if onnxruntime.get_device() == "GPU" and providers == ["CPUExecutionProvider"]: pass +# warnings.warn( +# "Inference will be executed on the CPU. Please provide gpu providers. Read `How to use GPU on espnet_onnx` in readme in detail." +# ) - logging.info(f'Providers [{" ,".join(providers)}] detected.') +# logging.info(f'Providers [{" ,".join(providers)}] detected.') diff --git a/espnet_onnx/tts/tts_model.py b/espnet_onnx/tts/tts_model.py index 78023f5..de4ebba 100644 --- a/espnet_onnx/tts/tts_model.py +++ b/espnet_onnx/tts/tts_model.py @@ -14,7 +14,7 @@ class Text2Speech(AbsTTSModel): self, tag_name: str = None, model_dir: Union[Path, str] = None, - providers: List[str] = ["CPUExecutionProvider"], + providers: List = ["CPUExecutionProvider"], use_quantized: bool = False, ): assert check_argument_types() diff --git a/espnet_onnx/utils/abs_model.py b/espnet_onnx/utils/abs_model.py index 1270468..4aa63c6 100644 --- a/espnet_onnx/utils/abs_model.py +++ b/espnet_onnx/utils/abs_model.py @@ -46,23 +46,23 @@ class AbsModel(ABC): def _build_model(self, providers, use_quantized): raise NotImplementedError - def _check_ort_version(self, providers: List[str]): + def _check_ort_version(self, providers: List): # check cpu if ( onnxruntime.get_device() == "CPU" and "CPUExecutionProvider" not in providers - ): - raise RuntimeError( - "If you want to use GPU, then follow `How to use GPU on espnet_onnx` chapter in readme to install onnxruntime-gpu." - ) + ): pass +# raise RuntimeError( +# "If you want to use GPU, then follow `How to use GPU on espnet_onnx` chapter in readme to install onnxruntime-gpu." +# ) # check GPU - if onnxruntime.get_device() == "GPU" and providers == ["CPUExecutionProvider"]: - warnings.warn( - "Inference will be executed on the CPU. Please provide gpu providers. Read `How to use GPU on espnet_onnx` in readme in detail." - ) + if onnxruntime.get_device() == "GPU" and providers == ["CPUExecutionProvider"]: pass +# warnings.warn( +# "Inference will be executed on the CPU. Please provide gpu providers. Read `How to use GPU on espnet_onnx` in readme in detail." +# ) - logging.info(f'Providers [{" ,".join(providers)}] detected.') +# logging.info(f'Providers [{" ,".join(providers)}] detected.') class AbsExportModel(ABC): diff --git a/setup.py b/setup.py index 483b062..ee37d37 100644 --- a/setup.py +++ b/setup.py @@ -4,9 +4,9 @@ requirements = { "install": [ "setuptools>=38.5.1", "librosa>=0.8.0", - "onnxruntime", + "onnxruntime-gpu", "sentencepiece>=0.1.91,!=0.1.92", - "typeguard==2.13.0", + "typeguard==2.13.3", "PyYAML>=5.1.2", "g2p-en", "jamo==0.4.1", # For kss