Commit c520124

[misc] only tqdm for first rank (#6672)
1 parent 97234be commit c520124

File tree: 1 file changed (+31, −6)

vllm/model_executor/model_loader/weight_utils.py

Lines changed: 31 additions & 6 deletions
@@ -313,6 +313,13 @@ def filter_files_not_needed_for_inference(
     return hf_weights_files
 
 
+# explicitly use pure text format, with a newline at the end
+# this makes it impossible to see the animation in the progress bar
+# but will avoid messing up with ray or multiprocessing, which wraps
+# each line of output with some prefix.
+_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n"  # noqa: E501
+
+
 def np_cache_weights_iterator(
     model_name_or_path: str, cache_dir: Optional[str], hf_folder: str,
     hf_weights_files: List[str]
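
The new `_BAR_FORMAT` deliberately ends in `\n` and drops the in-place animation, so each progress update is emitted as a complete line and stays readable when Ray or multiprocessing prefixes every line of worker output. A minimal sketch of what that looks like in isolation (the dummy loop and the `time.sleep` stand-in are illustrative, not part of this commit):

```python
# Illustrative only: exercise the newline-terminated bar format on a
# dummy loop. Each tqdm update is printed as its own full line instead
# of being redrawn in place.
import time

from tqdm import tqdm

_BAR_FORMAT = ("{desc}: {percentage:3.0f}% Completed | "
               "{n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n")

for _ in tqdm(range(3), desc="Loading shards", bar_format=_BAR_FORMAT):
    time.sleep(0.1)  # stand-in for loading one checkpoint shard
```
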
@@ -321,6 +328,8 @@ def np_cache_weights_iterator(
 
     Will dump the model weights to numpy files if they are not already dumped.
     """
+    enable_tqdm = not torch.distributed.is_initialized(
+    ) or torch.distributed.get_rank() == 0
     # Convert the model weights from torch tensors to numpy arrays for
     # faster loading.
     np_folder = os.path.join(hf_folder, "np")
@@ -331,8 +340,12 @@
     with get_lock(model_name_or_path, cache_dir):
         if not os.path.exists(weight_names_file):
             weight_names: List[str] = []
-            for bin_file in tqdm(hf_weights_files,
-                                 desc="Loading np_cache checkpoint shards"):
+            for bin_file in tqdm(
+                    hf_weights_files,
+                    desc="Loading np_cache checkpoint shards",
+                    disable=not enable_tqdm,
+                    bar_format=_BAR_FORMAT,
+            ):
                 state = torch.load(bin_file, map_location="cpu")
                 for name, param in state.items():
                     param_path = os.path.join(np_folder, name)
@@ -356,8 +369,14 @@ def safetensors_weights_iterator(
     hf_weights_files: List[str]
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files."""
-    for st_file in tqdm(hf_weights_files,
-                        desc="Loading safetensors checkpoint shards"):
+    enable_tqdm = not torch.distributed.is_initialized(
+    ) or torch.distributed.get_rank() == 0
+    for st_file in tqdm(
+            hf_weights_files,
+            desc="Loading safetensors checkpoint shards",
+            disable=not enable_tqdm,
+            bar_format=_BAR_FORMAT,
+    ):
         with safe_open(st_file, framework="pt") as f:
             for name in f.keys():  # noqa: SIM118
                 param = f.get_tensor(name)
@@ -368,8 +387,14 @@ def pt_weights_iterator(
     hf_weights_files: List[str]
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model bin/pt files."""
-    for bin_file in tqdm(hf_weights_files,
-                         desc="Loading pt checkpoint shards"):
+    enable_tqdm = not torch.distributed.is_initialized(
+    ) or torch.distributed.get_rank() == 0
+    for bin_file in tqdm(
+            hf_weights_files,
+            desc="Loading pt checkpoint shards",
+            disable=not enable_tqdm,
+            bar_format=_BAR_FORMAT,
+    ):
         state = torch.load(bin_file, map_location="cpu")
         for name, param in state.items():
             yield name, param
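
All three weight iterators now apply the same rank gate before constructing the progress bar. A self-contained sketch of that pattern (the wrapper name `rank_zero_tqdm` and the toy shard list are illustrative assumptions, not vLLM code):

```python
# Illustrative sketch of the rank-0-only progress bar pattern used above.
from typing import Iterable

import torch
from tqdm import tqdm

_BAR_FORMAT = ("{desc}: {percentage:3.0f}% Completed | "
               "{n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n")


def rank_zero_tqdm(iterable: Iterable, desc: str) -> Iterable:
    # With torch.distributed uninitialized there is only one process, so
    # the bar is always shown; in a distributed run only rank 0 shows it.
    enable_tqdm = (not torch.distributed.is_initialized()
                   or torch.distributed.get_rank() == 0)
    return tqdm(iterable,
                desc=desc,
                disable=not enable_tqdm,
                bar_format=_BAR_FORMAT)


for shard_file in rank_zero_tqdm(["shard-0.bin", "shard-1.bin"],
                                 desc="Loading checkpoint shards"):
    pass  # load and yield the shard's tensors here
```

In the commit itself the gate is recomputed inside each iterator rather than factored into a shared helper, keeping each function self-contained.
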
