triton_client.ipynb 2.92 KB
Newer Older
Nikhilesh Bhatnagar's avatar
Nikhilesh Bhatnagar committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from random import choice\n",
    "from tritonclient.utils import *\n",
    "import tritonclient.http as httpclient\n",
    "from multiprocessing.pool import ThreadPool"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"ssmt_pipeline\"\n",
    "shape = [1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def task(x):\n",
    "    lang_pair_map = list({'eng-hin': 1, 'hin-eng': 2, 'eng-tel':3, 'tel-eng': 4, 'hin-tel': 6, 'tel-hin': 7, 'eng-guj': 8, 'guj-eng': 9}.keys())\n",
    "    with httpclient.InferenceServerClient(\"localhost:8000\") as client:\n",
    "        async_responses = []\n",
    "        for i in range(10):\n",
    "            s = 'this is a sentence.'\n",
    "            source_data = np.array([[s]], dtype='object')\n",
    "            inputs = [httpclient.InferInput(\"INPUT_TEXT\", source_data.shape, np_to_triton_dtype(source_data.dtype)), httpclient.InferInput(\"INPUT_LANGUAGE_ID\", source_data.shape, np_to_triton_dtype(source_data.dtype)), httpclient.InferInput(\"OUTPUT_LANGUAGE_ID\", source_data.shape, np_to_triton_dtype(source_data.dtype))]\n",
    "            inputs[0].set_data_from_numpy(np.array([[s]], dtype='object'))\n",
    "            langpair = choice(lang_pair_map)\n",
    "            inputs[1].set_data_from_numpy(np.array([[langpair.split('-')[0].strip()]], dtype='object'))\n",
    "            inputs[2].set_data_from_numpy(np.array([[langpair.split('-')[1].strip()]], dtype='object'))\n",
    "            outputs = [httpclient.InferRequestedOutput(\"OUTPUT_TEXT\")]\n",
    "            async_responses.append(client.async_infer(model_name, inputs, request_id=str(1), outputs=outputs))\n",
    "        for r in async_responses: r.get_result(timeout=10).get_response()\n",
    "    return 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [01:49<00:00,  9.15it/s]\n"
     ]
    }
   ],
   "source": [
    "with ThreadPool(100) as pool:\n",
    "    for output in tqdm(pool.imap_unordered(task, range(1000), chunksize=1), total=1000): pass"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "model_metrics",
   "language": "python",
   "name": "model_metrics"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}