Commit 4f0f844

Fix cuda illegal mem access with Llama4 TP8 + rms_norm custom op (#22701)
Signed-off-by: Po-Han Huang <[email protected]>
1 parent c583038 commit 4f0f844

File tree

1 file changed: +6 −2 lines


vllm/model_executor/models/llama4.py

Lines changed: 6 additions & 2 deletions
@@ -224,10 +224,14 @@ def forward(
 
         if self.rotary_emb is not None:
             q, k = self.rotary_emb(positions, q, k)
+
         if self.qk_norm is not None:
-            q = q.reshape(-1, self.num_heads, self.head_dim)
+            # Normalization is applied on the head_dim dimension. The rest of
+            # the dimensions are collapsed into a single dimension to support
+            # the custom rms_norm cuda kernel.
+            q = q.reshape(-1, self.head_dim)
             q = self.qk_norm(q.float()).reshape(-1, self.q_size).to(q.dtype)
-            k = k.reshape(-1, self.num_kv_heads, self.head_dim)
+            k = k.reshape(-1, self.head_dim)
             k = self.qk_norm(k.float()).reshape(-1, self.kv_size).to(k.dtype)
 
         # We are applying temperature tuning (https://arxiv.org/abs/2501.19399)
