Commit 906abe3

Fix Llama4 shape mismatch for 32k+ context window (#842)
For `max_model_len > 32k`, Llama4 enables temperature adjustment: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama4.py#L719. With the adjustment enabled, the `q` tensor changes shape from 2D to 3D: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama4.py#L307. That tensor is then passed into `UnquantizedFusedMoEMethod.forward`: https://github.com/vllm-project/vllm-gaudi/blob/main/vllm_gaudi/ops/hpu_fused_moe.py#L163, which performs an invalid reshape: it tries to return a 3D `output.view` built from a 2D output tensor. The bug was introduced by PRs #680 and #684. --------- Signed-off-by: Artur Fierka <artur.fierka@intel.com>
1 parent 6e2d045 commit 906abe3
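A minimal sketch of the shape mismatch, using NumPy's `reshape` as a stand-in for torch's `.view` (the shapes below are illustrative, not taken from the actual model):

```python
import numpy as np

# With temperature adjustment enabled, the input to the fused-MoE forward
# becomes 3D, e.g. (batch, seq, hidden) -- illustrative numbers:
input_shape = (2, 8, 16)

# The kernel still produces a flattened 2D output of shape (batch * seq, hidden):
output = np.zeros((2 * 8, 16))

# Buggy path: build the target shape from output.shape[0] plus the trailing
# dims of the 3D input -> (16, 8, 16), which has 2048 elements while the
# output has only 256, so the reshape (torch .view) fails.
bad_target = (output.shape[0], *input_shape[1:])
try:
    output.reshape(bad_target)
    failed = False
except ValueError:
    failed = True
assert failed

# Fixed path (dp_size == 1): reshape straight back to the original input shape.
fixed = output.reshape(input_shape)
assert fixed.shape == input_shape
```

The commit keeps the old expression only for the `dp_size > 1` case, where the leading dimension of `output` can legitimately differ from `input_shape[0]`, and otherwise views the output back to the input shape directly.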

File tree

1 file changed: +4 −1 lines changed

vllm_gaudi/ops/hpu_fused_moe.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -89,7 +89,10 @@ def forward_oot(
             permuted_weights=True,
             activation=layer.activation,
         )
-        return output.view(*(output.size(0), *input_shape[1:]))
+        if layer.dp_size > 1:
+            return output.view(*(output.size(0), *input_shape[1:]))
+        else:
+            return output.view(*input_shape)
 
 
     def reduce_output(self, states: torch.Tensor) -> torch.Tensor:
```
