@@ -120,7 +120,7 @@ def __init__(
         self.mp_context = None
         self.mp_manager = None
         self.processes: list[BaseProcess] = None
-        self.processes_completed_events: list[Event] = None
+        self.requests_completed_event: Event = None
         self.startup_barrier: Barrier = None
         self.shutdown_event: Event = None
         self.error_event: Event = None
@@ -176,8 +176,11 @@ async def create_processes(self):
             raise RuntimeError("num_processes resolved to 0; increase limits/config")

         per_proc_max_conc = max_conc // num_processes
-        per_proc_max_receive_buffer = max(
-            1, math.floor(per_proc_max_conc * settings.mp_proc_receive_buffer_per)
+        max_pending_size = max(
+            1, math.floor(max_conc * settings.mp_max_pending_buffer_percent)
+        )
+        per_proc_max_buffer_size = max(
+            1, math.floor(per_proc_max_conc * settings.mp_max_worker_buffer_percent)
         )

         # Initialize multiprocessing components
@@ -186,12 +189,13 @@ async def create_processes(self):
         self.startup_barrier = self.mp_context.Barrier(num_processes + 1)
         self.shutdown_event = self.mp_context.Event()
         self.error_event = self.mp_context.Event()
+        self.requests_completed_event = self.mp_context.Event()

         if settings.mp_messaging_object == "queue":
             self.messaging = InterProcessMessagingQueue(
                 serialization=settings.mp_serialization,
                 encoding=settings.mp_encoding,
-                max_send_size=max_conc,
+                max_pending_size=max_pending_size,
                 max_buffer_send_size=settings.mp_requests_send_buffer_size,
                 poll_interval=settings.mp_poll_interval,
             )
@@ -200,7 +204,7 @@ async def create_processes(self):
                 manager=self.mp_manager,
                 serialization=settings.mp_serialization,
                 encoding=settings.mp_encoding,
-                max_send_size=max_conc,
+                max_pending_size=max_pending_size,
                 max_buffer_send_size=settings.mp_requests_send_buffer_size,
                 poll_interval=settings.mp_poll_interval,
             )
@@ -209,32 +213,30 @@ async def create_processes(self):
                 num_workers=num_processes,
                 serialization=settings.mp_serialization,
                 encoding=settings.mp_encoding,
-                max_send_size=max_conc,
+                max_pending_size=max_pending_size,
                 max_buffer_send_size=settings.mp_requests_send_buffer_size,
                 poll_interval=settings.mp_poll_interval,
             )

         # Initialize worker processes
         self.processes = []
-        self.processes_completed_events = []
         for rank in range(num_processes):
             # Distribute any remainder across the first N ranks
             async_limit = per_proc_max_conc + (
                 1 if rank < (max_conc % num_processes) else 0
             )

-            worker_completed_event = self.mp_context.Event()
             worker = WorkerProcess[RequestT, MeasuredRequestTimingsT, ResponseT](
                 messaging=self.messaging.create_worker_copy(
                     worker_index=rank,
                     max_buffer_send_size=None,
-                    max_buffer_receive_size=per_proc_max_receive_buffer,
+                    max_buffer_receive_size=per_proc_max_buffer_size,
                 ),
                 async_limit=async_limit,
                 startup_barrier=self.startup_barrier,
                 shutdown_event=self.shutdown_event,
                 error_event=self.error_event,
-                completed_event=worker_completed_event,
+                requests_completed_event=self.requests_completed_event,
                 backend=self.backend,
                 request_timings=self.strategy.create_request_timings(
                     local_rank=rank,
@@ -245,7 +247,6 @@ async def create_processes(self):
             proc = self.mp_context.Process(target=worker.run, daemon=False)
             proc.start()
             self.processes.append(proc)
-            self.processes_completed_events.append(worker_completed_event)

         reason, _ = await synchronous_to_exitable_async(
             synchronous=None,
@@ -279,7 +280,7 @@ async def start(self, start_time: float):
         self._state = _WorkerGroupState[RequestT, MeasuredRequestTimingsT, ResponseT](
             start_time=start_time,
             num_processes=len(self.processes),
-            processes_completed_events=self.processes_completed_events,
+            processes=self.processes,
             constraints=self.constraints,
             shutdown_event=self.shutdown_event,
         )
@@ -289,6 +290,7 @@ async def start(self, start_time: float):
             ),
             receive_callback=self._state.update_callback_receive,
             send_stop_criteria=[self.shutdown_event, self.error_event],
+            send_stopped_event=self.requests_completed_event,
             receive_stop_criteria=[self.error_event, self._state.stop_callback_receive],
             pydantic_models=list(SchedulerMessagingPydanticRegistry.registry.values()),
         )
@@ -408,7 +410,7 @@ def __init__(
         self,
         start_time: float,
         num_processes: int,
-        processes_completed_events: list[Event],
+        processes: list[BaseProcess],
         constraints: dict[str, Constraint],
         shutdown_event: Event,
     ):
@@ -419,7 +421,7 @@ def __init__(
             num_processes=num_processes,
             start_time=start_time,
         )
-        self.processes_completed_events = processes_completed_events
+        self.processes = processes
         self._constraints = constraints
         self._internal_constraints: dict[str, Constraint] = {}
         self._shutdown_event = shutdown_event
@@ -544,7 +546,7 @@ def stop_callback_receive(
             and messaging.send_stopped_event.is_set()  # No more requests will be added
             and self._shutdown_event.is_set()  # processing should stop
             and all(
-                event.is_set() for event in self.processes_completed_events
+                not proc.is_alive() for proc in self.processes
             )  # no more updates will be added by workers
         )

@@ -601,21 +603,19 @@ def _update_new_request(self):
         self._state.queued_requests += 1

     def _update_new_response(self, info: ScheduledRequestInfo[MeasuredRequestTimingsT]):
-        if info.status == "in_progress":
+        if info.status == "in_progress" or (
+            info.status == "cancelled" and info.scheduler_timings.resolve_start is None
+            # Cancelled request that never sent a progress update
+        ):
             self._state.queued_requests -= 1
             self._state.processing_requests += 1
-        elif info.status in ("completed", "errored", "cancelled"):
+
+        if info.status in ("completed", "errored", "cancelled"):
             self._state.processing_requests -= 1
             self._state.processed_requests += 1
             self._state.successful_requests += 1 if info.status == "completed" else 0
             self._state.errored_requests += 1 if info.status == "errored" else 0
             self._state.cancelled_requests += 1 if info.status == "cancelled" else 0
-        else:
-            raise ValueError(
-                f"Unknown request status: {info.status}. "
-                "Supported statuses are: queued, pending, in_progress, "
-                "completed, errored, cancelled."
-            )

     def _update_with_constraints(
         self, info: ScheduledRequestInfo[MeasuredRequestTimingsT]
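For context, a minimal sketch of the new buffer-size arithmetic introduced in create_processes() above, assuming illustrative values for max_conc, num_processes, and the two settings percentages; the standalone names and numbers below are assumptions for the example, not the project's defaults:

    import math

    # Hypothetical example values (assumptions, not the real settings object).
    max_conc = 100                        # total concurrent requests across the worker group
    num_processes = 8
    mp_max_pending_buffer_percent = 0.10  # stands in for settings.mp_max_pending_buffer_percent
    mp_max_worker_buffer_percent = 0.20   # stands in for settings.mp_max_worker_buffer_percent

    per_proc_max_conc = max_conc // num_processes  # 12 requests per worker

    # Group-wide pending buffer, sized from the total concurrency:
    max_pending_size = max(1, math.floor(max_conc * mp_max_pending_buffer_percent))  # -> 10

    # Per-worker receive buffer, sized from that worker's share of the concurrency:
    per_proc_max_buffer_size = max(
        1, math.floor(per_proc_max_conc * mp_max_worker_buffer_percent)
    )  # -> 2

    print(max_pending_size, per_proc_max_buffer_size)  # 10 2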