We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
cos_sin_cache
Llama4VisionRotaryEmbedding
1 parent cfd302d · commit 43b752c · Copy full SHA for 43b752c
vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py
@@ -59,7 +59,9 @@ def forward_native( # type: ignore[override]
59
key: Optional[torch.Tensor] = None,
60
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
61
assert key is not None
62
- self._match_cos_sin_cache_dtype(query)
+ # self.cos_sin_cache here is complex tensor so we cannot cast into
63
+ # query's dtype directly with self._match_cos_sin_cache_dtype
64
+ self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
65
query_ = torch.view_as_complex(query.float().reshape(
66
*query.shape[:-1], -1, 2))
67
key_ = torch.view_as_complex(key.float().reshape(
0 commit comments