model.py 4.64 KB
Newer Older
1 2 3 4 5 6 7 8
import json
import logging
import os
from argparse import Namespace
from itertools import islice

import numpy
import triton_python_backend_utils as pb_utils
from onmt.translate.translator import build_translator

Nikhilesh Bhatnagar's avatar
Nikhilesh Bhatnagar committed
9

10 11 12 13 14
class TritonPythonModel:
    """Triton Python-backend model that translates pre-tokenized sentences.

    Translation is done by an OpenNMT-py translator built from the
    ``translator.pt`` checkpoint stored in the same directory as this file.
    """

    def initialize(self, args):
        """Parse the model configuration and build the translator.

        Args:
            args: mapping supplied by Triton; this method reads its
                ``"model_config"`` (a JSON string) and
                ``"model_instance_device_id"`` entries.
        """
        current_path = os.path.dirname(os.path.abspath(__file__))
        # NOTE(review): the two names on the right-hand side are not defined
        # anywhere in this file; they look like placeholders filled in when the
        # model repository is generated -- TODO confirm. Left unsubstituted,
        # this line raises NameError. No method of this class reads these
        # attributes.
        self.source_lang, self.target_lang = input_lang, output_lang
        self.model_config = json.loads(args["model_config"])
        self.device_id = int(json.loads(args["model_instance_device_id"]))
        target_config = pb_utils.get_output_config_by_name(
            self.model_config, "OUTPUT_SENT"
        )
        self.target_dtype = pb_utils.triton_string_to_numpy(target_config["data_type"])

        model_path = os.path.join(current_path, "translator.pt")
        try:
            self.translator = self._build_translator(model_path, self.device_id)
        except Exception:
            # Building on the assigned device failed; record why and retry on
            # CPU (gpu=-1) so the model still loads.
            logging.getLogger(__name__).warning(
                "Could not build translator on device %s; falling back to CPU.",
                self.device_id,
                exc_info=True,
            )
            self.translator = self._build_translator(model_path, -1)

    @staticmethod
    def _build_translator(model_path, gpu):
        """Build the OpenNMT-py translator for ``model_path`` on device ``gpu``.

        ``gpu`` is forwarded as-is; the CPU fallback in ``initialize`` passes
        ``-1``. All decoding options are fixed here (beam_size=5, n_best=1,
        max_length=100, fp32=True, seed=829).
        """
        opts = Namespace(
            tgt_prefix=False,
            alpha=0.0,
            batch_type="sents",
            beam_size=5,
            beta=-0.0,
            block_ngram_repeat=0,
            coverage_penalty="none",
            data_type="text",
            dump_beam="",
            fp32=True,
            gpu=gpu,
            ignore_when_blocking=[],
            length_penalty="none",
            max_length=100,
            max_sent_length=None,
            min_length=0,
            models=[model_path],
            n_best=1,
            output="/dev/null",
            phrase_table="",
            random_sampling_temp=1.0,
            random_sampling_topk=1,
            ratio=-0.0,
            replace_unk=False,
            report_align=False,
            report_time=False,
            seed=829,
            stepwise_penalty=False,
            tgt=None,
            verbose=False,
        )
        return build_translator(opts, report_score=False)

    def clean_output(self, text):
        """Strip sub-word markers, ZWNJ and a leading/trailing ``<to-gu>`` tag.

        Removes every ``"@@ "`` (joining split sub-words), every U+200C
        ZERO WIDTH NON-JOINER, then at most one ``"<to-gu> "`` prefix and at
        most one ``" <to-gu>"`` suffix.
        """
        prefix, suffix = "<to-gu> ", " <to-gu>"
        text = text.replace("@@ ", "").replace("\u200c", "")
        if text.startswith(prefix):
            text = text[len(prefix):]
        if text.endswith(suffix):
            text = text[: -len(suffix)]
        return text

    @staticmethod
    def _split_by_sizes(items, sizes):
        """Split ``items`` into consecutive chunks whose lengths are ``sizes``.

        A single shared iterator is consumed so chunks never overlap; calling
        ``islice`` on the list itself would restart at index 0 for every chunk.
        """
        stream = iter(items)
        return [list(islice(stream, size)) for size in sizes]

    def execute(self, requests):
        """Translate the sentences of every request in ``requests``.

        Each request must carry an ``INPUT_SENT_TOKENIZED`` tensor; column 0
        of each row holds one space-tokenized UTF-8 sentence (assumed shape
        ``(batch, 1)`` -- confirm against the model config). Sentences from
        all requests are translated in one call, then split back per request.

        Returns:
            One ``pb_utils.InferenceResponse`` per request, in request order,
            whose ``OUTPUT_SENT`` tensor has shape ``(rows, 1)`` with one
            translation per input row.
        """
        sources = [
            pb_utils.get_input_tensor_by_name(
                request, "INPUT_SENT_TOKENIZED"
            ).as_numpy()
            for request in requests
        ]
        bsize_list = [source.shape[0] for source in sources]
        src_sentences = [
            row[0].decode("utf-8").strip().split(" ")
            for source in sources
            for row in source
        ]
        if src_sentences:
            # Element 1 of translate()'s result holds the per-sentence n-best
            # lists (OpenNMT-py returns (scores, predictions)); with n_best=1
            # the first entry is the only hypothesis. Assumes predictions come
            # back in input order -- relied on by the per-request split below.
            predictions = self.translator.translate(src_sentences, batch_size=128)[1]
            tgt_sentences = [self.clean_output(nbest[0]) for nbest in predictions]
        else:
            tgt_sentences = []

        responses = []
        for chunk in self._split_by_sizes(tgt_sentences, bsize_list):
            # reshape keeps the (rows, 1) layout even for an empty chunk.
            out = numpy.array(chunk, dtype="object").reshape(-1, 1)
            responses.append(
                pb_utils.InferenceResponse(
                    output_tensors=[
                        pb_utils.Tensor("OUTPUT_SENT", out.astype(self.target_dtype))
                    ]
                )
            )
        return responses

    def finalize(self):
        """Drop the translator reference when Triton unloads the model."""
        del self.translator