From da7746e47128903c60b6d381bc07921fad7fadb3 Mon Sep 17 00:00:00 2001
From: Nikhilesh Bhatnagar
Date: Thu, 17 Aug 2023 12:43:04 +0000
Subject: [PATCH] Fixes for en-gu and gu-en models

---
 make_triton_model_repo.sh                         | 2 +-
 triton_client.ipynb                               | 2 +-
 triton_models/ssmt_pipeline/config.pbtxt          | 2 +-
 triton_models/ssmt_template_model_repo/1/model.py | 2 +-
 triton_models/ssmt_tokenizer/1/model.py           | 7 ++++++-
 5 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/make_triton_model_repo.sh b/make_triton_model_repo.sh
index d0088e9..075b4ed 100644
--- a/make_triton_model_repo.sh
+++ b/make_triton_model_repo.sh
@@ -19,7 +19,7 @@ ct2-opennmt-py-converter --model_path 9.pt --output_dir ./9_ct2
 cd ..
 mkdir ssmt_triton_repo
 cd ssmt_triton_repo
-cp -r ../triton_models/ssmt_pipeline .
+cp -r ../triton_models/ssmt_pipeline nmt
 cp -r ../triton_models/ssmt_model_demuxer .
 cp -r ../triton_models/ssmt_tokenizer .
 cp -r ../models/*.src ssmt_tokenizer/1/bpe_src
diff --git a/triton_client.ipynb b/triton_client.ipynb
index ed5e2a6..4b92e98 100644
--- a/triton_client.ipynb
+++ b/triton_client.ipynb
@@ -39,7 +39,7 @@
    "source": [
     "shape = [1]\n",
     "MIN_WORDS, MAX_WORDS = 4, 20\n",
-    "model_name = \"ssmt_pipeline\"\n",
+    "model_name = \"nmt\"\n",
     "rs = wonderwords.RandomWord()"
    ]
   },
diff --git a/triton_models/ssmt_pipeline/config.pbtxt b/triton_models/ssmt_pipeline/config.pbtxt
index d519b03..3876e72 100644
--- a/triton_models/ssmt_pipeline/config.pbtxt
+++ b/triton_models/ssmt_pipeline/config.pbtxt
@@ -1,4 +1,4 @@
-name: "ssmt_pipeline"
+name: "nmt"
 platform: "ensemble"
 max_batch_size: 4096
 
diff --git a/triton_models/ssmt_template_model_repo/1/model.py b/triton_models/ssmt_template_model_repo/1/model.py
index 7c636bb..084a1c1 100644
--- a/triton_models/ssmt_template_model_repo/1/model.py
+++ b/triton_models/ssmt_template_model_repo/1/model.py
@@ -19,7 +19,7 @@ class TritonPythonModel:
         source_list = [pb_utils.get_input_tensor_by_name(request, "INPUT_TEXT_TOKENIZED") for request in requests]
         bsize_list = [source.as_numpy().shape[0] for source in source_list]
         src_sentences = [s[0].decode('utf-8').strip().split(' ') for source in source_list for s in source.as_numpy()]
-        tgt_sentences = [' '.join(result.hypotheses[0]).replace('@@ ', '') for result in self.translator.translate_iterable(src_sentences, max_batch_size=128, max_input_length=100, max_decoding_length=100)]
+        tgt_sentences = [' '.join(result.hypotheses[0]).replace('@@ ', '').removeprefix(' ').removesuffix(' ') for result in self.translator.translate_iterable(src_sentences, max_batch_size=128, max_input_length=100, max_decoding_length=100)]
         responses = [pb_utils.InferenceResponse(output_tensors=[pb_utils.Tensor("OUTPUT_TEXT", numpy.array([[s]for s in islice(tgt_sentences, bsize)], dtype='object').astype(self.target_dtype))]) for bsize in bsize_list]
         return responses
 
diff --git a/triton_models/ssmt_tokenizer/1/model.py b/triton_models/ssmt_tokenizer/1/model.py
index 76433e8..873b82a 100644
--- a/triton_models/ssmt_tokenizer/1/model.py
+++ b/triton_models/ssmt_tokenizer/1/model.py
@@ -14,9 +14,14 @@ class TritonPythonModel:
         self.lang_pair_map = {'en-hi': 1, 'hi-en': 2, 'en-te': 3, 'te-en': 4, 'hi-te': 6, 'te-hi': 7, 'en-gu': 8, 'gu-en': 9}
         self.bpes = {lang_pair: BPE(open(os.path.join(current_path, f'bpe_src/{model_id}.src'), encoding='utf-8')) for lang_pair, model_id in self.lang_pair_map.items()}
 
+    def tokenize_and_segment(self, input_text, source_lang, target_lang):
+        tokenized_text = tokenizer.tokenize(input_text)
+        if source_lang == 'en' and target_lang == 'gu': tokenized_text = f' {tokenized_text} '
+        return self.bpes[f'{source_lang}-{target_lang}'].segment(tokenized_text).strip()
+
     def execute(self, requests):
         source_gen = ((pb_utils.get_input_tensor_by_name(request, "INPUT_TEXT"), pb_utils.get_input_tensor_by_name(request, "INPUT_LANGUAGE_ID"), pb_utils.get_input_tensor_by_name(request, "OUTPUT_LANGUAGE_ID")) for request in requests)
-        tokenized_gen = (self.bpes[f"{input_language_id.as_numpy()[0, 0].decode('utf-8')}-{output_language_id.as_numpy()[0, 0].decode('utf-8')}"].segment(tokenizer.tokenize(input_text.as_numpy()[0, 0].decode('utf-8'))).strip() for input_text, input_language_id, output_language_id in source_gen)
+        tokenized_gen = (self.tokenize_and_segment(input_text.as_numpy()[0, 0].decode('utf-8'), input_language_id.as_numpy()[0, 0].decode('utf-8'), output_language_id.as_numpy()[0, 0].decode('utf-8')) for input_text, input_language_id, output_language_id in source_gen)
         responses = [pb_utils.InferenceResponse(output_tensors=[pb_utils.Tensor("INPUT_TEXT_TOKENIZED", numpy.array([[tokenized_sent]], dtype=self.target_dtype))]) for tokenized_sent in tokenized_gen]
         return responses
--
GitLab
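
For reference, a minimal client-side sketch of querying the renamed "nmt" ensemble for an en-gu request follows. It assumes the ensemble exposes the same tensor names as the Python models in this patch (INPUT_TEXT, INPUT_LANGUAGE_ID, OUTPUT_LANGUAGE_ID in; OUTPUT_TEXT out) and a Triton HTTP endpoint on localhost:8000; verify both against the full config.pbtxt and triton_client.ipynb, which this patch shows only in part.

# Illustrative sketch only; tensor names and server URL are assumptions, not confirmed by this patch.
import numpy as np
import tritonclient.http as http_client

client = http_client.InferenceServerClient(url="localhost:8000")

def make_input(name, value):
    # String tensors are sent as BYTES with shape [batch, 1], matching the [0, 0] indexing in the models above.
    tensor = http_client.InferInput(name, [1, 1], "BYTES")
    tensor.set_data_from_numpy(np.array([[value]], dtype="object"))
    return tensor

inputs = [make_input("INPUT_TEXT", "hello world"),
          make_input("INPUT_LANGUAGE_ID", "en"),
          make_input("OUTPUT_LANGUAGE_ID", "gu")]
result = client.infer("nmt", inputs)
print(result.as_numpy("OUTPUT_TEXT")[0, 0].decode("utf-8"))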