# model.py
import os
import json
import numpy
from time import time
from itertools import islice
from threading import Lock, Timer
from ctranslate2 import Translator
import triton_python_backend_utils as pb_utils
from tenacity import retry, wait_random_exponential


class DynamicModel(object):
    """Wrapper around a ``Translator`` that unloads itself after inactivity.

    A ``threading.Timer`` is armed whenever the model is used; when it fires
    it calls :meth:`unload`. :meth:`translate` reloads the model on demand
    when ``model_is_loaded`` is false.

    Args:
        path: Model directory handed to ``Translator``.
        device: Device name handed to ``Translator``.
        device_index: Device index (or indices) handed to ``Translator``.
        timeout: Delay in seconds used for timers armed by
            :meth:`restart_timer`.
        timer_min_delta: Minimum age in seconds the current timer must have
            before :meth:`restart_timer` replaces it.
    """

    def __init__(
        self, path, device, device_index=None, timeout=5, timer_min_delta=0.01
    ):
        self.model = None
        self.model_path = path
        self.model_device = device
        self.model_device_index = device_index
        # model_lock guards load/unload calls; timer_lock guards the timer.
        self.model_lock = Lock()
        self.timer_lock = Lock()
        self.timeout = timeout
        self.timer_min_delta = timer_min_delta
        self.initialize()

    def _arm_timer(self, delay):
        """Install and start a new unload timer firing after ``delay`` seconds."""
        self.timer = Timer(delay, self.unload)
        # Creation time is stored on the timer so restart_timer can tell how
        # old the currently armed timer is.
        self.timer.start_time = time()
        self.timer.start()

    @retry(wait=wait_random_exponential(multiplier=0.5, max=10, exp_base=1.2))
    def initialize(self):
        """Construct the ``Translator`` and arm the first unload timer.

        Retried with randomized exponential back-off (no stop condition is
        configured) whenever an exception is raised.
        """
        self.model = Translator(
            self.model_path,
            device=self.model_device,
            intra_threads=1,
            inter_threads=1,
            device_index=self.model_device_index,
        )
        # The first timer uses a fixed 1-second delay, not ``self.timeout``.
        self._arm_timer(1)

    def restart_timer(self):
        """Replace the unload timer with a fresh one lasting ``self.timeout``.

        Does nothing when the current timer is younger than
        ``self.timer_min_delta`` seconds.
        """
        with self.timer_lock:
            age = time() - self.timer.start_time
            if age < self.timer_min_delta:
                return
            self.timer.cancel()
            self._arm_timer(self.timeout)

    @retry(wait=wait_random_exponential(multiplier=0.5, max=20, exp_base=1.2))
    def load(self, reset_timer=True):
        """Cancel the pending timer and call ``load_model`` on the translator.

        Args:
            reset_timer: When true, call :meth:`restart_timer` after loading.
        """
        with self.timer_lock:
            self.timer.cancel()
        with self.model_lock:
            self.model.load_model()
        if not reset_timer:
            return
        self.restart_timer()

    def unload(self):
        """Call ``unload_model`` on the translator while holding the model lock."""
        with self.model_lock:
            self.model.unload_model()

    @retry(wait=wait_random_exponential(multiplier=0.5, max=20, exp_base=1.2))
    def translate(self, *args, **kwargs):
        """Translate via ``translate_iterable`` and return the results as a list.

        Loads the model first when it is not currently loaded. All arguments
        are forwarded unchanged to ``translate_iterable``. The unload timer is
        restarted only after the translation has completed successfully.
        Retried with randomized exponential back-off on any exception.
        """
        needs_load = not self.model.model_is_loaded
        if needs_load:
            # The timer is restarted below, after translating, so skip it here.
            self.load(reset_timer=False)
        outputs = self.model.translate_iterable(*args, **kwargs)
        results = list(outputs)
        self.restart_timer()
        return results


class TritonPythonModel:
    """Triton Python-backend model that translates pre-tokenized text.

    Reads the ``INPUT_TEXT_TOKENIZED`` tensor of every request, translates
    all rows in one call to the shared ``DynamicModel`` and returns one
    ``OUTPUT_TEXT`` tensor per request.
    """

    def initialize(self, args):
        """Read the model configuration and create the translator.

        Args:
            args: Mapping supplied by Triton; ``"model_config"`` (a JSON
                string) and ``"model_instance_device_id"`` are read from it.
        """
        current_path = os.path.dirname(os.path.abspath(__file__))
        self.model_config = json.loads(args["model_config"])
        self.device_id = int(json.loads(args["model_instance_device_id"]))
        target_config = pb_utils.get_output_config_by_name(
            self.model_config, "OUTPUT_TEXT"
        )
        self.target_dtype = pb_utils.triton_string_to_numpy(target_config["data_type"])
        # The translation model lives in the "translator" directory next to
        # this file.
        self.translator = DynamicModel(
            os.path.join(current_path, "translator"),
            device="cuda",
            device_index=[self.device_id],
        )

    def execute(self, requests):
        """Translate every sentence in ``requests``.

        Args:
            requests: Iterable of Triton inference requests, each carrying an
                ``INPUT_TEXT_TOKENIZED`` tensor.

        Returns:
            A list with one ``pb_utils.InferenceResponse`` per request, in
            request order. Each response holds an ``OUTPUT_TEXT`` tensor with
            one single-element row per input row of that request.
        """
        # Convert each input tensor to numpy once; it is needed both for the
        # per-request row count and for the sentences themselves.
        batches = [
            pb_utils.get_input_tensor_by_name(
                request, "INPUT_TEXT_TOKENIZED"
            ).as_numpy()
            for request in requests
        ]
        bsize_list = [batch.shape[0] for batch in batches]
        # Element 0 of each row is decoded as UTF-8 and split on single
        # spaces into tokens.
        src_sentences = [
            row[0].decode("utf-8").strip().split(" ")
            for batch in batches
            for row in batch
        ]
        # Join the best hypothesis back into a string and drop "@@ " markers
        # (presumably subword continuation markers -- TODO confirm).
        tgt_sentences = [
            " ".join(result.hypotheses[0]).replace("@@ ", "")
            for result in self.translator.translate(
                src_sentences,
                max_batch_size=128,
                max_input_length=100,
                max_decoding_length=100,
            )
        ]
        # Slice from ONE shared iterator so that each response consumes the
        # next ``bsize`` translations. Calling islice() on the list itself
        # restarts at index 0 every time, which would hand every request the
        # translations belonging to the first request(s).
        tgt_iter = iter(tgt_sentences)
        responses = [
            pb_utils.InferenceResponse(
                output_tensors=[
                    pb_utils.Tensor(
                        "OUTPUT_TEXT",
                        numpy.array(
                            [[s] for s in islice(tgt_iter, bsize)], dtype="object"
                        ).astype(self.target_dtype),
                    )
                ]
            )
            for bsize in bsize_list
        ]
        return responses

    def finalize(self):
        """Unload the translation model when Triton shuts this model down."""
        self.translator.unload()