Commit fc9f821

fix cross attention (#28346)
Signed-off-by: fsx950223 <[email protected]>
1 parent 9452863 commit fc9f821

File tree: 1 file changed (+9, -8 lines)
vllm/v1/attention/backends/triton_attn.py

Lines changed: 9 additions & 8 deletions
@@ -244,14 +244,11 @@ def __init__(
 
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        if attn_type != AttentionType.DECODER:
+        if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
             raise NotImplementedError(
-                "Encoder self-attention and "
-                "encoder/decoder cross-attention "
-                "are not implemented for "
-                "TritonAttentionImpl"
+                "Encoder self-attention is not implemented for TritonAttentionImpl"
             )
-
+        self.attn_type = attn_type
         self.fp8_dtype = current_platform.fp8_dtype()
 
         self.sinks = sinks
@@ -312,7 +309,11 @@ def forward(
         num_actual_tokens = attn_metadata.num_actual_tokens
         key_cache, value_cache = kv_cache.unbind(1)
 
-        if self.kv_sharing_target_layer_name is None:
+        if (
+            self.kv_sharing_target_layer_name is None
+            and key is not None
+            and value is not None
+        ):
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
             if self.kv_cache_dtype.startswith("fp8"):
@@ -346,7 +347,7 @@ def forward(
             max_seqlen_k = attn_metadata.max_seq_len
             block_table = attn_metadata.block_table
 
-            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+            descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
 
         unified_attention(
             q=query[:num_actual_tokens],

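Note on the change: the diff allows AttentionType.ENCODER_DECODER through the constructor check and then guards the forward path so that nothing dereferences `key`/`value` when they are None, which can happen for encoder/decoder cross-attention once the encoder-side KV is already in the cache. The sketch below is not vLLM code; it only mirrors the guard and the new descale-shape logic from the diff. The function name, the example shapes, and the (num_blocks, block_size, num_kv_heads, head_size) cache layout are assumptions, the layout being inferred from `key_cache.shape[2]` replacing `key.shape[1]`.

# Standalone sketch (hypothetical helper, not the vLLM implementation).
import torch


def write_kv_and_descale_shape(key, value, key_cache, cu_seqlens_q,
                               kv_sharing_target_layer_name=None):
    # Skip the KV-cache write when this layer shares its cache with an earlier
    # layer, or when no new key/value were passed in -- the cross-attention
    # decode case the commit is guarding against.
    if (
        kv_sharing_target_layer_name is None
        and key is not None
        and value is not None
    ):
        # The real backend reshapes key/value into key_cache/value_cache here.
        pass

    # Derive the FP8 descale shape from the cache tensor, which always exists,
    # instead of from `key`, which may be None on this path. With an assumed
    # (num_blocks, block_size, num_kv_heads, head_size) layout, shape[2] is
    # num_kv_heads, matching key.shape[1] whenever key is present.
    return (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])


# Hypothetical shapes for illustration only.
key_cache = torch.zeros(4, 16, 8, 64)   # 4 blocks, block_size 16, 8 KV heads, head_size 64
cu_seqlens_q = torch.tensor([0, 1, 2])  # two decode queries of length 1
print(write_kv_and_descale_shape(None, None, key_cache, cu_seqlens_q))  # -> (2, 8)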