1
- import os
2
-
3
1
import numpy as np
4
- import pytest
5
2
6
3
from tritony import InferenceClient
7
4
8
- TRITON_HOST = os .environ .get ("TRITON_HOST" , "localhost" )
9
- TRITON_HTTP = os .environ .get ("TRITON_HTTP" , "8000" )
10
- TRITON_GRPC = os .environ .get ("TRITON_GRPC" , "8001" )
11
-
5
+ from .common_fixtures import TRITON_HOST , config
12
6
13
7
EPSILON = 1e-8
8
+ __all__ = ["config" ]
14
9
15
10
16
- @pytest .fixture (params = [("http" , TRITON_HTTP ), ("grpc" , TRITON_GRPC )])
17
- def protocol_and_port (request ):
18
- return request .param
19
-
11
def get_client(protocol, port, run_async, model_name):
    """Build an InferenceClient for *model_name* against the configured Triton host.

    ``protocol``/``port`` select the HTTP or gRPC endpoint; ``run_async``
    is forwarded to ``InferenceClient.create_with`` to pick the async path.
    """
    print(f"Testing {protocol} with run_async={run_async}", flush=True)
    endpoint = f"{TRITON_HOST}:{port}"
    return InferenceClient.create_with(model_name, endpoint, protocol=protocol, run_async=run_async)
24
15
25
-
26
- def test_swithcing (protocol_and_port ):
27
- client = get_client (* protocol_and_port , model_name = "sample" )
16
+ def test_swithcing (config ):
17
+ client = get_client (* config , model_name = "sample" )
28
18
29
19
sample = np .random .rand (1 , 100 ).astype (np .float32 )
30
20
result = client (sample )
@@ -35,16 +25,16 @@ def test_swithcing(protocol_and_port):
35
25
assert np .isclose (result , sample ).all ()
36
26
37
27
38
def test_with_input_name(config):
    """Inference keyed by the model's declared input name should echo the input back."""
    client = get_client(*config, model_name="sample")

    payload = np.random.rand(100, 100).astype(np.float32)
    # Address the input explicitly by the name Triton reports in the model spec.
    input_name = client.default_model_spec.model_input[0].name
    output = client({input_name: payload})
    assert np.isclose(output, payload).all()
44
34
45
35
46
- def test_with_parameters (protocol_and_port ):
47
- client = get_client (* protocol_and_port , model_name = "sample" )
36
+ def test_with_parameters (config ):
37
+ client = get_client (* config , model_name = "sample" )
48
38
49
39
sample = np .random .rand (1 , 100 ).astype (np .float32 )
50
40
ADD_VALUE = 1
@@ -53,8 +43,8 @@ def test_with_parameters(protocol_and_port):
53
43
assert np .isclose (result [0 ], sample [0 ] + ADD_VALUE ).all ()
54
44
55
45
56
- def test_with_optional (protocol_and_port ):
57
- client = get_client (* protocol_and_port , model_name = "sample_optional" )
46
+ def test_with_optional (config ):
47
+ client = get_client (* config , model_name = "sample_optional" )
58
48
59
49
sample = np .random .rand (1 , 100 ).astype (np .float32 )
60
50
@@ -71,16 +61,11 @@ def test_with_optional(protocol_and_port):
71
61
assert np .isclose (result [0 ], sample [0 ] - OPTIONAL_SUB_VALUE , rtol = EPSILON ).all ()
72
62
73
63
74
def test_reload_model_spec(config):
    """Inference still round-trips after max_batch_size is forced below the request size."""
    client = get_client(*config, model_name="sample_autobatching")
    # Deliberately shrink max_batch_size so the 8-row request exceeds it;
    # presumably the client re-splits or reloads the spec — behavior under test.
    client.default_model_spec.max_batch_size = 4

    batch = np.random.rand(8, 100).astype(np.float32)
    echoed = client(batch)
    assert np.isclose(echoed, batch).all()
82
-
83
-
84
- if __name__ == "__main__" :
85
- test_with_parameters (("grpc" , "8101" ))
86
- test_with_optional (("grpc" , "8101" ))
0 commit comments