@@ -234,9 +234,14 @@ def abort_requests(self, request_ids: list[str]):
234
234
self .scheduler .finish_requests (request_ids ,
235
235
RequestStatus .FINISHED_ABORTED )
236
236
237
- def execute_model (self , scheduler_output : SchedulerOutput ):
237
+ def execute_model_with_error_logging (
238
+ self ,
239
+ model_fn : Callable [[SchedulerOutput ], ModelRunnerOutput ],
240
+ scheduler_output : SchedulerOutput ,
241
+ ) -> ModelRunnerOutput :
242
+ """Execute the model and log detailed info on failure."""
238
243
try :
239
- return self . model_executor . execute_model (scheduler_output )
244
+ return model_fn (scheduler_output )
240
245
except Exception as err :
241
246
# We do not want to catch BaseException here since we're only
242
247
# interested in dumping info when the exception is due to an
@@ -259,7 +264,9 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
259
264
if not self .scheduler .has_requests ():
260
265
return {}, False
261
266
scheduler_output = self .scheduler .schedule ()
262
- model_output = self .execute_model (scheduler_output )
267
+ model_output = self .execute_model_with_error_logging (
268
+ self .model_executor .execute_model , # type: ignore
269
+ scheduler_output )
263
270
engine_core_outputs = self .scheduler .update_from_output (
264
271
scheduler_output , model_output ) # type: ignore
265
272
@@ -306,8 +313,11 @@ def step_with_batch_queue(
306
313
# so we need more work.
307
314
if not scheduled_batch and not self .batch_queue .empty ():
308
315
future , scheduler_output = self .batch_queue .get_nowait ()
316
+
309
317
# Blocking until the first result is available.
310
- model_output = future .result ()
318
+ model_output = self .execute_model_with_error_logging (
319
+ lambda _ : future .result (), scheduler_output )
320
+
311
321
self .batch_queue .task_done ()
312
322
engine_core_outputs = (self .scheduler .update_from_output (
313
323
scheduler_output , model_output ))
0 commit comments