from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, make_layers)
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)


class GPTBigCodeAttention(nn.Module):
@@ -83,13 +84,15 @@ def __init__(
            total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
+            prefix=f"{prefix}.c_attn",
        )

        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
@@ -123,6 +126,7 @@ def __init__(
        intermediate_size: int,
        config: GPTBigCodeConfig,
        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
@@ -131,12 +135,14 @@ def __init__(
            intermediate_size,
            bias=True,
            quant_config=quant_config,
+            prefix=f"{prefix}.c_fc",
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
        )
        self.act = get_act_fn(config.activation_function)

@@ -167,7 +173,10 @@ def __init__(
                                        quant_config,
                                        prefix=f"{prefix}.attn")
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.mlp = GPTBigMLP(inner_dim, config, quant_config)
+        self.mlp = GPTBigMLP(inner_dim,
+                             config,
+                             quant_config,
+                             prefix=f"{prefix}.mlp")

    def forward(
        self,
@@ -260,7 +269,7 @@ def load_weights(self, weights: Iterable[tuple[str,
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
-                if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
+                if "c_attn.input_scale" in name:
                    weight_loader(param, loaded_weight, 'q')
                    weight_loader(param, loaded_weight, 'k')
                    weight_loader(param, loaded_weight, 'v')
@@ -284,7 +293,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

        self.quant_config = quant_config
        self.transformer = GPTBigCodeModel(vllm_config=vllm_config,
-                                           prefix=prefix)
+                                           prefix=maybe_prefix(
+                                               prefix, "transformer"))
        if self.config.tie_word_embeddings:
            self.lm_head = self.transformer.wte
        else:
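Note on the new `maybe_prefix` import: the sketch below shows the behaviour this diff appears to rely on, namely that the helper joins a parent prefix onto a module name only when the prefix is non-empty, so the threaded `prefix=...` arguments resolve to dotted, checkpoint-style module paths that a quantization config can match per layer. This is a minimal sketch of the assumed behaviour, not the actual implementation in `.utils`, and the example names are illustrative.

```python
# Minimal sketch of the assumed maybe_prefix behaviour (not the vLLM source).
def maybe_prefix(prefix: str, name: str) -> str:
    # Return the bare name when there is no parent prefix; otherwise join
    # the two with a dot, mirroring checkpoint-style module paths.
    return name if not prefix else f"{prefix}.{name}"


# How prefixes compose once threaded through the modules (illustrative names):
assert maybe_prefix("", "transformer") == "transformer"
assert maybe_prefix("transformer.h.0.attn", "c_attn") == "transformer.h.0.attn.c_attn"
```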