Commit 4029cd7c authored by Nikhilesh Bhatnagar's avatar Nikhilesh Bhatnagar

fix compatibility issue

parent da7746e4
......@@ -15,11 +15,17 @@ class TritonPythonModel:
try: self.translator = Translator(f"{os.path.join(current_path, 'translator')}", device="cuda", intra_threads=1, inter_threads=1, device_index=[self.device_id])
except: self.translator = Translator(f"{os.path.join(current_path, 'translator')}", device="cpu", intra_threads=4)
def clean_output(self, text):
text = text.replace('@@ ', '')
if text.startswith('<to-gu> '): text = text[8:]
if text.endswith(' <to-gu>'): text = text[:-8]
return text
def execute(self, requests):
source_list = [pb_utils.get_input_tensor_by_name(request, "INPUT_TEXT_TOKENIZED") for request in requests]
bsize_list = [source.as_numpy().shape[0] for source in source_list]
src_sentences = [s[0].decode('utf-8').strip().split(' ') for source in source_list for s in source.as_numpy()]
tgt_sentences = [' '.join(result.hypotheses[0]).replace('@@ ', '').removeprefix('<to-gu> ').removesuffix(' <to-gu>') for result in self.translator.translate_iterable(src_sentences, max_batch_size=128, max_input_length=100, max_decoding_length=100)]
tgt_sentences = [self.clean_output(' '.join(result.hypotheses[0])) for result in self.translator.translate_iterable(src_sentences, max_batch_size=128, max_input_length=100, max_decoding_length=100)]
responses = [pb_utils.InferenceResponse(output_tensors=[pb_utils.Tensor("OUTPUT_TEXT", numpy.array([[s]for s in islice(tgt_sentences, bsize)], dtype='object').astype(self.target_dtype))]) for bsize in bsize_list]
return responses
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment