from guidellm.scheduler.objects import (
    BackendInterface,
    MultiTurnRequestT,
+   MultiTurnT,
    RequestT,
    ResponseT,
+   ScheduledRequestAugmentation,
    ScheduledRequestInfo,
    SchedulerMessagingPydanticRegistry,
    SchedulerState,
@@ -471,9 +473,9 @@ def __init__(

    def requests_generator(
        self,
-        requests: Iterable[RequestT | MultiTurnRequestT[RequestT]] | None,
-        cycle_requests: Iterable[RequestT | MultiTurnRequestT[RequestT]] | None,
-    ) -> Generator[tuple[RequestT | MultiTurnRequestT[RequestT],], None, None]:
+        requests: Iterable[Iterable[tuple[RequestT, float]]] | None,
+        cycle_requests: Iterable[Iterable[tuple[RequestT, float]]] | None,
+    ) -> Generator[MultiTurnT[RequestT], None, None]:
        """
        Generate request-info pairs for worker processing with constraint evaluation.
@@ -494,31 +496,40 @@ def _iter():
                while True:
                    yield from cycle_requests

-        count = 0
-        request_info: ScheduledRequestInfo = None
+        count: int = 0
+        stop_queueing: bool = False
+
+        def _turn_iter(requests_chain: Iterable[tuple[RequestT, float]]):
+            nonlocal count, stop_queueing
+            for request, delay in requests_chain:
+                count += 1
+
+                if hasattr(request, "request_id"):
+                    request_id = request.request_id
+                elif hasattr(request, "id"):
+                    request_id = request.id
+                else:
+                    request_id = str(uuid.uuid4())
+                request_augmentation = ScheduledRequestAugmentation(
+                    post_requeue_delay=delay
+                )
+                request_info: ScheduledRequestInfo = ScheduledRequestInfo(
+                    request_id=request_id,
+                    status="queued",
+                    scheduler_process_id=0,
+                    scheduler_start_time=self.start_time,
+                )
+                state_update = self._locked_update(request_info)
+                yield (request, request_augmentation, request_info)
+
+                if state_update.stop_queueing:
+                    stop_queueing = True
+                    return
+
        for request_chain in _iter():
-            if isinstance(request_chain, (list, tuple)):
-                request = request_chain[0]
-            else:
-                request = request_chain
-            count += 1
-
-            if hasattr(request, "request_id"):
-                request_id = request.request_id
-            elif hasattr(request, "id"):
-                request_id = request.id
-            else:
-                request_id = str(uuid.uuid4())
-            request_info: ScheduledRequestInfo = ScheduledRequestInfo(
-                request_id=request_id,
-                status="queued",
-                scheduler_process_id=0,
-                scheduler_start_time=self.start_time,
-            )
-            state_update = self._locked_update(request_info)
-            yield (request, request_info)
+            yield list(_turn_iter(request_chain))

-            if state_update.stop_queueing:
+            if stop_queueing:
                self.stop_send_requests_event.set()
                return
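For readers following the shape change above: each element of requests (and cycle_requests) is now a chain of (request, delay) pairs, and the generator yields one list per chain whose elements are (request, ScheduledRequestAugmentation, ScheduledRequestInfo) triples. The sketch below is not part of the PR; DemoRequest and build_chain are hypothetical stand-ins used only to illustrate the input shape, and the consumption loop is left as comments because it depends on how the worker group is constructed.

import uuid
from dataclasses import dataclass


# Hypothetical request type for illustration only; real runs use the
# project's own request objects (anything exposing request_id or id works).
@dataclass
class DemoRequest:
    request_id: str
    prompt: str


def build_chain(prompts: list[str], delay: float) -> list[tuple[DemoRequest, float]]:
    # One multi-turn chain: a list of (request, post-requeue delay) pairs,
    # matching the new Iterable[Iterable[tuple[RequestT, float]]] input shape.
    return [(DemoRequest(str(uuid.uuid4()), p), delay) for p in prompts]


requests = [
    build_chain(["hello", "follow-up"], delay=0.5),
    build_chain(["ping", "pong"], delay=0.0),
]

# A consumer would then receive one list per chain, each element a
# (request, augmentation, info) triple, roughly:
#
#   for turns in group.requests_generator(requests, cycle_requests=None):
#       for request, augmentation, info in turns:
#           ...  # dispatch the turn, honoring augmentation.post_requeue_delay
for chain in requests:
    for request, delay in chain:
        print(request.request_id, request.prompt, delay)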