Commit 10a0253

Fix loading of quantized BigCode models (#22463)
Signed-off-by: Eldar Kurtic <[email protected]>
1 parent: 65552b4

File tree: 1 file changed (+14 -4 lines changed)

vllm/model_executor/models/gpt_bigcode.py (14 additions, 4 deletions)
```diff
@@ -45,7 +45,8 @@
 
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, make_layers)
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)
 
 
 class GPTBigCodeAttention(nn.Module):
@@ -83,13 +84,15 @@ def __init__(
             total_num_kv_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_attn",
         )
 
         self.c_proj = RowParallelLinear(
             self.hidden_size,
             self.hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,
@@ -123,6 +126,7 @@ def __init__(
         intermediate_size: int,
         config: GPTBigCodeConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         hidden_size = config.hidden_size
@@ -131,12 +135,14 @@ def __init__(
             intermediate_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_fc",
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
         )
         self.act = get_act_fn(config.activation_function)
 
@@ -167,7 +173,10 @@ def __init__(
                                         quant_config,
                                         prefix=f"{prefix}.attn")
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.mlp = GPTBigMLP(inner_dim, config, quant_config)
+        self.mlp = GPTBigMLP(inner_dim,
+                             config,
+                             quant_config,
+                             prefix=f"{prefix}.mlp")
 
     def forward(
         self,
@@ -260,7 +269,7 @@ def load_weights(self, weights: Iterable[tuple[str,
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
-            if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
+            if "c_attn.input_scale" in name:
                 weight_loader(param, loaded_weight, 'q')
                 weight_loader(param, loaded_weight, 'k')
                 weight_loader(param, loaded_weight, 'v')
@@ -284,7 +293,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         self.quant_config = quant_config
         self.transformer = GPTBigCodeModel(vllm_config=vllm_config,
-                                           prefix=prefix)
+                                           prefix=maybe_prefix(
+                                               prefix, "transformer"))
         if self.config.tie_word_embeddings:
             self.lm_head = self.transformer.wte
         else:
```
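Why the prefix plumbing matters: vLLM's quantization configs choose a per-layer quantization method keyed on the module's fully qualified name. `GPTBigMLP` and the attention projections were previously constructed without a `prefix`, so checkpoints whose quantization metadata names individual layers (FP8 scales, compressed-tensors ignore lists, and so on) could not be matched, which broke loading. A minimal sketch of the lookup, assuming a hypothetical `ToyQuantConfig` in place of vLLM's real config classes:

```python
# Hypothetical sketch, not vLLM's actual API: per-layer quantization
# decisions are keyed on the layer's fully qualified name ("prefix").


class ToyQuantConfig:

    def __init__(self, ignored_layers: list[str]):
        # Module names the checkpoint metadata says to keep unquantized.
        self.ignored_layers = ignored_layers

    def get_quant_method(self, prefix: str) -> str:
        # With prefix="" (the old behavior for GPTBigMLP's layers),
        # this lookup can never match the checkpoint metadata.
        return "unquantized" if prefix in self.ignored_layers else "fp8"


cfg = ToyQuantConfig(ignored_layers=["transformer.h.0.mlp.c_fc"])
assert cfg.get_quant_method("transformer.h.0.mlp.c_fc") == "unquantized"
assert cfg.get_quant_method("transformer.h.0.mlp.c_proj") == "fp8"
```

With the diff applied, `c_fc` in layer 0 is built with a name like `transformer.h.0.mlp.c_fc` (the top-level `maybe_prefix` supplying the `transformer` root), so such lookups can resolve.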

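The `load_weights` change is the companion fix: `c_attn.input_scale` is a single per-tensor scale that still has to be fanned out by hand to the fused q/k/v shards, while `c_attn.weight_scale` is dropped from the special case, presumably so it flows through the regular (now prefix-aware) quantized-linear loading path instead. A toy illustration of the fan-out, with `toy_qkv_scale_loader` standing in for the shard-aware `weight_loader` on the fused parameter:

```python
import torch


# Stand-in for the shard-aware weight_loader attached to a fused QKV
# parameter; the real loader also handles tensor-parallel slicing.
def toy_qkv_scale_loader(param: torch.Tensor, loaded_weight: torch.Tensor,
                         shard_id: str) -> None:
    param[{"q": 0, "k": 1, "v": 2}[shard_id]] = loaded_weight


# One scale in the checkpoint, three logical shards to fill, mirroring
# the q/k/v triple call in load_weights above.
scales = torch.zeros(3)
for shard_id in ("q", "k", "v"):
    toy_qkv_scale_loader(scales, torch.tensor(0.02), shard_id)
print(scales)  # tensor([0.0200, 0.0200, 0.0200])
```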
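To exercise the fixed path end to end, something like this should do (the checkpoint name is a placeholder; any FP8 or compressed-tensors quantized GPTBigCode/StarCoder-family model applies):

```python
from vllm import LLM, SamplingParams

# Placeholder name: substitute a real quantized GPTBigCode checkpoint.
llm = LLM(model="your-org/starcoderbase-FP8")
outputs = llm.generate(["def fibonacci(n):"],
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```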