Skip to content

Commit c00ac91

Browse files
committed
Fixes for e2e enablement
1 parent 6f3d753 commit c00ac91

File tree

17 files changed

+1539
-435
lines changed

17 files changed

+1539
-435
lines changed

src/guidellm/__main__.py

Lines changed: 194 additions & 164 deletions
Large diffs are not rendered by default.

src/guidellm/benchmark/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
CompilableAggregator,
44
GenerativeRequestsAggregator,
55
GenerativeStatsProgressAggregator,
6+
InjectExtrasAggregator,
67
SchedulerStatsAggregator,
78
SerializableAggregator,
89
)
@@ -62,6 +63,7 @@
6263
"GenerativeRequestStats",
6364
"GenerativeRequestsAggregator",
6465
"GenerativeStatsProgressAggregator",
66+
"InjectExtrasAggregator",
6567
"Profile",
6668
"ProfileType",
6769
"SchedulerStatsAggregator",

src/guidellm/benchmark/aggregator.py

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
runtime_checkable,
3434
)
3535

36-
import numpy
36+
import numpy as np
3737
from pydantic import Field, PrivateAttr
3838

3939
from guidellm.backend import (
@@ -70,6 +70,7 @@
7070
"CompilableAggregator",
7171
"GenerativeRequestsAggregator",
7272
"GenerativeStatsProgressAggregator",
73+
"InjectExtrasAggregator",
7374
"SchedulerStatsAggregator",
7475
"SerializableAggregator",
7576
]
@@ -284,6 +285,47 @@ def compile(
284285
"""
285286

286287

288+
@SerializableAggregator.register("inject_extras")
289+
class InjectExtrasAggregator(
290+
SerializableAggregator[ResponseT, RequestT, MeasuredRequestTimingsT], InfoMixin
291+
):
292+
"""
293+
Aggregator for injecting extra metadata into the output.
294+
"""
295+
296+
@classmethod
297+
def validated_kwargs(cls, extras: dict[str, Any], **kwargs) -> dict[str, Any]:
298+
return {"extras": extras}
299+
300+
type_: Literal["inject_extras"] = Field(default="inject_extras")
301+
extras: dict[str, Any] | None = Field(default_factory=None)
302+
303+
def __call__(
304+
self,
305+
agg_state: dict[str, Any],
306+
response: ResponseT | None,
307+
request: RequestT,
308+
request_info: ScheduledRequestInfo[MeasuredRequestTimingsT],
309+
scheduler_state: SchedulerState,
310+
) -> dict[str, Any] | None:
311+
"""
312+
Inject extra metadata into the aggregation state.
313+
314+
:param agg_state: Current aggregation state to update.
315+
:param response: Response generated for the request, if successful.
316+
:param request: The processed request object.
317+
:param request_info: Scheduling metadata and timing information.
318+
:param scheduler_state: Current scheduler execution state.
319+
:return: Updated aggregation state with injected extras.
320+
"""
321+
return None
322+
323+
def compile(
324+
self, agg_state: dict[str, Any], scheduler_state: SchedulerState
325+
) -> dict[str, Any]:
326+
return {"extras": self.extras} if self.extras else {}
327+
328+
287329
@SerializableAggregator.register("scheduler_stats")
288330
class SchedulerStatsAggregator(
289331
SerializableAggregator[ResponseT, RequestT, MeasuredRequestTimingsT], InfoMixin
@@ -600,7 +642,7 @@ def __call__(
600642
self.add_aggregate_metric_rate(f"{prefix}prompt_tokens", agg_state)
601643
)
602644
self.add_aggregate_metric(
603-
f"{prefix}prompt_tokens", agg_state, response.prompt_tokens
645+
"prompt_tokens", agg_state, response.prompt_tokens
604646
)
605647
agg_state["prompt_tokens_per_request"] = self.add_aggregate_metric_rate(
606648
"prompt_tokens", agg_state
@@ -842,26 +884,32 @@ def compile(
842884
"requests": StatusBreakdown(
843885
successful=(
844886
list(
845-
numpy.random.choice(
846-
successful, size=self.request_samples, replace=False
887+
np.random.choice(
888+
successful,
889+
size=min(self.request_samples, len(successful)),
890+
replace=False,
847891
)
848892
)
849893
if self.request_samples
850894
else successful
851895
),
852896
incomplete=(
853897
list(
854-
numpy.random.choice(
855-
incomplete, size=self.request_samples, replace=False
898+
np.random.choice(
899+
incomplete,
900+
size=min(self.request_samples, len(incomplete)),
901+
replace=False,
856902
)
857903
)
858904
if self.request_samples
859905
else incomplete
860906
),
861907
errored=(
862908
list(
863-
numpy.random.choice(
864-
errored, size=self.request_samples, replace=False
909+
np.random.choice(
910+
errored,
911+
size=min(self.request_samples, len(errored)),
912+
replace=False,
865913
)
866914
)
867915
if self.request_samples

0 commit comments

Comments
 (0)