@@ -103,6 +103,26 @@ fn main() -> ExitCode {
    &args.model_name, args.revision.as_deref()
).expect("Could not find tokenizer for model");

+    // Set max_split_size to default value if PYTORCH_CUDA_ALLOC_CONF is not set,
+    // or unset it if PYTORCH_CUDA_ALLOC_CONF is set but empty
+    let cuda_alloc_conf = match env::var("PYTORCH_CUDA_ALLOC_CONF") {
+        Err(VarError::NotPresent) if DEFAULT_SPLIT_SIZE == "none" => None,
+        Err(VarError::NotPresent) => {
+            let alloc_conf = format!("max_split_size_mb:{}", DEFAULT_SPLIT_SIZE);
+            info!("Setting PYTORCH_CUDA_ALLOC_CONF to default value: {alloc_conf}");
+            Some(alloc_conf)
+        },
+        Ok(alloc_conf) if alloc_conf.trim().is_empty() => {
+            info!("PYTORCH_CUDA_ALLOC_CONF is unset");
+            Some(String::new()) // This means remove it from the env
+        },
+        Ok(alloc_conf) => {
+            info!("PYTORCH_CUDA_ALLOC_CONF is set to: {alloc_conf}");
+            None
+        },
+        Err(VarError::NotUnicode(_)) => panic!("PYTORCH_CUDA_ALLOC_CONF set to non-unicode value"),
+    };
+
    // Signal handler
    let running = Arc::new(AtomicBool::new(true));
    let r = running.clone();
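The match above resolves the variable once, up front in `main()`, into a tri-state `Option<String>`: `None` means leave each shard's inherited environment alone, a non-empty `Some` means set a default value, and `Some(String::new())` is a sentinel meaning strip the variable. A minimal standalone sketch of the same resolution logic; the `DEFAULT_SPLIT_SIZE` value and the `resolve_cuda_alloc_conf` helper name are hypothetical stand-ins, not the launcher's actual API:

```rust
use std::env::{self, VarError};

// Hypothetical stand-in for the launcher's DEFAULT_SPLIT_SIZE constant;
// the special value "none" disables the default entirely.
const DEFAULT_SPLIT_SIZE: &str = "512";

/// Resolve PYTORCH_CUDA_ALLOC_CONF into a tri-state:
///   None            -> leave the inherited environment untouched
///   Some(non-empty) -> explicitly set this value for each shard
///   Some("")        -> sentinel: remove the variable from each shard's env
fn resolve_cuda_alloc_conf() -> Option<String> {
    match env::var("PYTORCH_CUDA_ALLOC_CONF") {
        Err(VarError::NotPresent) if DEFAULT_SPLIT_SIZE == "none" => None,
        Err(VarError::NotPresent) => Some(format!("max_split_size_mb:{}", DEFAULT_SPLIT_SIZE)),
        Ok(conf) if conf.trim().is_empty() => Some(String::new()),
        Ok(_) => None,
        Err(VarError::NotUnicode(_)) => panic!("PYTORCH_CUDA_ALLOC_CONF set to non-unicode value"),
    }
}

fn main() {
    println!("{:?}", resolve_cuda_alloc_conf());
}
```

Resolving once in the parent also means the decision is made exactly once and logged once, rather than repeated per shard.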
@@ -126,6 +146,7 @@ fn main() -> ExitCode {
        let status_sender = status_sender.clone();
        let shutdown = shutdown.clone();
        let shutdown_sender = shutdown_sender.clone();
+        let cuda_alloc_conf = cuda_alloc_conf.clone();
        thread::spawn(move || {
            shard_manager(
                args.model_name,
@@ -138,6 +159,7 @@ fn main() -> ExitCode {
                args.max_batch_weight,
                args.shard_uds_path,
                args.cuda_process_memory_fraction,
+                cuda_alloc_conf,
                rank,
                num_shard,
                args.master_addr,
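The per-iteration clone added above is needed because each shard runs in its own `move` closure, which takes ownership of everything it captures. A tiny self-contained sketch of that pattern; the shard count and value are made up for illustration:

```rust
use std::thread;

fn main() {
    let cuda_alloc_conf: Option<String> = Some("max_split_size_mb:512".into());
    let handles: Vec<_> = (0..2)
        .map(|rank| {
            // Each spawned shard thread takes ownership via `move`,
            // so the Option<String> must be cloned once per iteration.
            let cuda_alloc_conf = cuda_alloc_conf.clone();
            thread::spawn(move || {
                println!("shard {rank}: {cuda_alloc_conf:?}");
            })
        })
        .collect();
    for h in handles {
        h.join().unwrap();
    }
}
```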
@@ -380,6 +402,7 @@ fn shard_manager(
    max_batch_weight: Option<usize>,
    uds_path: String,
    cuda_process_memory_fraction: f32,
+    cuda_alloc_conf: Option<String>,
    rank: usize,
    world_size: usize,
    master_addr: String,
@@ -449,15 +472,13 @@ fn shard_manager(
        _ => (),
    }

-    // Set max_split_size to default value if PYTORCH_CUDA_ALLOC_CONF is not set
-    match env::var("PYTORCH_CUDA_ALLOC_CONF") {
-        Err(VarError::NotPresent) => {
-            let alloc_conf = format!("max_split_size_mb:{}", DEFAULT_SPLIT_SIZE);
-            info!("Setting PYTORCH_CUDA_ALLOC_CONF to default value {alloc_conf}");
+    if let Some(alloc_conf) = cuda_alloc_conf {
+        if alloc_conf.is_empty() {
+            // Remove it from the env
+            env.retain(|(k, _)| k != "PYTORCH_CUDA_ALLOC_CONF");
+        } else {
            env.push(("PYTORCH_CUDA_ALLOC_CONF".into(), alloc_conf.into()));
-        },
-        Err(VarError::NotUnicode(_)) => panic!("PYTORCH_CUDA_ALLOC_CONF set to non-unicode value"),
-        Ok(alloc_conf) => info!("PYTORCH_CUDA_ALLOC_CONF is set to {alloc_conf}"),
+        }
    }

    // Torch Distributed / DeepSpeed Env vars
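On the shard side, applying the tri-state to the child environment reduces to a retain-or-push on the vector of env pairs. A hedged sketch of how this could look in isolation, assuming the shard env is a `Vec<(OsString, OsString)>` eventually handed to `Command::envs`; the `apply_cuda_alloc_conf` helper and the `true` command are illustrative, not the launcher's actual code:

```rust
use std::ffi::OsString;
use std::process::Command;

// Apply the resolved tri-state to the environment passed to a shard process.
// `env` mirrors the launcher's Vec<(OsString, OsString)> of inherited variables.
fn apply_cuda_alloc_conf(env: &mut Vec<(OsString, OsString)>, cuda_alloc_conf: Option<String>) {
    if let Some(alloc_conf) = cuda_alloc_conf {
        if alloc_conf.is_empty() {
            // The empty string is the "unset" sentinel: drop any inherited value.
            env.retain(|(k, _)| k != "PYTORCH_CUDA_ALLOC_CONF");
        } else {
            env.push(("PYTORCH_CUDA_ALLOC_CONF".into(), alloc_conf.into()));
        }
    }
    // None: leave the inherited environment untouched.
}

fn main() {
    let mut env: Vec<(OsString, OsString)> = std::env::vars_os().collect();
    apply_cuda_alloc_conf(&mut env, Some(String::new()));

    // Hypothetical spawn site; the real launcher execs the shard process here.
    let mut cmd = Command::new("true");
    cmd.env_clear().envs(env.iter().cloned());
}
```

Keeping the removal explicit (rather than simply not pushing the variable) matters because the child env vector starts from the launcher's own inherited environment, so an unset request must actively filter the pair out.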