model.py 1.62 KB
Newer Older
Nikhilesh Bhatnagar's avatar
Nikhilesh Bhatnagar committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
import os
import json
import numpy
from .apply_bpe import BPE
from ilstokenizer import tokenizer
import triton_python_backend_utils as pb_utils

class TritonPythonModel:
    """Triton Python-backend model that tokenizes and BPE-segments input text.

    For each request it reads the ``INPUT_TEXT``, ``INPUT_LANGUAGE_ID`` and
    ``OUTPUT_LANGUAGE_ID`` input tensors, tokenizes the text with
    ``ilstokenizer.tokenizer.tokenize``, segments it with the BPE model chosen
    by the ``"<input>-<output>"`` language pair, and returns the stripped
    result as the ``INPUT_TEXT_TOKENIZED`` output tensor of shape (1, 1).
    """

    def initialize(self, args):
        """Load the model config and one BPE segmenter per supported language pair.

        Args:
            args: Dict supplied by Triton; only ``args["model_config"]`` (a JSON
                string) is read here.

        Raises:
            ValueError: If the model config declares no ``INPUT_TEXT_TOKENIZED``
                output.
        """
        current_path = os.path.dirname(os.path.abspath(__file__))
        self.model_config = json.loads(args["model_config"])
        target_config = pb_utils.get_output_config_by_name(self.model_config, "INPUT_TEXT_TOKENIZED")
        if target_config is None:
            # Fail with a clear message instead of "'NoneType' is not subscriptable".
            raise ValueError("model config has no output named 'INPUT_TEXT_TOKENIZED'")
        self.target_dtype = pb_utils.triton_string_to_numpy(target_config["data_type"])
        # "<input>-<output>" language pair -> numeric id used in the BPE codes
        # file name (bpe_src/<id>.src). NOTE(review): id 5 is absent —
        # presumably intentional; verify against the bpe_src directory.
        self.lang_pair_map = {'eng-hin': 1, 'hin-eng': 2, 'eng-tel': 3, 'tel-eng': 4, 'hin-tel': 6, 'tel-hin': 7, 'eng-guj': 8, 'guj-eng': 9}
        self.bpes = {}
        for lang_pair, model_id in self.lang_pair_map.items():
            codes_path = os.path.join(current_path, f'bpe_src/{model_id}.src')
            # Close each codes file once the BPE object is built rather than
            # leaking one open handle per language pair. Assumes BPE (a
            # subword-nmt style apply_bpe) reads all codes in its constructor —
            # TODO confirm against .apply_bpe.
            with open(codes_path, encoding='utf-8') as codes_file:
                self.bpes[lang_pair] = BPE(codes_file)

    def execute(self, requests):
        """Tokenize and BPE-segment the text of every request.

        Args:
            requests: Iterable of ``pb_utils.InferenceRequest``.

        Returns:
            A list with one ``pb_utils.InferenceResponse`` per request, in
            request order. A request whose language pair has no BPE model gets
            an error response instead of failing the whole batch.
        """
        return [self._respond(request) for request in requests]

    def _respond(self, request):
        """Build the response for one request: the tokenized text, or a per-request error."""
        input_text = self._read_string(request, "INPUT_TEXT")
        input_language_id = self._read_string(request, "INPUT_LANGUAGE_ID")
        output_language_id = self._read_string(request, "OUTPUT_LANGUAGE_ID")
        lang_pair = f"{input_language_id}-{output_language_id}"
        bpe = self.bpes.get(lang_pair)
        if bpe is None:
            # Report the bad pair on this request only; raising here would
            # fail every request in the batch.
            return pb_utils.InferenceResponse(
                output_tensors=[],
                error=pb_utils.TritonError(f"unsupported language pair: {lang_pair!r}"),
            )
        tokenized_sent = bpe.segment(tokenizer.tokenize(input_text)).strip()
        output_tensor = pb_utils.Tensor("INPUT_TEXT_TOKENIZED", numpy.array([[tokenized_sent]], dtype=self.target_dtype))
        return pb_utils.InferenceResponse(output_tensors=[output_tensor])

    @staticmethod
    def _read_string(request, name):
        """Return element [0, 0] of input tensor ``name`` decoded as UTF-8."""
        # Only the first element is used. NOTE(review): this assumes each
        # request carries a single string of shape (1, 1) — confirm against
        # the model's config.pbtxt.
        return pb_utils.get_input_tensor_by_name(request, name).as_numpy()[0, 0].decode('utf-8')

    def finalize(self):
        """Triton unload hook; no explicit cleanup is performed."""