@@ -313,6 +313,13 @@ def filter_files_not_needed_for_inference(
     return hf_weights_files


+# Explicitly use a plain-text bar format with a newline at the end.
+# This makes the progress-bar animation invisible, but avoids garbled
+# output under Ray or multiprocessing, which wrap each line of output
+# with a prefix.
+_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n"  # noqa: E501
+
+
 def np_cache_weights_iterator(
     model_name_or_path: str, cache_dir: Optional[str], hf_folder: str,
     hf_weights_files: List[str]
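For reference, a standalone snippet showing what this `bar_format` renders; the loop and `desc` here are illustrative, not part of the patch:

```python
from tqdm import tqdm

# Same format string as the patch: every update is a complete line
# terminated by "\n" instead of a carriage-return redraw.
_BAR_FORMAT = ("{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} "
               "[{elapsed}<{remaining}, {rate_fmt}]\n")

for _ in tqdm(range(4), desc="Loading shards", bar_format=_BAR_FORMAT):
    pass
# Output is append-only, e.g.:
#   Loading shards:  25% Completed | 1/4 [00:00<00:00, ...]
# so a line-prefixing wrapper (Ray, multiprocessing) cannot mangle it.
```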
@@ -321,6 +328,8 @@ def np_cache_weights_iterator(

    Will dump the model weights to numpy files if they are not already dumped.
    """
+    enable_tqdm = not torch.distributed.is_initialized(
+    ) or torch.distributed.get_rank() == 0
     # Convert the model weights from torch tensors to numpy arrays for
     # faster loading.
     np_folder = os.path.join(hf_folder, "np")
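The `enable_tqdm` predicate added here is the heart of the patch. A minimal standalone sketch of the same check (the helper name is mine, not the patch's):

```python
import torch.distributed as dist

def _is_progress_rank() -> bool:
    # Single-process runs have no initialized process group, so the bar
    # stays enabled; in distributed runs only global rank 0 shows it.
    return not dist.is_initialized() or dist.get_rank() == 0
```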
@@ -331,8 +340,12 @@ def np_cache_weights_iterator(
     with get_lock(model_name_or_path, cache_dir):
         if not os.path.exists(weight_names_file):
             weight_names: List[str] = []
-            for bin_file in tqdm(hf_weights_files,
-                                 desc="Loading np_cache checkpoint shards"):
+            for bin_file in tqdm(
+                    hf_weights_files,
+                    desc="Loading np_cache checkpoint shards",
+                    disable=not enable_tqdm,
+                    bar_format=_BAR_FORMAT,
+            ):
                 state = torch.load(bin_file, map_location="cpu")
                 for name, param in state.items():
                     param_path = os.path.join(np_folder, name)
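`get_lock` guards the one-time numpy dump against concurrent workers; it is defined elsewhere in this file. A rough equivalent built on the `filelock` package might look like this (the details are assumptions, not the actual helper):

```python
import hashlib
import os
import tempfile
from typing import Optional

from filelock import FileLock

def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
    # One lock file per model, so only the first worker converts the
    # checkpoint while the others block until the cache exists.
    lock_dir = cache_dir or tempfile.gettempdir()
    name = hashlib.sha256(model_name_or_path.encode()).hexdigest()
    return FileLock(os.path.join(lock_dir, f"{name}.lock"))
```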
@@ -356,8 +369,14 @@ def safetensors_weights_iterator(
     hf_weights_files: List[str]
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files."""
-    for st_file in tqdm(hf_weights_files,
-                        desc="Loading safetensors checkpoint shards"):
+    enable_tqdm = not torch.distributed.is_initialized(
+    ) or torch.distributed.get_rank() == 0
+    for st_file in tqdm(
+            hf_weights_files,
+            desc="Loading safetensors checkpoint shards",
+            disable=not enable_tqdm,
+            bar_format=_BAR_FORMAT,
+    ):
         with safe_open(st_file, framework="pt") as f:
             for name in f.keys():  # noqa: SIM118
                 param = f.get_tensor(name)
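For context, the `safe_open` pattern used above streams one tensor at a time instead of reading whole shards into memory. A self-contained usage sketch (the file path is a placeholder):

```python
import torch
from safetensors import safe_open

with safe_open("model-00001-of-00002.safetensors", framework="pt") as f:
    for name in f.keys():
        tensor: torch.Tensor = f.get_tensor(name)  # loads only this tensor
        print(name, tuple(tensor.shape), tensor.dtype)
```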
@@ -368,8 +387,14 @@ def pt_weights_iterator(
     hf_weights_files: List[str]
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model bin/pt files."""
-    for bin_file in tqdm(hf_weights_files,
-                         desc="Loading pt checkpoint shards"):
+    enable_tqdm = not torch.distributed.is_initialized(
+    ) or torch.distributed.get_rank() == 0
+    for bin_file in tqdm(
+            hf_weights_files,
+            desc="Loading pt checkpoint shards",
+            disable=not enable_tqdm,
+            bar_format=_BAR_FORMAT,
+    ):
         state = torch.load(bin_file, map_location="cpu")
         for name, param in state.items():
             yield name, param
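All three iterators yield `(name, tensor)` pairs, so callers can consume them interchangeably. A hypothetical driver loop (the path and the copy step are illustrative):

```python
import glob

files = sorted(glob.glob("/path/to/checkpoint/*.bin"))  # placeholder path
for name, param in pt_weights_iterator(files):
    # e.g. copy `param` into the model parameter registered under `name`
    print(f"loaded {name}: {tuple(param.shape)}")
```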