Commit 06292d4

feat: termination reason present and E2E tested
1 parent 24c6ea9 commit 06292d4

12 files changed: +195 −75 lines

src/guidellm/benchmark/aggregator.py

Lines changed: 16 additions & 8 deletions
```diff
@@ -1,19 +1,13 @@
 import time
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import (
-    Any,
-    Generic,
-    Literal,
-    Optional,
-    TypeVar,
-    Union,
-)
+from typing import Any, Generic, Literal, Optional, TypeVar, Union, get_args
 
 from pydantic import Field
 
 from guidellm.backend import ResponseSummary
 from guidellm.benchmark.benchmark import (
+    REASON_STATUS_MAPPING,
     BenchmarkArgs,
     BenchmarkRunStats,
     BenchmarkT,
@@ -40,6 +34,7 @@
     SchedulerRequestResult,
     WorkerDescription,
 )
+from guidellm.scheduler.result import TerminationReason
 from guidellm.utils import check_load_processor
 
 __all__ = [
@@ -305,6 +300,12 @@ class BenchmarkAggregator(
             total=None,
         ),
     )
+    termination_reason: TerminationReason = Field(
+        description=(
+            f"The benchmark termination reason, one of: {get_args(TerminationReason)}"
+        ),
+        default="interrupted",
+    )
 
     def add_result(
         self,
@@ -444,6 +445,9 @@ def add_result(
 
         return True
 
+    def set_termination_reason(self, termination_reason: TerminationReason) -> None:
+        self.termination_reason = termination_reason
+
     @abstractmethod
     def compile(self) -> BenchmarkT:
         """
@@ -602,6 +606,8 @@ def compile(self) -> GenerativeBenchmark:
 
         error_rate = self._calculate_error_rate()
 
+        termination_reason = self.termination_reason
+
         return GenerativeBenchmark.from_stats(
             run_id=self.run_id,
             successful=successful,
@@ -628,6 +634,8 @@ def compile(self) -> GenerativeBenchmark:
                 request_time_delay_avg=self.requests_stats.request_time_delay.mean,
                 request_time_avg=self.requests_stats.request_time.mean,
                 error_rate=error_rate,
+                status=REASON_STATUS_MAPPING[termination_reason],
+                termination_reason=termination_reason,
             ),
             worker=self.worker_description,
             requests_loader=self.request_loader_description,
```

src/guidellm/benchmark/benchmark.py

Lines changed: 22 additions & 1 deletion
```diff
@@ -1,6 +1,6 @@
 import random
 import uuid
-from typing import Any, Literal, Optional, TypeVar, Union
+from typing import Any, Literal, Optional, TypeVar, Union, get_args
 
 from pydantic import Field, computed_field
 
@@ -32,6 +32,7 @@
     ThroughputStrategy,
     WorkerDescription,
 )
+from guidellm.scheduler.result import TerminationReason
 
 __all__ = [
     "Benchmark",
@@ -46,6 +47,14 @@
     "StatusBreakdown",
 ]
 
+BenchmarkStatus = Literal["success", "error", "interrupted"]
+REASON_STATUS_MAPPING: dict[TerminationReason, BenchmarkStatus] = {
+    "interrupted": "interrupted",
+    "max_error_reached": "error",
+    "max_seconds_reached": "success",
+    "max_requests_reached": "success",
+}
+
 
 class BenchmarkArgs(StandardBaseModel):
     """
@@ -225,6 +234,18 @@ class BenchmarkRunStats(StandardBaseModel):
             "account incomplete requests."
         )
     )
+    status: BenchmarkStatus = Field(
+        description=(
+            f"The status of the benchmark output, "
+            f"one of the following options: {get_args(BenchmarkStatus)}."
+        )
+    )
+    termination_reason: TerminationReason = Field(
+        description=(
+            "The reason for the benchmark termination, "
+            f"one of the following options: {get_args(TerminationReason)}."
+        )
+    )
 
 
 class BenchmarkMetrics(StandardBaseModel):
```
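
The mapping above is the crux of the commit: a fine-grained termination reason collapses into a coarse benchmark status. A self-contained sketch (names mirror this diff; the snippet itself is not part of the commit) of how `get_args` serves both the field descriptions and a sanity check:

```python
from typing import Literal, get_args

BenchmarkStatus = Literal["success", "error", "interrupted"]
TerminationReason = Literal[
    "interrupted", "max_error_reached", "max_seconds_reached", "max_requests_reached"
]
REASON_STATUS_MAPPING: dict[str, str] = {
    "interrupted": "interrupted",
    "max_error_reached": "error",
    "max_seconds_reached": "success",
    "max_requests_reached": "success",
}

# get_args turns a Literal into its tuple of options, which is what the
# Field descriptions above interpolate:
print(get_args(BenchmarkStatus))  # ('success', 'error', 'interrupted')

# Every reason resolves to a status, so the direct indexing done in
# aggregator.py (REASON_STATUS_MAPPING[termination_reason]) cannot KeyError:
assert set(get_args(TerminationReason)) == set(REASON_STATUS_MAPPING)
```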

src/guidellm/benchmark/benchmarker.py

Lines changed: 11 additions & 4 deletions
```diff
@@ -74,7 +74,7 @@ class BenchmarkerStrategyLimits(StandardBaseModel):
         description="Maximum duration (in seconds) to process requests per strategy.",
         ge=0,
     )
-    max_error: Optional[float] = Field(
+    max_error_per_strategy: Optional[float] = Field(
         description="Maximum error after which a "
         "benchmark will stop,"
         " either rate or fixed number",
@@ -105,6 +105,10 @@ def max_number(self) -> Optional[int]:
     def max_duration(self) -> Optional[float]:
         return self.max_duration_per_strategy
 
+    @property
+    def max_error(self) -> Optional[float]:
+        return self.max_error_per_strategy
+
     @property
     def warmup_number(self) -> Optional[int]:
         if self.warmup_percent_per_strategy is None or self.max_number is None:
@@ -154,7 +158,7 @@ async def run(
         profile: Profile,
         max_number_per_strategy: Optional[int],
         max_duration_per_strategy: Optional[float],
-        max_error: Optional[float],
+        max_error_per_strategy: Optional[float],
         warmup_percent_per_strategy: Optional[float],
         cooldown_percent_per_strategy: Optional[float],
     ) -> AsyncGenerator[
@@ -169,7 +173,7 @@ async def run(
             requests_loader_size=requests_loader_size,
             max_number_per_strategy=max_number_per_strategy,
             max_duration_per_strategy=max_duration_per_strategy,
-            max_error=max_error,
+            max_error_per_strategy=max_error_per_strategy,
             warmup_percent_per_strategy=warmup_percent_per_strategy,
             cooldown_percent_per_strategy=cooldown_percent_per_strategy,
         )
@@ -204,7 +208,7 @@ async def run(
             scheduling_strategy=scheduling_strategy,
             max_number=max_number_per_strategy,
             max_duration=max_duration_per_strategy,
-            max_error=max_error,
+            max_error=max_error_per_strategy,
         ):
             if result.type_ == "run_start":
                 yield BenchmarkerResult(
@@ -219,6 +223,9 @@ async def run(
                     current_result=None,
                 )
             elif result.type_ == "run_complete":
+                aggregator.set_termination_reason(
+                    result.run_info.termination_reason
+                )
                 yield BenchmarkerResult(
                     type_="scheduler_complete",
                     start_time=start_time,
```
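
The rename is paired with a read-only alias so the class keeps the naming convention of its sibling limits (`max_number`, `max_duration`). A minimal sketch of the pattern, using a stand-in pydantic model rather than the commit's actual class:

```python
from typing import Optional

from pydantic import BaseModel, Field


class LimitsSketch(BaseModel):
    # The stored field carries the explicit per-strategy name...
    max_error_per_strategy: Optional[float] = Field(default=None)

    # ...while a property exposes the short name used by internal callers.
    @property
    def max_error(self) -> Optional[float]:
        return self.max_error_per_strategy


limits = LimitsSketch(max_error_per_strategy=0.05)
assert limits.max_error == 0.05
```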

src/guidellm/benchmark/entrypoints.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -120,7 +120,7 @@ async def benchmark_generative_text(
         profile=profile,
         max_number_per_strategy=max_requests,
         max_duration_per_strategy=max_seconds,
-        max_error=max_error,
+        max_error_per_strategy=max_error,
         warmup_percent_per_strategy=warmup_percent,
         cooldown_percent_per_strategy=cooldown_percent,
     ):
```

src/guidellm/scheduler/result.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -18,6 +18,9 @@
 
 
 RequestStatus = Literal["success", "error"]
+TerminationReason = Literal[
+    "interrupted", "max_error_reached", "max_seconds_reached", "max_requests_reached"
+]
 
 
 class SchedulerRunInfo(StandardBaseModel):
@@ -60,6 +63,8 @@ class SchedulerRunInfo(StandardBaseModel):
     completed_requests: int = 0
     errored_requests: int = 0
 
+    termination_reason: TerminationReason = "interrupted"
+
 
 class SchedulerRequestInfo(StandardBaseModel):
     """
```

src/guidellm/scheduler/scheduler.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -177,6 +177,7 @@ async def run(
         ):
             shutdown_event.set()
             max_error_reached = True
+            run_info.termination_reason = "max_error_reached"
             logger.info(
                 f"Max error rate of "
                 f"({iter_result.run_info.max_error}) "
@@ -394,11 +395,16 @@ def _add_requests(
             and added_count < settings.max_add_requests_per_loop
         ):
             if run_info.created_requests >= run_info.end_number:
+                if time.time() >= run_info.end_time - 1:
+                    run_info.termination_reason = "max_seconds_reached"
+                else:
+                    run_info.termination_reason = "max_requests_reached"
                 raise StopIteration
 
             if (
                 request_time := next(times_iter)
             ) >= run_info.end_time or time.time() >= run_info.end_time:
+                run_info.termination_reason = "max_seconds_reached"
                 raise StopIteration
 
             request = next(requests_iter)
```
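
The first hunk in `_add_requests` breaks a tie: hitting the request cap within the last second of the run window is credited to the time limit, since both caps expire together there. A hedged reading of that one-second heuristic as a standalone function (names assumed, not from the commit):

```python
def classify_stop(now: float, end_time: float) -> str:
    # Within the final second of the window, prefer the time-based reason;
    # otherwise the request cap is what genuinely ended the run.
    return "max_seconds_reached" if now >= end_time - 1 else "max_requests_reached"


assert classify_stop(now=99.5, end_time=100.0) == "max_seconds_reached"
assert classify_stop(now=42.0, end_time=100.0) == "max_requests_reached"
```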

tests/e2e/test_basic.py

Lines changed: 0 additions & 60 deletions
This file was deleted.

tests/e2e/test_interrupted.py renamed to tests/e2e/test_failed_benchmark.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -74,6 +74,15 @@ def test_interrupted_report(server: VllmSimServer):
         assert "errored" in requests
         errored = requests["errored"]
         assert len(errored) / (len(successful) + len(errored)) > max_error_rate
+
+        assert "run_stats" in benchmark
+        run_stats = benchmark["run_stats"]
+        assert "status" in run_stats
+        status = run_stats["status"]
+        assert status == "error"
+        assert "termination_reason" in run_stats
+        termination_reason = run_stats["termination_reason"]
+        assert termination_reason == "max_error_reached"
     finally:
         if report_path.exists():
             report_path.unlink()
```
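
For a run that trips the error cap, the saved report's `run_stats` section now carries both new fields; an illustration of the shape, inferred from the assertions above rather than from an actual report:

```python
# Hypothetical excerpt of benchmark["run_stats"] for the max-error case:
run_stats = {
    "status": "error",                          # coarse outcome: success/error/interrupted
    "termination_reason": "max_error_reached",  # the specific cap that ended the run
    # ...the pre-existing BenchmarkRunStats fields...
}
```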
