"""Triton Python-backend model: text -> FastSpeech2 (espnet_onnx) -> ONNX vocoder.

Each request carries INPUT_TEXT, INPUT_SPEAKER_ID and INPUT_LANGUAGE_ID tensors;
the response carries OUTPUT_GENERATED_AUDIO, one raw int16 PCM byte string per
input row.
"""
import os

# These run before the remaining imports on purpose: NLTK_DATA must be set before
# anything imports nltk (presumably text_preprocess_for_inference - TODO confirm),
# and the chdir makes the relative "vocoders/" and "text2phone/" model paths used
# below resolve against this file's directory. abspath() guards against
# dirname(__file__) being "" (which os.chdir rejects).
os.environ["NLTK_DATA"] = "."
os.chdir(os.path.dirname(os.path.abspath(__file__)))

import json  # noqa: E402
import logging  # noqa: E402
from itertools import product  # noqa: E402
from random import choice  # noqa: E402
from sys import path  # noqa: E402,F401

import numpy  # noqa: E402
import onnxruntime  # noqa: E402
from espnet_onnx import Text2Speech  # noqa: E402
import triton_python_backend_utils as pb_utils  # noqa: E402
from text_preprocess_for_inference import (  # noqa: E402
    TTSDurAlignPreprocessor,
    CharTextPreprocessor,
    TTSPreprocessor,
)

logger = logging.getLogger(__name__)

# Vocoder output is scaled by this value to get int16 PCM (2**15).
MAX_WAV_VALUE = 32768.0
SAMPLING_RATE = 22050
# Empirical scale applied to the FastSpeech2 features before vocoding; its origin
# is not documented here - TODO confirm against the training recipe.
FEATURE_SCALE = 2.3262
GENDERS = ("male", "female")

# ISO-like language id -> (language name used in model paths, vocoder family).
LANGMAP = {
    "as": ("assamese", "aryan"),
    "bn": ("bengali", "aryan"),
    "brx": ("bodo", "aryan"),
    "en": ("english", "aryan"),
    "gu": ("gujarati", "aryan"),
    "hi": ("hindi", "aryan"),
    "kn": ("kannada", "dravidian"),
    "ml": ("malayalam", "dravidian"),
    "mni": ("manipuri", "aryan"),
    "mr": ("marathi", "aryan"),
    "or": ("odia", "aryan"),
    "pa": ("punjabi", "aryan"),
    "rj": ("rajasthani", "aryan"),
    "ta": ("tamil", "dravidian"),
    "te": ("telugu", "dravidian"),
    "ur": ("urdu", "aryan"),
}


