Skip to content

Commit e519ae0

Browse files
zhaotyertianyi.zhaoyoukaichao
authored
add tqdm when loading checkpoint shards (#6569)
Co-authored-by: tianyi.zhao <[email protected]> Co-authored-by: youkaichao <[email protected]>
1 parent 7c2749a commit e519ae0

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

vllm/model_executor/model_loader/weight_utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,8 @@ def np_cache_weights_iterator(
331331
with get_lock(model_name_or_path, cache_dir):
332332
if not os.path.exists(weight_names_file):
333333
weight_names: List[str] = []
334-
for bin_file in hf_weights_files:
334+
for bin_file in tqdm(hf_weights_files,
335+
desc="Loading np_cache checkpoint shards"):
335336
state = torch.load(bin_file, map_location="cpu")
336337
for name, param in state.items():
337338
param_path = os.path.join(np_folder, name)
@@ -355,7 +356,8 @@ def safetensors_weights_iterator(
355356
hf_weights_files: List[str]
356357
) -> Generator[Tuple[str, torch.Tensor], None, None]:
357358
"""Iterate over the weights in the model safetensor files."""
358-
for st_file in hf_weights_files:
359+
for st_file in tqdm(hf_weights_files,
360+
desc="Loading safetensors checkpoint shards"):
359361
with safe_open(st_file, framework="pt") as f:
360362
for name in f.keys(): # noqa: SIM118
361363
param = f.get_tensor(name)
@@ -366,7 +368,8 @@ def pt_weights_iterator(
366368
hf_weights_files: List[str]
367369
) -> Generator[Tuple[str, torch.Tensor], None, None]:
368370
"""Iterate over the weights in the model bin/pt files."""
369-
for bin_file in hf_weights_files:
371+
for bin_file in tqdm(hf_weights_files,
372+
desc="Loading pt checkpoint shards"):
370373
state = torch.load(bin_file, map_location="cpu")
371374
for name, param in state.items():
372375
yield name, param

0 commit comments

Comments
 (0)