@@ -286,6 +286,192 @@ def flash_attn_with_kvcache(
     return (out, softmax_lse, *rest) if return_softmax_lse else out


+def flash_attn_with_kvcache_decode(
+    q,
+    k_cache,
+    v_cache,
+    k=None,
+    v=None,
+    qv=None,
+    rotary_cos=None,
+    rotary_sin=None,
+    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
+    cache_batch_idx: Optional[torch.Tensor] = None,
+    cache_leftpad: Optional[torch.Tensor] = None,
+    page_table: Optional[torch.Tensor] = None,
+    cu_seqlens_q: Optional[torch.Tensor] = None,
+    cu_seqlens_k_new: Optional[torch.Tensor] = None,
+    max_seqlen_q: Optional[int] = None,
+    rotary_seqlens: Optional[torch.Tensor] = None,
+    q_descale: Optional[torch.Tensor] = None,
+    k_descale: Optional[torch.Tensor] = None,
+    v_descale: Optional[torch.Tensor] = None,
+    softmax_scale=None,
+    sinks=None,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0,  # 0.0 means deactivated
+    rotary_interleaved=True,
+    scheduler_metadata=None,
+    num_splits=0,  # Can be tuned for speed
+    pack_gqa=None,  # Can be tuned for speed
+    sm_margin=0,  # Can be tuned if some SMs are used for communication
+    return_softmax_lse=False,
+):
321+ """
322+ If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
323+ k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
324+ the previous step, and update them with the new keys/values from the current step, and do
325+ attention with the updated cache, all in 1 kernel.
326+
327+ If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
328+ For example, the KV cache could be pre-allocated with the max sequence length, and you can use
329+ cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
330+
+    Rotary embedding is also applied if rotary_cos and rotary_sin are passed in. The key @k will be
+    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
+    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
+    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
+    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
+    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
+
+    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
+
+    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
+    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
+    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
+    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
+
+    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
+    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
+        1 1 1 1 0
+        1 1 1 1 1
+    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
+        0 0
+        0 0
+        0 0
+        1 0
+        1 1
+    If the row of the mask is all zero, the output will be zero.
+
+    If window_size != (-1, -1), implements sliding window local attention. Query at position i
+    will only attend to keys between
+    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
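+    For example, with seqlen_q == seqlen_k and window_size=(256, 0), each query attends to itself
+    and the 256 keys immediately preceding it.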
+
+    Note: Does not support backward pass.
+
+    Arguments:
+        q: (batch_size, seqlen, nheads, headdim)
+        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
+            or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache).
+            page_block_size must be a multiple of 256.
+        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
+            or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache).
+        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
+            k with k_cache, starting at the indices specified by cache_seqlens.
+        v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
+        qv [optional]: (batch_size, seqlen, nheads, headdim_v)
+        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
+            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
+        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
+        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
+            KV cache.
+        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
+            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
+            If the indices are not distinct, and k and v are provided, the values updated in the cache
+            might come from any of the duplicate indices.
+        cache_leftpad: (batch_size,), dtype torch.int32. The index at which the KV cache starts. If None, assume 0.
+        page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32. The block
+            indices of each sequence in the paged KV cache (see the example at the end of this docstring).
+        softmax_scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
+        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
+        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
+            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
+            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
+            (i.e. GPT-NeoX style).
+        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
+            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
+            to automatically determine the number of splits.
+            Don't change this unless you know what you are doing.
+        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
+
+    Return:
+        out: (batch_size, seqlen, nheads, headdim).
+        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
+            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
+            normalization factor).
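+
+    Example (paged KV cache; a minimal sketch, with illustrative shapes, dtypes, and values only):
+
+        num_blocks, page_block_size = 16, 256  # page_block_size must be a multiple of 256
+        k_cache = torch.zeros(num_blocks, page_block_size, 2, 128, dtype=torch.bfloat16, device="cuda")
+        v_cache = torch.zeros(num_blocks, page_block_size, 2, 128, dtype=torch.bfloat16, device="cuda")
+        # 2 sequences, each mapped to 8 blocks of the paged cache
+        page_table = torch.arange(16, dtype=torch.int32, device="cuda").reshape(2, 8)
+        q = torch.randn(2, 1, 8, 128, dtype=torch.bfloat16, device="cuda")
+        cache_seqlens = torch.tensor([300, 700], dtype=torch.int32, device="cuda")
+        out = flash_attn_with_kvcache_decode(
+            q, k_cache, v_cache,
+            cache_seqlens=cache_seqlens, page_table=page_table, causal=True,
+        )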
405+ """
+    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
+    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
+    if softmax_scale is None:
+        softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (
+            -0.5
+        )
+    if cache_seqlens is not None and isinstance(cache_seqlens, int):
+        cache_seqlens = torch.full(
+            (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
+        )
+    cache_seqlens = maybe_contiguous(cache_seqlens)
+
+    q, k_cache, k, v = [maybe_contiguous(x) for x in (q, k_cache, k, v)]
+    v_cache = (
+        v_cache.contiguous()
+        if v_cache.stride(-1) != 1 and v_cache.stride(-3) != 1
+        else v_cache
+    )
+    cu_seqlens_q, cu_seqlens_k_new = [
+        maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k_new)
+    ]
+    page_table, cache_batch_idx, cache_leftpad = [
+        maybe_contiguous(x) for x in (page_table, cache_batch_idx, cache_leftpad)
+    ]
+    rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
+    rotary_seqlens = maybe_contiguous(rotary_seqlens)
+
+    if cu_seqlens_q is None:  # not varlen q: build cumulative offsets for the fixed-length layout
+        cu_seqlens_q = torch.arange(
+            0, q.size(0) + 1, dtype=torch.int, device=q.device
+        ) * q.size(1)
+        max_seqlen_q = q.size(1)
+        q = q.view(-1, q.size(-2), q.size(-1)).contiguous()
+    if cache_seqlens is not None:
+        assert cache_seqlens.size(0) + 1 == cu_seqlens_q.size(0)
+    cu_seqlens_k = cache_seqlens
+    out, softmax_lse, *rest = torch.ops.sgl_kernel.flash_decode.default(
+        q,
+        k_cache,
+        v_cache,
+        qv,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        page_table,
+        cache_batch_idx,
+        cache_leftpad,
+        rotary_cos,
+        rotary_sin,
+        rotary_seqlens,
+        q_descale,
+        k_descale,
+        v_descale,
+        softmax_scale,
+        sinks,
+        causal,
+        window_size[0],
+        window_size[1],
+        softcap,
+        rotary_interleaved,
+        scheduler_metadata,
+        num_splits,
+        pack_gqa,
+        sm_margin,
+    )
+    # return (out, softmax_lse) if return_softmax_lse else out
+    return (out, softmax_lse, *rest) if return_softmax_lse else out
+
+
 def flash_attn_varlen_func(
     q,
     k,