Commit 43b752c

[Llama4] [multimodal] Fix misplaced dtype cast of cos_sin_cache in Llama4VisionRotaryEmbedding (#25889)
Signed-off-by: cjackal <[email protected]>
1 parent cfd302d commit 43b752c

File tree

1 file changed: +3 -1 lines changed

vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py

Lines changed: 3 additions & 1 deletion
@@ -59,7 +59,9 @@ def forward_native(  # type: ignore[override]
         key: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         assert key is not None
-        self._match_cos_sin_cache_dtype(query)
+        # self.cos_sin_cache here is complex tensor so we cannot cast into
+        # query's dtype directly with self._match_cos_sin_cache_dtype
+        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
         query_ = torch.view_as_complex(query.float().reshape(
             *query.shape[:-1], -1, 2))
         key_ = torch.view_as_complex(key.float().reshape(
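For context, here is a minimal standalone sketch of the failure mode the commit fixes. This is plain PyTorch, not vLLM code: the shapes and the cache construction are illustrative stand-ins for Llama4VisionRotaryEmbedding's complex-valued cos_sin_cache. The point is that casting a complex tensor to the query's real dtype (as _match_cos_sin_cache_dtype would) silently discards the imaginary part, whereas the fix only moves the cache to the query's device.

import torch

# Illustrative stand-in for the layer's cos_sin_cache: unit-magnitude
# complex rotations e^(i*theta), dtype torch.complex64.
theta = torch.arange(8, dtype=torch.float32)
cos_sin_cache = torch.polar(torch.ones_like(theta), theta)

query = torch.randn(2, 8, 16, dtype=torch.bfloat16)

# What the misplaced dtype cast amounted to: a complex -> real cast drops
# the imaginary part (PyTorch warns "Casting complex values to real
# discards the imaginary part"), corrupting the rotation.
broken = cos_sin_cache.to(query.dtype)
print(broken.dtype)  # torch.bfloat16, imaginary part gone

# What the fix does instead: move the cache to the query's device only,
# leaving its complex dtype untouched.
cos_sin_cache = cos_sin_cache.to(query.device)
print(cos_sin_cache.dtype)  # still torch.complex64

# The complex cache then composes with view_as_complex, as in the diff:
query_ = torch.view_as_complex(
    query.float().reshape(*query.shape[:-1], -1, 2))
print(query_.shape)  # torch.Size([2, 8, 8]), complex64

In short, a device transfer replaces the dtype match here because the rotary math for this layer is done in complex64 on the upcast query, not in the query's storage dtype.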
