ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py: 108 changes (66 additions, 42 deletions)
@@ -49,6 +49,7 @@ def setUp(self):
             "The future of AI is",
         ]
         self.sampling_parameters = {"temperature": "0", "top_p": "1"}
+        self.inference_count = 2
 
     def parse_vllm_metrics(self):
         """
@@ -82,34 +83,36 @@ def vllm_infer(
         user_data = UserData()
         number_of_vllm_reqs = len(prompts)
 
-        self.triton_client.start_stream(callback=partial(callback, user_data))
-        for i in range(number_of_vllm_reqs):
-            request_data = create_vllm_request(
-                prompts[i],
-                i,
-                False,
-                sampling_parameters,
-                model_name,
-                True,
-            )
-            self.triton_client.async_stream_infer(
-                model_name=model_name,
-                inputs=request_data["inputs"],
-                request_id=request_data["request_id"],
-                outputs=request_data["outputs"],
-                parameters=sampling_parameters,
-            )
-
-        for _ in range(number_of_vllm_reqs):
-            result = user_data._completed_requests.get()
-            if type(result) is InferenceServerException:
-                print(result.message())
-            self.assertIsNot(type(result), InferenceServerException, str(result))
-
-            output = result.as_numpy("text_output")
-            self.assertIsNotNone(output, "`text_output` should not be None")
-
-        self.triton_client.stop_stream()
+        # Run the inference twice in case metrics are updated but the engine crashes.
+        for _ in range(self.inference_count):
+            self.triton_client.start_stream(callback=partial(callback, user_data))
+            for i in range(number_of_vllm_reqs):
+                request_data = create_vllm_request(
+                    prompts[i],
+                    i,
+                    False,
+                    sampling_parameters,
+                    model_name,
+                    True,
+                )
+                self.triton_client.async_stream_infer(
+                    model_name=model_name,
+                    inputs=request_data["inputs"],
+                    request_id=request_data["request_id"],
+                    outputs=request_data["outputs"],
+                    parameters=sampling_parameters,
+                )
+
+            for _ in range(number_of_vllm_reqs):
+                result = user_data._completed_requests.get()
+                if type(result) is InferenceServerException:
+                    print(result.message())
+                self.assertIsNot(type(result), InferenceServerException, str(result))
+
+                output = result.as_numpy("text_output")
+                self.assertIsNotNone(output, "`text_output` should not be None")
+
+            self.triton_client.stop_stream()
 
     def test_vllm_metrics(self):
         # Test vLLM metrics
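
Editor's note: the streaming loop in the hunk above relies on helpers (UserData, callback, create_vllm_request) that are imported from the test utilities and not shown in this diff. A minimal sketch of the first two, written in the usual tritonclient streaming style rather than copied from the repository's actual helpers, looks roughly like this:

import queue


class UserData:
    """Holds the queue that the stream callback feeds and the test later drains."""

    def __init__(self):
        self._completed_requests = queue.Queue()


def callback(user_data, result, error):
    # The gRPC stream invokes this as callback(result, error); partial(callback,
    # user_data) binds the shared queue. Errors arrive as
    # tritonclient.utils.InferenceServerException objects, successes as InferResult.
    if error:
        user_data._completed_requests.put(error)
    else:
        user_data._completed_requests.put(result)

The main thread then calls user_data._completed_requests.get() once per request, exactly as the loop above does.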
@@ -125,49 +128,70 @@ def test_vllm_metrics(self):
         # (2, 133, 144, 2702, 3477, 16)
         # (2, 133, 812, 9, 1470, 16)
         # (2, 133, 499, 9, 4687, 16)
-        self.assertEqual(metrics_dict["vllm:prompt_tokens_total"], 18)
+        self.assertEqual(
+            metrics_dict["vllm:prompt_tokens_total"], 18 * self.inference_count
+        )
         # vllm:generation_tokens_total
         # (5, 65, 14, 16, 144, 533, 7, 28, 848, 30, 10, 512, 4, 50118, 100, 437)
         # (5, 812, 9, 5, 1515, 3497, 4, 50118, 50118, 133, 812, 9, 1470, 16, 5, 812)
         # (11, 5, 1420, 9, 5, 82, 4, 50118, 50118, 133, 499, 9, 4687, 16, 11, 5)
-        self.assertEqual(metrics_dict["vllm:generation_tokens_total"], 48)
+        self.assertEqual(
+            metrics_dict["vllm:generation_tokens_total"], 48 * self.inference_count
+        )
         # vllm:time_to_first_token_seconds
         self.assertEqual(
-            metrics_dict["vllm:time_to_first_token_seconds_count"], total_prompts
+            metrics_dict["vllm:time_to_first_token_seconds_count"],
+            total_prompts * self.inference_count,
         )
         self.assertGreater(metrics_dict["vllm:time_to_first_token_seconds_sum"], 0)
         self.assertEqual(
-            metrics_dict["vllm:time_to_first_token_seconds_bucket"], total_prompts
+            metrics_dict["vllm:time_to_first_token_seconds_bucket"],
+            total_prompts * self.inference_count,
         )
         # vllm:time_per_output_token_seconds
-        self.assertEqual(metrics_dict["vllm:time_per_output_token_seconds_count"], 45)
+        self.assertEqual(
+            metrics_dict["vllm:time_per_output_token_seconds_count"],
+            45 * self.inference_count,
+        )
         self.assertGreater(metrics_dict["vllm:time_per_output_token_seconds_sum"], 0)
-        self.assertEqual(metrics_dict["vllm:time_per_output_token_seconds_bucket"], 45)
+        self.assertEqual(
+            metrics_dict["vllm:time_per_output_token_seconds_bucket"],
+            45 * self.inference_count,
+        )
         # vllm:e2e_request_latency_seconds
         self.assertEqual(
-            metrics_dict["vllm:e2e_request_latency_seconds_count"], total_prompts
+            metrics_dict["vllm:e2e_request_latency_seconds_count"],
+            total_prompts * self.inference_count,
         )
         self.assertGreater(metrics_dict["vllm:e2e_request_latency_seconds_sum"], 0)
         self.assertEqual(
-            metrics_dict["vllm:e2e_request_latency_seconds_bucket"], total_prompts
+            metrics_dict["vllm:e2e_request_latency_seconds_bucket"],
+            total_prompts * self.inference_count,
         )
         # vllm:request_prompt_tokens
         self.assertEqual(
-            metrics_dict["vllm:request_prompt_tokens_count"], total_prompts
+            metrics_dict["vllm:request_prompt_tokens_count"],
+            total_prompts * self.inference_count,
+        )
+        self.assertEqual(
+            metrics_dict["vllm:request_prompt_tokens_sum"], 18 * self.inference_count
         )
-        self.assertEqual(metrics_dict["vllm:request_prompt_tokens_sum"], 18)
         self.assertEqual(
-            metrics_dict["vllm:request_prompt_tokens_bucket"], total_prompts
+            metrics_dict["vllm:request_prompt_tokens_bucket"],
+            total_prompts * self.inference_count,
         )
         # vllm:request_generation_tokens
         self.assertEqual(
             metrics_dict["vllm:request_generation_tokens_count"],
-            total_prompts,
+            total_prompts * self.inference_count,
+        )
+        self.assertEqual(
+            metrics_dict["vllm:request_generation_tokens_sum"],
+            48 * self.inference_count,
         )
-        self.assertEqual(metrics_dict["vllm:request_generation_tokens_sum"], 48)
         self.assertEqual(
             metrics_dict["vllm:request_generation_tokens_bucket"],
-            total_prompts,
+            total_prompts * self.inference_count,
         )
 
         # TODO: Revisit this test due to the removal of best_of
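
Editor's note on the expected values: each pass sends three six-token prompts (3 x 6 = 18 prompt tokens) and generates sixteen tokens per request (3 x 16 = 48 generation tokens); the 45 observations for time_per_output_token are presumably 48 minus one first token per request, which is attributed to time_to_first_token instead. With the new outer loop, every expectation is multiplied by self.inference_count. The assertions read from the flat metrics_dict returned by parse_vllm_metrics, whose body is truncated in this diff. A rough sketch of such a scraper is below; it is not the test's actual helper, and the endpoint URL, port 8002, the scrape_vllm_metrics name, and the use of prometheus_client are all assumptions:

import requests
from prometheus_client.parser import text_string_to_metric_families

# Assumed endpoint: Triton exposes Prometheus metrics on port 8002 by default.
TRITON_METRICS_URL = "http://localhost:8002/metrics"


def scrape_vllm_metrics(url=TRITON_METRICS_URL):
    """Flatten vLLM samples from Triton's metrics endpoint into {name: value}."""
    response = requests.get(url)
    response.raise_for_status()
    metrics = {}
    for family in text_string_to_metric_families(response.text):
        for sample in family.samples:
            if not sample.name.startswith("vllm:"):
                continue
            # Later samples overwrite earlier ones, so a histogram's *_bucket key
            # ends up holding the cumulative le="+Inf" bucket, i.e. the total
            # observation count that the assertions compare against.
            metrics[sample.name] = sample.value
    return metrics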
src/utils/metrics.py: 2 changes (2 additions, 0 deletions)
@@ -240,6 +240,8 @@ def record(
         Returns:
             None
         """
+        if iteration_stats is None:
+            return
 
         # Parse finished request stats into lists
         e2e_latency: List[float] = []
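
Editor's note: the guard added above makes record() a no-op when no iteration stats are available instead of dereferencing None further down. A minimal standalone sketch of the same defensive pattern; SimpleStatLogger and its fields are hypothetical, not the backend's actual class:

from typing import List, Optional


class SimpleStatLogger:
    """Hypothetical stand-in illustrating the early-return guard added above."""

    def __init__(self) -> None:
        self.e2e_latency: List[float] = []

    def record(self, iteration_stats: Optional[object]) -> None:
        # An engine step may report no stats at all; skip instead of touching None.
        if iteration_stats is None:
            return
        latency = getattr(iteration_stats, "e2e_latency", None)
        if latency is not None:
            self.e2e_latency.append(latency)


if __name__ == "__main__":
    logger = SimpleStatLogger()
    logger.record(None)  # must not raise
    assert logger.e2e_latency == []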