Skip to content

Commit 58933e2

Browse files
authored
Add asyncio support for tritonclient (beta) (#23)
* Add aio_grpcclient, aio_httpclient * add test_multiple_tasks test-case * tritony==0.0.19
1 parent 9629d07 commit 58933e2

File tree

13 files changed

+399
-62
lines changed

13 files changed

+399
-62
lines changed

.github/workflows/pre-commit_pytest.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ jobs:
5252
TRITON_CLIENT_TIMEOUT: 1
5353
run: |
5454
tritonserver --model-repo=$GITHUB_WORKSPACE/model_repository &
55-
pip install .[tests]
55+
pip install uv
56+
uv pip install .[tests]
5657
sleep 3
5758
5859
curl -v ${TRITON_HOST}:${TRITON_HTTP}/v2/health/ready

bin/run_triton_tritony_sample.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ docker run -it --rm --name triton_tritony \
1111
-e OMP_NUM_THREADS=2 \
1212
-e OPENBLAS_NUM_THREADS=2 \
1313
--shm-size=1g \
14-
nvcr.io/nvidia/tritonserver:23.08-pyt-python-py3 \
14+
nvcr.io/nvidia/tritonserver:24.05-pyt-python-py3 \
1515
tritonserver --model-repository=/models \
1616
--exit-timeout-secs 15 \
1717
--min-supported-compute-capability 7.0 \
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import json
2+
import time
3+
4+
import triton_python_backend_utils as pb_utils
5+
6+
7+
class TritonPythonModel:
    """Triton Python-backend model that echoes ``model_in`` after a one-second delay.

    The value of the optional ``add`` request parameter (default ``0``) is added
    to every input tensor named ``model_in`` before it is returned.
    """

    def initialize(self, args):
        """Cache the output names and NumPy dtypes declared in the model config.

        Args:
            args: Mapping supplied by Triton; ``args["model_config"]`` holds the
                JSON-encoded model configuration.
        """
        model_config = json.loads(args["model_config"])
        self.model_config = model_config

        output_name_list = []
        output_dtype_list = []
        for output_config in model_config["output"]:
            output_name_list.append(output_config["name"])
            output_dtype_list.append(pb_utils.triton_string_to_numpy(output_config["data_type"]))

        self.output_name_list = output_name_list
        self.output_dtype_list = output_dtype_list

    def execute(self, requests):
        """Build one response per request, then sleep for one second.

        Args:
            requests: Inference requests handed over by Triton.

        Returns:
            A list of ``pb_utils.InferenceResponse`` objects in request order.
        """
        responses = []
        for request in requests:
            # Request parameters arrive as a JSON string; "add" is optional.
            add_value = int(json.loads(request.parameters()).get("add", 0))
            shifted_inputs = [
                tensor.as_numpy() + add_value for tensor in request.inputs() if tensor.name() == "model_in"
            ]
            # Outputs are paired with the configured names/dtypes in config order.
            output_tensors = [
                pb_utils.Tensor(name, array.astype(dtype))
                for array, name, dtype in zip(shifted_inputs, self.output_name_list, self.output_dtype_list)
            ]
            responses.append(pb_utils.InferenceResponse(output_tensors=output_tensors))

        # NOTE(review): the delay is applied once per execute() call; confirm this
        # matches the intended "sleep 1 sec" behaviour when several requests arrive together.
        time.sleep(1)
        return responses
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
backend: "python"
2+
max_batch_size: 0
3+
4+
input [
5+
{
6+
name: "model_in"
7+
data_type: TYPE_FP32
8+
dims: [ -1 ]
9+
}
10+
]
11+
12+
output [
13+
{
14+
name: "model_out"
15+
data_type: TYPE_FP32
16+
dims: [ -1 ]
17+
}
18+
]
19+
20+
instance_group [{ kind: KIND_CPU, count: 10 }]
21+
22+
# Each warmup sample is declared as its own model_warmup entry. `inputs` is a
# map keyed by input name, so listing the key "model_in" twice inside a single
# sample collapses to one input instead of producing two warmup requests.
model_warmup [
  {
    name: "RandomSampleInput"
    batch_size: 1
    inputs {
      key: "model_in"
      value: {
        data_type: TYPE_FP32
        dims: [ 10 ]
        random_data: true
      }
    }
  },
  {
    name: "ZeroSampleInput"
    batch_size: 1
    inputs {
      key: "model_in"
      value: {
        data_type: TYPE_FP32
        dims: [ 10 ]
        zero_data: true
      }
    }
  }
]

packaging.md

Lines changed: 0 additions & 16 deletions
This file was deleted.

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,7 @@ line-length = 120
1717
ignore = ["F811","F841","E203","E402","E501","E712","B019"]
1818

1919
[tool.lint.isort]
20-
forced-separate = ["tests"]
20+
forced-separate = ["tests"]
21+
22+
[tool.pytest.ini_options]
23+
asyncio_default_fixture_loop_scope = "function"

pytest.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[pytest]
log_cli=true
log_level=NOTSET
# pytest loads a single configuration file and pytest.ini takes precedence over
# pyproject.toml, so the pytest-asyncio loop scope has to be declared here to
# take effect.
asyncio_default_fixture_loop_scope=function

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ tests =
6969
pytest-xdist
7070
pytest-mpl
7171
pytest-cov
72+
pytest-asyncio
7273
pytest
7374
pre-commit
7475
coveralls

tests/common_fixtures.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
1+
import logging
12
import os
23

34
import pytest
45

6+
logging.basicConfig(
7+
level=logging.INFO,
8+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
9+
datefmt="%m/%d/%Y %H:%M:%S",
10+
)
11+
12+
513
MODEL_NAME = os.environ.get("MODEL_NAME", "sample")
614
TRITON_HOST = os.environ.get("TRITON_HOST", "localhost")
715
TRITON_HTTP = os.environ.get("TRITON_HTTP", "8100")
@@ -14,3 +22,11 @@ def config(request):
1422
Returns a tuple of (protocol, port, run_async)
1523
"""
1624
return request.param
25+
26+
27+
@pytest.fixture(params=[("http", TRITON_HTTP), ("grpc", TRITON_GRPC)])
def async_config(request):
    """Parametrize a test over the HTTP and gRPC endpoints.

    Returns:
        The current ``(protocol, port)`` tuple.
    """
    protocol_and_port = request.param
    return protocol_and_port

tests/test_async_connect.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import asyncio
2+
import logging
3+
4+
import numpy as np
5+
import pytest
6+
7+
from tritony import InferenceClient
8+
9+
from .common_fixtures import MODEL_NAME, TRITON_HOST, async_config
10+
11+
logging.basicConfig(
12+
level=logging.INFO,
13+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
14+
datefmt="%m/%d/%Y %H:%M:%S",
15+
)
16+
17+
__all__ = ["async_config"]
18+
EPSILON = 1e-8
19+
20+
21+
def get_client(protocol, port, model_name):
    """Create an asyncio-enabled ``InferenceClient`` for ``model_name``.

    Args:
        protocol: Protocol name forwarded to the client (the fixtures use "http" and "grpc").
        port: Port appended to ``TRITON_HOST`` to form the server address.
        model_name: Name of the model the client will call.

    Returns:
        The client returned by ``InferenceClient.create_with_asyncio``.
    """
    print(f"Testing {protocol}", flush=True)
    url = f"{TRITON_HOST}:{port}"
    return InferenceClient.create_with_asyncio(model_name, url, protocol=protocol)
24+
25+
26+
@pytest.mark.asyncio
async def test_basics(async_config):
    """Check that ``aio_infer`` returns its input for both array and dict inputs."""
    protocol, port = async_config
    client = get_client(protocol, port, model_name=MODEL_NAME)

    sample = np.random.rand(1, 100).astype(np.float32)

    # Positional (bare array) input.
    array_result = await client.aio_infer(sample)
    assert np.isclose(array_result, sample).all()

    # Input addressed by tensor name.
    named_result = await client.aio_infer({"model_in": sample})
    assert np.isclose(named_result, sample).all()
38+
39+
40+
@pytest.mark.asyncio
async def test_multiple_tasks(async_config):
    """Check that concurrent ``aio_infer`` calls overlap instead of running serially.

    Ten clients each send one request to the ``sample_sleep_1sec`` model and the
    requests are awaited together. The whole batch must finish in under two
    seconds, and every result must match the input sample.
    """
    n_multiple_tasks = 10
    protocol, port = async_config
    print(f"Testing {protocol}:{port}")

    client_list = [get_client(protocol, port, model_name="sample_sleep_1sec") for _ in range(n_multiple_tasks)]

    sample = np.random.rand(1, 100).astype(np.float32)
    tasks = [client.aio_infer(sample) for client in client_list]

    # Inside a coroutine the running loop is obtained with get_running_loop();
    # its monotonic clock is used to time the gather.
    loop = asyncio.get_running_loop()
    start_time = loop.time()
    results = await asyncio.gather(*tasks)
    elapsed = loop.time() - start_time

    for result in results:
        assert np.isclose(result, sample).all()

    assert elapsed < 2, f"Time taken: {elapsed}"

0 commit comments

Comments
 (0)