class TritonPythonModel:
    """Triton Python-backend entry point for multilingual text-to-speech."""

    def initialize(self, args):
        """Load text preprocessors, per-(language, gender) acoustic models and vocoders.

        Args:
            args: Triton-provided dict; "model_instance_device_id" and
                "model_config" are JSON-encoded strings.
        """
        self.device_id = int(json.loads(args["model_instance_device_id"]))
        self.target_dtype = pb_utils.triton_string_to_numpy(
            pb_utils.get_output_config_by_name(
                json.loads(args["model_config"]), "OUTPUT_GENERATED_AUDIO"
            )["data_type"]
        )

        self.tts_preprocessor = TTSPreprocessor()
        self.char_text_preprocessor = CharTextPreprocessor()
        self.tts_dur_align_preprocessor = TTSDurAlignPreprocessor()
        self.preprocessors = {}
        for lang, _ in LANGMAP.values():
            if lang in ("urdu", "punjabi"):
                self.preprocessors[lang] = self.char_text_preprocessor
            elif lang == "english":
                self.preprocessors[lang] = self.tts_preprocessor
            else:
                self.preprocessors[lang] = self.tts_dur_align_preprocessor

        # Not every (language, gender) voice necessarily ships with the repository,
        # so loading is best-effort; failures are logged instead of silently dropped.
        self.models = {}
        for (language, _), gender in product(LANGMAP.values(), GENDERS):
            try:
                self.models[(language, gender)] = self.load_fastspeech2_model(
                    language, gender, "cuda"
                )
            except Exception:
                logger.warning(
                    "Could not load FastSpeech2 model for %s/%s; voice disabled",
                    language,
                    gender,
                    exc_info=True,
                )

        self.vocoders = {
            (gender, family): self.load_vocoder(gender, family, "cuda")
            for gender, family in product(GENDERS, ("aryan", "dravidian"))
        }

    def _providers(self, device):
        """Return the onnxruntime execution-provider list for ``device`` ("cpu" or GPU)."""
        if device == "cpu":
            return ["CPUExecutionProvider"]
        return [
            (
                "CUDAExecutionProvider",
                {
                    "device_id": self.device_id,
                    "arena_extend_strategy": "kSameAsRequested",
                },
            )
        ]

    def load_vocoder(self, gender, family, device):
        """Create an ONNX Runtime session for ``vocoders/{gender}-{family}-vocoder.onnx``."""
        return onnxruntime.InferenceSession(
            f"vocoders/{gender}-{family}-vocoder.onnx",
            providers=self._providers(device),
        )

    def load_fastspeech2_model(self, language, gender, device):
        """Load the quantized espnet_onnx model from ``text2phone/{language}-{gender}-ort``."""
        return Text2Speech(
            providers=self._providers(device),
            model_dir=f"text2phone/{language}-{gender}-ort",
            use_quantized=True,
        )

    def determine_gender(self, name, choices=GENDERS):
        """Map a speaker id to "male" or "female".

        "m"/"male" and "f"/"fem"/"female" (case-insensitive) are recognised; any
        other value gets a random pick from ``choices``.
        """
        normalized = name.lower()
        if normalized in ("m", "male"):
            return "male"
        if normalized in ("f", "fem", "female"):
            return "female"
        return choice(list(choices))

    def _select_gender(self, language, speaker):
        """Resolve ``speaker`` to a gender whose model is loaded for ``language``.

        Raises:
            ValueError: no voice is loaded for the language, or the explicitly
                requested gender's voice is not loaded.
        """
        available = tuple(g for g in GENDERS if (language, g) in self.models)
        if not available:
            raise ValueError(f"No TTS model is loaded for language {language!r}")
        # Random fallback is limited to voices that actually exist.
        gender = self.determine_gender(speaker, available)
        if gender not in available:
            raise ValueError(
                f"No {gender} TTS model is loaded for language {language!r}"
            )
        return gender

    def synthesize_audio(self, text, lang_id, speaker_id):
        """Synthesize one utterance and return it as an int16 numpy array.

        Args:
            text, lang_id, speaker_id: single-row arrays whose element 0 is
                UTF-8 encoded bytes (one row of the corresponding input tensor).

        Raises:
            ValueError: unsupported language id or unavailable voice.
        """
        lang_key = lang_id[0].decode("utf-8")
        try:
            language, family = LANGMAP[lang_key]
        except KeyError:
            raise ValueError(f"Unsupported language id {lang_key!r}") from None
        gender = self._select_gender(language, speaker_id[0].decode("utf-8"))

        preprocessor = self.preprocessors[language]
        preprocessed_text = " ".join(
            preprocessor.preprocess(text[0].decode("utf-8"), language, gender)[0]
        )

        model = self.models[(language, gender)]
        vocoder = self.vocoders[(gender, family)]

        token_ids = model.preprocess.token_id_converter.tokens2ids(
            model.preprocess.tokenizer.text2tokens(preprocessed_text)
        )
        features = model.postprocess(model.tts_model(token_ids)["feat_gen"])
        # Transpose and add a leading batch axis for the vocoder input.
        x = numpy.expand_dims(features.T, axis=0) * FEATURE_SCALE
        y_g_hat = vocoder.run(None, {"input": x})[0]

        # Clip before the int16 cast: out-of-range float->int casts are undefined
        # in numpy and in practice wrap around, producing loud clicks.
        audio = numpy.clip(
            y_g_hat.squeeze() * MAX_WAV_VALUE, -MAX_WAV_VALUE, MAX_WAV_VALUE - 1
        )
        return audio.astype("int16")

    def _process_request(self, request):
        """Synthesize every row of one request and wrap the results in a response."""
        input_texts = pb_utils.get_input_tensor_by_name(request, "INPUT_TEXT").as_numpy()
        input_speaker_ids = pb_utils.get_input_tensor_by_name(
            request, "INPUT_SPEAKER_ID"
        ).as_numpy()
        input_language_ids = pb_utils.get_input_tensor_by_name(
            request, "INPUT_LANGUAGE_ID"
        ).as_numpy()

        processed_sents = [
            self.synthesize_audio(input_text, input_language_id, input_speaker_id).tobytes()
            for input_text, input_speaker_id, input_language_id in zip(
                input_texts, input_speaker_ids, input_language_ids
            )
        ]
        return pb_utils.InferenceResponse(
            output_tensors=[
                pb_utils.Tensor(
                    "OUTPUT_GENERATED_AUDIO",
                    numpy.array(
                        [[processed_sent] for processed_sent in processed_sents],
                        dtype=self.target_dtype,
                    ),
                )
            ]
        )

    def execute(self, requests):
        """Return one InferenceResponse per request, in request order.

        A failure in one request yields an error response for that request only,
        instead of failing the whole batch.
        """
        responses = []
        for request in requests:
            try:
                responses.append(self._process_request(request))
            except Exception as err:
                logger.exception("TTS synthesis failed for a request")
                responses.append(
                    pb_utils.InferenceResponse(
                        output_tensors=[], error=pb_utils.TritonError(str(err))
                    )
                )
        return responses

    def finalize(self):
        """Nothing to release explicitly."""
        pass