model.py 2.8 KB
Newer Older
Nikhilesh Bhatnagar's avatar
Nikhilesh Bhatnagar committed
1 2 3 4 5 6 7
import os
import json
import numpy
from itertools import islice
from ctranslate2 import Translator
import triton_python_backend_utils as pb_utils

Nikhilesh Bhatnagar's avatar
Nikhilesh Bhatnagar committed
8

Nikhilesh Bhatnagar's avatar
Nikhilesh Bhatnagar committed
9 10 11
class TritonPythonModel:
    """Triton Python-backend model that translates tokenized sentences with CTranslate2.

    Each request carries an ``INPUT_SENT_TOKENIZED`` tensor whose rows hold a
    UTF-8 encoded, space-separated token sequence in column 0; the model replies
    with an ``OUTPUT_SENT`` tensor holding one translated sentence per row.
    """

    def initialize(self, args):
        """Read the model configuration and load the CTranslate2 translator.

        Args:
            args: dict supplied by Triton; this method reads its
                ``model_config`` (JSON string) and ``model_instance_device_id``
                entries.

        The translator is loaded from the ``translator`` directory next to this
        file, on CUDA device ``model_instance_device_id`` when that succeeds and
        on the CPU otherwise.
        """
        current_path = os.path.dirname(os.path.abspath(__file__))
        # NOTE(review): ``input_lang`` / ``output_lang`` are not defined in this
        # file; presumably a generator script substitutes literal language codes
        # for this text. Kept verbatim so that substitution keeps working —
        # TODO confirm.
        self.source_lang, self.target_lang = input_lang, output_lang
        self.model_config = json.loads(args["model_config"])
        self.device_id = int(json.loads(args["model_instance_device_id"]))
        target_config = pb_utils.get_output_config_by_name(
            self.model_config, "OUTPUT_SENT"
        )
        self.target_dtype = pb_utils.triton_string_to_numpy(target_config["data_type"])
        model_path = os.path.join(current_path, "translator")
        try:
            self.translator = Translator(
                model_path,
                device="cuda",
                intra_threads=1,
                inter_threads=1,
                device_index=[self.device_id],
            )
        except Exception:
            # CUDA load failed (no GPU / driver / build without CUDA): fall back
            # to the CPU. ``Exception`` rather than a bare ``except`` so that
            # KeyboardInterrupt / SystemExit still propagate.
            self.translator = Translator(
                model_path,
                device="cpu",
                intra_threads=4,
            )

    def clean_output(self, text):
        """Turn a space-joined hypothesis into a display string.

        Removes BPE continuation markers (``"@@ "``) and zero-width non-joiners
        (U+200C), drops the space before ``?``, ``!``, ``.`` and ``,``, and
        strips a leading ``"<to-gu> "`` / trailing ``" <to-gu>"`` tag.

        Args:
            text: hypothesis tokens joined with single spaces.

        Returns:
            The cleaned string.
        """
        text = text.replace("@@ ", "")
        text = text.replace("\u200c", "")
        text = text.replace(" ?", "?").replace(" !", "!").replace(" .", ".").replace(" ,", ",")
        if text.startswith("<to-gu> "):
            text = text[8:]  # len("<to-gu> ") == 8
        if text.endswith(" <to-gu>"):
            text = text[:-8]  # len(" <to-gu>") == 8
        return text

    @staticmethod
    def _split_by_sizes(items, sizes):
        """Split *items* into consecutive chunks whose lengths are *sizes*.

        A single shared iterator is consumed so each chunk starts where the
        previous one ended (``islice`` on a list would restart at index 0).
        """
        remaining = iter(items)
        return [list(islice(remaining, size)) for size in sizes]

    def execute(self, requests):
        """Translate all sentences of all requests with one translator call.

        Args:
            requests: list of Triton inference requests, each carrying an
                ``INPUT_SENT_TOKENIZED`` tensor (UTF-8 bytes, sentence in
                column 0 of every row).

        Returns:
            One ``pb_utils.InferenceResponse`` per request, in request order,
            whose ``OUTPUT_SENT`` tensor holds that request's translations.
        """
        source_arrays = [
            pb_utils.get_input_tensor_by_name(request, "INPUT_SENT_TOKENIZED").as_numpy()
            for request in requests
        ]
        # Rows per request, used to hand each request back its own results.
        bsize_list = [source.shape[0] for source in source_arrays]
        src_sentences = [
            s[0].decode("utf-8").strip().split(" ")
            for source in source_arrays
            for s in source
        ]
        tgt_sentences = [
            # hypotheses[0] is the first (highest-ranked) beam hypothesis.
            self.clean_output(" ".join(result.hypotheses[0]))
            for result in self.translator.translate_iterable(
                src_sentences,
                max_batch_size=128,
                max_input_length=100,
                max_decoding_length=100,
                beam_size=15,
                replace_unknowns=True,
            )
        ]
        responses = [
            pb_utils.InferenceResponse(
                output_tensors=[
                    pb_utils.Tensor(
                        "OUTPUT_SENT",
                        numpy.array(
                            [[s.encode("utf-8")] for s in chunk],
                            dtype=self.target_dtype,
                        ),
                    )
                ]
            )
            for chunk in self._split_by_sizes(tgt_sentences, bsize_list)
        ]
        return responses

    def finalize(self):
        """Unload the translator's model when Triton finalizes this model."""
        self.translator.unload_model()