From d0049da2a46e478f94a11c8befa98d24c8481bc4 Mon Sep 17 00:00:00 2001 From: Nikhilesh Bhatnagar Date: Mon, 4 Sep 2023 07:03:07 +0000 Subject: [PATCH] Formatting pass. --- .gitignore | 1 - triton_models/demuxer/1/model.py | 66 +++++++- triton_models/demuxer/config.pbtxt | 2 +- triton_models/model_ct2/1/model.py | 74 +++++++-- triton_models/model_ct2/config.pbtxt | 2 +- triton_models/model_onmt/1/model.py | 127 ++++++++++++++-- triton_models/model_onmt/config.pbtxt | 2 +- triton_models/nmt/config.pbtxt | 2 +- triton_models/tokenizer/1/apply_bpe.py | 199 ++++++++++++++++--------- triton_models/tokenizer/1/model.py | 73 ++++++++- triton_models/tokenizer/config.pbtxt | 2 +- 11 files changed, 437 insertions(+), 113 deletions(-) diff --git a/.gitignore b/.gitignore index c7d65da..163f573 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1 @@ -ssmt_triton_repo himangy_triton_repo \ No newline at end of file diff --git a/triton_models/demuxer/1/model.py b/triton_models/demuxer/1/model.py index 8527076..e73b062 100644 --- a/triton_models/demuxer/1/model.py +++ b/triton_models/demuxer/1/model.py @@ -3,7 +3,67 @@ import numpy import asyncio import triton_python_backend_utils as pb_utils + class TritonPythonModel: - def initialize(self, args): self.target_dtype = pb_utils.triton_string_to_numpy(pb_utils.get_output_config_by_name(json.loads(args['model_config']), 'OUTPUT_TEXT')['data_type']) - async def execute(self, requests): return [pb_utils.InferenceResponse(output_tensors=[pb_utils.Tensor('OUTPUT_TEXT', numpy.array([[pb_utils.get_output_tensor_by_name(result, 'OUTPUT_SENT').as_numpy()[0, 0].decode('utf-8')] for result in (await asyncio.gather(*awaits))], dtype=self.target_dtype))]) for awaits in [[pb_utils.InferenceRequest(model_name=f"himangy-{input_language_id[0].decode('utf-8')}-{output_language_id[0].decode('utf-8')}", requested_output_names=['OUTPUT_SENT'], inputs=[pb_utils.Tensor('INPUT_SENT_TOKENIZED', numpy.array([[input_text_tokenized[0].decode('utf-8')]], dtype='object'))]).async_exec() for input_text_tokenized, input_language_id, output_language_id in zip(pb_utils.get_input_tensor_by_name(request, 'INPUT_TEXT_TOKENIZED').as_numpy(), pb_utils.get_input_tensor_by_name(request, 'INPUT_LANGUAGE_ID').as_numpy(), pb_utils.get_input_tensor_by_name(request, 'OUTPUT_LANGUAGE_ID').as_numpy())] for request in requests]] - def finalize(self): pass \ No newline at end of file + def initialize(self, args): + self.target_dtype = pb_utils.triton_string_to_numpy( + pb_utils.get_output_config_by_name( + json.loads(args["model_config"]), "OUTPUT_TEXT" + )["data_type"] + ) + + async def execute(self, requests): + return [ + pb_utils.InferenceResponse( + output_tensors=[ + pb_utils.Tensor( + "OUTPUT_TEXT", + numpy.array( + [ + [ + pb_utils.get_output_tensor_by_name( + result, "OUTPUT_SENT" + ) + .as_numpy()[0, 0] + .decode("utf-8") + ] + for result in (await asyncio.gather(*awaits)) + ], + dtype=self.target_dtype, + ), + ) + ] + ) + for awaits in [ + [ + pb_utils.InferenceRequest( + model_name=f"himangy-{input_language_id[0].decode('utf-8')}-{output_language_id[0].decode('utf-8')}", + requested_output_names=["OUTPUT_SENT"], + inputs=[ + pb_utils.Tensor( + "INPUT_SENT_TOKENIZED", + numpy.array( + [[input_text_tokenized[0].decode("utf-8")]], + dtype="object", + ), + ) + ], + ).async_exec() + for input_text_tokenized, input_language_id, output_language_id in zip( + pb_utils.get_input_tensor_by_name( + request, "INPUT_TEXT_TOKENIZED" + ).as_numpy(), + pb_utils.get_input_tensor_by_name( + request, 
"INPUT_LANGUAGE_ID" + ).as_numpy(), + pb_utils.get_input_tensor_by_name( + request, "OUTPUT_LANGUAGE_ID" + ).as_numpy(), + ) + ] + for request in requests + ] + ] + + def finalize(self): + pass diff --git a/triton_models/demuxer/config.pbtxt b/triton_models/demuxer/config.pbtxt index 109a39f..5a5b5f6 100644 --- a/triton_models/demuxer/config.pbtxt +++ b/triton_models/demuxer/config.pbtxt @@ -39,4 +39,4 @@ instance_group [ count: 1 kind: KIND_CPU } -] \ No newline at end of file +] diff --git a/triton_models/model_ct2/1/model.py b/triton_models/model_ct2/1/model.py index 170e90a..2662d22 100644 --- a/triton_models/model_ct2/1/model.py +++ b/triton_models/model_ct2/1/model.py @@ -5,27 +5,75 @@ from itertools import islice from ctranslate2 import Translator import triton_python_backend_utils as pb_utils + class TritonPythonModel: def initialize(self, args): current_path = os.path.dirname(os.path.abspath(__file__)) self.source_lang, self.target_lang = input_lang, output_lang self.model_config = json.loads(args["model_config"]) - self.device_id = int(json.loads(args['model_instance_device_id'])) - target_config = pb_utils.get_output_config_by_name(self.model_config, "OUTPUT_SENT") + self.device_id = int(json.loads(args["model_instance_device_id"])) + target_config = pb_utils.get_output_config_by_name( + self.model_config, "OUTPUT_SENT" + ) self.target_dtype = pb_utils.triton_string_to_numpy(target_config["data_type"]) - try: self.translator = Translator(f"{os.path.join(current_path, 'translator')}", device="cuda", intra_threads=1, inter_threads=1, device_index=[self.device_id]) - except: self.translator = Translator(f"{os.path.join(current_path, 'translator')}", device="cpu", intra_threads=4) + try: + self.translator = Translator( + f"{os.path.join(current_path, 'translator')}", + device="cuda", + intra_threads=1, + inter_threads=1, + device_index=[self.device_id], + ) + except: + self.translator = Translator( + f"{os.path.join(current_path, 'translator')}", + device="cpu", + intra_threads=4, + ) + def clean_output(self, text): - text = text.replace('@@ ', '') - text = text.replace('\u200c', '') - if text.startswith(' '): text = text[8:] - if text.endswith(' '): text = text[:-8] + text = text.replace("@@ ", "") + text = text.replace("\u200c", "") + if text.startswith(" "): + text = text[8:] + if text.endswith(" "): + text = text[:-8] return text + def execute(self, requests): - source_list = [pb_utils.get_input_tensor_by_name(request, "INPUT_SENT_TOKENIZED") for request in requests] + source_list = [ + pb_utils.get_input_tensor_by_name(request, "INPUT_SENT_TOKENIZED") + for request in requests + ] bsize_list = [source.as_numpy().shape[0] for source in source_list] - src_sentences = [s[0].decode('utf-8').strip().split(' ') for source in source_list for s in source.as_numpy()] - tgt_sentences = [self.clean_output(' '.join(result.hypotheses[0])) for result in self.translator.translate_iterable(src_sentences, max_batch_size=128, max_input_length=100, max_decoding_length=100)] - responses = [pb_utils.InferenceResponse(output_tensors=[pb_utils.Tensor("OUTPUT_SENT", numpy.array([[s]for s in islice(tgt_sentences, bsize)], dtype='object').astype(self.target_dtype))]) for bsize in bsize_list] + src_sentences = [ + s[0].decode("utf-8").strip().split(" ") + for source in source_list + for s in source.as_numpy() + ] + tgt_sentences = [ + self.clean_output(" ".join(result.hypotheses[0])) + for result in self.translator.translate_iterable( + src_sentences, + max_batch_size=128, + max_input_length=100, + 
max_decoding_length=100, + ) + ] + responses = [ + pb_utils.InferenceResponse( + output_tensors=[ + pb_utils.Tensor( + "OUTPUT_SENT", + numpy.array( + [[s] for s in islice(tgt_sentences, bsize)], dtype="object" + ).astype(self.target_dtype), + ) + ] + ) + for bsize in bsize_list + ] return responses - def finalize(self): self.translator.unload_model() \ No newline at end of file + + def finalize(self): + self.translator.unload_model() diff --git a/triton_models/model_ct2/config.pbtxt b/triton_models/model_ct2/config.pbtxt index 82502f2..93c14a1 100644 --- a/triton_models/model_ct2/config.pbtxt +++ b/triton_models/model_ct2/config.pbtxt @@ -29,4 +29,4 @@ instance_group [ response_cache { enable: true -} \ No newline at end of file +} diff --git a/triton_models/model_onmt/1/model.py b/triton_models/model_onmt/1/model.py index d52b00d..adb06b8 100644 --- a/triton_models/model_onmt/1/model.py +++ b/triton_models/model_onmt/1/model.py @@ -6,27 +6,128 @@ from argparse import Namespace import triton_python_backend_utils as pb_utils from onmt.translate.translator import build_translator + class TritonPythonModel: def initialize(self, args): current_path = os.path.dirname(os.path.abspath(__file__)) self.source_lang, self.target_lang = input_lang, output_lang self.model_config = json.loads(args["model_config"]) - self.device_id = int(json.loads(args['model_instance_device_id'])) - target_config = pb_utils.get_output_config_by_name(self.model_config, "OUTPUT_SENT") + self.device_id = int(json.loads(args["model_instance_device_id"])) + target_config = pb_utils.get_output_config_by_name( + self.model_config, "OUTPUT_SENT" + ) self.target_dtype = pb_utils.triton_string_to_numpy(target_config["data_type"]) - try: self.translator = build_translator(Namespace(tgt_prefix=False, alpha=0.0, batch_type='sents', beam_size=5, beta=-0.0, block_ngram_repeat=0, coverage_penalty='none', data_type='text', dump_beam='', fp32=True, gpu=self.device_id, ignore_when_blocking=[], length_penalty='none', max_length=100, max_sent_length=None, min_length=0, models=[f"{os.path.join(current_path, 'translator.pt')}"], n_best=1, output='/dev/null', phrase_table='', random_sampling_temp=1.0, random_sampling_topk=1, ratio=-0.0, replace_unk=False, report_align=False, report_time=False, seed=829, stepwise_penalty=False, tgt=None, verbose=False), report_score=False) - except: self.translator = build_translator(Namespace(tgt_prefix=False, alpha=0.0, batch_type='sents', beam_size=5, beta=-0.0, block_ngram_repeat=0, coverage_penalty='none', data_type='text', dump_beam='', fp32=True, gpu=-1, ignore_when_blocking=[], length_penalty='none', max_length=100, max_sent_length=None, min_length=0, models=[f"{os.path.join(current_path, 'translator.pt')}"], n_best=1, output='/dev/null', phrase_table='', random_sampling_temp=1.0, random_sampling_topk=1, ratio=-0.0, replace_unk=False, report_align=False, report_time=False, seed=829, stepwise_penalty=False, tgt=None, verbose=False), report_score=False) + try: + self.translator = build_translator( + Namespace( + tgt_prefix=False, + alpha=0.0, + batch_type="sents", + beam_size=5, + beta=-0.0, + block_ngram_repeat=0, + coverage_penalty="none", + data_type="text", + dump_beam="", + fp32=True, + gpu=self.device_id, + ignore_when_blocking=[], + length_penalty="none", + max_length=100, + max_sent_length=None, + min_length=0, + models=[f"{os.path.join(current_path, 'translator.pt')}"], + n_best=1, + output="/dev/null", + phrase_table="", + random_sampling_temp=1.0, + random_sampling_topk=1, + ratio=-0.0, + 
replace_unk=False, + report_align=False, + report_time=False, + seed=829, + stepwise_penalty=False, + tgt=None, + verbose=False, + ), + report_score=False, + ) + except: + self.translator = build_translator( + Namespace( + tgt_prefix=False, + alpha=0.0, + batch_type="sents", + beam_size=5, + beta=-0.0, + block_ngram_repeat=0, + coverage_penalty="none", + data_type="text", + dump_beam="", + fp32=True, + gpu=-1, + ignore_when_blocking=[], + length_penalty="none", + max_length=100, + max_sent_length=None, + min_length=0, + models=[f"{os.path.join(current_path, 'translator.pt')}"], + n_best=1, + output="/dev/null", + phrase_table="", + random_sampling_temp=1.0, + random_sampling_topk=1, + ratio=-0.0, + replace_unk=False, + report_align=False, + report_time=False, + seed=829, + stepwise_penalty=False, + tgt=None, + verbose=False, + ), + report_score=False, + ) + def clean_output(self, text): - text = text.replace('@@ ', '') - text = text.replace('\u200c', '') - if text.startswith(' '): text = text[8:] - if text.endswith(' '): text = text[:-8] + text = text.replace("@@ ", "") + text = text.replace("\u200c", "") + if text.startswith(" "): + text = text[8:] + if text.endswith(" "): + text = text[:-8] return text + def execute(self, requests): - source_list = [pb_utils.get_input_tensor_by_name(request, "INPUT_SENT_TOKENIZED") for request in requests] + source_list = [ + pb_utils.get_input_tensor_by_name(request, "INPUT_SENT_TOKENIZED") + for request in requests + ] bsize_list = [source.as_numpy().shape[0] for source in source_list] - src_sentences = [s[0].decode('utf-8').strip().split(' ') for source in source_list for s in source.as_numpy()] - tgt_sentences = [self.clean_output(result[0]) for result in self.translator.translate(src_sentences, batch_size=128)[1]] - responses = [pb_utils.InferenceResponse(output_tensors=[pb_utils.Tensor("OUTPUT_SENT", numpy.array([[s]for s in islice(tgt_sentences, bsize)], dtype='object').astype(self.target_dtype))]) for bsize in bsize_list] + src_sentences = [ + s[0].decode("utf-8").strip().split(" ") + for source in source_list + for s in source.as_numpy() + ] + tgt_sentences = [ + self.clean_output(result[0]) + for result in self.translator.translate(src_sentences, batch_size=128)[1] + ] + responses = [ + pb_utils.InferenceResponse( + output_tensors=[ + pb_utils.Tensor( + "OUTPUT_SENT", + numpy.array( + [[s] for s in islice(tgt_sentences, bsize)], dtype="object" + ).astype(self.target_dtype), + ) + ] + ) + for bsize in bsize_list + ] return responses - def finalize(self): del self.translator \ No newline at end of file + + def finalize(self): + del self.translator diff --git a/triton_models/model_onmt/config.pbtxt b/triton_models/model_onmt/config.pbtxt index 82502f2..93c14a1 100644 --- a/triton_models/model_onmt/config.pbtxt +++ b/triton_models/model_onmt/config.pbtxt @@ -29,4 +29,4 @@ instance_group [ response_cache { enable: true -} \ No newline at end of file +} diff --git a/triton_models/nmt/config.pbtxt b/triton_models/nmt/config.pbtxt index 126bc1c..e7566cf 100644 --- a/triton_models/nmt/config.pbtxt +++ b/triton_models/nmt/config.pbtxt @@ -75,4 +75,4 @@ ensemble_scheduling { } } ] -} \ No newline at end of file +} diff --git a/triton_models/tokenizer/1/apply_bpe.py b/triton_models/tokenizer/1/apply_bpe.py index 2fe666e..8a87730 100644 --- a/triton_models/tokenizer/1/apply_bpe.py +++ b/triton_models/tokenizer/1/apply_bpe.py @@ -25,18 +25,21 @@ from collections import defaultdict # hack for python2/3 compatibility from io import open + argparse.open = open 
class BPE(object): - - def __init__(self, codes, separator='@@', vocab=None, glossaries=None): - + def __init__(self, codes, separator="@@", vocab=None, glossaries=None): # check version information firstline = codes.readline() - if firstline.startswith('#version:'): - self.version = tuple([int(x) for x in re.sub( - r'(\.0+)*$', '', firstline.split()[-1]).split(".")]) + if firstline.startswith("#version:"): + self.version = tuple( + [ + int(x) + for x in re.sub(r"(\.0+)*$", "", firstline.split()[-1]).split(".") + ] + ) else: self.version = (0, 1) codes.seek(0) @@ -45,10 +48,12 @@ class BPE(object): # some hacking to deal with duplicates (only consider first instance) self.bpe_codes = dict( - [(code, i) for (i, code) in reversed(list(enumerate(self.bpe_codes)))]) + [(code, i) for (i, code) in reversed(list(enumerate(self.bpe_codes)))] + ) self.bpe_codes_reverse = dict( - [(pair[0] + pair[1], pair) for pair, i in self.bpe_codes.items()]) + [(pair[0] + pair[1], pair) for pair, i in self.bpe_codes.items()] + ) self.separator = separator @@ -62,63 +67,99 @@ class BPE(object): """segment single sentence (whitespace-tokenized string) with BPE encoding""" output = [] for word in sentence.split(): - new_word = [out for segment in self._isolate_glossaries(word) - for out in encode(segment, - self.bpe_codes, - self.bpe_codes_reverse, - self.vocab, - self.separator, - self.version, - self.cache, - self.glossaries)] + new_word = [ + out + for segment in self._isolate_glossaries(word) + for out in encode( + segment, + self.bpe_codes, + self.bpe_codes_reverse, + self.vocab, + self.separator, + self.version, + self.cache, + self.glossaries, + ) + ] for item in new_word[:-1]: output.append(item + self.separator) output.append(new_word[-1]) - return ' '.join(output) + return " ".join(output) def _isolate_glossaries(self, word): word_segments = [word] for gloss in self.glossaries: - word_segments = [out_segments for segment in word_segments - for out_segments in isolate_glossary(segment, gloss)] + word_segments = [ + out_segments + for segment in word_segments + for out_segments in isolate_glossary(segment, gloss) + ] return word_segments def create_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.RawDescriptionHelpFormatter, - description="learn BPE-based word segmentation") + description="learn BPE-based word segmentation", + ) parser.add_argument( - '--input', '-i', type=argparse.FileType('r'), default=sys.stdin, - metavar='PATH', - help="Input file (default: standard input).") + "--input", + "-i", + type=argparse.FileType("r"), + default=sys.stdin, + metavar="PATH", + help="Input file (default: standard input).", + ) parser.add_argument( - '--codes', '-c', type=argparse.FileType('r'), metavar='PATH', + "--codes", + "-c", + type=argparse.FileType("r"), + metavar="PATH", required=True, - help="File with BPE codes (created by learn_bpe.py).") + help="File with BPE codes (created by learn_bpe.py).", + ) parser.add_argument( - '--output', '-o', type=argparse.FileType('w'), default=sys.stdout, - metavar='PATH', - help="Output file (default: standard output)") + "--output", + "-o", + type=argparse.FileType("w"), + default=sys.stdout, + metavar="PATH", + help="Output file (default: standard output)", + ) parser.add_argument( - '--separator', '-s', type=str, default='@@', metavar='STR', - help="Separator between non-final subword units (default: '%(default)s'))") + "--separator", + "-s", + type=str, + default="@@", + metavar="STR", + help="Separator between non-final subword units (default: 
'%(default)s'))", + ) parser.add_argument( - '--vocabulary', type=argparse.FileType('r'), default=None, + "--vocabulary", + type=argparse.FileType("r"), + default=None, metavar="PATH", - help="Vocabulary file (built with get_vocab.py). If provided, this script reverts any merge operations that produce an OOV.") + help="Vocabulary file (built with get_vocab.py). If provided, this script reverts any merge operations that produce an OOV.", + ) parser.add_argument( - '--vocabulary-threshold', type=int, default=None, + "--vocabulary-threshold", + type=int, + default=None, metavar="INT", - help="Vocabulary threshold. If vocabulary is provided, any word with frequency < threshold will be treated as OOV") + help="Vocabulary threshold. If vocabulary is provided, any word with frequency < threshold will be treated as OOV", + ) parser.add_argument( - '--glossaries', type=str, nargs='+', default=None, + "--glossaries", + type=str, + nargs="+", + default=None, metavar="STR", - help="Glossaries. The strings provided in glossaries will not be affected" + - "by the BPE (i.e. they will neither be broken into subwords, nor concatenated with other subwords") + help="Glossaries. The strings provided in glossaries will not be affected" + + "by the BPE (i.e. they will neither be broken into subwords, nor concatenated with other subwords", + ) return parser @@ -136,9 +177,17 @@ def get_pairs(word): return pairs -def encode(orig, bpe_codes, bpe_codes_reverse, vocab, separator, version, cache, glossaries=None): - """Encode word based on list of BPE merge operations, which are applied consecutively - """ +def encode( + orig, + bpe_codes, + bpe_codes_reverse, + vocab, + separator, + version, + cache, + glossaries=None, +): + """Encode word based on list of BPE merge operations, which are applied consecutively""" if orig in cache: return cache[orig] @@ -148,9 +197,9 @@ def encode(orig, bpe_codes, bpe_codes_reverse, vocab, separator, version, cache, return (orig,) if version == (0, 1): - word = tuple(orig) + ('',) + word = tuple(orig) + ("",) elif version == (0, 2): # more consistent handling of word-final segments - word = tuple(orig[:-1]) + (orig[-1] + '',) + word = tuple(orig[:-1]) + (orig[-1] + "",) else: raise NotImplementedError @@ -160,7 +209,7 @@ def encode(orig, bpe_codes, bpe_codes_reverse, vocab, separator, version, cache, return orig while True: - bigram = min(pairs, key=lambda pair: bpe_codes.get(pair, float('inf'))) + bigram = min(pairs, key=lambda pair: bpe_codes.get(pair, float("inf"))) if bigram not in bpe_codes: break first, second = bigram @@ -189,10 +238,10 @@ def encode(orig, bpe_codes, bpe_codes_reverse, vocab, separator, version, cache, pairs = get_pairs(word) # don't print end-of-word symbols - if word[-1] == '': + if word[-1] == "": word = word[:-1] - elif word[-1].endswith(''): - word = word[:-1] + (word[-1].replace('', ''),) + elif word[-1].endswith(""): + word = word[:-1] + (word[-1].replace("", ""),) if vocab: word = check_vocab_and_split(word, bpe_codes_reverse, vocab, separator) @@ -207,12 +256,12 @@ def recursive_split(segment, bpe_codes, vocab, separator, final=False): try: if final: - left, right = bpe_codes[segment + ''] + left, right = bpe_codes[segment + ""] right = right[:-4] else: left, right = bpe_codes[segment] except: - #sys.stderr.write('cannot split {0} further.\n'.format(segment)) + # sys.stderr.write('cannot split {0} further.\n'.format(segment)) yield segment return @@ -239,7 +288,7 @@ def check_vocab_and_split(orig, bpe_codes, vocab, separator): if segment + separator in 
vocab: out.append(segment) else: - #sys.stderr.write('OOV: {0}\n'.format(segment)) + # sys.stderr.write('OOV: {0}\n'.format(segment)) for item in recursive_split(segment, bpe_codes, vocab, separator, False): out.append(item) @@ -247,7 +296,7 @@ def check_vocab_and_split(orig, bpe_codes, vocab, separator): if segment in vocab: out.append(segment) else: - #sys.stderr.write('OOV: {0}\n'.format(segment)) + # sys.stderr.write('OOV: {0}\n'.format(segment)) for item in recursive_split(segment, bpe_codes, vocab, separator, True): out.append(item) @@ -255,8 +304,7 @@ def check_vocab_and_split(orig, bpe_codes, vocab, separator): def read_vocabulary(vocab_file, threshold): - """read vocabulary file produced by get_vocab.py, and filter according to frequency threshold. - """ + """read vocabulary file produced by get_vocab.py, and filter according to frequency threshold.""" vocabulary = set() @@ -273,7 +321,7 @@ def isolate_glossary(word, glossary): """ Isolate a glossary present inside a word. - Returns a list of subwords. In which all 'glossary' glossaries are isolated + Returns a list of subwords. In which all 'glossary' glossaries are isolated For example, if 'USA' is the glossary and '1934USABUSA' the word, the return value is: ['1934', 'USA', 'B', 'USA'] @@ -282,39 +330,42 @@ def isolate_glossary(word, glossary): return [word] else: splits = word.split(glossary) - segments = [segment.strip() for split in splits[:-1] - for segment in [split, glossary] if segment != ''] - return segments + [splits[-1].strip()] if splits[-1] != '' else segments - + segments = [ + segment.strip() + for split in splits[:-1] + for segment in [split, glossary] + if segment != "" + ] + return segments + [splits[-1].strip()] if splits[-1] != "" else segments -if __name__ == '__main__': +if __name__ == "__main__": # python 2/3 compatibility if sys.version_info < (3, 0): - sys.stderr = codecs.getwriter('UTF-8')(sys.stderr) - sys.stdout = codecs.getwriter('UTF-8')(sys.stdout) - sys.stdin = codecs.getreader('UTF-8')(sys.stdin) + sys.stderr = codecs.getwriter("UTF-8")(sys.stderr) + sys.stdout = codecs.getwriter("UTF-8")(sys.stdout) + sys.stdin = codecs.getreader("UTF-8")(sys.stdin) else: - sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8') - sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8') + sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding="utf-8") + sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8") sys.stdout = io.TextIOWrapper( - sys.stdout.buffer, encoding='utf-8', write_through=True, line_buffering=True) + sys.stdout.buffer, encoding="utf-8", write_through=True, line_buffering=True + ) parser = create_parser() args = parser.parse_args() # read/write files as UTF-8 - args.codes = codecs.open(args.codes.name, encoding='utf-8') - if args.input.name != '': - args.input = codecs.open(args.input.name, encoding='utf-8') - if args.output.name != '': - args.output = codecs.open(args.output.name, 'w', encoding='utf-8') + args.codes = codecs.open(args.codes.name, encoding="utf-8") + if args.input.name != "": + args.input = codecs.open(args.input.name, encoding="utf-8") + if args.output.name != "": + args.output = codecs.open(args.output.name, "w", encoding="utf-8") if args.vocabulary: - args.vocabulary = codecs.open(args.vocabulary.name, encoding='utf-8') + args.vocabulary = codecs.open(args.vocabulary.name, encoding="utf-8") if args.vocabulary: - vocabulary = read_vocabulary( - args.vocabulary, args.vocabulary_threshold) + vocabulary = read_vocabulary(args.vocabulary, 
args.vocabulary_threshold) else: vocabulary = None @@ -322,4 +373,4 @@ if __name__ == '__main__': for line in args.input: args.output.write(bpe.segment(line).strip()) - args.output.write('\n') + args.output.write("\n") diff --git a/triton_models/tokenizer/1/model.py b/triton_models/tokenizer/1/model.py index 4dad844..c64a9ff 100644 --- a/triton_models/tokenizer/1/model.py +++ b/triton_models/tokenizer/1/model.py @@ -6,8 +6,73 @@ from .apply_bpe import BPE from ilstokenizer import tokenizer import triton_python_backend_utils as pb_utils + class TritonPythonModel: - def initialize(self, args): self.target_dtype, self.bpes = pb_utils.triton_string_to_numpy(pb_utils.get_output_config_by_name(json.loads(args["model_config"]), "INPUT_TEXT_TOKENIZED")["data_type"]), {fname.rsplit('/', maxsplit=1)[-1][:-len('.src')]: BPE(open(fname, 'r', encoding='utf-8')) for fname in iglob(f'{os.path.dirname(os.path.abspath(__file__))}/bpe_src/*.src')} - def preprocess_text(self, text, source_lang, target_lang): return f" {text} " if source_lang == 'en' and target_lang == 'gu' else text - def execute(self, requests): return [pb_utils.InferenceResponse(output_tensors=[pb_utils.Tensor("INPUT_TEXT_TOKENIZED", numpy.array([[tokenized_sent] for tokenized_sent in tokenized_sents], dtype=self.target_dtype))]) for tokenized_sents in ((self.bpes[f"{input_language_id[0].decode('utf-8')}-{output_language_id[0].decode('utf-8')}"].segment(self.preprocess_text(tokenizer.tokenize(input_text[0].decode('utf-8').lower()), input_language_id[0].decode('utf-8'), output_language_id[0].decode('utf-8'))).strip() for input_text, input_language_id, output_language_id in zip(input_texts.as_numpy(), input_language_ids.as_numpy(), output_language_ids.as_numpy())) for input_texts, input_language_ids, output_language_ids in ((pb_utils.get_input_tensor_by_name(request, "INPUT_TEXT"), pb_utils.get_input_tensor_by_name(request, "INPUT_LANGUAGE_ID"), pb_utils.get_input_tensor_by_name(request, "OUTPUT_LANGUAGE_ID")) for request in requests))] - def finalize(self): pass \ No newline at end of file + def initialize(self, args): + self.target_dtype, self.bpes = pb_utils.triton_string_to_numpy( + pb_utils.get_output_config_by_name( + json.loads(args["model_config"]), "INPUT_TEXT_TOKENIZED" + )["data_type"] + ), { + fname.rsplit("/", maxsplit=1)[-1][: -len(".src")]: BPE( + open(fname, "r", encoding="utf-8") + ) + for fname in iglob( + f"{os.path.dirname(os.path.abspath(__file__))}/bpe_src/*.src" + ) + } + + def preprocess_text(self, text, source_lang, target_lang): + return ( + f" {text} " + if source_lang == "en" and target_lang == "gu" + else text + ) + + def execute(self, requests): + return [ + pb_utils.InferenceResponse( + output_tensors=[ + pb_utils.Tensor( + "INPUT_TEXT_TOKENIZED", + numpy.array( + [[tokenized_sent] for tokenized_sent in tokenized_sents], + dtype=self.target_dtype, + ), + ) + ] + ) + for tokenized_sents in ( + ( + self.bpes[ + f"{input_language_id[0].decode('utf-8')}-{output_language_id[0].decode('utf-8')}" + ] + .segment( + self.preprocess_text( + tokenizer.tokenize(input_text[0].decode("utf-8").lower()), + input_language_id[0].decode("utf-8"), + output_language_id[0].decode("utf-8"), + ) + ) + .strip() + for input_text, input_language_id, output_language_id in zip( + input_texts.as_numpy(), + input_language_ids.as_numpy(), + output_language_ids.as_numpy(), + ) + ) + for input_texts, input_language_ids, output_language_ids in ( + ( + pb_utils.get_input_tensor_by_name(request, "INPUT_TEXT"), + 
pb_utils.get_input_tensor_by_name(request, "INPUT_LANGUAGE_ID"), + pb_utils.get_input_tensor_by_name( + request, "OUTPUT_LANGUAGE_ID" + ), + ) + for request in requests + ) + ) + ] + + def finalize(self): + pass diff --git a/triton_models/tokenizer/config.pbtxt b/triton_models/tokenizer/config.pbtxt index 5cfe82a..9dd74d8 100644 --- a/triton_models/tokenizer/config.pbtxt +++ b/triton_models/tokenizer/config.pbtxt @@ -39,4 +39,4 @@ instance_group [ count: 8 kind: KIND_CPU } -] \ No newline at end of file +] -- GitLab
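
For reference, a minimal client-side sketch of how the ensemble exposed by these models could be queried over HTTP. The tensor names (INPUT_TEXT, INPUT_LANGUAGE_ID, OUTPUT_LANGUAGE_ID, OUTPUT_TEXT) mirror the demuxer and tokenizer code in the patch; the ensemble name "nmt", the server URL, the [1, 1] BYTES shapes, and the en-hi language pair are assumptions and may differ in the deployed repository.

    import numpy
    import tritonclient.http as http_client

    # Assumed endpoint; adjust to the deployed Triton instance.
    client = http_client.InferenceServerClient(url="localhost:8000")

    def build_input(name, value):
        # All three inputs are sent as BYTES tensors of shape [1, 1], mirroring
        # the object-dtype arrays the tokenizer/demuxer models read per request.
        tensor = http_client.InferInput(name, [1, 1], "BYTES")
        tensor.set_data_from_numpy(numpy.array([[value]], dtype="object"))
        return tensor

    inputs = [
        build_input("INPUT_TEXT", "This is a test sentence."),
        build_input("INPUT_LANGUAGE_ID", "en"),
        build_input("OUTPUT_LANGUAGE_ID", "hi"),
    ]
    outputs = [http_client.InferRequestedOutput("OUTPUT_TEXT")]

    # The demuxer routes each row to a "himangy-<src>-<tgt>" model internally.
    result = client.infer(model_name="nmt", inputs=inputs, outputs=outputs)
    print(result.as_numpy("OUTPUT_TEXT")[0, 0].decode("utf-8"))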
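
The tokenizer model above chains ilstokenizer with the bundled apply_bpe.BPE; a standalone sketch of that preprocessing path can help when checking subword output outside Triton. The codes-file path and the en-hi pair are hypothetical; the real files live under bpe_src/ next to the model and are keyed by the "<src>-<tgt>" prefix of the filename.

    from ilstokenizer import tokenizer
    from apply_bpe import BPE

    # Hypothetical codes file; the tokenizer model loads every bpe_src/*.src
    # file at initialize() and selects one per request by language pair.
    with open("bpe_src/en-hi.src", encoding="utf-8") as codes:
        bpe = BPE(codes)

    raw = "A formatting pass should not change behaviour."
    tokenized = tokenizer.tokenize(raw.lower())
    print(bpe.segment(tokenized).strip())  # subwords joined by the '@@ ' separator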