Skip to content

Commit bc4fbac

Browse files
committed
Ensure exit code and all stdout/stderr is logged upon shard failure
1 parent 1f6fcfe commit bc4fbac

File tree

1 file changed

+31
-25
lines changed

1 file changed

+31
-25
lines changed

launcher/src/main.rs

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use clap::Parser;
22
use std::env;
3-
use std::io::{BufRead, BufReader, ErrorKind};
3+
use std::io::{BufRead, BufReader, ErrorKind, Write};
44
use std::path::Path;
55
use std::process::ExitCode;
66
use std::sync::atomic::{AtomicBool, Ordering};
@@ -311,7 +311,7 @@ fn main() -> ExitCode {
311311

312312
while running.load(Ordering::SeqCst) {
313313
if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
314-
tracing::error!("Shard {rank} failed:\n{err}");
314+
tracing::error!("Shard {rank} failed: {err}");
315315
exit_code = ExitCode::FAILURE;
316316
break;
317317
};
@@ -475,7 +475,7 @@ fn shard_manager(
475475
if let Some(alloc_conf) = cuda_alloc_conf {
476476
if alloc_conf.is_empty() {
477477
// Remove it from env
478-
env.retain(|(k, v)| k != "PYTORCH_CUDA_ALLOC_CONF");
478+
env.retain(|(k, _)| k != "PYTORCH_CUDA_ALLOC_CONF");
479479
} else {
480480
env.push(("PYTORCH_CUDA_ALLOC_CONF".into(), alloc_conf.into()));
481481
}
@@ -532,45 +532,51 @@ fn shard_manager(
532532

533533
// Redirect STDOUT and STDERR to the console
534534
let shard_stdout = p.stdout.take().unwrap();
535-
thread::spawn(move || BufReader::new(shard_stdout).lines().for_each(|line|
536-
println!("Shard {}: {}", rank, line.unwrap())
537-
));
535+
let stdout_thread = thread::spawn(
536+
move || BufReader::new(shard_stdout).lines().for_each(
537+
|line| println!("Shard {rank}: {}", line.unwrap())
538+
)
539+
);
538540
let shard_stderr = p.stderr.take().unwrap();
539-
thread::spawn(move || BufReader::new(shard_stderr).lines().for_each(|line|
540-
eprintln!("Shard {}: {}", rank, line.unwrap())
541-
));
541+
let stderr_thread = thread::spawn(
542+
move || BufReader::new(shard_stderr).lines().for_each(
543+
|line| eprintln!("Shard {rank}: {}", line.unwrap())
544+
)
545+
);
542546

543547
let mut ready = false;
544548
let start_time = Instant::now();
545549
let mut wait_time = Instant::now();
546550
loop {
547551
// Process exited
548-
if p.poll().is_some() {
549-
let mut err = String::new();
550-
//We don't need to do this now that we're logging
551-
//p.stderr.take().unwrap().read_to_string(&mut err).unwrap();
552-
status_sender
553-
.send(ShardStatus::Failed((rank, err)))
554-
.unwrap();
555-
return;
552+
if let Some(status) = p.poll() {
553+
// Ensure we finish propagating any final stdout/stderr from the shard
554+
stdout_thread.join().unwrap_or_default();
555+
io::stdout().flush().unwrap_or_default();
556+
stderr_thread.join().unwrap_or_default();
557+
io::stderr().flush().unwrap_or_default();
558+
status_sender.send(ShardStatus::Failed((rank, format!("{status:?}")))).unwrap();
559+
return
556560
}
557561

558562
// We received a shutdown signal
559563
if *shutdown.lock().unwrap() {
560564
p.terminate().unwrap();
561565
let _ = p.wait_timeout(Duration::from_secs(90));
562566
info!("Shard {rank} terminated");
563-
return;
567+
return
564568
}
565569

566570
// Shard is ready
567-
if uds.exists() && !ready {
568-
info!("Shard {rank} ready in {:?}", start_time.elapsed());
569-
status_sender.send(ShardStatus::Ready).unwrap();
570-
ready = true;
571-
} else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
572-
tracing::info!("Waiting for shard {rank} to be ready...");
573-
wait_time = Instant::now();
571+
if !ready {
572+
if uds.exists() {
573+
info!("Shard {rank} ready in {:?}", start_time.elapsed());
574+
status_sender.send(ShardStatus::Ready).unwrap();
575+
ready = true;
576+
} else if wait_time.elapsed() > Duration::from_secs(10) {
577+
info!("Waiting for shard {rank} to be ready...");
578+
wait_time = Instant::now();
579+
}
574580
}
575581
sleep(Duration::from_millis(100));
576582
}

0 commit comments

Comments
 (0)