# model.py
import os
import json
import numpy
from itertools import islice
from argparse import Namespace
import triton_python_backend_utils as pb_utils
from onmt.translate.translator import build_translator

class TritonPythonModel:
    """Triton Python-backend model that translates tokenized sentences with OpenNMT-py.

    Each request carries an ``INPUT_SENT_TOKENIZED`` tensor; the first element of
    every row is a UTF-8, space-separated token string. The matching response
    carries an ``OUTPUT_SENT`` tensor with one translated sentence per input row.
    """

    # Checkpoint file name, resolved relative to the directory of this file.
    _CHECKPOINT_NAME = "translator.pt"
    # Sentences handed to the translator per internal batch.
    _TRANSLATE_BATCH_SIZE = 128
    # Tag that may wrap the decoded output and is stripped by clean_output.
    _TARGET_TAG = "<to-gu>"

    def initialize(self, args):
        """Read the model config and build the translator.

        Args:
            args: Triton-supplied dict; ``model_config`` (JSON string) and
                ``model_instance_device_id`` are read here.
        """
        current_path = os.path.dirname(os.path.abspath(__file__))
        # NOTE(review): ``input_lang`` / ``output_lang`` are not defined in this
        # file, so this line raises NameError unless the names are injected
        # (presumably by a templating step that generates this file) — TODO confirm.
        self.source_lang, self.target_lang = input_lang, output_lang
        self.model_config = json.loads(args["model_config"])
        self.device_id = int(json.loads(args["model_instance_device_id"]))
        target_config = pb_utils.get_output_config_by_name(
            self.model_config, "OUTPUT_SENT"
        )
        self.target_dtype = pb_utils.triton_string_to_numpy(target_config["data_type"])

        checkpoint = os.path.join(current_path, self._CHECKPOINT_NAME)
        try:
            self.translator = self._build_translator(checkpoint, self.device_id)
        except Exception as err:
            # Best-effort fallback to CPU (gpu=-1), but surface why the GPU build
            # failed instead of swallowing it. A failure here still propagates.
            import sys

            print(
                f"TritonPythonModel: building translator on GPU {self.device_id} "
                f"failed ({err!r}); falling back to gpu=-1.",
                file=sys.stderr,
            )
            self.translator = self._build_translator(checkpoint, -1)

    @staticmethod
    def _translator_options(checkpoint, gpu):
        """Return the OpenNMT-py translate options for *checkpoint* on device *gpu*.

        ``gpu=-1`` is the value used for the CPU fallback in ``initialize``.
        """
        return Namespace(
            tgt_prefix=False,
            alpha=0.0,
            batch_type="sents",
            beam_size=15,
            beta=-0.0,
            block_ngram_repeat=0,
            coverage_penalty="none",
            data_type="text",
            dump_beam="",
            fp32=True,
            gpu=gpu,
            ignore_when_blocking=[],
            length_penalty="none",
            max_length=100,
            max_sent_length=None,
            min_length=0,
            models=[checkpoint],
            n_best=1,
            output="/dev/null",
            phrase_table="",
            random_sampling_temp=1.0,
            random_sampling_topk=1,
            ratio=-0.0,
            replace_unk=True,
            report_align=False,
            report_time=False,
            seed=829,
            stepwise_penalty=False,
            tgt=None,
            verbose=False,
        )

    @classmethod
    def _build_translator(cls, checkpoint, gpu):
        """Build an OpenNMT-py translator for *checkpoint* on device *gpu*."""
        return build_translator(
            cls._translator_options(checkpoint, gpu), report_score=False
        )

    def clean_output(self, text):
        """Remove decoding artifacts from a translated sentence.

        Joins ``@@ ``-marked subword pieces, drops zero-width non-joiners
        (U+200C), and strips one leading and one trailing target tag.
        """
        text = text.replace("@@ ", "").replace("\u200c", "")
        prefix = f"{self._TARGET_TAG} "
        suffix = f" {self._TARGET_TAG}"
        if text.startswith(prefix):
            text = text[len(prefix):]
        if text.endswith(suffix):
            text = text[: -len(suffix)]
        return text

    @staticmethod
    def _split_by_request(sentences, sizes):
        """Split *sentences* into consecutive chunks of the given *sizes*.

        One shared iterator is consumed, so request k receives the sentences
        that follow those of requests 0..k-1. Calling ``islice`` on the list
        itself would restart at index 0 for every request.
        """
        remaining = iter(sentences)
        return [list(islice(remaining, size)) for size in sizes]

    def execute(self, requests):
        """Translate the sentences of all *requests* together and split the results.

        Returns:
            One ``InferenceResponse`` per request, in request order, whose
            ``OUTPUT_SENT`` tensor has one row per row of that request's input.
        """
        sources = [
            pb_utils.get_input_tensor_by_name(request, "INPUT_SENT_TOKENIZED").as_numpy()
            for request in requests
        ]
        batch_sizes = [source.shape[0] for source in sources]
        src_sentences = [
            row[0].decode("utf-8").strip().split(" ")
            for source in sources
            for row in source
        ]
        # Index [1] of translate() is taken as the per-sentence hypothesis lists
        # (n_best=1 in the options); the first hypothesis is kept.
        predictions = self.translator.translate(
            src_sentences, batch_size=self._TRANSLATE_BATCH_SIZE
        )[1]
        tgt_sentences = [self.clean_output(hypotheses[0]) for hypotheses in predictions]
        return [
            pb_utils.InferenceResponse(
                output_tensors=[
                    pb_utils.Tensor(
                        "OUTPUT_SENT",
                        numpy.array(
                            [[sentence] for sentence in chunk], dtype="object"
                        ).astype(self.target_dtype),
                    )
                ]
            )
            for chunk in self._split_by_request(tgt_sentences, batch_sizes)
        ]

    def finalize(self):
        """Release the translator when Triton unloads the model."""
        del self.translator