File tree Expand file tree Collapse file tree 1 file changed +6
-2
lines changed
vllm/model_executor/models Expand file tree Collapse file tree 1 file changed +6
-2
lines changed Original file line number Diff line number Diff line change @@ -224,10 +224,14 @@ def forward(
224
224
225
225
if self .rotary_emb is not None :
226
226
q , k = self .rotary_emb (positions , q , k )
227
+
227
228
if self .qk_norm is not None :
228
- q = q .reshape (- 1 , self .num_heads , self .head_dim )
229
+ # Normalization is applied on the head_dim dimension. The rest of
230
+ # the dimensions are collapsed into a single dimension to support
231
+ # custom rms_norm cuda kernel.
232
+ q = q .reshape (- 1 , self .head_dim )
229
233
q = self .qk_norm (q .float ()).reshape (- 1 , self .q_size ).to (q .dtype )
230
- k = k .reshape (- 1 , self .num_kv_heads , self . head_dim )
234
+ k = k .reshape (- 1 , self .head_dim )
231
235
k = self .qk_norm (k .float ()).reshape (- 1 , self .kv_size ).to (k .dtype )
232
236
233
237
# We are applying temperature tuning (https://arxiv.org/abs/2501.19399)
You can’t perform that action at this time.
0 commit comments