File tree (1 file changed, +15 −0 lines)
lines changed Original file line number Diff line number Diff line change @@ -11,6 +11,7 @@ use std::thread;
1111use std:: thread:: sleep;
1212use std:: time:: { Duration , Instant } ;
1313use std:: { fs, io} ;
14+ use std:: env:: VarError ;
1415use std:: ffi:: OsString ;
1516use subprocess:: { Popen , PopenConfig , PopenError , Redirection } ;
1617use tracing:: info;
@@ -364,6 +365,9 @@ enum ShardStatus {
364365 Failed ( ( usize , String ) ) ,
365366}
366367
/// Default value (in MiB) for PyTorch's `max_split_size_mb` CUDA allocator
/// setting, applied when `PYTORCH_CUDA_ALLOC_CONF` is not already set.
/// Kept as a string because it is spliced directly into the env-var value.
// `'static` is implied on a `const` of reference type, so the explicit
// lifetime is redundant (clippy::redundant_static_lifetimes).
const DEFAULT_SPLIT_SIZE: &str = "512";
369+
370+
367371#[ allow( clippy:: too_many_arguments) ]
368372fn shard_manager (
369373 model_name : String ,
@@ -445,6 +449,17 @@ fn shard_manager(
445449 _ => ( ) ,
446450 }
447451
452+ // Set max_split_size to default value if PYTORCH_CUDA_ALLOC_CONF is not set
453+ match env:: var ( "PYTORCH_CUDA_ALLOC_CONF" ) {
454+ Err ( VarError :: NotPresent ) => {
455+ let alloc_conf = format ! ( "max_split_size_mb:{}" , DEFAULT_SPLIT_SIZE ) ;
456+ info ! ( "Setting PYTORCH_CUDA_ALLOC_CONF to default value {alloc_conf}" ) ;
457+ env. push ( ( "PYTORCH_CUDA_ALLOC_CONF" . into ( ) , alloc_conf. into ( ) ) ) ;
458+ } ,
459+ Err ( VarError :: NotUnicode ( _) ) => panic ! ( "PYTORCH_CUDA_ALLOC_CONF set to non-unicode value" ) ,
460+ Ok ( alloc_conf) => info ! ( "PYTORCH_CUDA_ALLOC_CONF is set to {alloc_conf}" ) ,
461+ }
462+
448463 // Torch Distributed / DeepSpeed Env vars
449464 env. push ( ( "RANK" . into ( ) , rank. to_string ( ) . into ( ) ) ) ;
450465 env. push ( ( "LOCAL_RANK" . into ( ) , rank. to_string ( ) . into ( ) ) ) ;
You can’t perform that action at this time.
0 commit comments