diff --git a/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py b/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py
index 0111056c..dca808c1 100644
--- a/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py
+++ b/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py
@@ -49,6 +49,7 @@ def setUp(self):
             "The future of AI is",
         ]
         self.sampling_parameters = {"temperature": "0", "top_p": "1"}
+        self.inference_count = 2
 
     def parse_vllm_metrics(self):
         """
@@ -82,34 +83,36 @@ def vllm_infer(
         user_data = UserData()
         number_of_vllm_reqs = len(prompts)
 
-        self.triton_client.start_stream(callback=partial(callback, user_data))
-        for i in range(number_of_vllm_reqs):
-            request_data = create_vllm_request(
-                prompts[i],
-                i,
-                False,
-                sampling_parameters,
-                model_name,
-                True,
-            )
-            self.triton_client.async_stream_infer(
-                model_name=model_name,
-                inputs=request_data["inputs"],
-                request_id=request_data["request_id"],
-                outputs=request_data["outputs"],
-                parameters=sampling_parameters,
-            )
-
-        for _ in range(number_of_vllm_reqs):
-            result = user_data._completed_requests.get()
-            if type(result) is InferenceServerException:
-                print(result.message())
-            self.assertIsNot(type(result), InferenceServerException, str(result))
-
-            output = result.as_numpy("text_output")
-            self.assertIsNotNone(output, "`text_output` should not be None")
-
-        self.triton_client.stop_stream()
+        # Run the inference twice in case metrics are updated but engine crashes.
+        for _ in range(self.inference_count):
+            self.triton_client.start_stream(callback=partial(callback, user_data))
+            for i in range(number_of_vllm_reqs):
+                request_data = create_vllm_request(
+                    prompts[i],
+                    i,
+                    False,
+                    sampling_parameters,
+                    model_name,
+                    True,
+                )
+                self.triton_client.async_stream_infer(
+                    model_name=model_name,
+                    inputs=request_data["inputs"],
+                    request_id=request_data["request_id"],
+                    outputs=request_data["outputs"],
+                    parameters=sampling_parameters,
+                )
+
+            for _ in range(number_of_vllm_reqs):
+                result = user_data._completed_requests.get()
+                if type(result) is InferenceServerException:
+                    print(result.message())
+                self.assertIsNot(type(result), InferenceServerException, str(result))
+
+                output = result.as_numpy("text_output")
+                self.assertIsNotNone(output, "`text_output` should not be None")
+
+            self.triton_client.stop_stream()
 
     def test_vllm_metrics(self):
         # Test vLLM metrics
@@ -125,49 +128,70 @@ def test_vllm_metrics(self):
         # (2, 133, 144, 2702, 3477, 16)
         # (2, 133, 812, 9, 1470, 16)
         # (2, 133, 499, 9, 4687, 16)
-        self.assertEqual(metrics_dict["vllm:prompt_tokens_total"], 18)
+        self.assertEqual(
+            metrics_dict["vllm:prompt_tokens_total"], 18 * self.inference_count
+        )
         # vllm:generation_tokens_total
         # (5, 65, 14, 16, 144, 533, 7, 28, 848, 30, 10, 512, 4, 50118, 100, 437)
         # (5, 812, 9, 5, 1515, 3497, 4, 50118, 50118, 133, 812, 9, 1470, 16, 5, 812)
         # (11, 5, 1420, 9, 5, 82, 4, 50118, 50118, 133, 499, 9, 4687, 16, 11, 5)
-        self.assertEqual(metrics_dict["vllm:generation_tokens_total"], 48)
+        self.assertEqual(
+            metrics_dict["vllm:generation_tokens_total"], 48 * self.inference_count
+        )
         # vllm:time_to_first_token_seconds
         self.assertEqual(
-            metrics_dict["vllm:time_to_first_token_seconds_count"], total_prompts
+            metrics_dict["vllm:time_to_first_token_seconds_count"],
+            total_prompts * self.inference_count,
         )
         self.assertGreater(metrics_dict["vllm:time_to_first_token_seconds_sum"], 0)
         self.assertEqual(
-            metrics_dict["vllm:time_to_first_token_seconds_bucket"], total_prompts
+            metrics_dict["vllm:time_to_first_token_seconds_bucket"],
+            total_prompts * self.inference_count,
         )
         # vllm:time_per_output_token_seconds
-        self.assertEqual(metrics_dict["vllm:time_per_output_token_seconds_count"], 45)
+        self.assertEqual(
+            metrics_dict["vllm:time_per_output_token_seconds_count"],
+            45 * self.inference_count,
+        )
         self.assertGreater(metrics_dict["vllm:time_per_output_token_seconds_sum"], 0)
-        self.assertEqual(metrics_dict["vllm:time_per_output_token_seconds_bucket"], 45)
+        self.assertEqual(
+            metrics_dict["vllm:time_per_output_token_seconds_bucket"],
+            45 * self.inference_count,
+        )
         # vllm:e2e_request_latency_seconds
         self.assertEqual(
-            metrics_dict["vllm:e2e_request_latency_seconds_count"], total_prompts
+            metrics_dict["vllm:e2e_request_latency_seconds_count"],
+            total_prompts * self.inference_count,
         )
         self.assertGreater(metrics_dict["vllm:e2e_request_latency_seconds_sum"], 0)
         self.assertEqual(
-            metrics_dict["vllm:e2e_request_latency_seconds_bucket"], total_prompts
+            metrics_dict["vllm:e2e_request_latency_seconds_bucket"],
+            total_prompts * self.inference_count,
         )
         # vllm:request_prompt_tokens
         self.assertEqual(
-            metrics_dict["vllm:request_prompt_tokens_count"], total_prompts
+            metrics_dict["vllm:request_prompt_tokens_count"],
+            total_prompts * self.inference_count,
+        )
+        self.assertEqual(
+            metrics_dict["vllm:request_prompt_tokens_sum"], 18 * self.inference_count
         )
-        self.assertEqual(metrics_dict["vllm:request_prompt_tokens_sum"], 18)
         self.assertEqual(
-            metrics_dict["vllm:request_prompt_tokens_bucket"], total_prompts
+            metrics_dict["vllm:request_prompt_tokens_bucket"],
+            total_prompts * self.inference_count,
        )
         # vllm:request_generation_tokens
         self.assertEqual(
             metrics_dict["vllm:request_generation_tokens_count"],
-            total_prompts,
+            total_prompts * self.inference_count,
+        )
+        self.assertEqual(
+            metrics_dict["vllm:request_generation_tokens_sum"],
+            48 * self.inference_count,
         )
-        self.assertEqual(metrics_dict["vllm:request_generation_tokens_sum"], 48)
         self.assertEqual(
             metrics_dict["vllm:request_generation_tokens_bucket"],
-            total_prompts,
+            total_prompts * self.inference_count,
         )
 
         # TODO: Revisit this test due to the removal of best_of
diff --git a/src/utils/metrics.py b/src/utils/metrics.py
index 644eb6d9..4f8eceee 100644
--- a/src/utils/metrics.py
+++ b/src/utils/metrics.py
@@ -240,6 +240,8 @@ def record(
         Returns:
             None
         """
+        if iteration_stats is None:
+            return
         # Parse finished request stats into lists
         e2e_latency: List[float] = []
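
Note on the src/utils/metrics.py hunk: the early return keeps record() from parsing request stats when vLLM reports no iteration stats at all, which appears to be the crash the doubled inference run in the test is guarding against. Below is a minimal standalone sketch of that pattern; the Dummy* class and attribute names are invented for illustration, and only the "if iteration_stats is None: return" guard itself comes from the patch.

from typing import List, Optional


class DummyRequestStats:
    """Hypothetical stand-in for a finished-request stats record."""

    def __init__(self, e2e_latency: float):
        self.e2e_latency = e2e_latency


class DummyIterationStats:
    """Hypothetical stand-in for a per-iteration stats object."""

    def __init__(self, finished_requests: List[DummyRequestStats]):
        self.finished_requests = finished_requests


class DummyStatLogger:
    def __init__(self):
        self.e2e_latency: List[float] = []

    def record(self, iteration_stats: Optional[DummyIterationStats]) -> None:
        # Same shape as the patched record(): bail out before parsing
        # request stats when there is nothing to report this iteration.
        if iteration_stats is None:
            return
        self.e2e_latency.extend(
            req.e2e_latency for req in iteration_stats.finished_requests
        )


if __name__ == "__main__":
    logger = DummyStatLogger()
    logger.record(None)  # no-op instead of an attribute error
    logger.record(DummyIterationStats([DummyRequestStats(0.12)]))
    assert logger.e2e_latency == [0.12]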