model.py 2 KB
Newer Older
Nikhilesh Bhatnagar's avatar
Nikhilesh Bhatnagar committed
1 2 3 4 5 6 7 8 9 10
import os
import json
import numpy
from itertools import islice
from ctranslate2 import Translator
import triton_python_backend_utils as pb_utils

class TritonPythonModel:
    """Triton Python-backend model that translates pre-tokenized sentences with CTranslate2.

    Reads the string input tensor "INPUT_SENT_TOKENIZED" and returns the
    string output tensor "OUTPUT_SENT", with one row per input row.
    """

    def initialize(self, args):
        """Load the model config and the CTranslate2 translator.

        Args:
            args: Dict supplied by Triton. The keys read here are
                "model_config" (a JSON string) and "model_instance_device_id".
        """
        current_path = os.path.dirname(os.path.abspath(__file__))
        # NOTE(review): input_lang / output_lang are not defined in this file;
        # presumably placeholders substituted by a generator script — confirm.
        self.source_lang, self.target_lang = input_lang, output_lang
        self.model_config = json.loads(args["model_config"])
        self.device_id = int(json.loads(args['model_instance_device_id']))
        target_config = pb_utils.get_output_config_by_name(self.model_config, "OUTPUT_SENT")
        self.target_dtype = pb_utils.triton_string_to_numpy(target_config["data_type"])
        model_path = os.path.join(current_path, 'translator')
        try:
            self.translator = Translator(model_path, device="cuda", intra_threads=1, inter_threads=1, device_index=[self.device_id])
        except Exception:
            # Fall back to CPU when the CUDA translator cannot be created.
            # `Exception` (not a bare except) so Ctrl-C / SystemExit still propagate.
            self.translator = Translator(model_path, device="cpu", intra_threads=4)

    def clean_output(self, text):
        """Post-process one detokenized hypothesis string.

        Joins BPE continuations ("@@ "), removes U+200C (zero-width non-joiner)
        characters, and strips a leading "<to-gu> " / trailing " <to-gu>" tag.
        """
        text = text.replace('@@ ', '')
        text = text.replace('\u200c', '')
        if text.startswith('<to-gu> '): text = text[8:]  # 8 == len('<to-gu> ')
        if text.endswith(' <to-gu>'): text = text[:-8]   # 8 == len(' <to-gu>')
        return text

    @staticmethod
    def _split_by_sizes(items, sizes):
        """Split ``items`` into consecutive lists whose lengths are given by ``sizes``."""
        # A single shared iterator makes each islice continue where the previous
        # one stopped. Slicing the list directly would restart at index 0 every time.
        remaining = iter(items)
        return [list(islice(remaining, size)) for size in sizes]

    def execute(self, requests):
        """Translate all sentences from a batch of requests.

        All requests' rows are flattened into one list so the translator can
        batch across requests. The translations are then split back per
        request, in request order.

        Args:
            requests: List of pb_utils.InferenceRequest objects.

        Returns:
            A list of pb_utils.InferenceResponse, one per request.
        """
        source_arrays = [pb_utils.get_input_tensor_by_name(request, "INPUT_SENT_TOKENIZED").as_numpy() for request in requests]
        bsize_list = [array.shape[0] for array in source_arrays]
        # Each row's first element holds UTF-8 bytes of space-separated tokens.
        src_sentences = [row[0].decode('utf-8').strip().split(' ') for array in source_arrays for row in array]
        # Per the CTranslate2 docs, inputs longer than max_input_length tokens are truncated.
        results = self.translator.translate_iterable(src_sentences, max_batch_size=128, max_input_length=100, max_decoding_length=100)
        tgt_sentences = [self.clean_output(' '.join(result.hypotheses[0])) for result in results]
        responses = []
        for chunk in self._split_by_sizes(tgt_sentences, bsize_list):
            output = numpy.array([[s] for s in chunk], dtype='object').astype(self.target_dtype)
            responses.append(pb_utils.InferenceResponse(output_tensors=[pb_utils.Tensor("OUTPUT_SENT", output)]))
        return responses

    def finalize(self):
        """Release the translator's model when Triton unloads this model."""
        self.translator.unload_model()