@@ -331,7 +331,8 @@ def np_cache_weights_iterator(
     with get_lock(model_name_or_path, cache_dir):
         if not os.path.exists(weight_names_file):
             weight_names: List[str] = []
-            for bin_file in hf_weights_files:
+            for bin_file in tqdm(hf_weights_files,
+                                 desc="Loading np_cache checkpoint shards"):
                 state = torch.load(bin_file, map_location="cpu")
                 for name, param in state.items():
                     param_path = os.path.join(np_folder, name)
@@ -355,7 +356,8 @@ def safetensors_weights_iterator(
     hf_weights_files: List[str]
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files."""
-    for st_file in hf_weights_files:
+    for st_file in tqdm(hf_weights_files,
+                        desc="Loading safetensors checkpoint shards"):
         with safe_open(st_file, framework="pt") as f:
             for name in f.keys():  # noqa: SIM118
                 param = f.get_tensor(name)
@@ -366,7 +368,8 @@ def pt_weights_iterator(
     hf_weights_files: List[str]
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model bin/pt files."""
-    for bin_file in hf_weights_files:
+    for bin_file in tqdm(hf_weights_files,
+                         desc="Loading pt checkpoint shards"):
         state = torch.load(bin_file, map_location="cpu")
         for name, param in state.items():
             yield name, param
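
For reference, below is a minimal, self-contained sketch of the pattern all three hunks apply: wrapping the loop over checkpoint shard files in `tqdm(..., desc=...)` so the progress bar advances once per shard. The function mirrors `pt_weights_iterator` from the diff; the temporary shard files and the `__main__` harness are illustrative assumptions added here, not part of vLLM.

```python
import os
import tempfile
from typing import Dict, Generator, List, Tuple

import torch
from tqdm import tqdm


def pt_weights_iterator(
    hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Yield (name, tensor) pairs from .bin/.pt shards with a progress bar."""
    # tqdm wraps the list of shard paths; `desc` labels the bar, and the bar
    # ticks once per shard file as the outer loop advances.
    for bin_file in tqdm(hf_weights_files,
                         desc="Loading pt checkpoint shards"):
        state: Dict[str, torch.Tensor] = torch.load(bin_file,
                                                    map_location="cpu")
        for name, param in state.items():
            yield name, param


if __name__ == "__main__":
    # Create two tiny fake shards so the sketch runs end to end
    # (illustrative only; real shards come from the HF snapshot).
    with tempfile.TemporaryDirectory() as tmp:
        shards = []
        for i in range(2):
            path = os.path.join(tmp, f"shard-{i}.bin")
            torch.save({f"layer{i}.weight": torch.zeros(4, 4)}, path)
            shards.append(path)
        for name, tensor in pt_weights_iterator(shards):
            print(name, tuple(tensor.shape))
```

The same idea applies to the other two iterators: only the wrapped iterable and the `desc` string differ, so the behavior of each generator is unchanged apart from the per-shard progress reporting.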