model.py 2.64 KB
Newer Older
Nikhilesh Bhatnagar's avatar
Nikhilesh Bhatnagar committed
1 2 3 4 5 6 7
import os
import json
import numpy
from itertools import islice
from ctranslate2 import Translator
import triton_python_backend_utils as pb_utils

Nikhilesh Bhatnagar's avatar
Nikhilesh Bhatnagar committed
8

Nikhilesh Bhatnagar's avatar
Nikhilesh Bhatnagar committed
9 10 11
class TritonPythonModel:
    """Triton Python-backend model that translates tokenized sentences.

    Reads the ``INPUT_SENT_TOKENIZED`` tensor of every request, translates
    all sentences with a CTranslate2 ``Translator`` and returns the cleaned
    translations in the ``OUTPUT_SENT`` tensor of the matching response.
    """

    def initialize(self, args):
        """Load the translator and cache the output configuration.

        Args:
            args: Mapping supplied by the caller; this method reads the
                ``"model_config"`` (JSON string) and
                ``"model_instance_device_id"`` entries.
        """
        current_path = os.path.dirname(os.path.abspath(__file__))
        # NOTE(review): `input_lang` / `output_lang` are not defined anywhere
        # in the visible file. Unless they are injected before this module is
        # loaded (e.g. by a templating step), this line raises NameError.
        # Left untouched on purpose -- confirm and bind real values.
        self.source_lang, self.target_lang = input_lang, output_lang
        self.model_config = json.loads(args["model_config"])
        self.device_id = int(json.loads(args["model_instance_device_id"]))
        target_config = pb_utils.get_output_config_by_name(
            self.model_config, "OUTPUT_SENT"
        )
        self.target_dtype = pb_utils.triton_string_to_numpy(target_config["data_type"])

        # The translator model lives in a "translator" directory next to
        # this file.
        model_dir = os.path.join(current_path, "translator")
        try:
            self.translator = Translator(
                model_dir,
                device="cuda",
                intra_threads=1,
                inter_threads=1,
                device_index=[self.device_id],
            )
        except Exception:
            # Best-effort fallback: if the CUDA translator cannot be created,
            # run on CPU instead. Only `Exception` is caught so that
            # KeyboardInterrupt / SystemExit still propagate.
            self.translator = Translator(
                model_dir,
                device="cpu",
                intra_threads=4,
            )

    def clean_output(self, text):
        """Return ``text`` with subword markers and the language tag removed.

        Removes every ``"@@ "`` sequence (presumably the subword continuation
        marker -- confirm against the tokenizer), every U+200C (zero width
        non-joiner) character, and one leading ``"<to-gu> "`` and/or trailing
        ``" <to-gu>"`` tag if present.

        Args:
            text: Detokenized hypothesis string.

        Returns:
            The cleaned string.
        """
        text = text.replace("@@ ", "")
        text = text.replace("\u200c", "")
        if text.startswith("<to-gu> "):
            text = text[8:]
        if text.endswith(" <to-gu>"):
            text = text[:-8]
        return text

    def execute(self, requests):
        """Translate every sentence of every request.

        Args:
            requests: Iterable of inference requests, each carrying an
                ``INPUT_SENT_TOKENIZED`` tensor. Each row's first element is
                decoded as UTF-8 bytes (assumes shape ``(batch, 1)`` -- TODO
                confirm against the model config).

        Returns:
            A list with one ``pb_utils.InferenceResponse`` per request, in
            request order, each holding an ``OUTPUT_SENT`` tensor with one
            single-element row per input sentence of that request.
        """
        # Convert each input tensor to numpy once; it is needed both for the
        # per-request batch size and for the sentences themselves.
        batches = [
            pb_utils.get_input_tensor_by_name(
                request, "INPUT_SENT_TOKENIZED"
            ).as_numpy()
            for request in requests
        ]
        bsize_list = [batch.shape[0] for batch in batches]
        src_sentences = [
            row[0].decode("utf-8").strip().split(" ")
            for batch in batches
            for row in batch
        ]
        tgt_sentences = [
            self.clean_output(" ".join(result.hypotheses[0]))
            for result in self.translator.translate_iterable(
                src_sentences,
                max_batch_size=128,
                max_input_length=100,
                max_decoding_length=100,
            )
        ]
        # All requests were translated as one flat list. A single shared
        # iterator is consumed so each response takes the *next* `bsize`
        # translations; calling islice() on the list itself would restart
        # from the first sentence for every request.
        remaining = iter(tgt_sentences)
        responses = [
            pb_utils.InferenceResponse(
                output_tensors=[
                    pb_utils.Tensor(
                        "OUTPUT_SENT",
                        numpy.array(
                            [[s] for s in islice(remaining, bsize)], dtype="object"
                        ).astype(self.target_dtype),
                    )
                ]
            )
            for bsize in bsize_list
        ]
        return responses

    def finalize(self):
        """Unload the translator model."""
        self.translator.unload_model()