# model.py
# Committed by Nikhilesh Bhatnagar
import json
import asyncio
import triton_python_backend_utils as pb_utils

class TritonPythonModel:
    """Triton Python-backend model that dispatches requests by language pair.

    Each incoming request names a source and a target language. The pair is
    looked up in ``lang_pair_map`` and the tokenized input text is forwarded
    to the model named ``ssmt_<id>_ct2`` through an asynchronous
    ``pb_utils.InferenceRequest``.
    """

    def initialize(self, args):
        """Parse the model configuration and build the language-pair table.

        Args:
            args: Mapping supplied by Triton; ``args["model_config"]`` must
                hold the model configuration as a JSON string.
        """
        self.model_config = json.loads(args["model_config"])
        target_config = pb_utils.get_output_config_by_name(self.model_config, "OUTPUT_TEXT")
        self.target_dtype = pb_utils.triton_string_to_numpy(target_config["data_type"])
        # Maps "<source>-<target>" to the numeric id embedded in the name of
        # the model that serves that pair (see _resolve_model_name).
        self.lang_pair_map = {'en-hi': 1, 'hi-en': 2, 'te-en': 4, 'hi-te': 6, 'te-hi': 7, 'en-gu': 8, 'gu-en': 9}

    def _resolve_model_name(self, language_pair):
        """Return the target model name for *language_pair*, or None if unsupported."""
        model_id = self.lang_pair_map.get(language_pair)
        if model_id is None:
            return None
        return f'ssmt_{model_id}_ct2'

    @staticmethod
    def _read_language_code(request, tensor_name):
        """Decode element [0, 0] of the named input tensor as a UTF-8 string."""
        tensor = pb_utils.get_input_tensor_by_name(request, tensor_name)
        return tensor.as_numpy()[0, 0].decode('utf-8')

    async def execute(self, requests):
        """Forward every request to the model that serves its language pair.

        All supported requests are dispatched concurrently and awaited
        together. A request whose language pair is not in ``lang_pair_map``
        receives an error response instead of raising, so one bad request no
        longer fails the whole batch.

        Args:
            requests: The batch of inference requests handed over by Triton.

        Returns:
            A list with exactly one response per request, in request order.
        """
        responses = [None] * len(requests)
        pending_slots = []
        pending_calls = []
        for slot, request in enumerate(requests):
            source_lang = self._read_language_code(request, 'INPUT_LANGUAGE_ID')
            target_lang = self._read_language_code(request, 'OUTPUT_LANGUAGE_ID')
            language_pair = f"{source_lang}-{target_lang}"

            model_name = self._resolve_model_name(language_pair)
            if model_name is None:
                # Report the failure on this request only; the remaining
                # requests in the batch are still served.
                responses[slot] = pb_utils.InferenceResponse(
                    output_tensors=[],
                    error=pb_utils.TritonError(f"Unsupported language pair: {language_pair}"),
                )
                continue

            inference_request = pb_utils.InferenceRequest(
                model_name=model_name,
                requested_output_names=['OUTPUT_TEXT'],
                inputs=[pb_utils.get_input_tensor_by_name(request, 'INPUT_TEXT_TOKENIZED')],
            )
            pending_slots.append(slot)
            pending_calls.append(inference_request.async_exec())

        # gather() yields results in the order the awaitables were passed, so
        # zipping with the recorded slots puts each result back in its place.
        results = await asyncio.gather(*pending_calls)
        for slot, result in zip(pending_slots, results):
            responses[slot] = result
        return responses

    def finalize(self):
        """Do nothing; this model has no cleanup to perform."""