Commit 53a1ba6

[log] add weights loading time log to sharded_state loader (#28628)
Signed-off-by: Andy Xie <[email protected]>
1 parent: 1840c5c

File tree: 1 file changed (+8, −0)

vllm/model_executor/model_loader/sharded_state_loader.py

Lines changed: 8 additions & 0 deletions
@@ -4,6 +4,7 @@
 import collections
 import glob
 import os
+import time
 from collections.abc import Generator
 from typing import Any

@@ -132,6 +133,7 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
                 f"pre-sharded checkpoints are currently supported!"
             )
         state_dict = self._filter_subtensors(model.state_dict())
+        counter_before_loading_weights = time.perf_counter()
         for key, tensor in self.iterate_over_files(filepaths):
             # If loading with LoRA enabled, additional padding may
             # be added to certain parameters. We only load into a

@@ -150,6 +152,12 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
                 )
             param_data.copy_(tensor)
             state_dict.pop(key)
+        counter_after_loading_weights = time.perf_counter()
+        logger.info_once(
+            "Loading weights took %.2f seconds",
+            counter_after_loading_weights - counter_before_loading_weights,
+            scope="local",
+        )
         if state_dict:
             raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")

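For context, the change follows a standard timing pattern: take time.perf_counter() readings before and after the weight-copy loop and log the difference. Below is a minimal, standalone sketch of that pattern. The filepaths argument and the loop body are illustrative stand-ins (not vLLM's actual loader code), and Python's standard logging module is used in place of vLLM's deduplicating logger.info_once.

    import logging
    import time

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    def load_weights(filepaths):
        # perf_counter() is a monotonic, high-resolution clock, so the
        # difference between two readings is a reliable elapsed duration
        # (unlike time.time(), it is unaffected by system clock changes).
        counter_before_loading_weights = time.perf_counter()
        for path in filepaths:
            # Stand-in for copying each shard's tensors into the model.
            pass
        counter_after_loading_weights = time.perf_counter()
        logger.info(
            "Loading weights took %.2f seconds",
            counter_after_loading_weights - counter_before_loading_weights,
        )

    load_weights(["shard-00001.safetensors", "shard-00002.safetensors"])

In the actual diff, logger.info_once additionally suppresses repeated emissions of the same message; the scope="local" argument presumably scopes that deduplication to the local process, so each rank can report its own load time.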