
Commit fad0ebc

Run multiple inferences

1 parent d50bda1 · commit fad0ebc

2 files changed: +68 -42 lines


ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py

Lines changed: 66 additions & 42 deletions
@@ -49,6 +49,7 @@ def setUp(self):
             "The future of AI is",
         ]
         self.sampling_parameters = {"temperature": "0", "top_p": "1"}
+        self.inference_count = 2
 
     def parse_vllm_metrics(self):
         """
@@ -82,34 +83,36 @@ def vllm_infer(
         user_data = UserData()
         number_of_vllm_reqs = len(prompts)
 
-        self.triton_client.start_stream(callback=partial(callback, user_data))
-        for i in range(number_of_vllm_reqs):
-            request_data = create_vllm_request(
-                prompts[i],
-                i,
-                False,
-                sampling_parameters,
-                model_name,
-                True,
-            )
-            self.triton_client.async_stream_infer(
-                model_name=model_name,
-                inputs=request_data["inputs"],
-                request_id=request_data["request_id"],
-                outputs=request_data["outputs"],
-                parameters=sampling_parameters,
-            )
-
-        for _ in range(number_of_vllm_reqs):
-            result = user_data._completed_requests.get()
-            if type(result) is InferenceServerException:
-                print(result.message())
-            self.assertIsNot(type(result), InferenceServerException, str(result))
-
-            output = result.as_numpy("text_output")
-            self.assertIsNotNone(output, "`text_output` should not be None")
-
-        self.triton_client.stop_stream()
+        # Run the inference twice in case metrics are updated but engine crashes.
+        for _ in range(self.inference_count):
+            self.triton_client.start_stream(callback=partial(callback, user_data))
+            for i in range(number_of_vllm_reqs):
+                request_data = create_vllm_request(
+                    prompts[i],
+                    i,
+                    False,
+                    sampling_parameters,
+                    model_name,
+                    True,
+                )
+                self.triton_client.async_stream_infer(
+                    model_name=model_name,
+                    inputs=request_data["inputs"],
+                    request_id=request_data["request_id"],
+                    outputs=request_data["outputs"],
+                    parameters=sampling_parameters,
+                )
+
+            for _ in range(number_of_vllm_reqs):
+                result = user_data._completed_requests.get()
+                if type(result) is InferenceServerException:
+                    print(result.message())
+                self.assertIsNot(type(result), InferenceServerException, str(result))
+
+                output = result.as_numpy("text_output")
+                self.assertIsNotNone(output, "`text_output` should not be None")
+
+            self.triton_client.stop_stream()
 
     def test_vllm_metrics(self):
         # Test vLLM metrics
@@ -125,49 +128,70 @@ def test_vllm_metrics(self):
         # (2, 133, 144, 2702, 3477, 16)
         # (2, 133, 812, 9, 1470, 16)
         # (2, 133, 499, 9, 4687, 16)
-        self.assertEqual(metrics_dict["vllm:prompt_tokens_total"], 18)
+        self.assertEqual(
+            metrics_dict["vllm:prompt_tokens_total"], 18 * self.inference_count
+        )
         # vllm:generation_tokens_total
         # (5, 65, 14, 16, 144, 533, 7, 28, 848, 30, 10, 512, 4, 50118, 100, 437)
         # (5, 812, 9, 5, 1515, 3497, 4, 50118, 50118, 133, 812, 9, 1470, 16, 5, 812)
         # (11, 5, 1420, 9, 5, 82, 4, 50118, 50118, 133, 499, 9, 4687, 16, 11, 5)
-        self.assertEqual(metrics_dict["vllm:generation_tokens_total"], 48)
+        self.assertEqual(
+            metrics_dict["vllm:generation_tokens_total"], 48 * self.inference_count
+        )
         # vllm:time_to_first_token_seconds
         self.assertEqual(
-            metrics_dict["vllm:time_to_first_token_seconds_count"], total_prompts
+            metrics_dict["vllm:time_to_first_token_seconds_count"],
+            total_prompts * self.inference_count,
         )
         self.assertGreater(metrics_dict["vllm:time_to_first_token_seconds_sum"], 0)
         self.assertEqual(
-            metrics_dict["vllm:time_to_first_token_seconds_bucket"], total_prompts
+            metrics_dict["vllm:time_to_first_token_seconds_bucket"],
+            total_prompts * self.inference_count,
         )
         # vllm:time_per_output_token_seconds
-        self.assertEqual(metrics_dict["vllm:time_per_output_token_seconds_count"], 45)
+        self.assertEqual(
+            metrics_dict["vllm:time_per_output_token_seconds_count"],
+            45 * self.inference_count,
+        )
         self.assertGreater(metrics_dict["vllm:time_per_output_token_seconds_sum"], 0)
-        self.assertEqual(metrics_dict["vllm:time_per_output_token_seconds_bucket"], 45)
+        self.assertEqual(
+            metrics_dict["vllm:time_per_output_token_seconds_bucket"],
+            45 * self.inference_count,
+        )
         # vllm:e2e_request_latency_seconds
         self.assertEqual(
-            metrics_dict["vllm:e2e_request_latency_seconds_count"], total_prompts
+            metrics_dict["vllm:e2e_request_latency_seconds_count"],
+            total_prompts * self.inference_count,
        )
         self.assertGreater(metrics_dict["vllm:e2e_request_latency_seconds_sum"], 0)
         self.assertEqual(
-            metrics_dict["vllm:e2e_request_latency_seconds_bucket"], total_prompts
+            metrics_dict["vllm:e2e_request_latency_seconds_bucket"],
+            total_prompts * self.inference_count,
         )
         # vllm:request_prompt_tokens
         self.assertEqual(
-            metrics_dict["vllm:request_prompt_tokens_count"], total_prompts
+            metrics_dict["vllm:request_prompt_tokens_count"],
+            total_prompts * self.inference_count,
+        )
+        self.assertEqual(
+            metrics_dict["vllm:request_prompt_tokens_sum"], 18 * self.inference_count
         )
-        self.assertEqual(metrics_dict["vllm:request_prompt_tokens_sum"], 18)
         self.assertEqual(
-            metrics_dict["vllm:request_prompt_tokens_bucket"], total_prompts
+            metrics_dict["vllm:request_prompt_tokens_bucket"],
+            total_prompts * self.inference_count,
         )
         # vllm:request_generation_tokens
         self.assertEqual(
             metrics_dict["vllm:request_generation_tokens_count"],
-            total_prompts,
+            total_prompts * self.inference_count,
+        )
+        self.assertEqual(
+            metrics_dict["vllm:request_generation_tokens_sum"],
+            48 * self.inference_count,
         )
-        self.assertEqual(metrics_dict["vllm:request_generation_tokens_sum"], 48)
         self.assertEqual(
             metrics_dict["vllm:request_generation_tokens_bucket"],
-            total_prompts,
+            total_prompts * self.inference_count,
         )
 
         # TODO: Revisit this test due to the removal of best_of
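Every per-run expectation in these assertions is now multiplied by self.inference_count. A minimal sketch of that arithmetic, assuming inference_count = 2 as set in setUp; only the constants 18, 48, 45, and the three prompts come from the diff above, the names and the derivation of 45 (generation tokens minus one first token per prompt) are illustrative inferences:

# Hypothetical helper showing how the expected totals scale with the number of runs.
PROMPT_TOKENS_PER_RUN = 18        # sum of prompt tokens over the 3 prompts (6 each)
GENERATION_TOKENS_PER_RUN = 48    # sum of generated tokens over the 3 prompts (16 each)
OUTPUT_TOKEN_OBS_PER_RUN = 45     # presumably 48 minus one first token per prompt
TOTAL_PROMPTS = 3
INFERENCE_COUNT = 2               # mirrors self.inference_count in setUp

expected = {
    "vllm:prompt_tokens_total": PROMPT_TOKENS_PER_RUN * INFERENCE_COUNT,            # 36
    "vllm:generation_tokens_total": GENERATION_TOKENS_PER_RUN * INFERENCE_COUNT,    # 96
    "vllm:time_per_output_token_seconds_count": OUTPUT_TOKEN_OBS_PER_RUN * INFERENCE_COUNT,  # 90
    "vllm:e2e_request_latency_seconds_count": TOTAL_PROMPTS * INFERENCE_COUNT,      # 6
}
print(expected)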

src/utils/metrics.py

Lines changed: 2 additions & 0 deletions
@@ -240,6 +240,8 @@ def record(
         Returns:
             None
         """
+        if iteration_stats is None:
+            return
 
         # Parse finished request stats into lists
         e2e_latency: List[float] = []
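The two added lines make record a no-op when no iteration stats are passed in, so a callback that fires with iteration_stats set to None no longer reads attributes of a None object further down. A minimal sketch of that guard pattern; the classes and names below are stand-ins for illustration, not the backend's actual types:

from typing import List, Optional


class IterationStats:
    # Hypothetical stand-in for vLLM's per-iteration stats object.
    def __init__(self, e2e_latencies: List[float]) -> None:
        self.e2e_latencies = e2e_latencies


class Recorder:
    def __init__(self) -> None:
        self.e2e_latency: List[float] = []

    def record(self, iteration_stats: Optional[IterationStats]) -> None:
        # Mirrors the new guard: bail out early when there is nothing to record.
        if iteration_stats is None:
            return
        self.e2e_latency.extend(iteration_stats.e2e_latencies)


recorder = Recorder()
recorder.record(None)                          # no-op instead of an AttributeError
recorder.record(IterationStats([0.12, 0.34]))  # normal path still accumulates stats
print(recorder.e2e_latency)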
