|
1 | 1 | use clap::Parser;
|
2 | 2 | use std::env;
|
3 |
| -use std::io::{BufRead, BufReader, ErrorKind}; |
| 3 | +use std::io::{BufRead, BufReader, ErrorKind, Write}; |
4 | 4 | use std::path::Path;
|
5 | 5 | use std::process::ExitCode;
|
6 | 6 | use std::sync::atomic::{AtomicBool, Ordering};
|
@@ -311,7 +311,7 @@ fn main() -> ExitCode {
|
311 | 311 |
|
312 | 312 | while running.load(Ordering::SeqCst) {
|
313 | 313 | if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
|
314 |
| - tracing::error!("Shard {rank} failed:\n{err}"); |
| 314 | + tracing::error!("Shard {rank} failed: {err}"); |
315 | 315 | exit_code = ExitCode::FAILURE;
|
316 | 316 | break;
|
317 | 317 | };
|
@@ -475,7 +475,7 @@ fn shard_manager(
|
475 | 475 | if let Some(alloc_conf) = cuda_alloc_conf {
|
476 | 476 | if alloc_conf.is_empty() {
|
477 | 477 | // Remove it from env
|
478 |
| - env.retain(|(k, v)| k != "PYTORCH_CUDA_ALLOC_CONF"); |
| 478 | + env.retain(|(k, _)| k != "PYTORCH_CUDA_ALLOC_CONF"); |
479 | 479 | } else {
|
480 | 480 | env.push(("PYTORCH_CUDA_ALLOC_CONF".into(), alloc_conf.into()));
|
481 | 481 | }
|
@@ -532,45 +532,51 @@ fn shard_manager(
|
532 | 532 |
|
533 | 533 | // Redirect STDOUT and STDERR to the console
|
534 | 534 | let shard_stdout = p.stdout.take().unwrap();
|
535 |
| - thread::spawn(move || BufReader::new(shard_stdout).lines().for_each(|line| |
536 |
| - println!("Shard {}: {}", rank, line.unwrap()) |
537 |
| - )); |
| 535 | + let stdout_thread = thread::spawn( |
| 536 | + move || BufReader::new(shard_stdout).lines().for_each( |
| 537 | + |line| println!("Shard {rank}: {}", line.unwrap()) |
| 538 | + ) |
| 539 | + ); |
538 | 540 | let shard_stderr = p.stderr.take().unwrap();
|
539 |
| - thread::spawn(move || BufReader::new(shard_stderr).lines().for_each(|line| |
540 |
| - eprintln!("Shard {}: {}", rank, line.unwrap()) |
541 |
| - )); |
| 541 | + let stderr_thread = thread::spawn( |
| 542 | + move || BufReader::new(shard_stderr).lines().for_each( |
| 543 | + |line| eprintln!("Shard {rank}: {}", line.unwrap()) |
| 544 | + ) |
| 545 | + ); |
542 | 546 |
|
543 | 547 | let mut ready = false;
|
544 | 548 | let start_time = Instant::now();
|
545 | 549 | let mut wait_time = Instant::now();
|
546 | 550 | loop {
|
547 | 551 | // Process exited
|
548 |
| - if p.poll().is_some() { |
549 |
| - let mut err = String::new(); |
550 |
| - //We don't need to do this now that we're logging |
551 |
| - //p.stderr.take().unwrap().read_to_string(&mut err).unwrap(); |
552 |
| - status_sender |
553 |
| - .send(ShardStatus::Failed((rank, err))) |
554 |
| - .unwrap(); |
555 |
| - return; |
| 552 | + if let Some(status) = p.poll() { |
| 553 | + // Ensure we finish propagating any final stdout/stderr from the shard |
| 554 | + stdout_thread.join().unwrap_or_default(); |
| 555 | + io::stdout().flush().unwrap_or_default(); |
| 556 | + stderr_thread.join().unwrap_or_default(); |
| 557 | + io::stderr().flush().unwrap_or_default(); |
| 558 | + status_sender.send(ShardStatus::Failed((rank, format!("{status:?}")))).unwrap(); |
| 559 | + return |
556 | 560 | }
|
557 | 561 |
|
558 | 562 | // We received a shutdown signal
|
559 | 563 | if *shutdown.lock().unwrap() {
|
560 | 564 | p.terminate().unwrap();
|
561 | 565 | let _ = p.wait_timeout(Duration::from_secs(90));
|
562 | 566 | info!("Shard {rank} terminated");
|
563 |
| - return; |
| 567 | + return |
564 | 568 | }
|
565 | 569 |
|
566 | 570 | // Shard is ready
|
567 |
| - if uds.exists() && !ready { |
568 |
| - info!("Shard {rank} ready in {:?}", start_time.elapsed()); |
569 |
| - status_sender.send(ShardStatus::Ready).unwrap(); |
570 |
| - ready = true; |
571 |
| - } else if !ready && wait_time.elapsed() > Duration::from_secs(10) { |
572 |
| - tracing::info!("Waiting for shard {rank} to be ready..."); |
573 |
| - wait_time = Instant::now(); |
| 571 | + if !ready { |
| 572 | + if uds.exists() { |
| 573 | + info!("Shard {rank} ready in {:?}", start_time.elapsed()); |
| 574 | + status_sender.send(ShardStatus::Ready).unwrap(); |
| 575 | + ready = true; |
| 576 | + } else if wait_time.elapsed() > Duration::from_secs(10) { |
| 577 | + info!("Waiting for shard {rank} to be ready..."); |
| 578 | + wait_time = Instant::now(); |
| 579 | + } |
574 | 580 | }
|
575 | 581 | sleep(Duration::from_millis(100));
|
576 | 582 | }
|
|
0 commit comments