diff --git a/README.md b/README.md
index 3ad2698b3b9e8a4749bd62ab6a92cde646d71e70..91f898899b51ede9ef7947c9ac1ceaf870a66a50 100644
--- a/README.md
+++ b/README.md
@@ -18,12 +18,11 @@ nvidia-docker run --gpus=all --rm --shm-size 5g --network=host --name dhruva-ssm
 
 * This repo contains the templates and component triton models for the SSMT project.
 * Also contained is a Dockerfile to construct the triton server instance.
-* Given a URL and quantization method (those supported by CTranslate2 i.e. `int8`, `int8_float16`, `int8_bfloat16`, `int16`, `float16` and `bfloat16`) it will download, quantize and construct the SSMT Tritom Repository in `./ssmt_triton_repo`.
+* Given a URL and quantization method (those supported by CTranslate2 i.e. `int8`, `int8_float16`, `int8_bfloat16`, `int16`, `float16` and `bfloat16`) it will download, quantize and construct the SSMT Triton Repository in `./ssmt_triton_repo`.
 * Dynamic batching and caching is supported and enabled by default.
 * The repository folder can me mounted to the dhruva ssmt triton server on `/models` and can be queried via a client.
 * Sample client code is also given as an ipython notebook.
-* The `model.zip` package needs to contain a folder of `.pt` and `.src` files named `1` through `9` with each file corresponding to the following mapping:
-`{'eng-hin': 1, 'hin-eng': 2, 'eng-tel': 3, 'tel-eng': 4, 'hin-tel': 6, 'tel-hin': 7, 'eng-guj': 8, 'guj-eng': 9}`
+* The `model.zip` package needs to contain a folder of `.pt` and `.src` files named `1` through `9` with each file corresponding to the following mapping: `{'en-hi': 1, 'hi-en': 2, 'en-te': 3, 'te-en': 4, 'hi-te': 6, 'te-hi': 7, 'en-gu': 8, 'gu-en': 9}`
 
 ## Architecture of the pipeline
 
@@ -32,6 +31,7 @@ The pipeline consists of 4 components, executed in order:
 * Model Demuxer - This model is CPU only and depending on the language pair requested, queues up an `InferenceRequest` for the appropriate model and returns it as the response.
 * Model - This is a GPU based model and it processes the tokenized text and returns the final form of the translated text to the caller.
 * Pipeline - This is an ensemble model that wraps the above three components together and is the one meant to be exposed to the client.
+
 The exact specifications of the model inputs and outputs can be looked at in the corresponding `config.pbtxt` files.
 One can construct the triton repo like so:
 ```bash
diff --git a/triton_models/ssmt_model_demuxer/1/model.py b/triton_models/ssmt_model_demuxer/1/model.py
index 5c4e289e9ad64c3e11062290784f93f2ba66a3a3..e8c67449c8b134788466c5bc1da6035c4c923c80 100644
--- a/triton_models/ssmt_model_demuxer/1/model.py
+++ b/triton_models/ssmt_model_demuxer/1/model.py
@@ -7,7 +7,7 @@ class TritonPythonModel:
         self.model_config = json.loads(args["model_config"])
         target_config = pb_utils.get_output_config_by_name(self.model_config, "OUTPUT_TEXT")
         self.target_dtype = pb_utils.triton_string_to_numpy(target_config["data_type"])
-        self.lang_pair_map = {'eng-hin': 1, 'hin-eng': 2, 'eng-tel': 3, 'tel-eng': 4, 'hin-tel': 6, 'tel-hin': 7, 'eng-guj': 8, 'guj-eng': 9}
+        self.lang_pair_map = {'en-hi': 1, 'hi-en': 2, 'en-te': 3, 'te-en': 4, 'hi-te': 6, 'te-hi': 7, 'en-gu': 8, 'gu-en': 9}
 
     async def execute(self, requests):
         responses = []
diff --git a/triton_models/ssmt_tokenizer/1/model.py b/triton_models/ssmt_tokenizer/1/model.py
index 7a85827357c6621309e44b7896cbfffc2cf72c5f..76433e883eb02c89e3967805dbb5f5f35fd930ec 100644
--- a/triton_models/ssmt_tokenizer/1/model.py
+++ b/triton_models/ssmt_tokenizer/1/model.py
@@ -11,7 +11,7 @@ class TritonPythonModel:
         self.model_config = json.loads(args["model_config"])
         target_config = pb_utils.get_output_config_by_name(self.model_config, "INPUT_TEXT_TOKENIZED")
         self.target_dtype = pb_utils.triton_string_to_numpy(target_config["data_type"])
-        self.lang_pair_map = {'eng-hin': 1, 'hin-eng': 2, 'eng-tel': 3, 'tel-eng': 4, 'hin-tel': 6, 'tel-hin': 7, 'eng-guj': 8, 'guj-eng': 9}
+        self.lang_pair_map = {'en-hi': 1, 'hi-en': 2, 'en-te': 3, 'te-en': 4, 'hi-te': 6, 'te-hi': 7, 'en-gu': 8, 'gu-en': 9}
         self.bpes = {lang_pair: BPE(open(os.path.join(current_path, f'bpe_src/{model_id}.src'), encoding='utf-8')) for lang_pair, model_id in self.lang_pair_map.items()}
 
     def execute(self, requests):
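
Review note: this change renames every key in `lang_pair_map` from three-letter to two-letter language codes while keeping the model ids the same, so any client that still sends the old keys (for example `eng-hin`) would presumably no longer match. The sketch below is not part of this diff; it shows one possible way a caller could translate old keys to new ones by joining the two mappings on their shared model id. The names `OLD_LANG_PAIR_MAP`, `NEW_LANG_PAIR_MAP`, `LEGACY_TO_NEW` and `normalize_lang_pair` are hypothetical and introduced only for illustration; the two dictionaries are copied from the removed and added lines above.

```python
# Hypothetical compatibility shim; not part of the repository.
# Both mappings are copied verbatim from the "-" and "+" lines of this diff.
OLD_LANG_PAIR_MAP = {'eng-hin': 1, 'hin-eng': 2, 'eng-tel': 3, 'tel-eng': 4,
                     'hin-tel': 6, 'tel-hin': 7, 'eng-guj': 8, 'guj-eng': 9}
NEW_LANG_PAIR_MAP = {'en-hi': 1, 'hi-en': 2, 'en-te': 3, 'te-en': 4,
                     'hi-te': 6, 'te-hi': 7, 'en-gu': 8, 'gu-en': 9}

# Invert the new mapping, then join old keys to new keys on the model id.
_ID_TO_NEW_PAIR = {model_id: pair for pair, model_id in NEW_LANG_PAIR_MAP.items()}
LEGACY_TO_NEW = {old_pair: _ID_TO_NEW_PAIR[model_id]
                 for old_pair, model_id in OLD_LANG_PAIR_MAP.items()}


def normalize_lang_pair(pair: str) -> str:
    """Return the two-letter key for `pair`, accepting old or new spellings."""
    if pair in NEW_LANG_PAIR_MAP:
        return pair
    if pair in LEGACY_TO_NEW:
        return LEGACY_TO_NEW[pair]
    raise KeyError(f"unsupported language pair: {pair!r}")


if __name__ == "__main__":
    for key in ("eng-hin", "te-hi"):
        new_key = normalize_lang_pair(key)
        print(key, "->", new_key, "(model id", NEW_LANG_PAIR_MAP[new_key], ")")
```

Whether such a shim belongs in the client, in the tokenizer and demuxer models, or nowhere at all (if all callers are updated together) is a design decision this diff does not settle.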