|
5 | 5 |
|
6 | 6 | from tritony import InferenceClient
|
7 | 7 |
|
8 |
| -MODEL_NAME = os.environ.get("MODEL_NAME", "sample") |
9 | 8 | TRITON_HOST = os.environ.get("TRITON_HOST", "localhost")
|
10 | 9 | TRITON_HTTP = os.environ.get("TRITON_HTTP", "8000")
|
11 | 10 | TRITON_GRPC = os.environ.get("TRITON_GRPC", "8001")
|
12 | 11 |
|
13 | 12 |
|
| 13 | +EPSILON = 1e-8 |
| 14 | + |
| 15 | + |
@pytest.fixture(params=[("http", TRITON_HTTP), ("grpc", TRITON_GRPC)])
def protocol_and_port(request):
    """Yield each (protocol, port) pair so every test runs over both the
    HTTP and the gRPC Triton transports."""
    return request.param
|
17 | 19 |
|
18 | 20 |
|
19 |
def get_client(protocol, port, model_name):
    """Build an ``InferenceClient`` for *model_name* served at ``TRITON_HOST:port``.

    ``protocol`` selects the transport ("http" or "grpc"); the matching port
    comes from the ``protocol_and_port`` fixture.
    """
    print(f"Testing {protocol}", flush=True)
    endpoint = f"{TRITON_HOST}:{port}"
    return InferenceClient.create_with(model_name, endpoint, protocol=protocol)
22 | 24 |
|
23 | 25 |
|
def test_swithcing(protocol_and_port):
    """Verify that a client bound to one model can route a single call to a
    different model via the per-call ``model_name`` argument.

    NOTE(review): the name has a typo ("swithcing"); kept as-is so any test
    selection / CI filters referencing it keep working.
    """
    client = get_client(*protocol_and_port, model_name="sample")

    # "sample" is an identity model: the output should echo the input.
    sample = np.random.rand(1, 100).astype(np.float32)
    result = client(sample)
    # Bug fix: the original wrote `assert {np.isclose(...).all()}` — a
    # one-element set is always truthy, so the assert could never fail.
    assert np.isclose(result, sample).all()

    # Switch this single call to the auto-batching model; the client's
    # default model stays "sample".
    sample_batched = np.random.rand(100, 100).astype(np.float32)
    result_batched = client(sample_batched, model_name="sample_autobatching")
    # Bug fix: the original discarded the batched call's result and
    # re-asserted the first result against the first sample, so the
    # model-switching path was never actually checked.
    # Assumes "sample_autobatching" also echoes its input — TODO confirm
    # against the model repository.
    assert np.isclose(result_batched, sample_batched).all()
34 | 36 |
|
35 | 37 |
|
def test_with_input_name(protocol_and_port):
    """Send the input as an explicit ``{input_name: array}`` mapping instead
    of a bare positional array."""
    protocol, port = protocol_and_port
    client = get_client(protocol, port, model_name="sample")

    batch = np.random.rand(100, 100).astype(np.float32)
    input_name = client.default_model_spec.model_input[0].name
    echoed = client({input_name: batch})
    assert np.isclose(echoed, batch).all()
42 | 44 |
|
43 | 45 |
|
def test_with_parameters(protocol_and_port):
    """Pass a custom inference parameter ("add") and check the model applies it."""
    protocol, port = protocol_and_port
    client = get_client(protocol, port, model_name="sample")

    ADD_VALUE = 1
    row = np.random.rand(1, 100).astype(np.float32)
    input_name = client.default_model_spec.model_input[0].name

    # The "add" parameter is sent as a string; the model adds it element-wise.
    result = client({input_name: row}, parameters={"add": f"{ADD_VALUE}"})

    assert np.isclose(result[0], row[0] + ADD_VALUE).all()
| 54 | + |
| 55 | + |
def test_with_optional(protocol_and_port):
    """Exercise a model with an optional input: omitted, it echoes the input;
    supplied, "optional_model_sub" is subtracted element-wise."""
    client = get_client(*protocol_and_port, model_name="sample_optional")

    data = np.random.rand(1, 100).astype(np.float32)
    input_name = client.default_model_spec.model_input[0].name

    # Without the optional input the model acts as identity.
    plain = client({input_name: data})
    assert np.isclose(plain[0], data[0], rtol=EPSILON).all()

    # With the optional input the model subtracts it from the main input.
    OPTIONAL_SUB_VALUE = np.zeros_like(data) + 3
    payload = {
        input_name: data,
        "optional_model_sub": OPTIONAL_SUB_VALUE,
    }
    subtracted = client(payload)
    assert np.isclose(subtracted[0], data[0] - OPTIONAL_SUB_VALUE, rtol=EPSILON).all()
| 72 | + |
| 73 | + |
if __name__ == "__main__":
    # Manual smoke run against a gRPC endpoint on port 8101, bypassing pytest.
    for smoke in (test_with_parameters, test_with_optional):
        smoke(("grpc", "8101"))
0 commit comments