From b0edb83b86078be60fee37eaba0d555a9b3628d4 Mon Sep 17 00:00:00 2001
From: Nikhilesh Bhatnagar
Date: Mon, 24 Jul 2023 14:03:31 +0000
Subject: [PATCH] Deployment scripts

---
 .gitignore                                    |   1 +
 Dockerfile                                    |   9 +
 README.md                                     |  50 +++
 make_triton_model_repo.sh                     |  52 +++
 triton_client.ipynb                           |  92 +++++
 triton_models/ssmt_model_demuxer/1/model.py   |  22 ++
 triton_models/ssmt_model_demuxer/config.pbtxt |  46 +++
 triton_models/ssmt_pipeline/config.pbtxt      |  82 +++++
 .../ssmt_template_model_repo/1/model.py       |  26 ++
 .../ssmt_template_model_repo/config.pbtxt     |  31 ++
 triton_models/ssmt_tokenizer/1/apply_bpe.py   | 325 ++++++++++++++++++
 triton_models/ssmt_tokenizer/1/model.py       |  23 ++
 triton_models/ssmt_tokenizer/config.pbtxt     |  46 +++
 13 files changed, 805 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 Dockerfile
 create mode 100644 make_triton_model_repo.sh
 create mode 100644 triton_client.ipynb
 create mode 100644 triton_models/ssmt_model_demuxer/1/model.py
 create mode 100644 triton_models/ssmt_model_demuxer/config.pbtxt
 create mode 100644 triton_models/ssmt_pipeline/config.pbtxt
 create mode 100644 triton_models/ssmt_template_model_repo/1/model.py
 create mode 100644 triton_models/ssmt_template_model_repo/config.pbtxt
 create mode 100644 triton_models/ssmt_tokenizer/1/apply_bpe.py
 create mode 100644 triton_models/ssmt_tokenizer/1/model.py
 create mode 100644 triton_models/ssmt_tokenizer/config.pbtxt

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..7048a1d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+ssmt_triton_repo
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..04b307b
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,9 @@
+FROM nvcr.io/nvidia/tritonserver:23.06-py3
+WORKDIR /opt/tritonserver
+RUN apt-get update && apt-get install -y python3.10-venv
+ENV VIRTUAL_ENV=/opt/dhruva-mt
+RUN python3 -m venv $VIRTUAL_ENV
+ENV PATH="$VIRTUAL_ENV/bin:$PATH"
+RUN pip install -U ctranslate2 OpenNMT-py==1.2.0 git+https://github.com/vmujadia/tokenizer.git
+CMD ["tritonserver", "--model-repository=/models", "--cache-config=local,size=1048576"]
+EXPOSE 8000
diff --git a/README.md b/README.md
index 1d0b1e2..75e9c10 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,52 @@
 # mt-model-deploy-dhruva
 
+## TL;DR
+
+This repo contains code for Python-backend, CTranslate2-based Triton models for the SSMT project.
+Prerequisites: `python3.xx-venv`, `nvidia-docker`
+```bash
+git clone http://ssmt.iiit.ac.in/meitygit/ssmt/mt-model-deploy-dhruva.git
+cd mt-model-deploy-dhruva
+bash make_triton_model_repo.sh "https://ssmt.iiit.ac.in/uploads/data_mining/models.zip" "float16"
+docker build -t dhruva/ssmt-model-server:1 .
+nvidia-docker run --gpus=all --rm --shm-size 5g --network=host --name dhruva-ssmt-triton-server -v $(pwd)/ssmt_triton_repo:/models dhruva/ssmt-model-server:1
+```
+
+## What this repo does
+
+* This repo contains the templates and component Triton models for the SSMT project.
+* Also included is a Dockerfile to build the Triton server image.
+* Given a models URL and a quantization method (any supported by CTranslate2, i.e. `int8`, `int8_float16`, `int8_bfloat16`, `int16`, `float16` and `bfloat16`), `make_triton_model_repo.sh` downloads, quantizes and constructs the SSMT Triton repository in `./ssmt_triton_repo`.
+* Dynamic batching and response caching are supported and enabled by default.
+* The repository folder can be mounted into the Dhruva SSMT Triton server at `/models` and queried via a client.
+* Sample client code is provided as an IPython notebook.
+* The `models.zip` package needs to contain a folder of `.pt` and `.src` files named `1` through `9` (there is no `5`), with each file corresponding to the following mapping:
+`{'eng-hin': 1, 'hin-eng': 2, 'eng-tel': 3, 'tel-eng': 4, 'hin-tel': 6, 'tel-hin': 7, 'eng-guj': 8, 'guj-eng': 9}`
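+
+For illustration, a minimal sketch of how a language pair resolves to an internal Triton model name (the mapping is the one above; the `ssmt_<id>_ct2` naming follows `make_triton_model_repo.sh`):
+
+```python
+LANG_PAIR_MAP = {'eng-hin': 1, 'hin-eng': 2, 'eng-tel': 3, 'tel-eng': 4,
+                 'hin-tel': 6, 'tel-hin': 7, 'eng-guj': 8, 'guj-eng': 9}
+
+def model_name_for(src_lang: str, tgt_lang: str) -> str:
+    """Return the name of the CTranslate2 model the demuxer routes to."""
+    return f"ssmt_{LANG_PAIR_MAP[f'{src_lang}-{tgt_lang}']}_ct2"
+
+print(model_name_for('eng', 'hin'))  # -> ssmt_1_ct2
+```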
+
+## Architecture of the pipeline
+
+The pipeline consists of 4 components, executed in order:
+* Tokenizer - This model is CPU-only; it tokenizes the input string and applies BPE to it.
+* Model Demuxer - This model is CPU-only; depending on the requested language pair, it queues an `InferenceRequest` to the appropriate translation model and returns that model's output as its response.
+* Model - This is a GPU-based model; it translates the tokenized text and returns the final translated text to the caller.
+* Pipeline - This is an ensemble model that wires the above components together and is the one meant to be exposed to the client.
+The exact specifications of the model inputs and outputs can be found in the corresponding `config.pbtxt` files.
+One can construct the triton repo like so:
+```bash
+git clone http://ssmt.iiit.ac.in/meitygit/ssmt/mt-model-deploy-dhruva.git
+cd mt-model-deploy-dhruva
+bash make_triton_model_repo.sh "https://ssmt.iiit.ac.in/uploads/data_mining/models.zip" "float16"
+```
+
+## Starting the triton server
+
+We customize the tritonserver image with the required Python packages in a venv and enable the response cache in the startup command. After the model repo has been built, one can build and run the server like so:
+```bash
+docker build -t dhruva/ssmt-model-server:1 .
+nvidia-docker run --gpus=all --rm --shm-size 5g --network=host --name dhruva-ssmt-triton-server -v $(pwd)/ssmt_triton_repo:/models dhruva/ssmt-model-server:1
+```
+
+## Querying the triton server
+
+We provide a sample IPython notebook that shows how to send concurrent translation requests to the server.
+Prerequisites: `pip install "tritonclient[all]" tqdm numpy`
diff --git a/make_triton_model_repo.sh b/make_triton_model_repo.sh
new file mode 100644
index 0000000..8b5a8fc
--- /dev/null
+++ b/make_triton_model_repo.sh
@@ -0,0 +1,52 @@
+#!/bin/bash
+MODELS_URL=$1
+QUANTIZATION=$2
+wget -O models.zip $MODELS_URL
+unzip models.zip
+python3 -m venv ./ssmt_ct2
+source ./ssmt_ct2/bin/activate
+pip install ctranslate2 "OpenNMT-py==1.2.0"
+cd models
+# Convert each OpenNMT-py checkpoint to a CTranslate2 model (there is no model 5).
+for MODEL_ID in 1 2 3 4 6 7 8 9
+do
+ct2-opennmt-py-converter --model_path ${MODEL_ID}.pt --quantization $QUANTIZATION --output_dir ./${MODEL_ID}_ct2
+done
+cd ..
+mkdir ssmt_triton_repo
+cd ssmt_triton_repo
+cp -r ../triton_models/ssmt_pipeline .
+cp -r ../triton_models/ssmt_model_demuxer .
+cp -r ../triton_models/ssmt_tokenizer .
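+# Copy the BPE codes into the tokenizer model, then for each converted model:
+# instantiate the Python-backend template, drop the CTranslate2 weights in as
+# version 1, and patch the placeholder model name in config.pbtxt.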
+mkdir -p ssmt_tokenizer/1/bpe_src
+cp -r ../models/*.src ssmt_tokenizer/1/bpe_src
+for MODEL_ID in 1 2 3 4 6 7 8 9
+do
+cp -r ../triton_models/ssmt_template_model_repo ssmt_${MODEL_ID}_ct2
+cp -r ../models/${MODEL_ID}_ct2 ssmt_${MODEL_ID}_ct2/1/translator
+sed -i "s/model_name/ssmt_${MODEL_ID}_ct2/" ssmt_${MODEL_ID}_ct2/config.pbtxt
+done
+cd ..
+deactivate
+rm -rf ssmt_ct2
+rm -f models.zip
+rm -rf models
diff --git a/triton_client.ipynb b/triton_client.ipynb
new file mode 100644
index 0000000..e993bee
--- /dev/null
+++ b/triton_client.ipynb
@@ -0,0 +1,92 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "from tqdm import tqdm\n",
+    "from random import choice\n",
+    "from tritonclient.utils import *\n",
+    "import tritonclient.http as httpclient\n",
+    "from multiprocessing.pool import ThreadPool"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model_name = \"ssmt_pipeline\"\n",
+    "shape = [1]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def task(x):\n",
+    "    lang_pair_map = list({'eng-hin': 1, 'hin-eng': 2, 'eng-tel':3, 'tel-eng': 4, 'hin-tel': 6, 'tel-hin': 7, 'eng-guj': 8, 'guj-eng': 9}.keys())\n",
+    "    with httpclient.InferenceServerClient(\"localhost:8000\") as client:\n",
+    "        async_responses = []\n",
+    "        for i in range(10):\n",
+    "            s = 'this is a sentence.'\n",
+    "            source_data = np.array([[s]], dtype='object')\n",
+    "            inputs = [httpclient.InferInput(\"INPUT_TEXT\", source_data.shape, np_to_triton_dtype(source_data.dtype)), httpclient.InferInput(\"INPUT_LANGUAGE_ID\", source_data.shape, np_to_triton_dtype(source_data.dtype)), httpclient.InferInput(\"OUTPUT_LANGUAGE_ID\", source_data.shape, np_to_triton_dtype(source_data.dtype))]\n",
+    "            inputs[0].set_data_from_numpy(np.array([[s]], dtype='object'))\n",
+    "            langpair = choice(lang_pair_map)\n",
+    "            inputs[1].set_data_from_numpy(np.array([[langpair.split('-')[0].strip()]], dtype='object'))\n",
+    "            inputs[2].set_data_from_numpy(np.array([[langpair.split('-')[1].strip()]], dtype='object'))\n",
+    "            outputs = [httpclient.InferRequestedOutput(\"OUTPUT_TEXT\")]\n",
+    "            async_responses.append(client.async_infer(model_name, inputs, request_id=str(i), outputs=outputs))\n",
+    "        for r in async_responses: r.get_result(timeout=10).get_response()\n",
+    "    return 0"
+   ]
+  },
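+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "For reference, a minimal single synchronous request (an illustrative sketch; it assumes the server from the README is listening on `localhost:8000`):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative sketch: one eng-hin request through the ssmt_pipeline ensemble.\n",
+    "with httpclient.InferenceServerClient(\"localhost:8000\") as client:\n",
+    "    source_data = np.array([['this is a sentence.']], dtype='object')\n",
+    "    inputs = [httpclient.InferInput(name, source_data.shape, np_to_triton_dtype(source_data.dtype)) for name in (\"INPUT_TEXT\", \"INPUT_LANGUAGE_ID\", \"OUTPUT_LANGUAGE_ID\")]\n",
+    "    inputs[0].set_data_from_numpy(source_data)\n",
+    "    inputs[1].set_data_from_numpy(np.array([['eng']], dtype='object'))\n",
+    "    inputs[2].set_data_from_numpy(np.array([['hin']], dtype='object'))\n",
+    "    result = client.infer(model_name, inputs, outputs=[httpclient.InferRequestedOutput(\"OUTPUT_TEXT\")])\n",
+    "    print(result.as_numpy(\"OUTPUT_TEXT\")[0, 0].decode('utf-8'))"
+   ]
+  },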
"execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1000/1000 [01:49<00:00, 9.15it/s]\n" + ] + } + ], + "source": [ + "with ThreadPool(100) as pool:\n", + " for output in tqdm(pool.imap_unordered(task, range(1000), chunksize=1), total=1000): pass" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "model_metrics", + "language": "python", + "name": "model_metrics" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/triton_models/ssmt_model_demuxer/1/model.py b/triton_models/ssmt_model_demuxer/1/model.py new file mode 100644 index 0000000..5c4e289 --- /dev/null +++ b/triton_models/ssmt_model_demuxer/1/model.py @@ -0,0 +1,22 @@ +import json +import asyncio +import triton_python_backend_utils as pb_utils + +class TritonPythonModel: + def initialize(self, args): + self.model_config = json.loads(args["model_config"]) + target_config = pb_utils.get_output_config_by_name(self.model_config, "OUTPUT_TEXT") + self.target_dtype = pb_utils.triton_string_to_numpy(target_config["data_type"]) + self.lang_pair_map = {'eng-hin': 1, 'hin-eng': 2, 'eng-tel': 3, 'tel-eng': 4, 'hin-tel': 6, 'tel-hin': 7, 'eng-guj': 8, 'guj-eng': 9} + + async def execute(self, requests): + responses = [] + infer_response_awaits = [] + for request in requests: + language_pair = f"{pb_utils.get_input_tensor_by_name(request, 'INPUT_LANGUAGE_ID').as_numpy()[0, 0].decode('utf-8')}-{pb_utils.get_input_tensor_by_name(request, 'OUTPUT_LANGUAGE_ID').as_numpy()[0, 0].decode('utf-8')}" + inference_request = pb_utils.InferenceRequest(model_name=f'ssmt_{self.lang_pair_map[language_pair]}_ct2', requested_output_names=['OUTPUT_TEXT'], inputs=[pb_utils.get_input_tensor_by_name(request, 'INPUT_TEXT_TOKENIZED')]) + infer_response_awaits.append(inference_request.async_exec()) + responses = await asyncio.gather(*infer_response_awaits) + return responses + + def finalize(self): pass diff --git a/triton_models/ssmt_model_demuxer/config.pbtxt b/triton_models/ssmt_model_demuxer/config.pbtxt new file mode 100644 index 0000000..1e0a3db --- /dev/null +++ b/triton_models/ssmt_model_demuxer/config.pbtxt @@ -0,0 +1,46 @@ +name: "ssmt_model_demuxer" +backend: "python" +max_batch_size: 4096 + +input [ + { + name: "INPUT_TEXT_TOKENIZED" + data_type: TYPE_STRING + dims: [ 1 ] + } +] +input [ + { + name: "INPUT_LANGUAGE_ID" + data_type: TYPE_STRING + dims: [ 1 ] + } +] +input [ + { + name: "OUTPUT_LANGUAGE_ID" + data_type: TYPE_STRING + dims: [ 1 ] + } +] + +output [ + { + name: "OUTPUT_TEXT" + data_type: TYPE_STRING + dims: [ 1 ] + } +] + +dynamic_batching {} + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] + +response_cache { + enable: true +} diff --git a/triton_models/ssmt_pipeline/config.pbtxt b/triton_models/ssmt_pipeline/config.pbtxt new file mode 100644 index 0000000..d519b03 --- /dev/null +++ b/triton_models/ssmt_pipeline/config.pbtxt @@ -0,0 +1,82 @@ +name: "ssmt_pipeline" +platform: "ensemble" +max_batch_size: 4096 + +input [ + { + name: "INPUT_TEXT" + data_type: TYPE_STRING + dims: [ 1 ] + } +] +input [ + { + name: "INPUT_LANGUAGE_ID" + data_type: TYPE_STRING + dims: [ 1 ] + } +] +input [ + { + name: "OUTPUT_LANGUAGE_ID" + data_type: TYPE_STRING + dims: [ 1 ] 
+
+input [
+  {
+    name: "INPUT_TEXT"
+    data_type: TYPE_STRING
+    dims: [ 1 ]
+  }
+]
+input [
+  {
+    name: "INPUT_LANGUAGE_ID"
+    data_type: TYPE_STRING
+    dims: [ 1 ]
+  }
+]
+input [
+  {
+    name: "OUTPUT_LANGUAGE_ID"
+    data_type: TYPE_STRING
+    dims: [ 1 ]
+  }
+]
+
+output [
+  {
+    name: "OUTPUT_TEXT"
+    data_type: TYPE_STRING
+    dims: [ 1 ]
+  }
+]
+
+ensemble_scheduling {
+  step [
+    {
+      model_name: "ssmt_tokenizer"
+      model_version: 1
+      input_map {
+        key: "INPUT_TEXT"
+        value: "INPUT_TEXT"
+      }
+      input_map {
+        key: "INPUT_LANGUAGE_ID"
+        value: "INPUT_LANGUAGE_ID"
+      }
+      input_map {
+        key: "OUTPUT_LANGUAGE_ID"
+        value: "OUTPUT_LANGUAGE_ID"
+      }
+      output_map {
+        key: "INPUT_TEXT_TOKENIZED"
+        value: "INPUT_TEXT_TOKENIZED"
+      }
+    },
+    {
+      model_name: "ssmt_model_demuxer"
+      model_version: 1
+      input_map {
+        key: "INPUT_TEXT_TOKENIZED"
+        value: "INPUT_TEXT_TOKENIZED"
+      }
+      input_map {
+        key: "INPUT_LANGUAGE_ID"
+        value: "INPUT_LANGUAGE_ID"
+      }
+      input_map {
+        key: "OUTPUT_LANGUAGE_ID"
+        value: "OUTPUT_LANGUAGE_ID"
+      }
+      output_map {
+        key: "OUTPUT_TEXT"
+        value: "OUTPUT_TEXT"
+      }
+    }
+  ]
+}
+
+response_cache {
+  enable: true
+}
diff --git a/triton_models/ssmt_template_model_repo/1/model.py b/triton_models/ssmt_template_model_repo/1/model.py
new file mode 100644
index 0000000..7c636bb
--- /dev/null
+++ b/triton_models/ssmt_template_model_repo/1/model.py
@@ -0,0 +1,26 @@
+import os
+import json
+import numpy
+from itertools import islice
+from ctranslate2 import Translator
+import triton_python_backend_utils as pb_utils
+
+class TritonPythonModel:
+    def initialize(self, args):
+        current_path = os.path.dirname(os.path.abspath(__file__))
+        self.model_config = json.loads(args["model_config"])
+        self.device_id = int(json.loads(args['model_instance_device_id']))
+        target_config = pb_utils.get_output_config_by_name(self.model_config, "OUTPUT_TEXT")
+        self.target_dtype = pb_utils.triton_string_to_numpy(target_config["data_type"])
+        # Prefer the assigned GPU; fall back to CPU if CUDA is unavailable.
+        try: self.translator = Translator(os.path.join(current_path, 'translator'), device="cuda", intra_threads=1, inter_threads=1, device_index=[self.device_id])
+        except Exception: self.translator = Translator(os.path.join(current_path, 'translator'), device="cpu", intra_threads=4)
+
+    def execute(self, requests):
+        source_list = [pb_utils.get_input_tensor_by_name(request, "INPUT_TEXT_TOKENIZED") for request in requests]
+        bsize_list = [source.as_numpy().shape[0] for source in source_list]
+        src_sentences = [s[0].decode('utf-8').strip().split(' ') for source in source_list for s in source.as_numpy()]
+        # Keep this a generator so that islice() below consumes the flat translation stream progressively when slicing it back into per-request batches.
+        tgt_sentences = (' '.join(result.hypotheses[0]).replace('@@ ', '') for result in self.translator.translate_iterable(src_sentences, max_batch_size=128, max_input_length=100, max_decoding_length=100))
+        responses = [pb_utils.InferenceResponse(output_tensors=[pb_utils.Tensor("OUTPUT_TEXT", numpy.array([[s] for s in islice(tgt_sentences, bsize)], dtype='object').astype(self.target_dtype))]) for bsize in bsize_list]
+        return responses
+
+    def finalize(self): self.translator.unload_model()
diff --git a/triton_models/ssmt_template_model_repo/config.pbtxt b/triton_models/ssmt_template_model_repo/config.pbtxt
new file mode 100644
index 0000000..0127f90
--- /dev/null
+++ b/triton_models/ssmt_template_model_repo/config.pbtxt
@@ -0,0 +1,31 @@
+# "model_name" is a placeholder; make_triton_model_repo.sh rewrites it to the
+# concrete ssmt_<id>_ct2 name when instantiating this template.
+name: "model_name"
+backend: "python"
+max_batch_size: 512
+
+input [
+  {
+    name: "INPUT_TEXT_TOKENIZED"
+    data_type: TYPE_STRING
+    dims: [ 1 ]
+  }
+]
+output [
+  {
+    name: "OUTPUT_TEXT"
+    data_type: TYPE_STRING
+    dims: [ 1 ]
+  }
+]
+
+dynamic_batching {}
+
+instance_group [
+  {
+    count: 1
+    kind: KIND_GPU
+  }
+]
+
+response_cache {
+  enable: true
+}
diff --git a/triton_models/ssmt_tokenizer/1/apply_bpe.py b/triton_models/ssmt_tokenizer/1/apply_bpe.py
new file mode 100644
index 0000000..2fe666e
--- /dev/null
+++ b/triton_models/ssmt_tokenizer/1/apply_bpe.py
@@ -0,0 +1,325 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Author: Rico Sennrich
+# flake8: noqa
+
+"""Use operations learned with learn_bpe.py to encode a new text.
+The text will not be smaller, but use only a fixed vocabulary, with rare words
+encoded as variable-length sequences of subword units.
+
+Reference:
+Rico Sennrich, Barry Haddow and Alexandra Birch (2015). Neural Machine Translation of Rare Words with Subword Units.
+Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (ACL 2016). Berlin, Germany.
+"""
+# This file is retrieved from https://github.com/rsennrich/subword-nmt
+
+from __future__ import unicode_literals, division
+
+import sys
+import codecs
+import io
+import argparse
+import json
+import re
+from collections import defaultdict
+
+# hack for python2/3 compatibility
+from io import open
+argparse.open = open
+
+
+class BPE(object):
+
+    def __init__(self, codes, separator='@@', vocab=None, glossaries=None):
+
+        # check version information
+        firstline = codes.readline()
+        if firstline.startswith('#version:'):
+            self.version = tuple([int(x) for x in re.sub(
+                r'(\.0+)*$', '', firstline.split()[-1]).split(".")])
+        else:
+            self.version = (0, 1)
+            codes.seek(0)
+
+        self.bpe_codes = [tuple(item.split()) for item in codes]
+
+        # some hacking to deal with duplicates (only consider first instance)
+        self.bpe_codes = dict(
+            [(code, i) for (i, code) in reversed(list(enumerate(self.bpe_codes)))])
+
+        self.bpe_codes_reverse = dict(
+            [(pair[0] + pair[1], pair) for pair, i in self.bpe_codes.items()])
+
+        self.separator = separator
+
+        self.vocab = vocab
+
+        self.glossaries = glossaries if glossaries else []
+
+        self.cache = {}
+
+    def segment(self, sentence):
+        """segment single sentence (whitespace-tokenized string) with BPE encoding"""
+        output = []
+        for word in sentence.split():
+            new_word = [out for segment in self._isolate_glossaries(word)
+                        for out in encode(segment,
+                                          self.bpe_codes,
+                                          self.bpe_codes_reverse,
+                                          self.vocab,
+                                          self.separator,
+                                          self.version,
+                                          self.cache,
+                                          self.glossaries)]
+
+            for item in new_word[:-1]:
+                output.append(item + self.separator)
+            output.append(new_word[-1])
+
+        return ' '.join(output)
+
+    def _isolate_glossaries(self, word):
+        word_segments = [word]
+        for gloss in self.glossaries:
+            word_segments = [out_segments for segment in word_segments
+                             for out_segments in isolate_glossary(segment, gloss)]
+        return word_segments
+
+
+def create_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        description="learn BPE-based word segmentation")
+
+    parser.add_argument(
+        '--input', '-i', type=argparse.FileType('r'), default=sys.stdin,
+        metavar='PATH',
+        help="Input file (default: standard input).")
+    parser.add_argument(
+        '--codes', '-c', type=argparse.FileType('r'), metavar='PATH',
+        required=True,
+        help="File with BPE codes (created by learn_bpe.py).")
+    parser.add_argument(
+        '--output', '-o', type=argparse.FileType('w'), default=sys.stdout,
+        metavar='PATH',
+        help="Output file (default: standard output)")
+    parser.add_argument(
+        '--separator', '-s', type=str, default='@@', metavar='STR',
+        help="Separator between non-final subword units (default: '%(default)s')")
+    parser.add_argument(
+        '--vocabulary', type=argparse.FileType('r'), default=None,
+        metavar="PATH",
+        help="Vocabulary file (built with get_vocab.py). If provided, this script reverts any merge operations that produce an OOV.")
+    parser.add_argument(
+        '--vocabulary-threshold', type=int, default=None,
+        metavar="INT",
+        help="Vocabulary threshold. If vocabulary is provided, any word with frequency < threshold will be treated as OOV")
+    parser.add_argument(
+        '--glossaries', type=str, nargs='+', default=None,
+        metavar="STR",
+        help="Glossaries. The strings provided in glossaries will not be affected " +
+             "by the BPE (i.e. they will neither be broken into subwords, nor concatenated with other subwords).")
+
+    return parser
+
+
+def get_pairs(word):
+    """Return set of symbol pairs in a word.
+
+    word is represented as tuple of symbols (symbols being variable-length strings)
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+def encode(orig, bpe_codes, bpe_codes_reverse, vocab, separator, version, cache, glossaries=None):
+    """Encode word based on list of BPE merge operations, which are applied consecutively
+    """
+
+    if orig in cache:
+        return cache[orig]
+
+    if orig in glossaries:
+        cache[orig] = (orig,)
+        return (orig,)
+
+    if version == (0, 1):
+        word = tuple(orig) + ('</w>',)
+    elif version == (0, 2):  # more consistent handling of word-final segments
+        word = tuple(orig[:-1]) + (orig[-1] + '</w>',)
+    else:
+        raise NotImplementedError
+
+    pairs = get_pairs(word)
+
+    if not pairs:
+        return orig
+
+    while True:
+        bigram = min(pairs, key=lambda pair: bpe_codes.get(pair, float('inf')))
+        if bigram not in bpe_codes:
+            break
+        first, second = bigram
+        new_word = []
+        i = 0
+        while i < len(word):
+            try:
+                j = word.index(first, i)
+                new_word.extend(word[i:j])
+                i = j
+            except ValueError:
+                new_word.extend(word[i:])
+                break
+
+            if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                new_word.append(first + second)
+                i += 2
+            else:
+                new_word.append(word[i])
+                i += 1
+        new_word = tuple(new_word)
+        word = new_word
+        if len(word) == 1:
+            break
+        else:
+            pairs = get_pairs(word)
+
+    # don't print end-of-word symbols
+    if word[-1] == '</w>':
+        word = word[:-1]
+    elif word[-1].endswith('</w>'):
+        word = word[:-1] + (word[-1].replace('</w>', ''),)
+
+    if vocab:
+        word = check_vocab_and_split(word, bpe_codes_reverse, vocab, separator)
+
+    cache[orig] = word
+    return word
+
+
+def recursive_split(segment, bpe_codes, vocab, separator, final=False):
+    """Recursively split segment into smaller units (by reversing BPE merges)
+    until all units are either in-vocabulary, or cannot be split further."""
+
+    try:
+        if final:
+            left, right = bpe_codes[segment + '</w>']
+            right = right[:-4]
+        else:
+            left, right = bpe_codes[segment]
+    except KeyError:
+        #sys.stderr.write('cannot split {0} further.\n'.format(segment))
+        yield segment
+        return
+
+    if left + separator in vocab:
+        yield left
+    else:
+        for item in recursive_split(left, bpe_codes, vocab, separator, False):
+            yield item
+
+    if (final and right in vocab) or (not final and right + separator in vocab):
+        yield right
+    else:
+        for item in recursive_split(right, bpe_codes, vocab, separator, final):
+            yield item
+
+
+def check_vocab_and_split(orig, bpe_codes, vocab, separator):
+    """Check for each segment in word if it is in-vocabulary,
+    and segment OOV segments into smaller units by reversing the BPE merge operations"""
+
+    out = []
+
+    for segment in orig[:-1]:
+        if segment + separator in vocab:
+            out.append(segment)
+        else:
+            #sys.stderr.write('OOV: {0}\n'.format(segment))
+            for item in recursive_split(segment, bpe_codes, vocab, separator, False):
+                out.append(item)
+
+    segment = orig[-1]
+    if segment in vocab:
+        out.append(segment)
+    else:
+        #sys.stderr.write('OOV: {0}\n'.format(segment))
+        for item in recursive_split(segment, bpe_codes, vocab, separator, True):
+            out.append(item)
+
+    return out
+
+
+def read_vocabulary(vocab_file, threshold):
+    """read vocabulary file produced by get_vocab.py, and filter according to frequency threshold.
+    """
+
+    vocabulary = set()
+
+    for line in vocab_file:
+        word, freq = line.split()
+        freq = int(freq)
+        if threshold is None or freq >= threshold:
+            vocabulary.add(word)
+
+    return vocabulary
+
+
+def isolate_glossary(word, glossary):
+    """
+    Isolate a glossary present inside a word.
+
+    Returns a list of subwords. In which all 'glossary' glossaries are isolated
+
+    For example, if 'USA' is the glossary and '1934USABUSA' the word, the return value is:
+    ['1934', 'USA', 'B', 'USA']
+    """
+    if word == glossary or glossary not in word:
+        return [word]
+    else:
+        splits = word.split(glossary)
+        segments = [segment.strip() for split in splits[:-1]
+                    for segment in [split, glossary] if segment != '']
+        return segments + [splits[-1].strip()] if splits[-1] != '' else segments
+
+
+if __name__ == '__main__':
+
+    # python 2/3 compatibility
+    if sys.version_info < (3, 0):
+        sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
+        sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
+        sys.stdin = codecs.getreader('UTF-8')(sys.stdin)
+    else:
+        sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8')
+        sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
+        sys.stdout = io.TextIOWrapper(
+            sys.stdout.buffer, encoding='utf-8', write_through=True, line_buffering=True)
+
+    parser = create_parser()
+    args = parser.parse_args()
+
+    # read/write files as UTF-8
+    args.codes = codecs.open(args.codes.name, encoding='utf-8')
+    if args.input.name != '<stdin>':
+        args.input = codecs.open(args.input.name, encoding='utf-8')
+    if args.output.name != '<stdout>':
+        args.output = codecs.open(args.output.name, 'w', encoding='utf-8')
+    if args.vocabulary:
+        args.vocabulary = codecs.open(args.vocabulary.name, encoding='utf-8')
+
+    if args.vocabulary:
+        vocabulary = read_vocabulary(
+            args.vocabulary, args.vocabulary_threshold)
+    else:
+        vocabulary = None
+
+    bpe = BPE(args.codes, args.separator, vocabulary, args.glossaries)
+
+    for line in args.input:
+        args.output.write(bpe.segment(line).strip())
+        args.output.write('\n')
diff --git a/triton_models/ssmt_tokenizer/1/model.py b/triton_models/ssmt_tokenizer/1/model.py
new file mode 100644
index 0000000..7a85827
--- /dev/null
+++ b/triton_models/ssmt_tokenizer/1/model.py
@@ -0,0 +1,23 @@
+import os
+import json
+import numpy
+from .apply_bpe import BPE
+from ilstokenizer import tokenizer
+import triton_python_backend_utils as pb_utils
+
+class TritonPythonModel:
+    def initialize(self, args):
+        current_path = os.path.dirname(os.path.abspath(__file__))
+        self.model_config = json.loads(args["model_config"])
+        target_config = pb_utils.get_output_config_by_name(self.model_config, "INPUT_TEXT_TOKENIZED")
+        self.target_dtype = pb_utils.triton_string_to_numpy(target_config["data_type"])
+        self.lang_pair_map = {'eng-hin': 1, 'hin-eng': 2, 'eng-tel': 3, 'tel-eng': 4, 'hin-tel': 6, 'tel-hin': 7, 'eng-guj': 8, 'guj-eng': 9}
+        # One BPE codes file per language pair, loaded once at startup.
+        self.bpes = {lang_pair: BPE(open(os.path.join(current_path, f'bpe_src/{model_id}.src'), encoding='utf-8')) for lang_pair, model_id in self.lang_pair_map.items()}
+
+    def execute(self, requests):
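+        # For each request: read the three string tensors, pick the BPE model for
+        # the requested language pair, word-tokenize, then BPE-segment.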
+        source_gen = ((pb_utils.get_input_tensor_by_name(request, "INPUT_TEXT"), pb_utils.get_input_tensor_by_name(request, "INPUT_LANGUAGE_ID"), pb_utils.get_input_tensor_by_name(request, "OUTPUT_LANGUAGE_ID")) for request in requests)
+        tokenized_gen = (self.bpes[f"{input_language_id.as_numpy()[0, 0].decode('utf-8')}-{output_language_id.as_numpy()[0, 0].decode('utf-8')}"].segment(tokenizer.tokenize(input_text.as_numpy()[0, 0].decode('utf-8'))).strip() for input_text, input_language_id, output_language_id in source_gen)
+        responses = [pb_utils.InferenceResponse(output_tensors=[pb_utils.Tensor("INPUT_TEXT_TOKENIZED", numpy.array([[tokenized_sent]], dtype=self.target_dtype))]) for tokenized_sent in tokenized_gen]
+        return responses
+
+    def finalize(self): pass
diff --git a/triton_models/ssmt_tokenizer/config.pbtxt b/triton_models/ssmt_tokenizer/config.pbtxt
new file mode 100644
index 0000000..051e2e4
--- /dev/null
+++ b/triton_models/ssmt_tokenizer/config.pbtxt
@@ -0,0 +1,46 @@
+name: "ssmt_tokenizer"
+backend: "python"
+max_batch_size: 4096
+
+input [
+  {
+    name: "INPUT_TEXT"
+    data_type: TYPE_STRING
+    dims: [ 1 ]
+  }
+]
+input [
+  {
+    name: "INPUT_LANGUAGE_ID"
+    data_type: TYPE_STRING
+    dims: [ 1 ]
+  }
+]
+input [
+  {
+    name: "OUTPUT_LANGUAGE_ID"
+    data_type: TYPE_STRING
+    dims: [ 1 ]
+  }
+]
+
+output [
+  {
+    name: "INPUT_TEXT_TOKENIZED"
+    data_type: TYPE_STRING
+    dims: [ 1 ]
+  }
+]
+
+dynamic_batching {}
+
+instance_group [
+  {
+    count: 8
+    kind: KIND_CPU
+  }
+]
+
+response_cache {
+  enable: true
+}
--
GitLab