# model.py
# Committed by Nikhilesh Bhatnagar.
import os
import json
import numpy
from itertools import islice
from ctranslate2 import Translator
import triton_python_backend_utils as pb_utils

class TritonPythonModel:
    """Triton Python-backend model that translates pre-tokenized sentences.

    Translation is done by a CTranslate2 ``Translator`` loaded from the
    ``translator`` directory that sits next to this file. The GPU is tried
    first and the CPU is used as a fallback.
    """

    def initialize(self, args):
        """Read the model configuration and load the CTranslate2 translator.

        Args:
            args: Dictionary passed by Triton. This method reads its
                ``"model_config"`` entry (a JSON string) and its
                ``"model_instance_device_id"`` entry.
        """
        current_path = os.path.dirname(os.path.abspath(__file__))
        # NOTE(review): the two names on the right-hand side are not defined
        # in this file. They are presumably substituted when the model
        # repository is generated from this template; confirm before editing.
        self.source_lang, self.target_lang = input_lang, output_lang
        self.model_config = json.loads(args["model_config"])
        self.device_id = int(json.loads(args["model_instance_device_id"]))
        target_config = pb_utils.get_output_config_by_name(
            self.model_config, "OUTPUT_SENT"
        )
        # numpy dtype matching the configured data_type of OUTPUT_SENT.
        self.target_dtype = pb_utils.triton_string_to_numpy(target_config["data_type"])
        model_path = os.path.join(current_path, "translator")
        try:
            self.translator = Translator(
                model_path,
                device="cuda",
                intra_threads=1,
                inter_threads=1,
                device_index=[self.device_id],
            )
        except Exception:
            # Fall back to CPU if the CUDA translator cannot be created.
            # Catching Exception rather than using a bare except lets
            # KeyboardInterrupt/SystemExit propagate.
            self.translator = Translator(
                model_path,
                device="cpu",
                intra_threads=4,
            )

    def clean_output(self, text):
        """Post-process one joined hypothesis string.

        The following are removed from ``text``:

        * every ``"@@ "`` sequence (presumably subword-continuation markers);
        * every U+200C ZERO WIDTH NON-JOINER;
        * a leading ``"<to-gu> "`` tag and a trailing ``" <to-gu>"`` tag.

        Args:
            text: Space-joined output tokens.

        Returns:
            The cleaned string.
        """
        text = text.replace("@@ ", "")
        text = text.replace("\u200c", "")
        prefix = "<to-gu> "
        suffix = " <to-gu>"
        # Slice by the tag lengths instead of a hard-coded 8, so this stays
        # correct if the tag literals are ever changed.
        if text.startswith(prefix):
            text = text[len(prefix):]
        if text.endswith(suffix):
            text = text[: -len(suffix)]
        return text

    @staticmethod
    def _split_by_sizes(items, sizes):
        """Split ``items`` into consecutive chunks of the given ``sizes``.

        A single shared iterator is used, so each chunk continues where the
        previous one ended.
        """
        iterator = iter(items)
        return [list(islice(iterator, size)) for size in sizes]

    def execute(self, requests):
        """Translate every sentence in ``requests``.

        The sentences of all requests are translated together in one pass.
        The results are then split back into one response per request.

        Args:
            requests: Triton inference requests. Each one carries an
                ``INPUT_SENT_TOKENIZED`` tensor whose rows hold UTF-8 bytes
                in column 0 (assumed shape ``(batch, 1)``; TODO confirm
                against config.pbtxt). Each row is a space-separated token
                sequence.

        Returns:
            One ``InferenceResponse`` per request, in request order. Each
            response has an ``OUTPUT_SENT`` tensor of shape
            ``(batch, 1)`` holding that request's translations.
        """
        source_arrays = [
            pb_utils.get_input_tensor_by_name(request, "INPUT_SENT_TOKENIZED").as_numpy()
            for request in requests
        ]
        bsize_list = [array.shape[0] for array in source_arrays]
        src_sentences = [
            row[0].decode("utf-8").strip().split(" ")
            for array in source_arrays
            for row in array
        ]
        tgt_sentences = [
            self.clean_output(" ".join(result.hypotheses[0]))
            for result in self.translator.translate_iterable(
                src_sentences,
                max_batch_size=128,
                max_input_length=100,
                max_decoding_length=100,
                beam_size=15,
                replace_unknowns=True,
            )
        ]
        # Results come back in input order. Slice them into consecutive
        # per-request chunks; the original islice over the list restarted
        # at index 0 for every request.
        return [
            pb_utils.InferenceResponse(
                output_tensors=[
                    pb_utils.Tensor(
                        "OUTPUT_SENT",
                        numpy.array(chunk, dtype="object")
                        .reshape(-1, 1)
                        .astype(self.target_dtype),
                    )
                ]
            )
            for chunk in self._split_by_sizes(tgt_sentences, bsize_list)
        ]

    def finalize(self):
        """Unload the translator's model; Triton calls this when unloading the model."""
        self.translator.unload_model()