import os
import json
from glob import iglob

import numpy

from .apply_bpe import BPE
from ilstokenizer import tokenizer
import triton_python_backend_utils as pb_utils


class _UnsupportedLanguagePairError(KeyError):
    """Raised when no BPE model was loaded for a ``<src>-<tgt>`` pair.

    Subclasses ``KeyError`` so anything that previously caught the raw
    dictionary ``KeyError`` keeps working.
    """


class TritonPythonModel:
    """Triton Python-backend model that tokenizes and BPE-segments text.

    For each request it reads ``INPUT_TEXT``, ``INPUT_LANGUAGE_ID`` and
    ``OUTPUT_LANGUAGE_ID``. It lower-cases and tokenizes each sentence,
    applies the BPE model for the ``<src>-<tgt>`` language pair, and returns
    the results as the ``INPUT_TEXT_TOKENIZED`` output tensor.
    """

    def initialize(self, args):
        """Read the output dtype from the model config and load BPE models.

        Args:
            args: Triton initialization arguments; ``args["model_config"]``
                is the model configuration as a JSON string.
        """
        model_config = json.loads(args["model_config"])
        output_config = pb_utils.get_output_config_by_name(
            model_config, "INPUT_TEXT_TOKENIZED"
        )
        self.target_dtype = pb_utils.triton_string_to_numpy(
            output_config["data_type"]
        )

        # One BPE model per ``bpe_src/<pair>.src`` file next to this module,
        # keyed by ``<pair>`` (the file stem). ``_segment`` looks models up
        # as ``f"{source_lang}-{target_lang}"``.
        bpe_dir = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "bpe_src"
        )
        self.bpes = {}
        for fname in iglob(os.path.join(bpe_dir, "*.src")):
            pair = os.path.splitext(os.path.basename(fname))[0]
            # Close the codes file after construction; the previous version
            # leaked one open handle per file.
            # NOTE(review): assumes BPE reads the codes fully in its
            # constructor (as subword-nmt's apply_bpe.BPE does) — confirm
            # against the local .apply_bpe module.
            with open(fname, "r", encoding="utf-8") as codes:
                self.bpes[pair] = BPE(codes)

    def preprocess_text(self, text, source_lang, target_lang):
        """Return ``text`` ready for BPE segmentation.

        For the English -> Gujarati pair the text is padded with one space
        on each side; every other pair is returned unchanged.
        NOTE(review): the reason for the padding is not visible here;
        presumably the en-gu BPE model expects it.
        """
        if source_lang == "en" and target_lang == "gu":
            return f" {text} "
        return text

    @staticmethod
    def _language_code(language_id):
        """Return the prefix before the first ``_`` of a language-ID row.

        ``language_id`` is one row of a language-ID tensor's numpy view.
        The code reads its first element as UTF-8 bytes; presumably each
        row holds exactly one element — confirm against the model config.
        """
        return language_id[0].decode("utf-8").split("_", maxsplit=1)[0]

    def _segment(self, input_text, source_lang, target_lang):
        """Lower-case, tokenize and BPE-segment a single sentence row.

        Raises:
            _UnsupportedLanguagePairError: if no BPE model was loaded for
                ``f"{source_lang}-{target_lang}"``.
        """
        pair = f"{source_lang}-{target_lang}"
        bpe = self.bpes.get(pair)
        if bpe is None:
            raise _UnsupportedLanguagePairError(pair)
        tokens = tokenizer.tokenize(input_text[0].decode("utf-8").lower())
        return bpe.segment(
            self.preprocess_text(tokens, source_lang, target_lang)
        ).strip()

    def _tokenize_request(self, request):
        """Return one BPE-segmented string per row of ``INPUT_TEXT``.

        Rows of ``INPUT_TEXT``, ``INPUT_LANGUAGE_ID`` and
        ``OUTPUT_LANGUAGE_ID`` are paired positionally (``zip`` stops at
        the shortest).

        Raises:
            _UnsupportedLanguagePairError: if any row's pair is not loaded.
        """
        input_texts = pb_utils.get_input_tensor_by_name(request, "INPUT_TEXT")
        input_language_ids = pb_utils.get_input_tensor_by_name(
            request, "INPUT_LANGUAGE_ID"
        )
        output_language_ids = pb_utils.get_input_tensor_by_name(
            request, "OUTPUT_LANGUAGE_ID"
        )
        return [
            self._segment(
                input_text,
                self._language_code(input_language_id),
                self._language_code(output_language_id),
            )
            for input_text, input_language_id, output_language_id in zip(
                input_texts.as_numpy(),
                input_language_ids.as_numpy(),
                output_language_ids.as_numpy(),
            )
        ]

    def execute(self, requests):
        """Handle a batch of requests.

        Returns:
            list: one ``pb_utils.InferenceResponse`` per request, in order.
            A request whose language pair has no loaded BPE model gets an
            error response; the other requests in the batch still succeed.
        """
        responses = []
        for request in requests:
            try:
                tokenized_sents = self._tokenize_request(request)
            except _UnsupportedLanguagePairError as err:
                # Fail only this request rather than the whole batch.
                responses.append(
                    pb_utils.InferenceResponse(
                        output_tensors=[],
                        error=pb_utils.TritonError(
                            f"Unsupported language pair: {err.args[0]}"
                        ),
                    )
                )
                continue
            responses.append(
                pb_utils.InferenceResponse(
                    output_tensors=[
                        pb_utils.Tensor(
                            "INPUT_TEXT_TOKENIZED",
                            # One sentence per row, in a single column.
                            numpy.array(
                                [[sent] for sent in tokenized_sents],
                                dtype=self.target_dtype,
                            ),
                        )
                    ]
                )
            )
        return responses

    def finalize(self):
        """Release resources on unload; nothing to do for this model."""
        pass