@@ -84,7 +84,7 @@ def __init__(
8484 backend : BackendInterface [RequestT , ResponseT ],
8585 strategy : SchedulingStrategy ,
8686 startup_duration : float ,
87- ** constraints : dict [ str , Constraint ] ,
87+ ** constraints : Constraint ,
8888 ):
8989 """
9090 Initialize a worker process group for distributed request processing.
@@ -232,7 +232,7 @@ async def create_processes(self):
232232 worker_index = rank ,
233233 max_buffer_send_size = None ,
234234 max_buffer_receive_size = per_proc_max_buffer_size ,
235- ),
235+ ), # The non-group worker lacks the SchedulerState type. Type err.
236236 backend = self .backend ,
237237 strategy = self .strategy ,
238238 async_limit = async_limit ,
@@ -478,9 +478,9 @@ def __init__(
478478 num_processes = len (processes ),
479479 start_time = start_time ,
480480 )
481- self ._queued_requests : set [RequestT | MultiTurnRequestT [ RequestT ] ] = set ()
482- self ._pending_requests : set [RequestT | MultiTurnRequestT [ RequestT ] ] = set ()
483- self ._processing_requests : set [RequestT | MultiTurnRequestT [ RequestT ] ] = set ()
481+ self ._queued_request_ids : set [str ] = set ()
482+ self ._pending_request_ids : set [str ] = set ()
483+ self ._processing_request_ids : set [str ] = set ()
484484
485485 def requests_generator (
486486 self , requests : Iterable [RequestT | MultiTurnRequestT [RequestT ]]
@@ -517,11 +517,13 @@ def requests_generator(
517517 )
518518 state_update = self ._locked_update (request_info )
519519 request_info .timings .queued = time .time ()
520+ if self .messaging .buffer_receive_queue is None :
521+ raise RuntimeError ("buffer receive queue is None" )
520522 self .messaging .buffer_receive_queue .sync_put (
521523 (None , request , request_info , state_update .state )
522524 )
523525
524- yield ( request , request_info )
526+ yield request , request_info
525527
526528 if state_update .stop_queueing :
527529 self .stop_send_requests_event .set ()
@@ -530,8 +532,8 @@ def requests_generator(
530532 # Reached the end, inject a RequestsExhaustedConstraint to record
531533 self ._locked_update (
532534 info = None ,
533- requests_exhausted = {
534- "requests_exhausted" : RequestsExhaustedConstraint (
535+ add_constraints = {
536+ "requests_exhausted" : RequestsExhaustedConstraint ( # type: ignore[dict-item]
535537 num_requests = count
536538 )
537539 },
@@ -610,10 +612,10 @@ def received_callback(
610612 def _locked_update (
611613 self ,
612614 info : RequestInfo | None = None ,
613- ** add_constraints : dict [str , Constraint ],
615+ add_constraints : dict [str , Constraint ] | None = None ,
614616 ) -> _StateUpdate :
615617 with self ._update_lock :
616- if add_constraints :
618+ if add_constraints is not None :
617619 self .constraints .update (add_constraints )
618620
619621 if info is not None :
@@ -631,34 +633,34 @@ def _locked_update(
631633
632634 def _update_state_request_counts (self , info : RequestInfo ):
633635 if info .status == "queued" :
634- self ._queued_requests .add (info .request_id )
635- self ._state .queued_requests = len (self ._queued_requests )
636+ self ._queued_request_ids .add (info .request_id )
637+ self ._state .queued_requests = len (self ._queued_request_ids )
636638 self ._state .created_requests += 1
637639 elif info .status == "pending" :
638- self ._queued_requests .remove (info .request_id )
639- self ._state .queued_requests = len (self ._queued_requests )
640- self ._pending_requests .add (info .request_id )
641- self ._state .pending_requests = len (self ._pending_requests )
640+ self ._queued_request_ids .remove (info .request_id )
641+ self ._state .queued_requests = len (self ._queued_request_ids )
642+ self ._pending_request_ids .add (info .request_id )
643+ self ._state .pending_requests = len (self ._pending_request_ids )
642644 elif info .status == "in_progress" :
643- self ._pending_requests .remove (info .request_id )
644- self ._state .pending_requests = len (self ._pending_requests )
645- self ._processing_requests .add (info .request_id )
646- self ._state .processing_requests = len (self ._processing_requests )
645+ self ._pending_request_ids .remove (info .request_id )
646+ self ._state .pending_requests = len (self ._pending_request_ids )
647+ self ._processing_request_ids .add (info .request_id )
648+ self ._state .processing_requests = len (self ._processing_request_ids )
647649 elif info .status == "completed" :
648- self ._processing_requests .remove (info .request_id )
649- self ._state .processing_requests = len (self ._processing_requests )
650+ self ._processing_request_ids .remove (info .request_id )
651+ self ._state .processing_requests = len (self ._processing_request_ids )
650652 self ._state .processed_requests += 1
651653 self ._state .successful_requests += 1
652654 elif info .status in ("errored" , "cancelled" ):
653- if info .request_id in self ._queued_requests :
654- self ._queued_requests .remove (info .request_id )
655- self ._state .queued_requests = len (self ._queued_requests )
656- elif info .request_id in self ._pending_requests :
657- self ._pending_requests .remove (info .request_id )
658- self ._state .pending_requests = len (self ._pending_requests )
659- elif info .request_id in self ._processing_requests :
660- self ._processing_requests .remove (info .request_id )
661- self ._state .processing_requests = len (self ._processing_requests )
655+ if info .request_id in self ._queued_request_ids :
656+ self ._queued_request_ids .remove (info .request_id )
657+ self ._state .queued_requests = len (self ._queued_request_ids )
658+ elif info .request_id in self ._pending_request_ids :
659+ self ._pending_request_ids .remove (info .request_id )
660+ self ._state .pending_requests = len (self ._pending_request_ids )
661+ elif info .request_id in self ._processing_request_ids :
662+ self ._processing_request_ids .remove (info .request_id )
663+ self ._state .processing_requests = len (self ._processing_request_ids )
662664
663665 self ._state .processed_requests += 1
664666 self ._state .errored_requests += 1 if info .status == "errored" else 0
0 commit comments