@@ -313,6 +313,13 @@ def filter_files_not_needed_for_inference(
313313 return hf_weights_files
314314
315315
316+ # Explicitly use a pure-text bar format that ends with a newline.
317+ # This makes it impossible to see the animation in the progress bar,
318+ # but it avoids garbled output under ray or multiprocessing, which
319+ # wrap each line of output with some prefix.
320+ _BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n " # noqa: E501
321+
322+
316323def np_cache_weights_iterator (
317324 model_name_or_path : str , cache_dir : Optional [str ], hf_folder : str ,
318325 hf_weights_files : List [str ]
@@ -321,6 +328,8 @@ def np_cache_weights_iterator(
321328
322329 Will dump the model weights to numpy files if they are not already dumped.
323330 """
331+ enable_tqdm = not torch .distributed .is_initialized (
332+ ) or torch .distributed .get_rank () == 0
324333 # Convert the model weights from torch tensors to numpy arrays for
325334 # faster loading.
326335 np_folder = os .path .join (hf_folder , "np" )
@@ -331,8 +340,12 @@ def np_cache_weights_iterator(
331340 with get_lock (model_name_or_path , cache_dir ):
332341 if not os .path .exists (weight_names_file ):
333342 weight_names : List [str ] = []
334- for bin_file in tqdm (hf_weights_files ,
335- desc = "Loading np_cache checkpoint shards" ):
343+ for bin_file in tqdm (
344+ hf_weights_files ,
345+ desc = "Loading np_cache checkpoint shards" ,
346+ disable = not enable_tqdm ,
347+ bar_format = _BAR_FORMAT ,
348+ ):
336349 state = torch .load (bin_file , map_location = "cpu" )
337350 for name , param in state .items ():
338351 param_path = os .path .join (np_folder , name )
@@ -356,8 +369,14 @@ def safetensors_weights_iterator(
356369 hf_weights_files : List [str ]
357370) -> Generator [Tuple [str , torch .Tensor ], None , None ]:
358371 """Iterate over the weights in the model safetensor files."""
359- for st_file in tqdm (hf_weights_files ,
360- desc = "Loading safetensors checkpoint shards" ):
372+ enable_tqdm = not torch .distributed .is_initialized (
373+ ) or torch .distributed .get_rank () == 0
374+ for st_file in tqdm (
375+ hf_weights_files ,
376+ desc = "Loading safetensors checkpoint shards" ,
377+ disable = not enable_tqdm ,
378+ bar_format = _BAR_FORMAT ,
379+ ):
361380 with safe_open (st_file , framework = "pt" ) as f :
362381 for name in f .keys (): # noqa: SIM118
363382 param = f .get_tensor (name )
@@ -368,8 +387,14 @@ def pt_weights_iterator(
368387 hf_weights_files : List [str ]
369388) -> Generator [Tuple [str , torch .Tensor ], None , None ]:
370389 """Iterate over the weights in the model bin/pt files."""
371- for bin_file in tqdm (hf_weights_files ,
372- desc = "Loading pt checkpoint shards" ):
390+ enable_tqdm = not torch .distributed .is_initialized (
391+ ) or torch .distributed .get_rank () == 0
392+ for bin_file in tqdm (
393+ hf_weights_files ,
394+ desc = "Loading pt checkpoint shards" ,
395+ disable = not enable_tqdm ,
396+ bar_format = _BAR_FORMAT ,
397+ ):
373398 state = torch .load (bin_file , map_location = "cpu" )
374399 for name , param in state .items ():
375400 yield name , param
0 commit comments