def execute(self, requests):
    """Tokenize a batch of Triton inference requests.

    Each request carries three parallel string tensors — INPUT_TEXT,
    INPUT_LANGUAGE_ID and OUTPUT_LANGUAGE_ID.  Every batch row is decoded
    from UTF-8, run through ``self.tokenize_and_segment`` for its language
    pair, and the tokenized sentences are returned (one response per
    request) as the INPUT_TEXT_TOKENIZED output tensor with shape
    ``(batch, 1)`` and dtype ``self.target_dtype``.

    NOTE(review): assumes each row of the three inputs is a length-1 array
    of byte strings (the code indexes ``row[0]``) — confirm against the
    model's config.pbtxt.
    """
    responses = []
    for request in requests:
        # Fetch the three inputs for this request and materialize them
        # as numpy arrays of byte strings.
        texts = pb_utils.get_input_tensor_by_name(request, "INPUT_TEXT").as_numpy()
        src_langs = pb_utils.get_input_tensor_by_name(request, "INPUT_LANGUAGE_ID").as_numpy()
        tgt_langs = pb_utils.get_input_tensor_by_name(request, "OUTPUT_LANGUAGE_ID").as_numpy()
        # One tokenized sentence per batch row, wrapped in a singleton
        # list so the output tensor keeps its (batch, 1) shape.
        tokenized_rows = [
            [self.tokenize_and_segment(text[0].decode('utf-8'),
                                       src[0].decode('utf-8'),
                                       tgt[0].decode('utf-8'))]
            for text, src, tgt in zip(texts, src_langs, tgt_langs)
        ]
        output_tensor = pb_utils.Tensor(
            "INPUT_TEXT_TOKENIZED",
            numpy.array(tokenized_rows, dtype=self.target_dtype))
        responses.append(pb_utils.InferenceResponse(output_tensors=[output_tensor]))
    return responses

def finalize(self):
    """No per-model resources to release."""
    pass