Commit d0049da2 authored by Nikhilesh Bhatnagar's avatar Nikhilesh Bhatnagar

Formatting pass.

parent f61cdc30
ssmt_triton_repo
himangy_triton_repo
\ No newline at end of file
......@@ -3,7 +3,67 @@ import numpy
import asyncio
import triton_python_backend_utils as pb_utils
class TritonPythonModel:
def initialize(self, args): self.target_dtype = pb_utils.triton_string_to_numpy(pb_utils.get_output_config_by_name(json.loads(args['model_config']), 'OUTPUT_TEXT')['data_type'])
async def execute(self, requests): return [pb_utils.InferenceResponse(output_tensors=[pb_utils.Tensor('OUTPUT_TEXT', numpy.array([[pb_utils.get_output_tensor_by_name(result, 'OUTPUT_SENT').as_numpy()[0, 0].decode('utf-8')] for result in (await asyncio.gather(*awaits))], dtype=self.target_dtype))]) for awaits in [[pb_utils.InferenceRequest(model_name=f"himangy-{input_language_id[0].decode('utf-8')}-{output_language_id[0].decode('utf-8')}", requested_output_names=['OUTPUT_SENT'], inputs=[pb_utils.Tensor('INPUT_SENT_TOKENIZED', numpy.array([[input_text_tokenized[0].decode('utf-8')]], dtype='object'))]).async_exec() for input_text_tokenized, input_language_id, output_language_id in zip(pb_utils.get_input_tensor_by_name(request, 'INPUT_TEXT_TOKENIZED').as_numpy(), pb_utils.get_input_tensor_by_name(request, 'INPUT_LANGUAGE_ID').as_numpy(), pb_utils.get_input_tensor_by_name(request, 'OUTPUT_LANGUAGE_ID').as_numpy())] for request in requests]]
def finalize(self): pass
\ No newline at end of file
def initialize(self, args):
self.target_dtype = pb_utils.triton_string_to_numpy(
pb_utils.get_output_config_by_name(
json.loads(args["model_config"]), "OUTPUT_TEXT"
)["data_type"]
)
async def execute(self, requests):
return [
pb_utils.InferenceResponse(
output_tensors=[
pb_utils.Tensor(
"OUTPUT_TEXT",
numpy.array(
[
[
pb_utils.get_output_tensor_by_name(
result, "OUTPUT_SENT"
)
.as_numpy()[0, 0]
.decode("utf-8")
]
for result in (await asyncio.gather(*awaits))
],
dtype=self.target_dtype,
),
)
]
)
for awaits in [
[
pb_utils.InferenceRequest(
model_name=f"himangy-{input_language_id[0].decode('utf-8')}-{output_language_id[0].decode('utf-8')}",
requested_output_names=["OUTPUT_SENT"],
inputs=[
pb_utils.Tensor(
"INPUT_SENT_TOKENIZED",
numpy.array(
[[input_text_tokenized[0].decode("utf-8")]],
dtype="object",
),
)
],
).async_exec()
for input_text_tokenized, input_language_id, output_language_id in zip(
pb_utils.get_input_tensor_by_name(
request, "INPUT_TEXT_TOKENIZED"
).as_numpy(),
pb_utils.get_input_tensor_by_name(
request, "INPUT_LANGUAGE_ID"
).as_numpy(),
pb_utils.get_input_tensor_by_name(
request, "OUTPUT_LANGUAGE_ID"
).as_numpy(),
)
]
for request in requests
]
]
def finalize(self):
pass
......@@ -5,27 +5,75 @@ from itertools import islice
from ctranslate2 import Translator
import triton_python_backend_utils as pb_utils
class TritonPythonModel:
def initialize(self, args):
current_path = os.path.dirname(os.path.abspath(__file__))
self.source_lang, self.target_lang = input_lang, output_lang
self.model_config = json.loads(args["model_config"])
self.device_id = int(json.loads(args['model_instance_device_id']))
target_config = pb_utils.get_output_config_by_name(self.model_config, "OUTPUT_SENT")
self.device_id = int(json.loads(args["model_instance_device_id"]))
target_config = pb_utils.get_output_config_by_name(
self.model_config, "OUTPUT_SENT"
)
self.target_dtype = pb_utils.triton_string_to_numpy(target_config["data_type"])
try: self.translator = Translator(f"{os.path.join(current_path, 'translator')}", device="cuda", intra_threads=1, inter_threads=1, device_index=[self.device_id])
except: self.translator = Translator(f"{os.path.join(current_path, 'translator')}", device="cpu", intra_threads=4)
try:
self.translator = Translator(
f"{os.path.join(current_path, 'translator')}",
device="cuda",
intra_threads=1,
inter_threads=1,
device_index=[self.device_id],
)
except:
self.translator = Translator(
f"{os.path.join(current_path, 'translator')}",
device="cpu",
intra_threads=4,
)
def clean_output(self, text):
text = text.replace('@@ ', '')
text = text.replace('\u200c', '')
if text.startswith('<to-gu> '): text = text[8:]
if text.endswith(' <to-gu>'): text = text[:-8]
text = text.replace("@@ ", "")
text = text.replace("\u200c", "")
if text.startswith("<to-gu> "):
text = text[8:]
if text.endswith(" <to-gu>"):
text = text[:-8]
return text
def execute(self, requests):
source_list = [pb_utils.get_input_tensor_by_name(request, "INPUT_SENT_TOKENIZED") for request in requests]
source_list = [
pb_utils.get_input_tensor_by_name(request, "INPUT_SENT_TOKENIZED")
for request in requests
]
bsize_list = [source.as_numpy().shape[0] for source in source_list]
src_sentences = [s[0].decode('utf-8').strip().split(' ') for source in source_list for s in source.as_numpy()]
tgt_sentences = [self.clean_output(' '.join(result.hypotheses[0])) for result in self.translator.translate_iterable(src_sentences, max_batch_size=128, max_input_length=100, max_decoding_length=100)]
responses = [pb_utils.InferenceResponse(output_tensors=[pb_utils.Tensor("OUTPUT_SENT", numpy.array([[s]for s in islice(tgt_sentences, bsize)], dtype='object').astype(self.target_dtype))]) for bsize in bsize_list]
src_sentences = [
s[0].decode("utf-8").strip().split(" ")
for source in source_list
for s in source.as_numpy()
]
tgt_sentences = [
self.clean_output(" ".join(result.hypotheses[0]))
for result in self.translator.translate_iterable(
src_sentences,
max_batch_size=128,
max_input_length=100,
max_decoding_length=100,
)
]
responses = [
pb_utils.InferenceResponse(
output_tensors=[
pb_utils.Tensor(
"OUTPUT_SENT",
numpy.array(
[[s] for s in islice(tgt_sentences, bsize)], dtype="object"
).astype(self.target_dtype),
)
]
)
for bsize in bsize_list
]
return responses
def finalize(self): self.translator.unload_model()
\ No newline at end of file
def finalize(self):
self.translator.unload_model()
......@@ -6,27 +6,128 @@ from argparse import Namespace
import triton_python_backend_utils as pb_utils
from onmt.translate.translator import build_translator
class TritonPythonModel:
def initialize(self, args):
current_path = os.path.dirname(os.path.abspath(__file__))
self.source_lang, self.target_lang = input_lang, output_lang
self.model_config = json.loads(args["model_config"])
self.device_id = int(json.loads(args['model_instance_device_id']))
target_config = pb_utils.get_output_config_by_name(self.model_config, "OUTPUT_SENT")
self.device_id = int(json.loads(args["model_instance_device_id"]))
target_config = pb_utils.get_output_config_by_name(
self.model_config, "OUTPUT_SENT"
)
self.target_dtype = pb_utils.triton_string_to_numpy(target_config["data_type"])
try: self.translator = build_translator(Namespace(tgt_prefix=False, alpha=0.0, batch_type='sents', beam_size=5, beta=-0.0, block_ngram_repeat=0, coverage_penalty='none', data_type='text', dump_beam='', fp32=True, gpu=self.device_id, ignore_when_blocking=[], length_penalty='none', max_length=100, max_sent_length=None, min_length=0, models=[f"{os.path.join(current_path, 'translator.pt')}"], n_best=1, output='/dev/null', phrase_table='', random_sampling_temp=1.0, random_sampling_topk=1, ratio=-0.0, replace_unk=False, report_align=False, report_time=False, seed=829, stepwise_penalty=False, tgt=None, verbose=False), report_score=False)
except: self.translator = build_translator(Namespace(tgt_prefix=False, alpha=0.0, batch_type='sents', beam_size=5, beta=-0.0, block_ngram_repeat=0, coverage_penalty='none', data_type='text', dump_beam='', fp32=True, gpu=-1, ignore_when_blocking=[], length_penalty='none', max_length=100, max_sent_length=None, min_length=0, models=[f"{os.path.join(current_path, 'translator.pt')}"], n_best=1, output='/dev/null', phrase_table='', random_sampling_temp=1.0, random_sampling_topk=1, ratio=-0.0, replace_unk=False, report_align=False, report_time=False, seed=829, stepwise_penalty=False, tgt=None, verbose=False), report_score=False)
try:
self.translator = build_translator(
Namespace(
tgt_prefix=False,
alpha=0.0,
batch_type="sents",
beam_size=5,
beta=-0.0,
block_ngram_repeat=0,
coverage_penalty="none",
data_type="text",
dump_beam="",
fp32=True,
gpu=self.device_id,
ignore_when_blocking=[],
length_penalty="none",
max_length=100,
max_sent_length=None,
min_length=0,
models=[f"{os.path.join(current_path, 'translator.pt')}"],
n_best=1,
output="/dev/null",
phrase_table="",
random_sampling_temp=1.0,
random_sampling_topk=1,
ratio=-0.0,
replace_unk=False,
report_align=False,
report_time=False,
seed=829,
stepwise_penalty=False,
tgt=None,
verbose=False,
),
report_score=False,
)
except:
self.translator = build_translator(
Namespace(
tgt_prefix=False,
alpha=0.0,
batch_type="sents",
beam_size=5,
beta=-0.0,
block_ngram_repeat=0,
coverage_penalty="none",
data_type="text",
dump_beam="",
fp32=True,
gpu=-1,
ignore_when_blocking=[],
length_penalty="none",
max_length=100,
max_sent_length=None,
min_length=0,
models=[f"{os.path.join(current_path, 'translator.pt')}"],
n_best=1,
output="/dev/null",
phrase_table="",
random_sampling_temp=1.0,
random_sampling_topk=1,
ratio=-0.0,
replace_unk=False,
report_align=False,
report_time=False,
seed=829,
stepwise_penalty=False,
tgt=None,
verbose=False,
),
report_score=False,
)
def clean_output(self, text):
text = text.replace('@@ ', '')
text = text.replace('\u200c', '')
if text.startswith('<to-gu> '): text = text[8:]
if text.endswith(' <to-gu>'): text = text[:-8]
text = text.replace("@@ ", "")
text = text.replace("\u200c", "")
if text.startswith("<to-gu> "):
text = text[8:]
if text.endswith(" <to-gu>"):
text = text[:-8]
return text
def execute(self, requests):
source_list = [pb_utils.get_input_tensor_by_name(request, "INPUT_SENT_TOKENIZED") for request in requests]
source_list = [
pb_utils.get_input_tensor_by_name(request, "INPUT_SENT_TOKENIZED")
for request in requests
]
bsize_list = [source.as_numpy().shape[0] for source in source_list]
src_sentences = [s[0].decode('utf-8').strip().split(' ') for source in source_list for s in source.as_numpy()]
tgt_sentences = [self.clean_output(result[0]) for result in self.translator.translate(src_sentences, batch_size=128)[1]]
responses = [pb_utils.InferenceResponse(output_tensors=[pb_utils.Tensor("OUTPUT_SENT", numpy.array([[s]for s in islice(tgt_sentences, bsize)], dtype='object').astype(self.target_dtype))]) for bsize in bsize_list]
src_sentences = [
s[0].decode("utf-8").strip().split(" ")
for source in source_list
for s in source.as_numpy()
]
tgt_sentences = [
self.clean_output(result[0])
for result in self.translator.translate(src_sentences, batch_size=128)[1]
]
responses = [
pb_utils.InferenceResponse(
output_tensors=[
pb_utils.Tensor(
"OUTPUT_SENT",
numpy.array(
[[s] for s in islice(tgt_sentences, bsize)], dtype="object"
).astype(self.target_dtype),
)
]
)
for bsize in bsize_list
]
return responses
def finalize(self): del self.translator
\ No newline at end of file
def finalize(self):
del self.translator
......@@ -25,18 +25,21 @@ from collections import defaultdict
# hack for python2/3 compatibility
from io import open
argparse.open = open
class BPE(object):
def __init__(self, codes, separator='@@', vocab=None, glossaries=None):
def __init__(self, codes, separator="@@", vocab=None, glossaries=None):
# check version information
firstline = codes.readline()
if firstline.startswith('#version:'):
self.version = tuple([int(x) for x in re.sub(
r'(\.0+)*$', '', firstline.split()[-1]).split(".")])
if firstline.startswith("#version:"):
self.version = tuple(
[
int(x)
for x in re.sub(r"(\.0+)*$", "", firstline.split()[-1]).split(".")
]
)
else:
self.version = (0, 1)
codes.seek(0)
......@@ -45,10 +48,12 @@ class BPE(object):
# some hacking to deal with duplicates (only consider first instance)
self.bpe_codes = dict(
[(code, i) for (i, code) in reversed(list(enumerate(self.bpe_codes)))])
[(code, i) for (i, code) in reversed(list(enumerate(self.bpe_codes)))]
)
self.bpe_codes_reverse = dict(
[(pair[0] + pair[1], pair) for pair, i in self.bpe_codes.items()])
[(pair[0] + pair[1], pair) for pair, i in self.bpe_codes.items()]
)
self.separator = separator
......@@ -62,63 +67,99 @@ class BPE(object):
"""segment single sentence (whitespace-tokenized string) with BPE encoding"""
output = []
for word in sentence.split():
new_word = [out for segment in self._isolate_glossaries(word)
for out in encode(segment,
new_word = [
out
for segment in self._isolate_glossaries(word)
for out in encode(
segment,
self.bpe_codes,
self.bpe_codes_reverse,
self.vocab,
self.separator,
self.version,
self.cache,
self.glossaries)]
self.glossaries,
)
]
for item in new_word[:-1]:
output.append(item + self.separator)
output.append(new_word[-1])
return ' '.join(output)
return " ".join(output)
def _isolate_glossaries(self, word):
word_segments = [word]
for gloss in self.glossaries:
word_segments = [out_segments for segment in word_segments
for out_segments in isolate_glossary(segment, gloss)]
word_segments = [
out_segments
for segment in word_segments
for out_segments in isolate_glossary(segment, gloss)
]
return word_segments
def create_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.RawDescriptionHelpFormatter,
description="learn BPE-based word segmentation")
description="learn BPE-based word segmentation",
)
parser.add_argument(
'--input', '-i', type=argparse.FileType('r'), default=sys.stdin,
metavar='PATH',
help="Input file (default: standard input).")
"--input",
"-i",
type=argparse.FileType("r"),
default=sys.stdin,
metavar="PATH",
help="Input file (default: standard input).",
)
parser.add_argument(
'--codes', '-c', type=argparse.FileType('r'), metavar='PATH',
"--codes",
"-c",
type=argparse.FileType("r"),
metavar="PATH",
required=True,
help="File with BPE codes (created by learn_bpe.py).")
help="File with BPE codes (created by learn_bpe.py).",
)
parser.add_argument(
'--output', '-o', type=argparse.FileType('w'), default=sys.stdout,
metavar='PATH',
help="Output file (default: standard output)")
"--output",
"-o",
type=argparse.FileType("w"),
default=sys.stdout,
metavar="PATH",
help="Output file (default: standard output)",
)
parser.add_argument(
'--separator', '-s', type=str, default='@@', metavar='STR',
help="Separator between non-final subword units (default: '%(default)s'))")
"--separator",
"-s",
type=str,
default="@@",
metavar="STR",
help="Separator between non-final subword units (default: '%(default)s'))",
)
parser.add_argument(
'--vocabulary', type=argparse.FileType('r'), default=None,
"--vocabulary",
type=argparse.FileType("r"),
default=None,
metavar="PATH",
help="Vocabulary file (built with get_vocab.py). If provided, this script reverts any merge operations that produce an OOV.")
help="Vocabulary file (built with get_vocab.py). If provided, this script reverts any merge operations that produce an OOV.",
)
parser.add_argument(
'--vocabulary-threshold', type=int, default=None,
"--vocabulary-threshold",
type=int,
default=None,
metavar="INT",
help="Vocabulary threshold. If vocabulary is provided, any word with frequency < threshold will be treated as OOV")
help="Vocabulary threshold. If vocabulary is provided, any word with frequency < threshold will be treated as OOV",
)
parser.add_argument(
'--glossaries', type=str, nargs='+', default=None,
"--glossaries",
type=str,
nargs="+",
default=None,
metavar="STR",
help="Glossaries. The strings provided in glossaries will not be affected" +
"by the BPE (i.e. they will neither be broken into subwords, nor concatenated with other subwords")
help="Glossaries. The strings provided in glossaries will not be affected"
+ "by the BPE (i.e. they will neither be broken into subwords, nor concatenated with other subwords",
)
return parser
......@@ -136,9 +177,17 @@ def get_pairs(word):
return pairs
def encode(orig, bpe_codes, bpe_codes_reverse, vocab, separator, version, cache, glossaries=None):
"""Encode word based on list of BPE merge operations, which are applied consecutively
"""
def encode(
orig,
bpe_codes,
bpe_codes_reverse,
vocab,
separator,
version,
cache,
glossaries=None,
):
"""Encode word based on list of BPE merge operations, which are applied consecutively"""
if orig in cache:
return cache[orig]
......@@ -148,9 +197,9 @@ def encode(orig, bpe_codes, bpe_codes_reverse, vocab, separator, version, cache,
return (orig,)
if version == (0, 1):
word = tuple(orig) + ('</w>',)
word = tuple(orig) + ("</w>",)
elif version == (0, 2): # more consistent handling of word-final segments
word = tuple(orig[:-1]) + (orig[-1] + '</w>',)
word = tuple(orig[:-1]) + (orig[-1] + "</w>",)
else:
raise NotImplementedError
......@@ -160,7 +209,7 @@ def encode(orig, bpe_codes, bpe_codes_reverse, vocab, separator, version, cache,
return orig
while True:
bigram = min(pairs, key=lambda pair: bpe_codes.get(pair, float('inf')))
bigram = min(pairs, key=lambda pair: bpe_codes.get(pair, float("inf")))
if bigram not in bpe_codes:
break
first, second = bigram
......@@ -189,10 +238,10 @@ def encode(orig, bpe_codes, bpe_codes_reverse, vocab, separator, version, cache,
pairs = get_pairs(word)
# don't print end-of-word symbols
if word[-1] == '</w>':
if word[-1] == "</w>":
word = word[:-1]
elif word[-1].endswith('</w>'):
word = word[:-1] + (word[-1].replace('</w>', ''),)
elif word[-1].endswith("</w>"):
word = word[:-1] + (word[-1].replace("</w>", ""),)
if vocab:
word = check_vocab_and_split(word, bpe_codes_reverse, vocab, separator)
......@@ -207,12 +256,12 @@ def recursive_split(segment, bpe_codes, vocab, separator, final=False):
try:
if final:
left, right = bpe_codes[segment + '</w>']
left, right = bpe_codes[segment + "</w>"]
right = right[:-4]
else:
left, right = bpe_codes[segment]
except:
#sys.stderr.write('cannot split {0} further.\n'.format(segment))
# sys.stderr.write('cannot split {0} further.\n'.format(segment))
yield segment
return
......@@ -239,7 +288,7 @@ def check_vocab_and_split(orig, bpe_codes, vocab, separator):
if segment + separator in vocab:
out.append(segment)
else:
#sys.stderr.write('OOV: {0}\n'.format(segment))
# sys.stderr.write('OOV: {0}\n'.format(segment))
for item in recursive_split(segment, bpe_codes, vocab, separator, False):
out.append(item)
......@@ -247,7 +296,7 @@ def check_vocab_and_split(orig, bpe_codes, vocab, separator):
if segment in vocab:
out.append(segment)
else:
#sys.stderr.write('OOV: {0}\n'.format(segment))
# sys.stderr.write('OOV: {0}\n'.format(segment))
for item in recursive_split(segment, bpe_codes, vocab, separator, True):
out.append(item)
......@@ -255,8 +304,7 @@ def check_vocab_and_split(orig, bpe_codes, vocab, separator):
def read_vocabulary(vocab_file, threshold):
"""read vocabulary file produced by get_vocab.py, and filter according to frequency threshold.
"""
"""read vocabulary file produced by get_vocab.py, and filter according to frequency threshold."""
vocabulary = set()
......@@ -282,39 +330,42 @@ def isolate_glossary(word, glossary):
return [word]
else:
splits = word.split(glossary)
segments = [segment.strip() for split in splits[:-1]
for segment in [split, glossary] if segment != '']
return segments + [splits[-1].strip()] if splits[-1] != '' else segments
segments = [
segment.strip()
for split in splits[:-1]
for segment in [split, glossary]
if segment != ""
]
return segments + [splits[-1].strip()] if splits[-1] != "" else segments
if __name__ == '__main__':
if __name__ == "__main__":
# python 2/3 compatibility
if sys.version_info < (3, 0):
sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
sys.stdin = codecs.getreader('UTF-8')(sys.stdin)
sys.stderr = codecs.getwriter("UTF-8")(sys.stderr)
sys.stdout = codecs.getwriter("UTF-8")(sys.stdout)
sys.stdin = codecs.getreader("UTF-8")(sys.stdin)
else:
sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8')
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding="utf-8")
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8")
sys.stdout = io.TextIOWrapper(
sys.stdout.buffer, encoding='utf-8', write_through=True, line_buffering=True)
sys.stdout.buffer, encoding="utf-8", write_through=True, line_buffering=True
)
parser = create_parser()
args = parser.parse_args()
# read/write files as UTF-8
args.codes = codecs.open(args.codes.name, encoding='utf-8')
if args.input.name != '<stdin>':
args.input = codecs.open(args.input.name, encoding='utf-8')
if args.output.name != '<stdout>':
args.output = codecs.open(args.output.name, 'w', encoding='utf-8')
args.codes = codecs.open(args.codes.name, encoding="utf-8")
if args.input.name != "<stdin>":
args.input = codecs.open(args.input.name, encoding="utf-8")
if args.output.name != "<stdout>":
args.output = codecs.open(args.output.name, "w", encoding="utf-8")
if args.vocabulary:
args.vocabulary = codecs.open(args.vocabulary.name, encoding='utf-8')
args.vocabulary = codecs.open(args.vocabulary.name, encoding="utf-8")
if args.vocabulary:
vocabulary = read_vocabulary(
args.vocabulary, args.vocabulary_threshold)
vocabulary = read_vocabulary(args.vocabulary, args.vocabulary_threshold)
else:
vocabulary = None
......@@ -322,4 +373,4 @@ if __name__ == '__main__':
for line in args.input:
args.output.write(bpe.segment(line).strip())
args.output.write('\n')
args.output.write("\n")
......@@ -6,8 +6,73 @@ from .apply_bpe import BPE
from ilstokenizer import tokenizer
import triton_python_backend_utils as pb_utils
class TritonPythonModel:
def initialize(self, args): self.target_dtype, self.bpes = pb_utils.triton_string_to_numpy(pb_utils.get_output_config_by_name(json.loads(args["model_config"]), "INPUT_TEXT_TOKENIZED")["data_type"]), {fname.rsplit('/', maxsplit=1)[-1][:-len('.src')]: BPE(open(fname, 'r', encoding='utf-8')) for fname in iglob(f'{os.path.dirname(os.path.abspath(__file__))}/bpe_src/*.src')}
def preprocess_text(self, text, source_lang, target_lang): return f"<to-gu> {text} <to-gu>" if source_lang == 'en' and target_lang == 'gu' else text
def execute(self, requests): return [pb_utils.InferenceResponse(output_tensors=[pb_utils.Tensor("INPUT_TEXT_TOKENIZED", numpy.array([[tokenized_sent] for tokenized_sent in tokenized_sents], dtype=self.target_dtype))]) for tokenized_sents in ((self.bpes[f"{input_language_id[0].decode('utf-8')}-{output_language_id[0].decode('utf-8')}"].segment(self.preprocess_text(tokenizer.tokenize(input_text[0].decode('utf-8').lower()), input_language_id[0].decode('utf-8'), output_language_id[0].decode('utf-8'))).strip() for input_text, input_language_id, output_language_id in zip(input_texts.as_numpy(), input_language_ids.as_numpy(), output_language_ids.as_numpy())) for input_texts, input_language_ids, output_language_ids in ((pb_utils.get_input_tensor_by_name(request, "INPUT_TEXT"), pb_utils.get_input_tensor_by_name(request, "INPUT_LANGUAGE_ID"), pb_utils.get_input_tensor_by_name(request, "OUTPUT_LANGUAGE_ID")) for request in requests))]
def finalize(self): pass
\ No newline at end of file
def initialize(self, args):
self.target_dtype, self.bpes = pb_utils.triton_string_to_numpy(
pb_utils.get_output_config_by_name(
json.loads(args["model_config"]), "INPUT_TEXT_TOKENIZED"
)["data_type"]
), {
fname.rsplit("/", maxsplit=1)[-1][: -len(".src")]: BPE(
open(fname, "r", encoding="utf-8")
)
for fname in iglob(
f"{os.path.dirname(os.path.abspath(__file__))}/bpe_src/*.src"
)
}
def preprocess_text(self, text, source_lang, target_lang):
return (
f"<to-gu> {text} <to-gu>"
if source_lang == "en" and target_lang == "gu"
else text
)
def execute(self, requests):
return [
pb_utils.InferenceResponse(
output_tensors=[
pb_utils.Tensor(
"INPUT_TEXT_TOKENIZED",
numpy.array(
[[tokenized_sent] for tokenized_sent in tokenized_sents],
dtype=self.target_dtype,
),
)
]
)
for tokenized_sents in (
(
self.bpes[
f"{input_language_id[0].decode('utf-8')}-{output_language_id[0].decode('utf-8')}"
]
.segment(
self.preprocess_text(
tokenizer.tokenize(input_text[0].decode("utf-8").lower()),
input_language_id[0].decode("utf-8"),
output_language_id[0].decode("utf-8"),
)
)
.strip()
for input_text, input_language_id, output_language_id in zip(
input_texts.as_numpy(),
input_language_ids.as_numpy(),
output_language_ids.as_numpy(),
)
)
for input_texts, input_language_ids, output_language_ids in (
(
pb_utils.get_input_tensor_by_name(request, "INPUT_TEXT"),
pb_utils.get_input_tensor_by_name(request, "INPUT_LANGUAGE_ID"),
pb_utils.get_input_tensor_by_name(
request, "OUTPUT_LANGUAGE_ID"
),
)
for request in requests
)
)
]
def finalize(self):
pass
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment