Commit bf7e3c5

[Model] use AutoWeightsLoader for baichuan, gpt-neox, mpt (#15939)
Signed-off-by: Jonghyun Choe <[email protected]>
1 parent a35a8a8 commit bf7e3c5

File tree (3 files changed, +119 −100 lines):
  vllm/model_executor/models/baichuan.py
  vllm/model_executor/models/gpt_neox.py
  vllm/model_executor/models/mpt.py
vllm/model_executor/models/baichuan.py
Lines changed: 57 additions & 48 deletions

@@ -47,7 +47,7 @@
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers)
 
 
@@ -321,6 +321,45 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
                               SupportsQuant):
@@ -353,6 +392,7 @@ def __init__(
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       quant_config=quant_config)
+        self.lm_head.weight.weight_loader = self.lm_head_weight_loader
         if self.config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
@@ -393,53 +433,22 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if name == "lm_head.weight":
-                # Unlike Baichuan, Baichuan2 normalizes the head weights.
-                # Refer to:
-                # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
-                # Distinguish between Baichuan and Baichuan2 by checking the
-                # vocab size. This is suggested by
-                # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
-                is_baichuan2 = self.config.vocab_size == 125696
-                if is_baichuan2:
-                    loaded_weight = torch.nn.functional.normalize(
-                        loaded_weight)
-
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
+
+    def lm_head_weight_loader(self, param: nn.Parameter,
+                              loaded_weight: torch.Tensor):
+        # Unlike Baichuan, Baichuan2 normalizes the head weights.
+        # Refer to:
+        # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
+        # Distinguish between Baichuan and Baichuan2 by checking the
+        # vocab size. This is suggested by
+        # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
+        is_baichuan2 = self.config.vocab_size == 125696
+        if is_baichuan2:
+            loaded_weight = torch.nn.functional.normalize(loaded_weight)
+
+        default_weight_loader(param, loaded_weight)
 
 
 class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
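Because AutoWeightsLoader now routes lm_head.weight through a generic path, the Baichuan2-specific normalization is attached as a custom weight_loader on the parameter rather than living inside load_weights. A standalone sketch of what that hook does, with made-up tensor values for illustration (torch.nn.functional.normalize defaults to L2 normalization along dim=1, i.e. per output row):

import torch
import torch.nn.functional as F

# Pretend lm_head weight for a 3-token vocab and hidden size 4.
lm_head_weight = torch.tensor([[3.0, 4.0, 0.0, 0.0],
                               [1.0, 0.0, 0.0, 0.0],
                               [0.0, 0.0, 2.0, 0.0]])

# Baichuan2's reference implementation normalizes the head weights,
# so vLLM normalizes each row to unit L2 norm at load time.
normalized = F.normalize(lm_head_weight)

print(normalized.norm(dim=1))  # tensor([1., 1., 1.])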

vllm/model_executor/models/gpt_neox.py
Lines changed: 42 additions & 37 deletions

@@ -42,7 +42,7 @@
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -241,6 +241,45 @@ def forward(
         hidden_states = self.final_layer_norm(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            if ("attention.bias" in name or "attention.masked_bias" in name
+                    or "rotary_emb.inv_freq" in name):
+                continue
+            if ("rotary_emb.cos_cached" in name
+                    or "rotary_emb.sin_cached" in name):
+                # Models trained using OpenRLHF may include
+                # these tensors in the checkpoint. Skip them.
+                continue
+            if is_pp_missing_parameter(name, self):
+                continue
+            param = params_dict[name]
+
+            if "query_key_value" in name:
+                # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
+                # (num_heads * 3 * head_size), while the
+                # required shape is (3 * num_heads * head_size).
+                # Thus, we need weight conversion.
+                output_dim = getattr(param, "output_dim", None)
+                num_heads = self.config.num_attention_heads
+                if output_dim is not None:
+                    loaded_weight_shape = loaded_weight.shape
+                    loaded_weight = loaded_weight.view(
+                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
+                        loaded_weight_shape[output_dim + 1:])
+                    loaded_weight = loaded_weight.transpose(
+                        output_dim, output_dim + 1)
+                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)
+
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class GPTNeoXForCausalLM(nn.Module, SupportsPP):
 
@@ -297,39 +336,5 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if ("attention.bias" in name or "attention.masked_bias" in name
-                    or "rotary_emb.inv_freq" in name):
-                continue
-            if ("rotary_emb.cos_cached" in name
-                    or "rotary_emb.sin_cached" in name):
-                # Models trained using OpenRLHF may include
-                # these tensors in the checkpoint. Skip them.
-                continue
-            if is_pp_missing_parameter(name, self):
-                continue
-            param = params_dict[name]
-
-            if "query_key_value" in name:
-                # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
-                # (num_heads * 3 * head_size), while the
-                # required shape is (3 * num_heads * head_size).
-                # Thus, we need weight conversion.
-                output_dim = getattr(param, "output_dim", None)
-                num_heads = self.config.num_attention_heads
-                if output_dim is not None:
-                    loaded_weight_shape = loaded_weight.shape
-                    loaded_weight = loaded_weight.view(
-                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
-                        loaded_weight_shape[output_dim + 1:])
-                    loaded_weight = loaded_weight.transpose(
-                        output_dim, output_dim + 1)
-                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)
-
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
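The fused-QKV conversion that moves into the inner model's load_weights is easiest to see on a tiny tensor. GPT-NeoX checkpoints lay out the fused projection as (num_heads * 3 * head_size) along the output dimension, while the vLLM parameter expects (3 * num_heads * head_size); the view/transpose/reshape below reproduces that reordering in isolation. The sizes are toy values chosen for the demo, not real model dimensions.

import torch

num_heads, head_size, hidden = 2, 3, 4
output_dim = 0  # the dim along which the fused QKV output is laid out

# Checkpoint layout: for each head, its q, k, and v rows are adjacent.
qkv = torch.arange(num_heads * 3 * head_size * hidden,
                   dtype=torch.float32).view(num_heads * 3 * head_size, hidden)

shape = qkv.shape
converted = qkv.view(shape[:output_dim] + (num_heads, 3, -1) +
                     shape[output_dim + 1:])
converted = converted.transpose(output_dim, output_dim + 1)
converted = converted.reshape(shape)

# The result groups all Q rows first (head 0, then head 1),
# then all K rows, then all V rows, with the overall shape unchanged.
assert converted.shape == qkv.shape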

vllm/model_executor/models/mpt.py
Lines changed: 20 additions & 15 deletions

@@ -27,7 +27,7 @@
 from vllm.transformers_utils.configs.mpt import MPTConfig
 
 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -266,6 +266,23 @@ def forward(
         hidden_states = self.norm_f(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+            if is_pp_missing_parameter(name, self):
+                continue
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class MPTForCausalLM(nn.Module, SupportsPP):
 
@@ -318,17 +335,5 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            # Skip loading extra bias for GPTQ models.
-            if name.endswith(".bias") and name not in params_dict:
-                continue
-            if is_pp_missing_parameter(name, self):
-                continue
-            param = params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
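All three rewritten loaders keep the same "skip loading extra bias for GPTQ models" guard, now living in the inner model classes: some quantized checkpoints carry .bias tensors for modules that vLLM builds without a bias parameter, so any bias name absent from named_parameters() is dropped rather than treated as an error. A small self-contained sketch of that filter; filter_extra_biases and the checkpoint names are invented for this illustration only.

import torch
import torch.nn as nn


def filter_extra_biases(module: nn.Module, weights):
    """Yield only the (name, tensor) pairs the module can actually hold."""
    params_dict = dict(module.named_parameters())
    for name, tensor in weights:
        # Skip loading extra bias for GPTQ models.
        if name.endswith(".bias") and name not in params_dict:
            continue
        yield name, tensor


module = nn.Sequential(nn.Linear(4, 4, bias=False))
checkpoint = [("0.weight", torch.zeros(4, 4)),
              ("0.bias", torch.zeros(4))]  # extra bias the module cannot hold

kept = dict(filter_extra_biases(module, checkpoint))
print(list(kept))  # ['0.weight']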
