
Commit 8dcbad0

Don't log PYTORCH_CUDA_ALLOC_CONF setting for every shard

1 parent 011ba5d

1 file changed: 29 additions, 8 deletions

launcher/src/main.rs

@@ -103,6 +103,26 @@ fn main() -> ExitCode {
         &args.model_name, args.revision.as_deref()
     ).expect("Could not find tokenizer for model");
 
+    // Set max_split_size to default value if PYTORCH_CUDA_ALLOC_CONF is not set,
+    // or unset it if PYTORCH_CUDA_ALLOC_CONF is set but empty
+    let cuda_alloc_conf = match env::var("PYTORCH_CUDA_ALLOC_CONF") {
+        Err(VarError::NotPresent) if DEFAULT_SPLIT_SIZE == "none" => None,
+        Err(VarError::NotPresent) => {
+            let alloc_conf = format!("max_split_size_mb:{}", DEFAULT_SPLIT_SIZE);
+            info!("Setting PYTORCH_CUDA_ALLOC_CONF to default value: {alloc_conf}");
+            Some(alloc_conf)
+        },
+        Ok(alloc_conf) if alloc_conf.trim().is_empty() => {
+            info!("PYTORCH_CUDA_ALLOC_CONF is unset");
+            Some(String::new()) // This means remove it from the env
+        },
+        Ok(alloc_conf) => {
+            info!("PYTORCH_CUDA_ALLOC_CONF is set to: {alloc_conf}");
+            None
+        },
+        Err(VarError::NotUnicode(_)) => panic!("PYTORCH_CUDA_ALLOC_CONF set to non-unicode value"),
+    };
+
     // Signal handler
     let running = Arc::new(AtomicBool::new(true));
     let r = running.clone();
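
Resolved once in the parent, the match above yields a tri-state Option<String>: None leaves the shard environment untouched, Some(conf) sets the variable, and Some("") (empty) marks it for removal. A minimal standalone sketch of the same resolution logic, assuming a hypothetical DEFAULT_SPLIT_SIZE value (in the launcher this constant is defined elsewhere in main.rs):

use std::env::{self, VarError};

// Hypothetical default; the launcher defines this constant elsewhere.
const DEFAULT_SPLIT_SIZE: &str = "512";

fn resolve_cuda_alloc_conf() -> Option<String> {
    match env::var("PYTORCH_CUDA_ALLOC_CONF") {
        // Unset, and no default configured: leave the env alone.
        Err(VarError::NotPresent) if DEFAULT_SPLIT_SIZE == "none" => None,
        // Unset: fall back to the default max split size.
        Err(VarError::NotPresent) => Some(format!("max_split_size_mb:{}", DEFAULT_SPLIT_SIZE)),
        // Set but empty: empty string signals "remove from the shard env".
        Ok(conf) if conf.trim().is_empty() => Some(String::new()),
        // Set by the user: shards inherit it as-is, nothing to do.
        Ok(_) => None,
        Err(VarError::NotUnicode(_)) => panic!("PYTORCH_CUDA_ALLOC_CONF set to non-unicode value"),
    }
}

fn main() {
    println!("resolved: {:?}", resolve_cuda_alloc_conf());
}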
@@ -126,6 +146,7 @@ fn main() -> ExitCode {
         let status_sender = status_sender.clone();
         let shutdown = shutdown.clone();
         let shutdown_sender = shutdown_sender.clone();
+        let cuda_alloc_conf = cuda_alloc_conf.clone();
         thread::spawn(move || {
             shard_manager(
                 args.model_name,
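
The added clone matters because thread::spawn takes a move closure: each shard thread must own its own copy of the resolved value, and Option<String> is not Copy. A small sketch of that pattern in isolation:

use std::thread;

fn main() {
    let cuda_alloc_conf: Option<String> = Some("max_split_size_mb:512".to_string());

    let handles: Vec<_> = (0..4)
        .map(|rank| {
            // Shadow with a per-thread clone; the `move` closure below takes
            // ownership of this copy, leaving the original intact for the
            // next iteration.
            let cuda_alloc_conf = cuda_alloc_conf.clone();
            thread::spawn(move || {
                println!("shard {rank}: {cuda_alloc_conf:?}");
            })
        })
        .collect();

    for handle in handles {
        handle.join().unwrap();
    }
}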
@@ -138,6 +159,7 @@ fn main() -> ExitCode {
                 args.max_batch_weight,
                 args.shard_uds_path,
                 args.cuda_process_memory_fraction,
+                cuda_alloc_conf,
                 rank,
                 num_shard,
                 args.master_addr,
@@ -380,6 +402,7 @@ fn shard_manager(
     max_batch_weight: Option<usize>,
     uds_path: String,
     cuda_process_memory_fraction: f32,
+    cuda_alloc_conf: Option<String>,
     rank: usize,
     world_size: usize,
     master_addr: String,
@@ -449,15 +472,13 @@ fn shard_manager(
         _ => (),
     }
 
-    // Set max_split_size to default value if PYTORCH_CUDA_ALLOC_CONF is not set
-    match env::var("PYTORCH_CUDA_ALLOC_CONF") {
-        Err(VarError::NotPresent) => {
-            let alloc_conf = format!("max_split_size_mb:{}", DEFAULT_SPLIT_SIZE);
-            info!("Setting PYTORCH_CUDA_ALLOC_CONF to default value {alloc_conf}");
+    if let Some(alloc_conf) = cuda_alloc_conf {
+        if alloc_conf.is_empty() {
+            // Remove it from env
+            env.retain(|(k, v)| k != "PYTORCH_CUDA_ALLOC_CONF");
+        } else {
             env.push(("PYTORCH_CUDA_ALLOC_CONF".into(), alloc_conf.into()));
-        },
-        Err(VarError::NotUnicode(_)) => panic!("PYTORCH_CUDA_ALLOC_CONF set to non-unicode value"),
-        Ok(alloc_conf) => info!("PYTORCH_CUDA_ALLOC_CONF is set to {alloc_conf}"),
+        }
     }
 
     // Torch Distributed / DeepSpeed Env vars
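
With the decision precomputed, the shard side only applies it, so the info! logging happens once in main rather than once per shard. A sketch of how the Option drives a shard's env list; the Vec<(String, String)> stands in for the launcher's env vector, and the Command call at the end is illustrative only (the launcher's actual spawn logic differs):

use std::process::Command;

fn apply_alloc_conf(env: &mut Vec<(String, String)>, cuda_alloc_conf: Option<String>) {
    if let Some(alloc_conf) = cuda_alloc_conf {
        if alloc_conf.is_empty() {
            // Empty string means: drop the variable from the shard env.
            env.retain(|(k, _)| k != "PYTORCH_CUDA_ALLOC_CONF");
        } else {
            env.push(("PYTORCH_CUDA_ALLOC_CONF".to_string(), alloc_conf));
        }
    }
    // None: the shard inherits whatever is already there, untouched.
}

fn main() {
    let mut env = vec![
        ("RANK".to_string(), "0".to_string()),
        ("PYTORCH_CUDA_ALLOC_CONF".to_string(), "max_split_size_mb:512".to_string()),
    ];
    apply_alloc_conf(&mut env, Some(String::new())); // "" => remove
    assert!(env.iter().all(|(k, _)| k != "PYTORCH_CUDA_ALLOC_CONF"));

    // A shard process would then be spawned with exactly these vars
    // (illustrative; assumes a Unix `true` binary is on PATH).
    let _ = Command::new("true").env_clear().envs(env.clone()).status();
}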
