
Commit da0998d

Enable the run_async=False option for the InferenceClient. (#17)

* Add `tools.request` for `run_async=False`
* Fix type annotations, refactor tests

Co-authored-by: Dongwoo Arthur Kim <[email protected]>
1 parent 0ef6067 commit da0998d
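
For orientation, a minimal usage sketch of the new option, distilled from the test code in this diff (the model name, host, and port are the defaults used by the tests below, not fixed API values):

    import numpy as np

    from tritony import InferenceClient

    # run_async=False switches the client to synchronous requests;
    # the test matrix below exercises it over gRPC.
    client = InferenceClient.create_with(
        "sample", "localhost:8101", protocol="grpc", run_async=False
    )

    sample = np.random.rand(1, 100).astype(np.float32)
    result = client(sample)  # the "sample" test model echoes its input back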

5 files changed: +155 −70 lines changed

tests/common_fixtures.py

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+import os
+
+import pytest
+
+MODEL_NAME = os.environ.get("MODEL_NAME", "sample")
+TRITON_HOST = os.environ.get("TRITON_HOST", "localhost")
+TRITON_HTTP = os.environ.get("TRITON_HTTP", "8100")
+TRITON_GRPC = os.environ.get("TRITON_GRPC", "8101")
+
+
+@pytest.fixture(params=[("http", TRITON_HTTP, True), ("grpc", TRITON_GRPC, True), ("grpc", TRITON_GRPC, False)])
+def config(request):
+    """
+    Returns a tuple of (protocol, port, run_async)
+    """
+    return request.param
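
The fixture yields a (protocol, port, run_async) triple for every test; note that run_async=False is exercised only over gRPC, while HTTP runs on the async path alone. Test modules pick the fixture up by importing it and re-exporting it through `__all__`, which keeps linters from stripping the seemingly unused import — the pattern used in the diffs below:

    # in a test module
    from .common_fixtures import config

    __all__ = ["config"]  # keep the fixture import visible to pytest and linters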

tests/test_connect.py

Lines changed: 18 additions & 23 deletions
@@ -1,27 +1,18 @@
-import os
-
 import grpc
 import numpy as np
-import pytest
 
 from tritony import InferenceClient
 
-MODEL_NAME = os.environ.get("MODEL_NAME", "sample")
-TRITON_HOST = os.environ.get("TRITON_HOST", "localhost")
-TRITON_HTTP = os.environ.get("TRITON_HTTP", "8000")
-TRITON_GRPC = os.environ.get("TRITON_GRPC", "8001")
-
+from .common_fixtures import MODEL_NAME, TRITON_HOST, config
 
-@pytest.fixture(params=[("http", TRITON_HTTP), ("grpc", TRITON_GRPC)])
-def protocol_and_port(request):
-    return request.param
+__all__ = ["config"]
 
 
-def test_basics(protocol_and_port):
-    protocol, port = protocol_and_port
-    print(f"Testing {protocol}")
+def test_basics(config):
+    protocol, port, run_async = config
+    print(f"Testing {protocol} with run_async={run_async}")
 
-    client = InferenceClient.create_with(MODEL_NAME, f"{TRITON_HOST}:{port}", protocol=protocol)
+    client = InferenceClient.create_with(MODEL_NAME, f"{TRITON_HOST}:{port}", protocol=protocol, run_async=run_async)
 
     sample = np.random.rand(1, 100).astype(np.float32)
     result = client(sample)
@@ -31,23 +22,27 @@ def test_basics(protocol_and_port):
     assert np.isclose(result, sample).all()
 
 
-def test_batching(protocol_and_port):
-    protocol, port = protocol_and_port
-    print(f"{__name__}, Testing {protocol}")
+def test_batching(config):
+    protocol, port, run_async = config
+    print(f"{__name__}, Testing {protocol} with run_async={run_async}")
 
-    client = InferenceClient.create_with("sample_autobatching", f"{TRITON_HOST}:{port}", protocol=protocol)
+    client = InferenceClient.create_with(
+        "sample_autobatching", f"{TRITON_HOST}:{port}", protocol=protocol, run_async=run_async
+    )
 
     sample = np.random.rand(100, 100).astype(np.float32)
     # client automatically makes sub batches with (50, 2, 100)
     result = client(sample)
     assert np.isclose(result, sample).all()
 
 
-def test_exception(protocol_and_port):
-    protocol, port = protocol_and_port
-    print(f"{__name__}, Testing {protocol}")
+def test_exception(config):
+    protocol, port, run_async = config
+    print(f"{__name__}, Testing {protocol} with run_async={run_async}")
 
-    client = InferenceClient.create_with("sample_autobatching", f"{TRITON_HOST}:{port}", protocol=protocol)
+    client = InferenceClient.create_with(
+        "sample_autobatching", f"{TRITON_HOST}:{port}", protocol=protocol, run_async=run_async
+    )
 
     sample = np.random.rand(100, 100, 100).astype(np.float32)
     # client automatically makes sub batches with (50, 2, 100)
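
For context on the autobatching comments: with the model's max batch size at 50, a (100, 100) input is transparently split into two (50, 100) sub-requests, which appears to be what the "(50, 2, 100)" notation hints at. A rough illustration of such a split (this mirrors, but does not reproduce, the client's internal logic):

    import numpy as np

    sample = np.random.rand(100, 100).astype(np.float32)
    # two sub-batches of 50 rows each
    sub_batches = np.split(sample, 2, axis=0)
    assert [b.shape for b in sub_batches] == [(50, 100), (50, 100)]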

tests/test_model_call.py

Lines changed: 15 additions & 30 deletions
@@ -1,30 +1,20 @@
-import os
-
 import numpy as np
-import pytest
 
 from tritony import InferenceClient
 
-TRITON_HOST = os.environ.get("TRITON_HOST", "localhost")
-TRITON_HTTP = os.environ.get("TRITON_HTTP", "8000")
-TRITON_GRPC = os.environ.get("TRITON_GRPC", "8001")
-
+from .common_fixtures import TRITON_HOST, config
 
 EPSILON = 1e-8
+__all__ = ["config"]
 
 
-@pytest.fixture(params=[("http", TRITON_HTTP), ("grpc", TRITON_GRPC)])
-def protocol_and_port(request):
-    return request.param
-
+def get_client(protocol, port, run_async, model_name):
+    print(f"Testing {protocol} with run_async={run_async}", flush=True)
+    return InferenceClient.create_with(model_name, f"{TRITON_HOST}:{port}", protocol=protocol, run_async=run_async)
 
-def get_client(protocol, port, model_name):
-    print(f"Testing {protocol}", flush=True)
-    return InferenceClient.create_with(model_name, f"{TRITON_HOST}:{port}", protocol=protocol)
 
-
-def test_swithcing(protocol_and_port):
-    client = get_client(*protocol_and_port, model_name="sample")
+def test_swithcing(config):
+    client = get_client(*config, model_name="sample")
 
     sample = np.random.rand(1, 100).astype(np.float32)
     result = client(sample)
@@ -35,16 +25,16 @@ def test_swithcing(protocol_and_port):
     assert np.isclose(result, sample).all()
 
 
-def test_with_input_name(protocol_and_port):
-    client = get_client(*protocol_and_port, model_name="sample")
+def test_with_input_name(config):
+    client = get_client(*config, model_name="sample")
 
     sample = np.random.rand(100, 100).astype(np.float32)
     result = client({client.default_model_spec.model_input[0].name: sample})
     assert np.isclose(result, sample).all()
 
 
-def test_with_parameters(protocol_and_port):
-    client = get_client(*protocol_and_port, model_name="sample")
+def test_with_parameters(config):
+    client = get_client(*config, model_name="sample")
 
     sample = np.random.rand(1, 100).astype(np.float32)
     ADD_VALUE = 1
@@ -53,8 +43,8 @@ def test_with_parameters(protocol_and_port):
     assert np.isclose(result[0], sample[0] + ADD_VALUE).all()
 
 
-def test_with_optional(protocol_and_port):
-    client = get_client(*protocol_and_port, model_name="sample_optional")
+def test_with_optional(config):
+    client = get_client(*config, model_name="sample_optional")
 
     sample = np.random.rand(1, 100).astype(np.float32)
 
@@ -71,16 +61,11 @@ def test_with_optional(protocol_and_port):
     assert np.isclose(result[0], sample[0] - OPTIONAL_SUB_VALUE, rtol=EPSILON).all()
 
 
-def test_reload_model_spec(protocol_and_port):
-    client = get_client(*protocol_and_port, model_name="sample_autobatching")
+def test_reload_model_spec(config):
+    client = get_client(*config, model_name="sample_autobatching")
     # force to change max_batch_size
     client.default_model_spec.max_batch_size = 4
 
     sample = np.random.rand(8, 100).astype(np.float32)
     result = client(sample)
     assert np.isclose(result, sample).all()
-
-
-if __name__ == "__main__":
-    test_with_parameters(("grpc", "8101"))
-    test_with_optional(("grpc", "8101"))
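
A side effect of the new `get_client` signature is that the fixture triple unpacks positionally ahead of the `model_name` keyword argument, so the old two-element `__main__` calls no longer fit and the ad-hoc runner is dropped in favor of pytest parametrization. For clarity, what the unpacking expands to (illustrative values):

    # get_client(*config, model_name="sample") with config = ("grpc", "8101", False)
    client = get_client("grpc", "8101", False, model_name="sample")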

tritony/helpers.py

Lines changed: 5 additions & 5 deletions
@@ -4,7 +4,7 @@
 from collections import defaultdict
 from enum import Enum
 from types import SimpleNamespace
-from typing import Any, Optional, Union
+from typing import Any
 
 import attrs
 from attrs import define
@@ -18,7 +18,7 @@ class TritonProtocol(Enum):
     http = "http"
 
 
-COMPRESSION_ALGORITHM_MAP = defaultdict(int)
+COMPRESSION_ALGORITHM_MAP: dict[str, int] = defaultdict(int)
 COMPRESSION_ALGORITHM_MAP.update({"deflate": 1, "gzip": 2})
 
 
@@ -83,13 +83,13 @@ class TritonClientFlag:
     concurrency: int = 6  # only for TritonProtocol.http client
     verbose: bool = False
     input_dims: int = 1
-    compression_algorithm: Optional[str] = None
+    compression_algorithm: str | None = None
     ssl: bool = False
 
 
 def init_triton_client(
     flag: TritonClientFlag,
-) -> Union[grpcclient.InferenceServerClient, httpclient.InferenceServerClient]:
+) -> grpcclient.InferenceServerClient | httpclient.InferenceServerClient:
     assert not (
         flag.streaming and not (flag.protocol is TritonProtocol.grpc)
     ), "Streaming is only allowed with gRPC protocol"
@@ -107,7 +107,7 @@ def init_triton_client(
 
 
 def get_triton_client(
-    triton_client: Union[grpcclient.InferenceServerClient, httpclient.InferenceServerClient],
+    triton_client: grpcclient.InferenceServerClient | httpclient.InferenceServerClient,
     model_name: str,
     model_version: str,
     protocol: TritonProtocol,
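
One portability note on the annotation changes above: PEP 604 unions (`str | None`) and built-in generics (`dict[str, int]`) are evaluated at runtime in these positions, which requires Python 3.10+ (3.9+ for `dict[...]`). A minimal sketch of the usual escape hatch for older interpreters (an assumption about compatibility goals, not something this commit adds):

    from __future__ import annotations  # annotations become lazy strings

    def lookup(name: str | None) -> dict[str, int]:  # OK even on older Pythons
        return {} if name is None else {name: 1}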
