@@ -4,7 +4,7 @@ use nix::unistd::Pid;
4
4
use std:: env;
5
5
use std:: io:: { BufRead , BufReader , ErrorKind , Write } ;
6
6
use std:: path:: Path ;
7
- use std:: process:: { Command , ExitCode , Stdio } ;
7
+ use std:: process:: { Command , ExitCode , ExitStatus , Stdio } ;
8
8
use std:: sync:: atomic:: { AtomicBool , Ordering } ;
9
9
use std:: sync:: mpsc:: TryRecvError ;
10
10
use std:: sync:: Arc ;
@@ -16,7 +16,7 @@ use std::{fs, io};
16
16
use std:: env:: VarError ;
17
17
use std:: ffi:: OsString ;
18
18
use std:: fs:: File ;
19
- use std:: os:: unix:: process:: CommandExt ;
19
+ use std:: os:: unix:: process:: { CommandExt , ExitStatusExt } ;
20
20
use tracing:: { info, warn} ;
21
21
22
22
// In most cases this gives the best performance for inferencing
@@ -238,7 +238,7 @@ fn main() -> ExitCode {
238
238
Err ( TryRecvError :: Empty ) => {
239
239
sleep ( Duration :: from_millis ( 100 ) ) ;
240
240
}
241
- Ok ( ShardStatus :: Failed ) => {
241
+ Ok ( ShardStatus :: Failed ( _status ) ) => {
242
242
shutdown_shards ( shutdown, shutdown_receiver) ;
243
243
return ExitCode :: FAILURE ;
244
244
}
@@ -347,9 +347,17 @@ fn main() -> ExitCode {
347
347
let mut exit_code = ExitCode :: SUCCESS ;
348
348
349
349
while running. load ( Ordering :: SeqCst ) {
350
- if let Ok ( ShardStatus :: Failed ) = status_receiver. try_recv ( ) {
350
+ if let Ok ( ShardStatus :: Failed ( status ) ) = status_receiver. try_recv ( ) {
351
351
exit_code = ExitCode :: FAILURE ;
352
- break ;
352
+ terminate_gracefully ( & mut webserver, shutdown. clone ( ) , shutdown_receiver) ;
353
+ if status. signal ( ) == Some ( 7 ) && num_shard > 1 {
354
+ panic ! (
355
+ "Encountered SIGBUS error. This is usually caused by NCCL having insufficient shared memory. \
356
+ Ensure at least 1GB of shared memory is available. In case of OpenShift/K8s, \
357
+ mount a memory medium emptyDir volume to /dev/shm"
358
+ )
359
+ }
360
+ return exit_code
353
361
} ;
354
362
355
363
match webserver. try_wait ( ) . expect ( "Error polling status of router process" ) {
@@ -362,17 +370,21 @@ fn main() -> ExitCode {
362
370
} ;
363
371
}
364
372
365
- // Graceful termination
373
+ terminate_gracefully ( & mut webserver, shutdown. clone ( ) , shutdown_receiver) ;
374
+
375
+ exit_code
376
+ }
377
+
378
+ /// Graceful termination
379
+ fn terminate_gracefully ( webserver : & mut std:: process:: Child , shutdown : Arc < Mutex < bool > > , shutdown_receiver : & mpsc:: Receiver < ( ) > ) {
366
380
signal:: kill ( Pid :: from_raw ( webserver. id ( ) as i32 ) , Signal :: SIGTERM ) . unwrap ( ) ;
367
381
info ! ( "Waiting for router to gracefully shutdown" ) ;
368
382
webserver. wait ( ) . unwrap ( ) ;
369
383
info ! ( "Router terminated" ) ;
370
384
shutdown_shards ( shutdown, & shutdown_receiver) ;
371
385
372
- exit_code
373
386
}
374
387
375
-
376
388
fn num_cuda_devices ( ) -> Option < usize > {
377
389
let devices = match env:: var ( "CUDA_VISIBLE_DEVICES" ) {
378
390
Ok ( devices) => devices,
@@ -481,7 +493,7 @@ fn find_num_shards(num_shard: Option<usize>) -> usize {
481
493
/// Status message sent by a shard manager over the status channel.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ShardStatus {
    /// Presumably reported once the shard is up and serving — confirm against
    /// the sender in `shard_manager`.
    Ready,
    /// The shard failed. Carries the exit status of the shard process so the
    /// receiver can inspect it (e.g. the terminating signal).
    ///
    /// NOTE(review): when the shard process cannot be started at all there is
    /// no real exit status, and the sender supplies a placeholder raw status.
    Failed(ExitStatus),
}
486
498
487
499
#[ allow( clippy:: too_many_arguments) ]
@@ -619,7 +631,7 @@ fn shard_manager(
619
631
} else {
620
632
tracing:: error!( "Shard {rank} failed to start:\n {err}" ) ;
621
633
}
622
- status_sender. send ( ShardStatus :: Failed ) . unwrap ( ) ;
634
+ status_sender. send ( ShardStatus :: Failed ( ExitStatus :: from_raw ( 0 ) ) ) . unwrap ( ) ;
623
635
return
624
636
}
625
637
} ;
@@ -654,7 +666,9 @@ fn shard_manager(
654
666
io:: stdout ( ) . flush ( ) . unwrap_or_default ( ) ;
655
667
stderr_thread. join ( ) . unwrap_or_default ( ) ;
656
668
io:: stderr ( ) . flush ( ) . unwrap_or_default ( ) ;
657
- status_sender. send ( ShardStatus :: Failed ) . unwrap ( ) ;
669
+ status_sender
670
+ . send ( ShardStatus :: Failed ( status) )
671
+ . unwrap ( ) ;
658
672
}
659
673
return
660
674
}
0 commit comments