@@ -70,46 +70,17 @@ def get_metrics(self):
 
         return vllm_dict
 
-    def test_vllm_metrics(self):
-        # All vLLM metrics from tritonserver
-        expected_metrics_dict = {
-            "vllm:prompt_tokens_total": 0,
-            "vllm:generation_tokens_total": 0,
-        }
-
-        # Test vLLM metrics
-        self._test_vllm_model(
-            prompts=self.prompts,
-            sampling_parameters=self.sampling_parameters,
-            stream=False,
-            send_parameters_as_tensor=True,
-            model_name=self.vllm_model_name,
-        )
-        expected_metrics_dict["vllm:prompt_tokens_total"] = 18
-        expected_metrics_dict["vllm:generation_tokens_total"] = 48
-        self.assertEqual(self.get_metrics(), expected_metrics_dict)
-
-        self._test_vllm_model(
-            prompts=self.prompts,
-            sampling_parameters=self.sampling_parameters,
-            stream=False,
-            send_parameters_as_tensor=False,
-            model_name=self.vllm_model_name,
-        )
-        expected_metrics_dict["vllm:prompt_tokens_total"] = 36
-        expected_metrics_dict["vllm:generation_tokens_total"] = 96
-        self.assertEqual(self.get_metrics(), expected_metrics_dict)
-
-    def _test_vllm_model(
+    def vllm_async_stream_infer(
         self,
         prompts,
         sampling_parameters,
         stream,
         send_parameters_as_tensor,
-        exclude_input_in_output=None,
-        expected_output=None,
-        model_name="vllm_opt",
+        model_name,
     ):
+        """
+        Helper function to send async stream infer requests to vLLM.
+        """
         user_data = UserData()
         number_of_vllm_reqs = len(prompts)
 
@@ -122,7 +93,6 @@ def _test_vllm_model(
                 sampling_parameters,
                 model_name,
                 send_parameters_as_tensor,
-                exclude_input_in_output=exclude_input_in_output,
             )
             self.triton_client.async_stream_infer(
                 model_name=model_name,
@@ -132,26 +102,47 @@ def _test_vllm_model(
                 parameters=sampling_parameters,
             )
 
-        for i in range(number_of_vllm_reqs):
+        for _ in range(number_of_vllm_reqs):
             result = user_data._completed_requests.get()
             if type(result) is InferenceServerException:
                 print(result.message())
             self.assertIsNot(type(result), InferenceServerException, str(result))
 
             output = result.as_numpy("text_output")
             self.assertIsNotNone(output, "`text_output` should not be None")
-            if expected_output is not None:
-                self.assertEqual(
-                    output,
-                    expected_output[i],
-                    'Actual and expected outputs do not match.\n \
-                    Expected "{}" \n Actual:"{}"'.format(
-                        output, expected_output[i]
-                    ),
-                )
 
         self.triton_client.stop_stream()
 
+    def test_vllm_metrics(self):
+        # All vLLM metrics from tritonserver
+        expected_metrics_dict = {
+            "vllm:prompt_tokens_total": 0,
+            "vllm:generation_tokens_total": 0,
+        }
+
+        # Test vLLM metrics
+        self.vllm_async_stream_infer(
+            prompts=self.prompts,
+            sampling_parameters=self.sampling_parameters,
+            stream=False,
+            send_parameters_as_tensor=True,
+            model_name=self.vllm_model_name,
+        )
+        expected_metrics_dict["vllm:prompt_tokens_total"] = 18
+        expected_metrics_dict["vllm:generation_tokens_total"] = 48
+        self.assertEqual(self.get_metrics(), expected_metrics_dict)
+
+        self.vllm_async_stream_infer(
+            prompts=self.prompts,
+            sampling_parameters=self.sampling_parameters,
+            stream=False,
+            send_parameters_as_tensor=False,
+            model_name=self.vllm_model_name,
+        )
+        expected_metrics_dict["vllm:prompt_tokens_total"] = 36
+        expected_metrics_dict["vllm:generation_tokens_total"] = 96
+        self.assertEqual(self.get_metrics(), expected_metrics_dict)
+
     def tearDown(self):
         self.triton_client.close()
 
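The assertions in test_vllm_metrics compare against a get_metrics() helper whose "return vllm_dict" line appears as context at the top of the first hunk. The sketch below is only an illustration of how such a helper might scrape Triton's Prometheus-format metrics endpoint; the function name, the default metrics port 8002, and the label-stripping logic are assumptions for illustration, not code from this PR.

import requests


def get_vllm_metrics(metrics_url="http://localhost:8002/metrics"):
    """Hypothetical sketch: return {metric_name: int_value} for the vllm:
    counters scraped from Triton's metrics endpoint (Prometheus text format)."""
    vllm_dict = {}
    response = requests.get(metrics_url)
    response.raise_for_status()
    for line in response.text.splitlines():
        # Skip Prometheus HELP/TYPE comment lines and blanks.
        if not line or line.startswith("#"):
            continue
        # Each sample line looks like: name{labels} value
        name_with_labels, _, value = line.rpartition(" ")
        name = name_with_labels.split("{", 1)[0]
        if name.startswith("vllm:"):
            # Counters are exported as floats; the test compares integer totals.
            vllm_dict[name] = int(float(value))
    return vllm_dict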