@@ -49,6 +49,7 @@ def setUp(self):
             "The future of AI is",
         ]
         self.sampling_parameters = {"temperature": "0", "top_p": "1"}
+        self.inference_count = 2
 
     def parse_vllm_metrics(self):
         """
@@ -82,34 +83,36 @@ def vllm_infer(
         user_data = UserData()
         number_of_vllm_reqs = len(prompts)
 
-        self.triton_client.start_stream(callback=partial(callback, user_data))
-        for i in range(number_of_vllm_reqs):
-            request_data = create_vllm_request(
-                prompts[i],
-                i,
-                False,
-                sampling_parameters,
-                model_name,
-                True,
-            )
-            self.triton_client.async_stream_infer(
-                model_name=model_name,
-                inputs=request_data["inputs"],
-                request_id=request_data["request_id"],
-                outputs=request_data["outputs"],
-                parameters=sampling_parameters,
-            )
-
-        for _ in range(number_of_vllm_reqs):
-            result = user_data._completed_requests.get()
-            if type(result) is InferenceServerException:
-                print(result.message())
-            self.assertIsNot(type(result), InferenceServerException, str(result))
-
-            output = result.as_numpy("text_output")
-            self.assertIsNotNone(output, "`text_output` should not be None")
-
-        self.triton_client.stop_stream()
+        # Run the inference twice in case metrics are updated but the engine crashes.
+        for _ in range(self.inference_count):
+            self.triton_client.start_stream(callback=partial(callback, user_data))
+            for i in range(number_of_vllm_reqs):
+                request_data = create_vllm_request(
+                    prompts[i],
+                    i,
+                    False,
+                    sampling_parameters,
+                    model_name,
+                    True,
+                )
+                self.triton_client.async_stream_infer(
+                    model_name=model_name,
+                    inputs=request_data["inputs"],
+                    request_id=request_data["request_id"],
+                    outputs=request_data["outputs"],
+                    parameters=sampling_parameters,
+                )
+
+            for _ in range(number_of_vllm_reqs):
+                result = user_data._completed_requests.get()
+                if type(result) is InferenceServerException:
+                    print(result.message())
+                self.assertIsNot(type(result), InferenceServerException, str(result))
+
+                output = result.as_numpy("text_output")
+                self.assertIsNotNone(output, "`text_output` should not be None")
+
+            self.triton_client.stop_stream()
 
     def test_vllm_metrics(self):
         # Test vLLM metrics
@@ -125,49 +128,70 @@ def test_vllm_metrics(self):
         # (2, 133, 144, 2702, 3477, 16)
         # (2, 133, 812, 9, 1470, 16)
         # (2, 133, 499, 9, 4687, 16)
-        self.assertEqual(metrics_dict["vllm:prompt_tokens_total"], 18)
+        self.assertEqual(
+            metrics_dict["vllm:prompt_tokens_total"], 18 * self.inference_count
+        )
         # vllm:generation_tokens_total
         # (5, 65, 14, 16, 144, 533, 7, 28, 848, 30, 10, 512, 4, 50118, 100, 437)
         # (5, 812, 9, 5, 1515, 3497, 4, 50118, 50118, 133, 812, 9, 1470, 16, 5, 812)
         # (11, 5, 1420, 9, 5, 82, 4, 50118, 50118, 133, 499, 9, 4687, 16, 11, 5)
-        self.assertEqual(metrics_dict["vllm:generation_tokens_total"], 48)
+        self.assertEqual(
+            metrics_dict["vllm:generation_tokens_total"], 48 * self.inference_count
+        )
         # vllm:time_to_first_token_seconds
         self.assertEqual(
-            metrics_dict["vllm:time_to_first_token_seconds_count"], total_prompts
+            metrics_dict["vllm:time_to_first_token_seconds_count"],
+            total_prompts * self.inference_count,
         )
         self.assertGreater(metrics_dict["vllm:time_to_first_token_seconds_sum"], 0)
         self.assertEqual(
-            metrics_dict["vllm:time_to_first_token_seconds_bucket"], total_prompts
+            metrics_dict["vllm:time_to_first_token_seconds_bucket"],
+            total_prompts * self.inference_count,
        )
         # vllm:time_per_output_token_seconds
-        self.assertEqual(metrics_dict["vllm:time_per_output_token_seconds_count"], 45)
+        self.assertEqual(
+            metrics_dict["vllm:time_per_output_token_seconds_count"],
+            45 * self.inference_count,
+        )
         self.assertGreater(metrics_dict["vllm:time_per_output_token_seconds_sum"], 0)
-        self.assertEqual(metrics_dict["vllm:time_per_output_token_seconds_bucket"], 45)
+        self.assertEqual(
+            metrics_dict["vllm:time_per_output_token_seconds_bucket"],
+            45 * self.inference_count,
+        )
         # vllm:e2e_request_latency_seconds
         self.assertEqual(
-            metrics_dict["vllm:e2e_request_latency_seconds_count"], total_prompts
+            metrics_dict["vllm:e2e_request_latency_seconds_count"],
+            total_prompts * self.inference_count,
         )
         self.assertGreater(metrics_dict["vllm:e2e_request_latency_seconds_sum"], 0)
         self.assertEqual(
-            metrics_dict["vllm:e2e_request_latency_seconds_bucket"], total_prompts
+            metrics_dict["vllm:e2e_request_latency_seconds_bucket"],
+            total_prompts * self.inference_count,
         )
         # vllm:request_prompt_tokens
         self.assertEqual(
-            metrics_dict["vllm:request_prompt_tokens_count"], total_prompts
+            metrics_dict["vllm:request_prompt_tokens_count"],
+            total_prompts * self.inference_count,
+        )
+        self.assertEqual(
+            metrics_dict["vllm:request_prompt_tokens_sum"], 18 * self.inference_count
         )
-        self.assertEqual(metrics_dict["vllm:request_prompt_tokens_sum"], 18)
         self.assertEqual(
-            metrics_dict["vllm:request_prompt_tokens_bucket"], total_prompts
+            metrics_dict["vllm:request_prompt_tokens_bucket"],
+            total_prompts * self.inference_count,
         )
         # vllm:request_generation_tokens
         self.assertEqual(
             metrics_dict["vllm:request_generation_tokens_count"],
-            total_prompts,
+            total_prompts * self.inference_count,
+        )
+        self.assertEqual(
+            metrics_dict["vllm:request_generation_tokens_sum"],
+            48 * self.inference_count,
         )
-        self.assertEqual(metrics_dict["vllm:request_generation_tokens_sum"], 48)
         self.assertEqual(
             metrics_dict["vllm:request_generation_tokens_bucket"],
-            total_prompts,
+            total_prompts * self.inference_count,
         )
 
         # TODO: Revisit this test due to the removal of best_of
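For reference, the scaled constants follow from the token-id tuples in the comments above: 6 prompt ids and 16 generated tokens per prompt, three prompts per pass, two passes. A quick arithmetic check (assuming, as the value 45 suggests, that each request's first token is recorded under time_to_first_token rather than time_per_output_token):

# Quick check of the scaled expectations above (illustrative only).
total_prompts = 3
inference_count = 2
prompt_tokens = 6 * total_prompts            # 6 ids per prompt -> 18 per pass
generation_tokens = 16 * total_prompts       # 16 generated tokens per prompt -> 48 per pass
inter_token_latencies = 15 * total_prompts   # first token tracked by time_to_first_token -> 45 per pass

assert prompt_tokens * inference_count == 36
assert generation_tokens * inference_count == 96
assert inter_token_latencies * inference_count == 90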