diff --git a/make_triton_repo.sh b/make_triton_repo.sh
index 1a9b4c15b167fde671b18d5dc61293604c4cd9e3..c730fdd1c59080c95b51ac0bec107e66bfd2001b 100644
--- a/make_triton_repo.sh
+++ b/make_triton_repo.sh
@@ -6,5 +6,6 @@ nvidia-docker run --gpus=all -it --rm --name iitm-tts-dhruva-builder -v ./Fastsp
 mkdir triton_model_repo
 nvidia-docker run --gpus=all -it --rm --name iitm-tts-dhruva-builder -v ./patches:/patches -v ./triton_models/tts:/model -v ./triton_model_repo:/model_repo dhruva/iitm-tts-envbuilder bash /model/envbuilder.sh
 cp Fastspeech2_HS/text_preprocess_for_inference.py triton_model_repo/tts/1
+cp Fastspeech2_HS/get_phone_mapped_python.py triton_model_repo/tts/1
 cp -r triton_models/tts/config.pbtxt triton_models/tts/1 triton_model_repo/tts
 cp -r onnx_models/* Fastspeech2_HS/phone_dict Fastspeech2_HS/multilingualcharmap.json triton_model_repo/tts/1
\ No newline at end of file
diff --git a/patches/espnet_onnx.patch b/patches/espnet_onnx.patch
index 99374d4a39fd64122201bfa4cc679710074f3601..169c09cd83a68d627bfd5e767185b02f5b9b5aed 100644
--- a/patches/espnet_onnx.patch
+++ b/patches/espnet_onnx.patch
@@ -1,3 +1,25 @@
+diff --git a/espnet_onnx/export/tts/models/tts_models/fastspeech2.py b/espnet_onnx/export/tts/models/tts_models/fastspeech2.py
+index 29197e2..0022e51 100644
+--- a/espnet_onnx/export/tts/models/tts_models/fastspeech2.py
++++ b/espnet_onnx/export/tts/models/tts_models/fastspeech2.py
+@@ -10,7 +10,7 @@ from espnet_onnx.utils.torch_function import MakePadMask, normalize
+ 
+ 
+ class OnnxLengthRegurator(nn.Module):
+-    def __init__(self, alpha=1.0, max_seq_len=512):
++    def __init__(self, alpha=1.0, max_seq_len=1000):
+         super().__init__()
+         self.alpha = alpha
+         # The maximum length of the make_pad_mask is the
+@@ -59,7 +59,7 @@ class OnnxFastSpeech2(nn.Module, AbsExportModel):
+     def __init__(
+         self,
+         model,
+-        max_seq_len: int = 512,
++        max_seq_len: int = 1000,
+         alpha: float = 1.0,
+         use_cache: bool = True,
+         **kwargs,
 diff --git a/espnet_onnx/tts/abs_tts_model.py b/espnet_onnx/tts/abs_tts_model.py
 index 591947a..9de64a5 100644
 --- a/espnet_onnx/tts/abs_tts_model.py
@@ -84,6 +106,19 @@ index 1270468..4aa63c6 100644
  
  
   class AbsExportModel(ABC):
+diff --git a/espnet_onnx/utils/torch_function.py b/espnet_onnx/utils/torch_function.py
+index c274346..f88f8bc 100644
+--- a/espnet_onnx/utils/torch_function.py
++++ b/espnet_onnx/utils/torch_function.py
+@@ -6,7 +6,7 @@ import torch.nn as nn
+ 
+ 
+ class MakePadMask(nn.Module):
+-    def __init__(self, max_seq_len=512, flip=True):
++    def __init__(self, max_seq_len=1000, flip=True):
+         super().__init__()
+         if flip:
+             self.mask_pad = torch.Tensor(1 - np.tri(max_seq_len)).type(torch.bool)
 diff --git a/setup.py b/setup.py
 index 483b062..ee37d37 100644
 --- a/setup.py
diff --git a/triton_models/tts/1/model.py b/triton_models/tts/1/model.py
index 66e71bd7fb6a85e3485d188ce44d1a64da4df490..8d3a631925e71263226fa3cbc845faf399bc0199 100644
--- a/triton_models/tts/1/model.py
+++ b/triton_models/tts/1/model.py
@@ -79,7 +79,7 @@ class TritonPythonModel:
             providers=[
                 "CPUExecutionProvider"
                 if device == "cpu"
-                else ("CUDAExecutionProvider", {"device_id": self.device_id})
+                else ("CUDAExecutionProvider", {"device_id": self.device_id, "arena_extend_strategy": "kSameAsRequested"})
             ],
         )
 
@@ -88,7 +88,7 @@ class TritonPythonModel:
             providers=[
                 "CPUExecutionProvider"
                 if device == "cpu"
-                else ("CUDAExecutionProvider", {"device_id": self.device_id})
+                else ("CUDAExecutionProvider", {"device_id": self.device_id, "arena_extend_strategy": "kSameAsRequested"})
             ],
             model_dir=f"text2phone/{language}-{gender}-ort",
             use_quantized=True,
diff --git a/triton_models/tts/envbuilder.sh b/triton_models/tts/envbuilder.sh
index c77e31db1d612f87bc0c52d7be4feff92f4fd3c8..f684a08f68ee72ce6699c9cfc4075ec8ff283cc6 100644
--- a/triton_models/tts/envbuilder.sh
+++ b/triton_models/tts/envbuilder.sh
@@ -11,7 +11,7 @@ conda activate tts
 mamba install -c "nvidia/label/cuda-11.8.0" libcublas libcufft cuda-cudart -y
 git clone --recursive https://github.com/espnet/espnet_onnx.git
 cd espnet_onnx && git apply /patches/espnet_onnx.patch && python setup.py bdist_wheel && cd ..
-pip install -U numpy pandas nltk indic-num2words g2p_en "espnet_onnx/dist/espnet_onnx-0.2.0-py3-none-any.whl"
+pip install -U numpy pandas nltk indic-num2words g2p_en "espnet_onnx/dist/espnet_onnx-0.2.0-py3-none-any.whl" indic-unified-parser
 conda deactivate
 conda pack -n tts
 conda activate tts