@@ -148,54 +148,47 @@ def __init__(
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
 
         if self.q_lora_rank is not None:
-            self.q_a_proj = ReplicatedLinear(
-                self.hidden_size,
-                self.q_lora_rank,
-                bias=False,
-                quant_config=quant_config,
-                prefix=f"{prefix}.q_a_proj"
-            )
+            self.q_a_proj = ReplicatedLinear(self.hidden_size,
+                                             self.q_lora_rank,
+                                             bias=False,
+                                             quant_config=quant_config,
+                                             prefix=f"{prefix}.q_a_proj")
             self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                          eps=config.rms_norm_eps)
-            self.q_b_proj = ColumnParallelLinear(
-                q_lora_rank,
-                self.num_heads * self.qk_head_dim,
-                bias=False,
-                quant_config=quant_config,
-                prefix=f"{prefix}.q_b_proj"
-            )
+            self.q_b_proj = ColumnParallelLinear(q_lora_rank,
+                                                 self.num_heads *
+                                                 self.qk_head_dim,
+                                                 bias=False,
+                                                 quant_config=quant_config,
+                                                 prefix=f"{prefix}.q_b_proj")
         else:
-            self.q_proj = ColumnParallelLinear(
-                self.hidden_size,
-                self.num_heads * self.qk_head_dim,
-                bias=False,
-                quant_config=quant_config,
-                prefix=f"{prefix}.q_proj"
-            )
+            self.q_proj = ColumnParallelLinear(self.hidden_size,
+                                               self.num_heads *
+                                               self.qk_head_dim,
+                                               bias=False,
+                                               quant_config=quant_config,
+                                               prefix=f"{prefix}.q_proj")
 
         self.kv_a_proj_with_mqa = ReplicatedLinear(
             self.hidden_size,
             self.kv_lora_rank + self.qk_rope_head_dim,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.kv_a_proj_with_mqa"
-        )
+            prefix=f"{prefix}.kv_a_proj_with_mqa")
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                       eps=config.rms_norm_eps)
         self.kv_b_proj = ColumnParallelLinear(
             self.kv_lora_rank,
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.kv_b_proj"
-        )
+            prefix=f"{prefix}.kv_b_proj")
         self.o_proj = CustomDeepseekV2RowParallelLinear(
             self.num_heads * self.v_head_dim,
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.o_proj"
-        )
+            prefix=f"{prefix}.o_proj")
 
         if rope_scaling:
             rope_scaling["rope_type"] = 'deepseek_yarn'
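For context, the layers this hunk reformats make up the projection stack of DeepSeek-V2's multi-head latent attention (MLA): queries and keys/values are factored through low-rank bottlenecks (q_lora_rank, kv_lora_rank), which is what lets the KV cache store the small compressed latent instead of full per-head keys and values. Below is a minimal single-device sketch of how these projections compose in a forward pass. It is an illustration only, not the vllm-ascend implementation: plain torch.nn.Linear and nn.RMSNorm (torch >= 2.4) stand in for ReplicatedLinear, ColumnParallelLinear, and the custom row-parallel o_proj; rotary embedding is elided; all dimension values are made up.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Illustrative dimensions (not the model's real config).
    hidden_size, num_heads = 256, 4
    q_lora_rank, kv_lora_rank = 64, 32
    qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 32, 16, 32
    qk_head_dim = qk_nope_head_dim + qk_rope_head_dim

    # Single-device stand-ins for the parallel/quantized layers above.
    q_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=False)
    q_a_layernorm = nn.RMSNorm(q_lora_rank)
    q_b_proj = nn.Linear(q_lora_rank, num_heads * qk_head_dim, bias=False)
    kv_a_proj_with_mqa = nn.Linear(hidden_size,
                                   kv_lora_rank + qk_rope_head_dim,
                                   bias=False)
    kv_a_layernorm = nn.RMSNorm(kv_lora_rank)
    kv_b_proj = nn.Linear(kv_lora_rank,
                          num_heads * (qk_nope_head_dim + v_head_dim),
                          bias=False)
    o_proj = nn.Linear(num_heads * v_head_dim, hidden_size, bias=False)

    batch, seqlen = 2, 8
    h = torch.randn(batch, seqlen, hidden_size)

    # Query path: down-project, normalize, up-project, split into heads.
    q = q_b_proj(q_a_layernorm(q_a_proj(h)))
    q = q.view(batch, seqlen, num_heads, qk_head_dim).transpose(1, 2)

    # KV path: one projection yields the compressed KV latent plus a
    # shared (multi-query) rotary key component.
    kv_a = kv_a_proj_with_mqa(h)
    kv_c, k_pe = kv_a.split([kv_lora_rank, qk_rope_head_dim], dim=-1)
    kv = kv_b_proj(kv_a_layernorm(kv_c))
    kv = kv.view(batch, seqlen, num_heads, qk_nope_head_dim + v_head_dim)
    k_nope, v = kv.split([qk_nope_head_dim, v_head_dim], dim=-1)

    # The rotary part is shared across all heads (RoPE itself elided).
    k_pe = k_pe.unsqueeze(2).expand(-1, -1, num_heads, -1)
    k = torch.cat([k_nope, k_pe], dim=-1).transpose(1, 2)
    v = v.transpose(1, 2)

    attn = F.scaled_dot_product_attention(q, k, v)  # (batch, heads, seq, v_head_dim)
    out = o_proj(attn.transpose(1, 2).reshape(batch, seqlen,
                                              num_heads * v_head_dim))
    print(out.shape)  # torch.Size([2, 8, 256])

Note that the q_lora_rank is None branch skips the query bottleneck and projects hidden_size directly to num_heads * qk_head_dim with a single q_proj, which is why the constructor keeps the two code paths separate.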