"""Triton Python-backend model that serves an OpenNMT-py translator.

Each request carries an ``INPUT_SENT_TOKENIZED`` tensor; column 0 of every
row is a UTF-8 encoded, space-tokenized source sentence. The response for a
request carries one translation per row in ``OUTPUT_SENT``.
"""
import os
import json
import logging
from itertools import islice
from argparse import Namespace

import numpy
import triton_python_backend_utils as pb_utils
from onmt.translate.translator import build_translator

_logger = logging.getLogger(__name__)


def _translator_options(model_path, gpu):
    """Return the OpenNMT-py translation options used to build the translator.

    ``gpu`` is forwarded unchanged as the ``gpu`` option: ``initialize`` passes
    the Triton instance device id first and ``-1`` (OpenNMT-py's CPU setting)
    for the fallback. All other values are fixed.
    """
    return Namespace(
        tgt_prefix=False, alpha=0.0, batch_type='sents', beam_size=5,
        beta=-0.0, block_ngram_repeat=0, coverage_penalty='none',
        data_type='text', dump_beam='', fp32=True, gpu=gpu,
        ignore_when_blocking=[], length_penalty='none', max_length=100,
        max_sent_length=None, min_length=0, models=[model_path], n_best=1,
        output='/dev/null', phrase_table='', random_sampling_temp=1.0,
        random_sampling_topk=1, ratio=-0.0, replace_unk=False,
        report_align=False, report_time=False, seed=829,
        stepwise_penalty=False, tgt=None, verbose=False,
    )


def _config_string(model_config, key):
    """Return ``parameters[key].string_value`` from a Triton model config.

    Returns ``None`` when the config has no such parameter.
    """
    entry = model_config.get("parameters", {}).get(key)
    if isinstance(entry, dict):
        return entry.get("string_value")
    return None


class TritonPythonModel:
    """Triton entry point: builds the translator and serves translation requests."""

    # Markers stripped from the start / end of each cleaned translation.
    # NOTE(review): the previous code matched a single ' ' but then sliced off
    # 8 characters. That suggests an 8-character tag literal (perhaps a language
    # token) was lost from these strings. TODO confirm the intended markers.
    output_prefix = ' '
    output_suffix = ' '

    def initialize(self, args):
        """Parse the model config and build the OpenNMT-py translator.

        Args:
            args: Triton-provided dict; ``model_config`` (JSON string) and
                ``model_instance_device_id`` are read here.
        """
        current_path = os.path.dirname(os.path.abspath(__file__))
        self.model_config = json.loads(args["model_config"])
        # The previous code referenced undefined names ``input_lang`` /
        # ``output_lang`` (NameError on every load). They are now read from
        # optional config parameters of the same names; None when absent.
        # NOTE(review): confirm the parameter keys with the deployment config.
        self.source_lang = _config_string(self.model_config, "input_lang")
        self.target_lang = _config_string(self.model_config, "output_lang")
        self.device_id = int(json.loads(args['model_instance_device_id']))
        target_config = pb_utils.get_output_config_by_name(self.model_config, "OUTPUT_SENT")
        self.target_dtype = pb_utils.triton_string_to_numpy(target_config["data_type"])

        model_path = os.path.join(current_path, 'translator.pt')
        try:
            self.translator = build_translator(
                _translator_options(model_path, gpu=self.device_id),
                report_score=False,
            )
        except Exception:
            # Best-effort fallback, kept from the original. Log it so a silent
            # CPU deployment is visible. Catching Exception rather than a bare
            # except lets KeyboardInterrupt / SystemExit propagate.
            _logger.warning(
                "Building translator on device %d failed; falling back to gpu=-1",
                self.device_id, exc_info=True,
            )
            self.translator = build_translator(
                _translator_options(model_path, gpu=-1),
                report_score=False,
            )

    def clean_output(self, text):
        """Post-process one translation string.

        Removes BPE continuation markers (``'@@ '``) and U+200C characters,
        then strips ``output_prefix`` / ``output_suffix`` when present. Exactly
        the matched marker's length is removed.
        """
        text = text.replace('@@ ', '')
        text = text.replace('\u200c', '')
        prefix, suffix = self.output_prefix, self.output_suffix
        if prefix and text.startswith(prefix):
            text = text[len(prefix):]
        # Guard on a non-empty suffix: text[:-0] would return ''.
        if suffix and text.endswith(suffix):
            text = text[:-len(suffix)]
        return text

    def execute(self, requests):
        """Translate every sentence across ``requests`` in one translator call.

        Returns one ``InferenceResponse`` per request, in request order. Each
        response holds that request's translations as an ``(n, 1)`` array cast
        to the configured ``OUTPUT_SENT`` dtype.
        """
        arrays = [
            pb_utils.get_input_tensor_by_name(request, "INPUT_SENT_TOKENIZED").as_numpy()
            for request in requests
        ]
        batch_sizes = [array.shape[0] for array in arrays]
        src_sentences = [
            row[0].decode('utf-8').strip().split(' ')
            for array in arrays
            for row in array
        ]

        if src_sentences:
            predictions = self.translator.translate(src_sentences, batch_size=128)[1]
            # n_best is 1, so keep the top hypothesis of each sentence.
            tgt_sentences = [self.clean_output(nbest[0]) for nbest in predictions]
        else:
            tgt_sentences = []

        # A single shared iterator makes each request consume its own slice.
        # islice over the list itself would restart at index 0 every time and
        # hand every request the first request's translations.
        remaining = iter(tgt_sentences)
        responses = []
        for bsize in batch_sizes:
            chunk = list(islice(remaining, bsize))
            # reshape keeps the (n, 1) layout even when the chunk is empty.
            out = numpy.array(chunk, dtype='object').reshape(-1, 1).astype(self.target_dtype)
            responses.append(pb_utils.InferenceResponse(
                output_tensors=[pb_utils.Tensor("OUTPUT_SENT", out)]))
        return responses

    def finalize(self):
        """Release the translator when Triton unloads the model."""
        del self.translator