import json
import logging
import os
from argparse import Namespace
from itertools import islice

import numpy
import triton_python_backend_utils as pb_utils
from onmt.translate.translator import build_translator

logger = logging.getLogger(__name__)


class TritonPythonModel:
    """Triton Python-backend model that translates pre-tokenized sentences.

    Reads the ``INPUT_SENT_TOKENIZED`` tensor of every request, translates all
    sentences in one call to an OpenNMT-py translator and returns the cleaned
    hypotheses in the ``OUTPUT_SENT`` tensor of the matching response.
    """

    # File name of the checkpoint expected next to this module.
    _CHECKPOINT_NAME = "translator.pt"
    # Batch size handed to the translator for a single ``execute`` call.
    _TRANSLATE_BATCH_SIZE = 128
    # Marker removed once from each end of a hypothesis by ``clean_output``.
    _EDGE_MARKER = " "

    def initialize(self, args):
        """Load the model configuration and build the translator.

        Args:
            args: Dictionary supplied by Triton. This method reads the
                ``"model_config"`` (JSON string) and
                ``"model_instance_device_id"`` entries.
        """
        current_path = os.path.dirname(os.path.abspath(__file__))
        # NOTE(review): `input_lang` and `output_lang` are not defined anywhere
        # in this file, so this line raises NameError unless they are injected
        # by something outside it (e.g. a templating step) - confirm and fix.
        self.source_lang, self.target_lang = input_lang, output_lang
        self.model_config = json.loads(args["model_config"])
        self.device_id = int(json.loads(args["model_instance_device_id"]))

        target_config = pb_utils.get_output_config_by_name(
            self.model_config, "OUTPUT_SENT"
        )
        self.target_dtype = pb_utils.triton_string_to_numpy(
            target_config["data_type"]
        )

        checkpoint_path = os.path.join(current_path, self._CHECKPOINT_NAME)
        try:
            self.translator = self._build_translator(
                checkpoint_path, self.device_id
            )
        except Exception:
            # Best-effort fallback: retry with gpu=-1 (OpenNMT-py's CPU value).
            # `Exception` rather than a bare `except` so KeyboardInterrupt and
            # SystemExit still propagate; the cause is logged, not hidden.
            logger.warning(
                "Could not build translator on device %s; retrying with gpu=-1",
                self.device_id,
                exc_info=True,
            )
            self.translator = self._build_translator(checkpoint_path, -1)

    @staticmethod
    def _build_translator(checkpoint_path, gpu):
        """Build an OpenNMT-py translator for ``checkpoint_path`` on ``gpu``."""
        options = Namespace(
            tgt_prefix=False,
            alpha=0.0,
            batch_type="sents",
            beam_size=15,
            beta=-0.0,
            block_ngram_repeat=0,
            coverage_penalty="none",
            data_type="text",
            dump_beam="",
            fp32=True,
            gpu=gpu,
            ignore_when_blocking=[],
            length_penalty="none",
            max_length=100,
            max_sent_length=None,
            min_length=0,
            models=[checkpoint_path],
            n_best=1,
            output="/dev/null",
            phrase_table="",
            random_sampling_temp=1.0,
            random_sampling_topk=1,
            ratio=-0.0,
            replace_unk=True,
            report_align=False,
            report_time=False,
            seed=829,
            stepwise_penalty=False,
            tgt=None,
            verbose=False,
        )
        return build_translator(options, report_score=False)

    def clean_output(self, text):
        """Return ``text`` with sub-word joiners and edge markers removed.

        Removes every ``"@@ "`` sequence and every U+200C (zero width
        non-joiner) character, then strips at most one edge marker from the
        start and at most one from the end.

        Args:
            text: Raw hypothesis string produced by the translator.

        Returns:
            The cleaned string.
        """
        text = text.replace("@@ ", "")
        text = text.replace("\u200c", "")
        marker = self._EDGE_MARKER
        # Slice by the marker's own length. The previous code cut a fixed
        # 8 characters although the marker tested for is a single character,
        # which threw away translated text.
        if text.startswith(marker):
            text = text[len(marker):]
        if text.endswith(marker):
            text = text[: -len(marker)]
        return text

    def execute(self, requests):
        """Translate the sentences of all ``requests`` in a single batch.

        Args:
            requests: Inference requests, each carrying an
                ``INPUT_SENT_TOKENIZED`` tensor. Each row's first element is
                decoded as UTF-8 and split on single spaces into tokens.

        Returns:
            One ``pb_utils.InferenceResponse`` per request, in request order,
            whose ``OUTPUT_SENT`` tensor holds one single-element row per
            input row of that request.
        """
        batches = [
            pb_utils.get_input_tensor_by_name(
                request, "INPUT_SENT_TOKENIZED"
            ).as_numpy()
            for request in requests
        ]
        bsize_list = [batch.shape[0] for batch in batches]
        src_sentences = [
            row[0].decode("utf-8").strip().split(" ")
            for batch in batches
            for row in batch
        ]

        # Element 1 of the translate() result is indexed per sentence; the
        # first entry of each item is the hypothesis that gets returned.
        predictions = self.translator.translate(
            src_sentences, batch_size=self._TRANSLATE_BATCH_SIZE
        )[1]

        # A single shared iterator, so each islice() below continues where the
        # previous request stopped. Slicing the list itself restarted at index
        # 0 and gave every request the first request's translations.
        remaining = (self.clean_output(result[0]) for result in predictions)

        responses = []
        for bsize in bsize_list:
            rows = [[sentence] for sentence in islice(remaining, bsize)]
            output = numpy.array(rows, dtype="object").astype(self.target_dtype)
            responses.append(
                pb_utils.InferenceResponse(
                    output_tensors=[pb_utils.Tensor("OUTPUT_SENT", output)]
                )
            )
        return responses

    def finalize(self):
        """Drop the reference to the translator when the model is unloaded."""
        del self.translator