"""Triton Python backend serving multilingual ASR: faster-whisper for English,
SPRING_INX data2vec-aqc fairseq checkpoints with Viterbi CTC decoding for the
Indic languages. Idle models unload themselves to free GPU memory."""

import gc
import os
import re
import sys
import json
import time
import numpy
import torch
import soundfile

sys.path.append("aqc")

from threading import Timer
from types import SimpleNamespace

import torch.nn.functional as F

# Resolve model paths relative to this file, regardless of the server's CWD.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

import fairseq
import torchaudio.sox_effects as ta_sox
from faster_whisper import WhisperModel
import triton_python_backend_utils as pb_utils

# Imported for their side effects: they register the speech recognition and
# data2vec tasks/models with fairseq before any checkpoint is loaded.
from examples import speech_recognition, data2vec  # noqa: F401
from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

SAMPLING_RATE = 16000
LANGMAP = {
    "bn": "Bengali",
    "bh": "Bhojpuri",
    "gu": "Gujarati",
    "hi": "Hindi",
    "kn": "Kannada",
    "ml": "Malayalam",
    "mr": "Marathi",
    "or": "Odia",
    "pa": "Punjabi",
    "ta": "Tamil",
    "te": "Telugu",
    "ur": "Urdu",
    "en": "English",
}


class ASRModel(object):
    """One ASR model per language. Each instance unloads itself after
    `timeout` seconds without a request and reloads lazily on the next one."""

    def __init__(self, lang, device):
        self.lang = lang
        self.timeout = 60 * 30  # unload after 30 idle minutes
        self.language = LANGMAP[lang]
        self.sampling_rate = SAMPLING_RATE
        self.device = device
        self.effects = [["gain", "-n"]]  # normalise gain before inference
        self.decoder_args = SimpleNamespace(
            post_process="letter",
            sil_weight=0,
            max_tokens=4000000,
            beam=100,
            nbest=1,
            criterion="ctc",
        )
        self.space_remover = re.compile(" +")
        self.replacer = re.compile(r"\|+")  # "|" is the fairseq word boundary
        self.load()
        self.last_accessed = time.time()

    def load(self):
        self.model, self.cfg, self.task, self.token, self.decoder = (
            None,
            None,
            None,
            None,
            None,
        )
        print(f"loading {self.language} asr....", end="")
        if self.language == "English":
            device_name, device_index = self.device.split(":")
            self.model = WhisperModel(
                "medium",
                device=device_name,
                device_index=int(device_index),
                compute_type="float32",
                download_root="./whisper_models",
            )
        else:
            models, self.cfg, self.task = (
                fairseq.checkpoint_utils.load_model_ensemble_and_task(
                    [f"models/SPRING_INX_data2vec_aqc_{self.language}.pt"],
                    arg_overrides={
                        "data": f"models/SPRING_INX_{self.language}_dict.txt"
                    },
                )
            )
            self.model = models[0].to(self.device).eval()
            with open(f"models/SPRING_INX_{self.language}_dict.txt", "rt") as f:
                self.token = [x.strip().split(maxsplit=1)[0] for x in f]
            self.decoder = W2lViterbiDecoder(
                self.decoder_args, self.task.target_dictionary
            )
        self.loaded = True
        self.timer = Timer(self.timeout, self.unload)
        self.timer.start()
        print(" ....loaded!")

    def unload(self):
        print(f"unloading {self.language} asr....", end="")
        self.model, self.cfg, self.task, self.token, self.decoder = (
            None,
            None,
            None,
            None,
            None,
        )
        gc.collect()
        torch.cuda.empty_cache()
        self.loaded = False
        print(" ....unloaded!")

    def recognize(self, audio_frames):
        if not self.loaded:
            self.load()
        self.timer.cancel()  # suspend the idle-unload timer while decoding
        self.last_accessed = time.time()
        if self.language == "English":
            segments, _ = self.model.transcribe(
                audio_frames, beam_size=5, language="en"
            )
            output_text = self.space_remover.sub(
                " ", " ".join(segment.text for segment in segments)
            )
        else:
            # Gain-normalise, then layer-normalise, as data2vec expects.
            input_sample, _ = ta_sox.apply_effects_tensor(
                torch.tensor(audio_frames).unsqueeze(0),
                self.sampling_rate,
                self.effects,
            )
            input_sample = input_sample.float().to(self.device)
            with torch.no_grad():
                input_sample = F.layer_norm(input_sample, input_sample.shape)
                logits = self.model(
                    source=input_sample.unsqueeze(0), padding_mask=None
                )["encoder_out"]
            # Greedy CTC collapse of the frame-wise argmax (ids 0-3 are the
            # dictionary's special symbols, hence the -4 offset). Unused by
            # the return path; the final text comes from the Viterbi decoder.
            predicted_ids = [
                self.token[x - 4]
                for x in torch.unique_consecutive(
                    torch.argmax(logits[:, 0], dim=-1)
                ).tolist()
                if x != 0
            ]
            hypotheses = self.decoder.decode(logits.to("cpu"))
            output_text = self.replacer.sub(
                " ",
                "".join(
                    self.task.target_dictionary.string(tstep[0]["tokens"])
                    for tstep in hypotheses
                ),
            )
        self.timer = Timer(self.timeout, self.unload)
        self.timer.start()
        return output_text


class TritonPythonModel:
    def initialize(self, args):
        self.device_id = int(json.loads(args["model_instance_device_id"]))
        self.target_dtype = pb_utils.triton_string_to_numpy(
            pb_utils.get_output_config_by_name(
                json.loads(args["model_config"]), "OUTPUT_RECOGNIZED_TEXT"
            )["data_type"]
        )
        self.models = {}  # language id -> ASRModel, created lazily

    def execute(self, requests):
        responses = []
        for request in requests:
            input_audio = pb_utils.get_input_tensor_by_name(
                request, "INPUT_AUDIO"
            ).as_numpy()
            input_language_id = (
                pb_utils.get_input_tensor_by_name(request, "INPUT_LANGUAGE_ID")
                .as_numpy()[0]
                .decode("utf-8")
            )
            # CUDA out-of-memory surfaces as a RuntimeError; retry after
            # progressively freeing memory (cache first, then the least
            # recently used loaded model).
            success_flag = False
            for i in range(100):
                try:
                    if input_language_id not in self.models:
                        self.models[input_language_id] = ASRModel(
                            input_language_id, f"cuda:{self.device_id}"
                        )
                    recognized_text = self.models[input_language_id].recognize(
                        input_audio
                    )
                    success_flag = True
                    break
                except RuntimeError:
                    if i == 0:
                        print(f"GPU {self.device_id} full, releasing cuda cache...")
                        gc.collect()
                        torch.cuda.empty_cache()
                    else:
                        print(
                            f"GPU {self.device_id} full, unloading the least"
                            " recently used model..."
                        )
                        recency_list = sorted(
                            filter(lambda x: x.loaded, self.models.values()),
                            key=lambda x: x.last_accessed,
                        )
                        if len(recency_list) == 0:
                            # Nothing left to unload; wait for memory to free.
                            time.sleep(5)
                        else:
                            recency_list[0].unload()
            if not success_flag:
                # All retries failed: return an error response instead of
                # hitting a NameError on the undefined recognized_text below.
                responses.append(
                    pb_utils.InferenceResponse(
                        output_tensors=[],
                        error=pb_utils.TritonError(
                            f"GPU {self.device_id} out of memory,"
                            " recognition failed"
                        ),
                    )
                )
                continue
            responses.append(
                pb_utils.InferenceResponse(
                    output_tensors=[
                        pb_utils.Tensor(
                            "OUTPUT_RECOGNIZED_TEXT",
                            numpy.array(
                                [recognized_text], dtype=self.target_dtype
                            ),
                        )
                    ]
                )
            )
        return responses

    def finalize(self):
        pass
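

# ---------------------------------------------------------------------------
# Hedged local smoke test -- a minimal sketch, not part of the Triton entry
# points above (the server only ever calls TritonPythonModel). It assumes a
# 16 kHz mono WAV file and that the checkpoints under models/ (or the whisper
# download) are available; the file path and language id below are
# hypothetical examples, as is the module name in the usage line.
# Usage: python model.py hi /path/to/sample.wav
if __name__ == "__main__":
    lang_id, wav_path = sys.argv[1], sys.argv[2]
    audio, rate = soundfile.read(wav_path, dtype="float32")
    assert rate == SAMPLING_RATE, f"expected {SAMPLING_RATE} Hz audio, got {rate}"
    asr = ASRModel(lang_id, "cuda:0")
    print(asr.recognize(audio))
    asr.timer.cancel()  # stop the idle-unload timer so the process can exit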