38
38
39
39
from guidellm .backend import (
40
40
GenerationRequest ,
41
- GenerationRequestTimings ,
42
41
GenerationResponse ,
43
42
)
44
43
from guidellm .benchmark .objects import (
47
46
GenerativeRequestStats ,
48
47
)
49
48
from guidellm .scheduler import (
50
- MeasuredRequestTimingsT ,
51
49
RequestT ,
52
50
ResponseT ,
53
51
ScheduledRequestInfo ,
@@ -153,7 +151,7 @@ def get_metric(
153
151
154
152
155
153
@runtime_checkable
156
- class Aggregator (Protocol [ResponseT , RequestT , MeasuredRequestTimingsT ]):
154
+ class Aggregator (Protocol [ResponseT , RequestT ]):
157
155
"""
158
156
Protocol for processing benchmark data updates during execution.
159
157
@@ -167,7 +165,7 @@ def __call__(
167
165
state : AggregatorState ,
168
166
response : ResponseT | None ,
169
167
request : RequestT ,
170
- request_info : ScheduledRequestInfo [ MeasuredRequestTimingsT ] ,
168
+ request_info : ScheduledRequestInfo ,
171
169
scheduler_state : SchedulerState ,
172
170
) -> dict [str , Any ] | None :
173
171
"""
@@ -183,7 +181,7 @@ def __call__(
183
181
184
182
185
183
@runtime_checkable
186
- class CompilableAggregator (Protocol [ResponseT , RequestT , MeasuredRequestTimingsT ]):
184
+ class CompilableAggregator (Protocol [ResponseT , RequestT ]):
187
185
"""
188
186
Protocol for aggregators that compile final results from aggregated state.
189
187
@@ -196,7 +194,7 @@ def __call__(
196
194
state : AggregatorState ,
197
195
response : ResponseT | None ,
198
196
request : RequestT ,
199
- request_info : ScheduledRequestInfo [ MeasuredRequestTimingsT ] ,
197
+ request_info : ScheduledRequestInfo ,
200
198
scheduler_state : SchedulerState ,
201
199
) -> dict [str , Any ] | None :
202
200
"""
@@ -225,7 +223,7 @@ def compile(
225
223
class SerializableAggregator (
226
224
PydanticClassRegistryMixin [type ["SerializableAggregator" ]],
227
225
ABC ,
228
- Generic [ResponseT , RequestT , MeasuredRequestTimingsT ],
226
+ Generic [ResponseT , RequestT ],
229
227
):
230
228
schema_discriminator : ClassVar [str ] = "type_"
231
229
@@ -286,7 +284,7 @@ def __call__(
286
284
state : AggregatorState ,
287
285
response : ResponseT | None ,
288
286
request : RequestT ,
289
- request_info : ScheduledRequestInfo [ MeasuredRequestTimingsT ] ,
287
+ request_info : ScheduledRequestInfo ,
290
288
scheduler_state : SchedulerState ,
291
289
) -> dict [str , Any ] | None :
292
290
"""
@@ -314,9 +312,7 @@ def compile(
314
312
315
313
316
314
@SerializableAggregator .register ("inject_extras" )
317
- class InjectExtrasAggregator (
318
- SerializableAggregator [ResponseT , RequestT , MeasuredRequestTimingsT ], InfoMixin
319
- ):
315
+ class InjectExtrasAggregator (SerializableAggregator [ResponseT , RequestT ], InfoMixin ):
320
316
"""
321
317
Aggregator for injecting extra metadata into the output.
322
318
"""
@@ -333,7 +329,7 @@ def __call__(
333
329
state : AggregatorState ,
334
330
response : ResponseT | None ,
335
331
request : RequestT ,
336
- request_info : ScheduledRequestInfo [ MeasuredRequestTimingsT ] ,
332
+ request_info : ScheduledRequestInfo ,
337
333
scheduler_state : SchedulerState ,
338
334
) -> dict [str , Any ] | None :
339
335
"""
@@ -355,9 +351,7 @@ def compile(
355
351
356
352
357
353
@SerializableAggregator .register ("scheduler_stats" )
358
- class SchedulerStatsAggregator (
359
- SerializableAggregator [ResponseT , RequestT , MeasuredRequestTimingsT ], InfoMixin
360
- ):
354
+ class SchedulerStatsAggregator (SerializableAggregator [ResponseT , RequestT ], InfoMixin ):
361
355
"""
362
356
Aggregates scheduler timing and performance metrics.
363
357
@@ -376,7 +370,7 @@ def __call__(
376
370
state : AggregatorState ,
377
371
response : ResponseT | None ,
378
372
request : RequestT ,
379
- request_info : ScheduledRequestInfo [ MeasuredRequestTimingsT ] ,
373
+ request_info : ScheduledRequestInfo ,
380
374
scheduler_state : SchedulerState ,
381
375
) -> dict [str , Any ] | None :
382
376
"""
@@ -499,9 +493,7 @@ def compile(
499
493
500
494
@SerializableAggregator .register ("generative_stats_progress" )
501
495
class GenerativeStatsProgressAggregator (
502
- SerializableAggregator [
503
- GenerationResponse , GenerationRequest , GenerationRequestTimings
504
- ]
496
+ SerializableAggregator [GenerationResponse , GenerationRequest ]
505
497
):
506
498
"""
507
499
Tracks generative model metrics during benchmark execution.
@@ -523,7 +515,7 @@ def __call__(
523
515
state : AggregatorState ,
524
516
response : GenerationResponse | None ,
525
517
request : GenerationRequest ,
526
- request_info : ScheduledRequestInfo [ GenerationRequestTimings ] ,
518
+ request_info : ScheduledRequestInfo ,
527
519
scheduler_state : SchedulerState ,
528
520
) -> dict [str , Any ] | None :
529
521
"""
@@ -667,9 +659,7 @@ def compile(
667
659
668
660
@SerializableAggregator .register ("generative_requests" )
669
661
class GenerativeRequestsAggregator (
670
- SerializableAggregator [
671
- GenerationResponse , GenerationRequest , GenerationRequestTimings
672
- ],
662
+ SerializableAggregator [GenerationResponse , GenerationRequest ],
673
663
):
674
664
"""
675
665
Compiles complete generative benchmark results with warmup/cooldown filtering.
@@ -712,7 +702,7 @@ def __call__(
712
702
state : AggregatorState ,
713
703
response : GenerationResponse | None ,
714
704
request : GenerationRequest ,
715
- request_info : ScheduledRequestInfo [ GenerationRequestTimings ] ,
705
+ request_info : ScheduledRequestInfo ,
716
706
scheduler_state : SchedulerState ,
717
707
) -> dict [str , Any ] | None :
718
708
"""
@@ -875,7 +865,7 @@ def compile(
875
865
876
866
def _is_in_warmup (
877
867
self ,
878
- request_info : ScheduledRequestInfo [ GenerationRequestTimings ] ,
868
+ request_info : ScheduledRequestInfo ,
879
869
scheduler_state : SchedulerState ,
880
870
) -> bool :
881
871
"""Check if the current request is within the warmup period."""
@@ -902,7 +892,7 @@ def _is_in_warmup(
902
892
903
893
def _is_in_cooldown (
904
894
self ,
905
- request_info : ScheduledRequestInfo [ GenerationRequestTimings ] ,
895
+ request_info : ScheduledRequestInfo ,
906
896
scheduler_state : SchedulerState ,
907
897
) -> bool :
908
898
"""Check if the current request is within the cooldown period."""
@@ -936,7 +926,7 @@ def _create_generative_request_stats(
936
926
cls ,
937
927
response : GenerationResponse ,
938
928
request : GenerationRequest ,
939
- request_info : ScheduledRequestInfo [ GenerationRequestTimings ] ,
929
+ request_info : ScheduledRequestInfo ,
940
930
) -> GenerativeRequestStats :
941
931
prompt_tokens = response .preferred_prompt_tokens (
942
932
settings .preferred_prompt_tokens_source
0 commit comments