Commit eed74a5

Simplify weight loading logic (#2133)
1 parent: 2acd76f

3 files changed: +33 −37 lines

3 files changed

+33
-37
lines changed

vllm/config.py

Lines changed: 4 additions & 9 deletions
@@ -122,15 +122,10 @@ def _verify_load_format(self) -> None:
 
         # TODO: Remove this check once HF updates the pt weights of Mixtral.
         architectures = getattr(self.hf_config, "architectures", [])
-        if "MixtralForCausalLM" in architectures:
-            if load_format == "pt":
-                raise ValueError(
-                    "Currently, the 'pt' format is not supported for Mixtral. "
-                    "Please use the 'safetensors' format instead. ")
-            elif load_format == "auto":
-                # Do not fall back to pt weights.
-                load_format = "safetensors"
-
+        if "MixtralForCausalLM" in architectures and load_format == "pt":
+            raise ValueError(
+                "Currently, the 'pt' format is not supported for Mixtral. "
+                "Please use the 'safetensors' format instead. ")
         self.load_format = load_format
 
     def _verify_tokenizer_mode(self) -> None:
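The config-side special case becomes pure validation: an explicit "pt" request for Mixtral still fails, but "auto" is no longer silently rewritten to "safetensors"; that decision now moves into the weight loader (see mixtral.py below). A minimal sketch of the resulting behavior, using a hypothetical standalone helper rather than the real ModelConfig method:

# Minimal sketch of the simplified check. verify_load_format is a
# hypothetical standalone helper, not the actual ModelConfig method.
def verify_load_format(load_format: str, architectures: list) -> str:
    if "MixtralForCausalLM" in architectures and load_format == "pt":
        raise ValueError(
            "Currently, the 'pt' format is not supported for Mixtral. "
            "Please use the 'safetensors' format instead. ")
    return load_format

# "auto" now passes through unchanged instead of being forced to safetensors.
assert verify_load_format("auto", ["MixtralForCausalLM"]) == "auto"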

vllm/model_executor/models/mixtral.py

Lines changed: 5 additions & 1 deletion
@@ -412,7 +412,11 @@ def load_weights(self,
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision):
+                model_name_or_path,
+                cache_dir,
+                load_format,
+                revision,
+                fall_back_to_pt=False):
             if "rotary_emb.inv_freq" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
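Instead of the config forcing a format, the Mixtral call site itself opts out of the pt fallback. A rough usage sketch, assuming the iterator signature from the weight_utils.py diff below (the model id is only an example):

from vllm.model_executor.weight_utils import hf_model_weights_iterator

# With fall_back_to_pt=False, "auto" resolves to *.safetensors or *.bin
# checkpoints only; *.pt files are never globbed for Mixtral.
for name, loaded_weight in hf_model_weights_iterator(
        "mistralai/Mixtral-8x7B-v0.1",  # example model id
        None,                           # cache_dir
        "auto",                         # load_format
        None,                           # revision
        fall_back_to_pt=False):
    print(name, tuple(loaded_weight.shape))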

vllm/model_executor/weight_utils.py

Lines changed: 24 additions & 27 deletions
@@ -125,15 +125,29 @@ def get_quant_config(
 def prepare_hf_model_weights(
     model_name_or_path: str,
     cache_dir: Optional[str] = None,
-    use_safetensors: bool = False,
+    load_format: str = "auto",
     fall_back_to_pt: bool = True,
     revision: Optional[str] = None,
 ) -> Tuple[str, List[str], bool]:
     # Download model weights from huggingface.
     is_local = os.path.isdir(model_name_or_path)
+    use_safetensors = False
     # Some quantized models use .pt files for storing the weights.
-    allow_patterns = ["*.safetensors"
-                      ] if use_safetensors else ["*.bin", "*.pt"]
+    if load_format == "auto":
+        allow_patterns = ["*.safetensors", "*.bin"]
+    elif load_format == "safetensors":
+        use_safetensors = True
+        allow_patterns = ["*.safetensors"]
+    elif load_format == "pt":
+        allow_patterns = ["*.pt"]
+    elif load_format == "npcache":
+        allow_patterns = ["*.bin"]
+    else:
+        raise ValueError(f"Unknown load_format: {load_format}")
+
+    if fall_back_to_pt:
+        allow_patterns += ["*.pt"]
+
     if not is_local:
         # Use file lock to prevent multiple processes from
         # downloading the same model weights at the same time.
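The format dispatch that used to live in hf_model_weights_iterator now happens here: each load_format maps to a fixed list of glob patterns, and fall_back_to_pt appends *.pt as the lowest-priority pattern. A self-contained sketch of that mapping (select_patterns is a hypothetical helper mirroring the logic above):

from typing import List, Tuple

def select_patterns(load_format: str,
                    fall_back_to_pt: bool = True) -> Tuple[List[str], bool]:
    # Mirrors the branch added to prepare_hf_model_weights above.
    use_safetensors = False
    if load_format == "auto":
        allow_patterns = ["*.safetensors", "*.bin"]
    elif load_format == "safetensors":
        use_safetensors = True
        allow_patterns = ["*.safetensors"]
    elif load_format == "pt":
        allow_patterns = ["*.pt"]
    elif load_format == "npcache":
        allow_patterns = ["*.bin"]
    else:
        raise ValueError(f"Unknown load_format: {load_format}")
    if fall_back_to_pt:
        allow_patterns += ["*.pt"]
    return allow_patterns, use_safetensors

# "auto" still reaches *.pt, but only as a last resort ...
assert select_patterns("auto")[0] == ["*.safetensors", "*.bin", "*.pt"]
# ... while Mixtral's fall_back_to_pt=False keeps *.pt out entirely.
assert select_patterns("auto", fall_back_to_pt=False)[0] == [
    "*.safetensors", "*.bin"
]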
@@ -148,6 +162,10 @@ def prepare_hf_model_weights(
     hf_weights_files: List[str] = []
     for pattern in allow_patterns:
         hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
+        if len(hf_weights_files) > 0:
+            if pattern == "*.safetensors":
+                use_safetensors = True
+            break
     if not use_safetensors:
         # Exclude files that are not needed for inference.
         # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
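Because the loop now breaks on the first pattern that matches anything, the pattern list doubles as a preference order: under "auto", safetensors beats .bin when a repo ships both. An illustrative sketch with fnmatch standing in for the on-disk glob:

import fnmatch

# Hypothetical directory listing; the real code globs hf_folder instead.
files = ["model.safetensors", "pytorch_model.bin"]

use_safetensors = False
for pattern in ["*.safetensors", "*.bin", "*.pt"]:
    matched = fnmatch.filter(files, pattern)
    if matched:
        use_safetensors = (pattern == "*.safetensors")
        break  # first pattern with any matches wins

assert matched == ["model.safetensors"] and use_safetensors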
@@ -163,13 +181,6 @@ def prepare_hf_model_weights(
             if not any(f.endswith(x) for x in blacklist)
         ]
 
-    if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt:
-        return prepare_hf_model_weights(model_name_or_path,
-                                        cache_dir=cache_dir,
-                                        use_safetensors=False,
-                                        fall_back_to_pt=False,
-                                        revision=revision)
-
     if len(hf_weights_files) == 0:
         raise RuntimeError(
             f"Cannot find any model weights with `{model_name_or_path}`")
@@ -182,30 +193,16 @@ def hf_model_weights_iterator(
     cache_dir: Optional[str] = None,
     load_format: str = "auto",
     revision: Optional[str] = None,
+    fall_back_to_pt: Optional[bool] = True,
 ) -> Iterator[Tuple[str, torch.Tensor]]:
-    use_safetensors = False
-    use_np_cache = False
-    fall_back_to_pt = False
-    if load_format == "auto":
-        use_safetensors = True
-        fall_back_to_pt = True
-    elif load_format == "safetensors":
-        use_safetensors = True
-    elif load_format == "pt":
-        pass
-    elif load_format == "npcache":
-        use_np_cache = True
-    else:
-        raise ValueError(f"Unknown load_format: {load_format}")
-
     hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
         model_name_or_path,
         cache_dir=cache_dir,
-        use_safetensors=use_safetensors,
+        load_format=load_format,
         fall_back_to_pt=fall_back_to_pt,
         revision=revision)
 
-    if use_np_cache:
+    if load_format == "npcache":
         # Currently np_cache only support *.bin checkpoints
         assert use_safetensors is False
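After this change hf_model_weights_iterator simply forwards load_format and fall_back_to_pt, keeping only the local npcache special case. A sketch of calling the preparation step directly, assuming a placeholder checkpoint path:

from vllm.model_executor.weight_utils import prepare_hf_model_weights

# Format dispatch now lives inside prepare_hf_model_weights itself.
hf_folder, weight_files, use_safetensors = prepare_hf_model_weights(
    "/path/to/checkpoint",      # placeholder local model directory
    cache_dir=None,
    load_format="safetensors",  # selects ["*.safetensors"] internally
    fall_back_to_pt=True,       # appends "*.pt" as a last resort
    revision=None)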
