|
33 | 33 | runtime_checkable,
|
34 | 34 | )
|
35 | 35 |
|
36 |
| -import numpy |
| 36 | +import numpy as np |
37 | 37 | from pydantic import Field, PrivateAttr
|
38 | 38 |
|
39 | 39 | from guidellm.backend import (
|
|
70 | 70 | "CompilableAggregator",
|
71 | 71 | "GenerativeRequestsAggregator",
|
72 | 72 | "GenerativeStatsProgressAggregator",
|
| 73 | + "InjectExtrasAggregator", |
73 | 74 | "SchedulerStatsAggregator",
|
74 | 75 | "SerializableAggregator",
|
75 | 76 | ]
|
@@ -284,6 +285,47 @@ def compile(
|
284 | 285 | """
|
285 | 286 |
|
286 | 287 |
|
@SerializableAggregator.register("inject_extras")
class InjectExtrasAggregator(
    SerializableAggregator[ResponseT, RequestT, MeasuredRequestTimingsT], InfoMixin
):
    """
    Aggregator for injecting extra metadata into the output.

    Holds a static ``extras`` mapping and emits it under the ``"extras"`` key
    when compiled. It records nothing while requests are being processed.
    """

    @classmethod
    def validated_kwargs(cls, extras: dict[str, Any], **kwargs) -> dict[str, Any]:
        """
        Build the constructor keyword arguments for this aggregator.

        :param extras: Metadata to inject into the compiled output.
        :param kwargs: Additional keyword arguments; accepted and ignored.
        :return: Mapping containing only the ``extras`` entry.
        """
        return {"extras": extras}

    type_: Literal["inject_extras"] = Field(default="inject_extras")
    # Use ``default=None`` rather than ``default_factory=None``: a factory of
    # None supplies no default at all, which would leave this field required
    # even though its annotation allows None.
    extras: dict[str, Any] | None = Field(default=None)

    def __call__(
        self,
        agg_state: dict[str, Any],
        response: ResponseT | None,
        request: RequestT,
        request_info: ScheduledRequestInfo[MeasuredRequestTimingsT],
        scheduler_state: SchedulerState,
    ) -> dict[str, Any] | None:
        """
        Accept a request update without recording anything.

        The extras are static, so they are only emitted by ``compile``; no
        argument is read or modified here.

        :param agg_state: Current aggregation state (left untouched).
        :param response: Response generated for the request, if successful.
        :param request: The processed request object.
        :param request_info: Scheduling metadata and timing information.
        :param scheduler_state: Current scheduler execution state.
        :return: Always None.
        """
        return None

    def compile(
        self, agg_state: dict[str, Any], scheduler_state: SchedulerState
    ) -> dict[str, Any]:
        """
        Produce the final output fragment carrying the injected extras.

        :param agg_state: Aggregation state (unused).
        :param scheduler_state: Final scheduler execution state (unused).
        :return: ``{"extras": self.extras}`` when extras is set and non-empty,
            otherwise an empty dictionary.
        """
        # Falsy extras (None or an empty dict) contribute nothing to the output.
        return {"extras": self.extras} if self.extras else {}
| 327 | + |
| 328 | + |
287 | 329 | @SerializableAggregator.register("scheduler_stats")
|
288 | 330 | class SchedulerStatsAggregator(
|
289 | 331 | SerializableAggregator[ResponseT, RequestT, MeasuredRequestTimingsT], InfoMixin
|
@@ -600,7 +642,7 @@ def __call__(
|
600 | 642 | self.add_aggregate_metric_rate(f"{prefix}prompt_tokens", agg_state)
|
601 | 643 | )
|
602 | 644 | self.add_aggregate_metric(
|
603 |
| - f"{prefix}prompt_tokens", agg_state, response.prompt_tokens |
| 645 | + "prompt_tokens", agg_state, response.prompt_tokens |
604 | 646 | )
|
605 | 647 | agg_state["prompt_tokens_per_request"] = self.add_aggregate_metric_rate(
|
606 | 648 | "prompt_tokens", agg_state
|
@@ -842,26 +884,32 @@ def compile(
|
842 | 884 | "requests": StatusBreakdown(
|
843 | 885 | successful=(
|
844 | 886 | list(
|
845 |
| - numpy.random.choice( |
846 |
| - successful, size=self.request_samples, replace=False |
| 887 | + np.random.choice( |
| 888 | + successful, |
| 889 | + size=min(self.request_samples, len(successful)), |
| 890 | + replace=False, |
847 | 891 | )
|
848 | 892 | )
|
849 | 893 | if self.request_samples
|
850 | 894 | else successful
|
851 | 895 | ),
|
852 | 896 | incomplete=(
|
853 | 897 | list(
|
854 |
| - numpy.random.choice( |
855 |
| - incomplete, size=self.request_samples, replace=False |
| 898 | + np.random.choice( |
| 899 | + incomplete, |
| 900 | + size=min(self.request_samples, len(incomplete)), |
| 901 | + replace=False, |
856 | 902 | )
|
857 | 903 | )
|
858 | 904 | if self.request_samples
|
859 | 905 | else incomplete
|
860 | 906 | ),
|
861 | 907 | errored=(
|
862 | 908 | list(
|
863 |
| - numpy.random.choice( |
864 |
| - errored, size=self.request_samples, replace=False |
| 909 | + np.random.choice( |
| 910 | + errored, |
| 911 | + size=min(self.request_samples, len(errored)), |
| 912 | + replace=False, |
865 | 913 | )
|
866 | 914 | )
|
867 | 915 | if self.request_samples
|
|
0 commit comments