@@ -244,14 +244,11 @@ def __init__(
 
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        if attn_type != AttentionType.DECODER:
+        if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
             raise NotImplementedError(
-                "Encoder self-attention and "
-                "encoder/decoder cross-attention "
-                "are not implemented for "
-                "TritonAttentionImpl"
+                "Encoder self-attention is not implemented for TritonAttentionImpl"
             )
-
+        self.attn_type = attn_type
         self.fp8_dtype = current_platform.fp8_dtype()
 
         self.sinks = sinks
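For readers outside the codebase, a minimal sketch of the new guard in isolation; the `AttentionType` enum below is a hypothetical stand-in for vLLM's constants, and only the guard logic itself is taken from the diff:

```python
from enum import Enum


class AttentionType(str, Enum):
    # Hypothetical stand-in for vLLM's AttentionType constants.
    DECODER = "decoder"
    ENCODER = "encoder"
    ENCODER_DECODER = "encoder_decoder"


def validate_attn_type(attn_type: AttentionType) -> None:
    # Same guard as the diff: decoder self-attention and
    # encoder/decoder cross-attention pass; encoder self-attention
    # still raises.
    if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
        raise NotImplementedError(
            "Encoder self-attention is not implemented for TritonAttentionImpl"
        )


validate_attn_type(AttentionType.ENCODER_DECODER)  # accepted after this change
```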
@@ -312,7 +309,11 @@ def forward(
         num_actual_tokens = attn_metadata.num_actual_tokens
         key_cache, value_cache = kv_cache.unbind(1)
 
-        if self.kv_sharing_target_layer_name is None:
+        if (
+            self.kv_sharing_target_layer_name is None
+            and key is not None
+            and value is not None
+        ):
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
             if self.kv_cache_dtype.startswith("fp8"):
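The added `None` checks are what make cross-attention work here: the encoder's K/V are written to the cache once, and later decode steps invoke the layer with `key=None` / `value=None`, reading the cached encoder K/V instead. A minimal sketch of that contract, with illustrative names and a simplified flat slot mapping rather than vLLM's actual cache-write kernel:

```python
from typing import Optional

import torch


def maybe_write_kv_cache(
    key: Optional[torch.Tensor],
    value: Optional[torch.Tensor],
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_sharing_target_layer_name: Optional[str] = None,
) -> None:
    # Write new K/V into the paged cache unless (a) this layer shares
    # an earlier layer's cache, or (b) there is nothing to write:
    # cross-attention decode steps pass key=None / value=None and read
    # the encoder K/V already stored in the cache.
    if (
        kv_sharing_target_layer_name is None
        and key is not None
        and value is not None
    ):
        # Flatten (num_blocks, block_size, ...) into flat slots so
        # slot_mapping can index them; views share storage with the cache.
        flat_k = key_cache.view(-1, *key_cache.shape[-2:])
        flat_v = value_cache.view(-1, *value_cache.shape[-2:])
        flat_k[slot_mapping] = key
        flat_v[slot_mapping] = value
```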
@@ -346,7 +347,7 @@ def forward(
         max_seqlen_k = attn_metadata.max_seq_len
         block_table = attn_metadata.block_table
 
-        descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+        descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
 
         unified_attention(
             q=query[:num_actual_tokens],
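The `descale_shape` change follows from the same contract: `key` can be `None` on cross-attention decode steps, so the KV-head count must come from the cache tensor instead. Assuming `key_cache` has the layout `(num_blocks, block_size, num_kv_heads, head_size)` after the `unbind(1)` above, `key_cache.shape[2]` yields the same value `key.shape[1]` used to provide:

```python
import torch

# Illustrative shapes only; assumes the paged-cache layout
# (num_blocks, block_size, num_kv_heads, head_size) for key_cache.
num_tokens, num_kv_heads, head_size = 8, 4, 64
num_blocks, block_size = 16, 32

key = torch.randn(num_tokens, num_kv_heads, head_size)
key_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size)

# Both index the KV-head count, but only key_cache is guaranteed to be
# non-None on cross-attention decode steps, where key is None.
assert key.shape[1] == key_cache.shape[2] == num_kv_heads
```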