@@ -111,11 +111,8 @@ class FlashAttentionMetadata(AttentionMetadata):
    # Maximum query length in the batch.
    max_query_len: Optional[int]

-    # Number of query tokens for each request in the batch.
-    # Currently, we require that all requests have the same number of query
-    # tokens during the decoding phase. When speculative decoding is enabled,
-    # decode_query_len might be greater than 1. In all other cases, it is 1.
-    decode_query_len: Optional[int]
+    # Max number of query tokens among requests in the batch.
+    max_decode_query_len: Optional[int]

    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
@@ -173,9 +170,9 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
            slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
            seq_lens=self.seq_lens[:self.num_prefills],
            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
-            decode_query_len=0,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
+            max_decode_query_len=0,
            max_decode_seq_len=0,
            query_start_loc=self.query_start_loc[:self.num_prefills + 1],
            seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
@@ -202,12 +199,14 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
-            decode_query_len=self.decode_query_len,
+            max_decode_query_len=self.max_decode_query_len,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
-            query_start_loc=None,
-            seq_start_loc=None,
+            query_start_loc=self.query_start_loc[self.num_prefills:]
+            if self.query_start_loc is not None else None,
+            seq_start_loc=self.seq_start_loc[self.num_prefills:]
+            if self.seq_start_loc is not None else None,
            context_lens_tensor=None,
            block_tables=self.block_tables[self.num_prefills:],
            use_cuda_graph=self.use_cuda_graph,
@@ -413,9 +412,9 @@ def build(self, seq_lens: List[int], query_lens: List[int],
        max_query_len = max(query_lens)
        decode_query_lens = query_lens[self.num_prefills:]
        if len(decode_query_lens) > 0:
-            decode_query_len = max(decode_query_lens)
+            max_decode_query_len = max(decode_query_lens)
        else:
-            decode_query_len = 1
+            max_decode_query_len = 1
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens
@@ -468,7 +467,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
-            decode_query_len=decode_query_len,
+            max_decode_query_len=max_decode_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc,
@@ -714,20 +713,37 @@ def unified_flash_attention(

    if decode_meta := attn_metadata.decode_metadata:
        # Decoding run.
-        _, num_head, head_dim = decode_query.shape
-        decode_query = decode_query.reshape(-1, decode_meta.decode_query_len,
-                                            num_head, head_dim)
-        decode_output = flash_attn_with_kvcache(
-            q=decode_query,
-            k_cache=key_cache,
-            v_cache=value_cache,
-            block_table=decode_meta.block_tables,
-            cache_seqlens=decode_meta.seq_lens_tensor,
-            softmax_scale=softmax_scale,
-            causal=True,
-            alibi_slopes=alibi_slopes,
-            softcap=logits_soft_cap,
-        ).squeeze(1)
+        # Use flash_attn_varlen_func kernel for speculative decoding
+        # because different queries might have different lengths.
+        assert decode_meta.max_decode_query_len is not None
+        if decode_meta.max_decode_query_len > 1:
+            decode_output = flash_attn_varlen_func(
+                q=decode_query,
+                k=key_cache,
+                v=value_cache,
+                cu_seqlens_q=decode_meta.query_start_loc,
+                max_seqlen_q=decode_meta.max_decode_query_len,
+                cu_seqlens_k=decode_meta.seq_start_loc,
+                max_seqlen_k=decode_meta.max_decode_seq_len,
+                softmax_scale=softmax_scale,
+                causal=True,
+                alibi_slopes=alibi_slopes,
+                softcap=logits_soft_cap,
+                block_table=decode_meta.block_tables,
+            )
+        else:
+            # Use flash_attn_with_kvcache for normal decoding.
+            decode_output = flash_attn_with_kvcache(
+                q=decode_query.unsqueeze(1),
+                k_cache=key_cache,
+                v_cache=value_cache,
+                block_table=decode_meta.block_tables,
+                cache_seqlens=decode_meta.seq_lens_tensor,
+                softmax_scale=softmax_scale,
+                causal=True,
+                alibi_slopes=alibi_slopes,
+                softcap=logits_soft_cap,
+            ).squeeze(1)

    if prefill_output is None:
        assert decode_output is not None
@@ -739,7 +755,6 @@ def unified_flash_attention(
    # Chunked prefill does not work with speculative decoding.
    # Therefore, the query length for decode should be 1 in chunked prefill.
    assert decode_meta is not None
-    assert decode_meta.decode_query_len == 1
    decode_output = decode_output.squeeze(1)
    output = torch.cat([prefill_output, decode_output], dim=0)
    return output.view(num_tokens, hidden_size)
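
To make the kernel-selection change in the decode hunk easier to follow, here is a minimal, self-contained sketch (not the vLLM code itself) of how a per-batch `max_decode_query_len` and a cumulative `query_start_loc` can be derived from per-request decode query lengths and used to choose between the varlen and kvcache kernels. The helper `pick_decode_kernel` and the example lengths are hypothetical; only the field names mirror the metadata in the diff above.

```python
import torch


def pick_decode_kernel(max_decode_query_len: int) -> str:
    # Speculative decoding can give a request more than one query token per
    # step, so a variable-length kernel is needed; plain decoding always has
    # exactly one query token per request.
    return ("flash_attn_varlen_func" if max_decode_query_len > 1
            else "flash_attn_with_kvcache")


# Hypothetical decode batch: two requests carry several speculative tokens
# (4 query tokens each), one request is a normal single-token decode.
decode_query_lens = [4, 1, 4]
max_decode_query_len = max(decode_query_lens)  # 4

# Cumulative offsets into the flattened query tensor, analogous to the
# query_start_loc passed as cu_seqlens_q in the varlen branch.
query_start_loc = torch.cumsum(torch.tensor([0] + decode_query_lens), dim=0)

print(pick_decode_kernel(max_decode_query_len))  # flash_attn_varlen_func
print(query_start_loc.tolist())                  # [0, 4, 5, 9]
```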