model.py
import os

# Point NLTK at bundled data and run relative to this file before other imports.
os.environ["NLTK_DATA"] = "."
os.chdir(os.path.dirname(__file__))
import json
import numpy
import onnxruntime
from random import choice
from itertools import product
from espnet_onnx import Text2Speech
import triton_python_backend_utils as pb_utils
from text_preprocess_for_inference import (
    TTSDurAlignPreprocessor,
    CharTextPreprocessor,
    TTSPreprocessor,
)

# Peak amplitude for 16-bit PCM output.
MAX_WAV_VALUE = 32768.0
# Output sample rate of the synthesized audio, in Hz.
SAMPLING_RATE = 22050

# ISO language code -> (language name, family used to pick the matching vocoder).
LANGMAP = {
    "as": ("assamese", "aryan"),
    "bn": ("bengali", "aryan"),
    "brx": ("bodo", "aryan"),
    "en": ("english", "aryan"),
    "gu": ("gujarati", "aryan"),
    "hi": ("hindi", "aryan"),
    "kn": ("kannada", "dravidian"),
    "ml": ("malayalam", "dravidian"),
    "mni": ("manipuri", "aryan"),
    "mr": ("marathi", "aryan"),
    "or": ("odia", "aryan"),
    "pa": ("punjabi", "aryan"),
    "rj": ("rajasthani", "aryan"),
    "ta": ("tamil", "dravidian"),
    "te": ("telugu", "dravidian"),
    "ur": ("urdu", "aryan"),
}


class TritonPythonModel:
    def initialize(self, args):
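        # Read the GPU id assigned to this model instance and the configured
        # output dtype, set up per-language text preprocessors, and eagerly load
        # every available acoustic model and vocoder onto the GPU.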
        self.device_id = int(json.loads(args["model_instance_device_id"]))
        self.target_dtype = pb_utils.triton_string_to_numpy(
            pb_utils.get_output_config_by_name(
                json.loads(args["model_config"]), "OUTPUT_GENERATED_AUDIO"
            )["data_type"]
        )
        self.tts_preprocessor = TTSPreprocessor()
        self.char_text_preprocessor = CharTextPreprocessor()
        self.tts_dur_align_preprocessor = TTSDurAlignPreprocessor()
        self.preprocessors = {}
        for lang, _ in LANGMAP.values():
            if lang == "urdu" or lang == "punjabi":
                self.preprocessors[lang] = self.char_text_preprocessor
            elif lang == "english":
                self.preprocessors[lang] = self.tts_preprocessor
            else:
                self.preprocessors[lang] = self.tts_dur_align_preprocessor
        self.models = {}
        for (language, _), gender in product(LANGMAP.values(), ("male", "female")):
            try:
                self.models[(language, gender)] = self.load_fastspeech2_model(
                    language,
                    gender,
                    "cuda",
                )
            except Exception:
                # Skip combinations that fail to load (e.g., missing model files).
                pass
        self.vocoders = {
            (gender, family): self.load_vocoder(gender, family, "cuda")
            for gender, family in product(("male", "female"), ("aryan", "dravidian"))
        }

    def load_vocoder(self, gender, family, device):
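        # Build an ONNX Runtime session for the (gender, language-family) vocoder,
        # on CPU or on this model instance's GPU.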
        return onnxruntime.InferenceSession(
            f"vocoders/{gender}-{family}-vocoder.onnx",
            providers=[
                "CPUExecutionProvider"
                if device == "cpu"
                else (
                    "CUDAExecutionProvider",
                    {
                        "device_id": self.device_id,
                        "arena_extend_strategy": "kSameAsRequested",
                    },
                )
            ],
        )

    def load_fastspeech2_model(self, language, gender, device):
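        # Load the quantized espnet_onnx FastSpeech2 model exported for this
        # (language, gender) pair.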
        model = Text2Speech(
            providers=[
                "CPUExecutionProvider"
                if device == "cpu"
                else (
                    "CUDAExecutionProvider",
                    {
                        "device_id": self.device_id,
                        "arena_extend_strategy": "kSameAsRequested",
                    },
                )
            ],
            model_dir=f"text2phone/{language}-{gender}-ort",
            use_quantized=True,
        )
        return model

    def determine_gender(self, name):
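        # Normalize the requested speaker id to "male"/"female"; any other value
        # falls back to a random choice.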
        if name.lower() in ("m", "male"):
            return "male"
        elif name.lower() in ("f", "fem", "female"):
            return "female"
        else:
            return choice(["male", "female"])

    def synthesize_audio(self, text, lang_id, speaker_id):
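        # Synthesize a single utterance: decode the request fields, pick the
        # matching preprocessor, acoustic model, and vocoder, and return 16-bit
        # PCM audio at SAMPLING_RATE.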
        language, family = LANGMAP[lang_id[0].decode("utf-8")]
        gender = self.determine_gender(speaker_id[0].decode("utf-8"))
        preprocessor = self.preprocessors[language]
        preprocessed_text = " ".join(
            preprocessor.preprocess(text[0].decode("utf-8"), language, gender)[0]
        )
        model, vocoder = (
            self.models[(language, gender)],
            self.vocoders[(gender, family)],
        )
        # Text -> phoneme ids -> mel features; 2.3262 is a fixed feature-scaling
        # constant applied before vocoding.
        x = (
            numpy.expand_dims(
                model.postprocess(
                    model.tts_model(
                        model.preprocess.token_id_converter.tokens2ids(
                            model.preprocess.tokenizer.text2tokens(preprocessed_text)
                        )
                    )["feat_gen"]
                ).T,
                axis=0,
            )
            * 2.3262
        )
        # Vocode the features and rescale from float in [-1, 1] to 16-bit PCM.
        y_g_hat = vocoder.run(None, {"input": x})[0]
        audio = y_g_hat.squeeze() * MAX_WAV_VALUE
        return audio.astype("int16")

    def execute(self, requests):
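        # One InferenceResponse per request: synthesize every (text, speaker,
        # language) row in the batch and pack the waveforms, as raw bytes, into
        # the OUTPUT_GENERATED_AUDIO tensor.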
        responses = []
        for request in requests:
            input_texts = pb_utils.get_input_tensor_by_name(request, "INPUT_TEXT")
            input_speaker_ids = pb_utils.get_input_tensor_by_name(
                request, "INPUT_SPEAKER_ID"
            )
            input_language_ids = pb_utils.get_input_tensor_by_name(
                request, "INPUT_LANGUAGE_ID"
            )
            processed_sents = [
                [
                    self.synthesize_audio(
                        input_text, input_language_id, input_speaker_id
                    ).tobytes()
                ]
                for input_text, input_speaker_id, input_language_id in zip(
                    input_texts.as_numpy(),
                    input_speaker_ids.as_numpy(),
                    input_language_ids.as_numpy(),
                )
            ]
            responses.append(
                pb_utils.InferenceResponse(
                    output_tensors=[
                        pb_utils.Tensor(
                            "OUTPUT_GENERATED_AUDIO",
                            numpy.array(processed_sents, dtype=self.target_dtype),
                        )
                    ]
                )
            )
        return responses

    def finalize(self):
        pass