
Commit 489e3a5

fix c++ code compile
1 parent 809401f

9 files changed: +1084 -79 lines

include/sgl_flash_kernel_ops.h

Lines changed: 33 additions & 0 deletions
@@ -72,3 +72,36 @@ std::vector<at::Tensor> mha_fwd(
     int num_splits,
     std::optional<bool> pack_gqa_,
     int const sm_margin);
+/*
+ * From flash-attention
+ */
+std::vector<at::Tensor> flash_decode(
+    at::Tensor& q,        // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
+    const at::Tensor& k,  // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
+                          // h_k, d) if there is page_table.
+    const at::Tensor& v,  // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
+                          // page_size, h_k, dv) if there is page_table.
+    std::optional<const at::Tensor>& q_v_,  // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
+    const at::Tensor& cu_seqlens_q,         // b+1
+    const at::Tensor& cu_seqlens_k,         // b+1
+    int max_seqlen_q,
+    const at::Tensor& page_table,
+    std::optional<const at::Tensor>& kv_batch_idx_,    // b. indices to index into the KV cache
+    std::optional<const at::Tensor>& leftpad_k_,       // b
+    std::optional<const at::Tensor>& rotary_cos_,      // seqlen_ro x (rotary_dim / 2)
+    std::optional<const at::Tensor>& rotary_sin_,      // seqlen_ro x (rotary_dim / 2)
+    std::optional<const at::Tensor>& seqlens_rotary_,  // b
+    std::optional<at::Tensor>& q_descale_,  // (b, h_k), not (b, h)
+    std::optional<at::Tensor>& k_descale_,  // (b, h_k)
+    std::optional<at::Tensor>& v_descale_,  // (b, h_k)
+    float const softmax_scale,
+    std::optional<const at::Tensor>& sinks,
+    bool is_causal,
+    int window_size_left,
+    int window_size_right,
+    float const softcap,
+    bool const is_rotary_interleaved,  // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
+    std::optional<at::Tensor>& scheduler_metadata_,  // (b + 1)
+    int num_splits,
+    std::optional<bool> pack_gqa_,
+    int const sm_margin);
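
The k/v/page_table comments above describe a paged KV cache layout. The short Python sketch below is not part of this commit; kv_slot and the sizes are hypothetical, and it assumes the usual block-table convention (page_table[b, i] names the cache page holding tokens i * page_size through (i + 1) * page_size - 1 of sequence b). The kernel's exact convention should be checked against the implementation.

    import torch

    # Illustrative sketch only (hypothetical helper, assumed block-table convention).
    num_pages, page_size, h_k, d = 8, 256, 2, 128
    k_cache = torch.zeros(num_pages, page_size, h_k, d)             # (num_pages, page_size, h_k, d)
    page_table = torch.tensor([[3, 5], [0, 7]], dtype=torch.int32)  # (batch, max_pages_per_seq)

    def kv_slot(batch_idx: int, token_pos: int):
        """Map a (sequence, token position) pair to its (page, row) in k_cache."""
        page = page_table[batch_idx, token_pos // page_size].item()
        return page, token_pos % page_size

    page, row = kv_slot(batch_idx=1, token_pos=300)  # -> page 7, row 44
    k_vec = k_cache[page, row]                       # (h_k, d) key vectors for that token
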

python/sgl_kernel/flash_attn.py

Lines changed: 186 additions & 0 deletions
@@ -286,6 +286,192 @@ def flash_attn_with_kvcache(
     return (out, softmax_lse, *rest) if return_softmax_lse else out


+def flash_attn_with_kvcache_decode(
+    q,
+    k_cache,
+    v_cache,
+    k=None,
+    v=None,
+    qv=None,
+    rotary_cos=None,
+    rotary_sin=None,
+    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
+    cache_batch_idx: Optional[torch.Tensor] = None,
+    cache_leftpad: Optional[torch.Tensor] = None,
+    page_table: Optional[torch.Tensor] = None,
+    cu_seqlens_q: Optional[torch.Tensor] = None,
+    cu_seqlens_k_new: Optional[torch.Tensor] = None,
+    max_seqlen_q: Optional[int] = None,
+    rotary_seqlens: Optional[torch.Tensor] = None,
+    q_descale: Optional[torch.Tensor] = None,
+    k_descale: Optional[torch.Tensor] = None,
+    v_descale: Optional[torch.Tensor] = None,
+    softmax_scale=None,
+    sinks=None,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0,  # 0.0 means deactivated
+    rotary_interleaved=True,
+    scheduler_metadata=None,
+    num_splits=0,  # Can be tuned for speed
+    pack_gqa=None,  # Can be tuned for speed
+    sm_margin=0,  # Can be tuned if some SMs are used for communication
+    return_softmax_lse=False,
+):
+    """
+    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
+    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
+    the previous step, update them with the new keys/values from the current step, and do
+    attention with the updated cache, all in one kernel.
+
+    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
+    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
+    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
+
+    Rotary embedding is also applied if rotary_cos and rotary_sin are passed in. The key @k will be
+    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
+    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
+    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
+    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
+    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
+
+    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
+
+    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
+    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
+    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
+    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
+
+    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
+    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
+        1 1 1 1 0
+        1 1 1 1 1
+    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
+        0 0
+        0 0
+        0 0
+        1 0
+        1 1
+    If a row of the mask is all zero, the output will be zero.
+
+    If window_size != (-1, -1), implements sliding window local attention. Query at position i
+    will only attend to keys between
+    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
+
+    Note: Does not support backward pass.
+
+    Arguments:
+        q: (batch_size, seqlen, nheads, headdim)
+        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
+            or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache).
+            page_block_size must be a multiple of 256.
+        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
+            or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache).
+        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
+            k with k_cache, starting at the indices specified by cache_seqlens.
+        v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
+        qv [optional]: (batch_size, seqlen, nheads, headdim_v)
+        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
+            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
+        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
+        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
+            KV cache.
+        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
+            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
+            If the indices are not distinct, and k and v are provided, the values updated in the cache
+            might come from any of the duplicate indices.
+        cache_leftpad: (batch_size,), dtype torch.int32. The index at which the KV cache starts. If None, assume 0.
+        page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
+        softmax_scale: float. The scaling of QK^T before applying softmax.
+            Defaults to 1 / sqrt(headdim).
+        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
+        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
+            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
+            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
+            (i.e. GPT-NeoX style).
+        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
+            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
+            to automatically determine the number of splits.
+            Don't change this unless you know what you are doing.
+        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
+
+    Return:
+        out: (batch_size, seqlen, nheads, headdim).
+        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
+            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
+            normalization factor).
+    """
+    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
+    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
+    if softmax_scale is None:
+        softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (
+            -0.5
+        )
+    if cache_seqlens is not None and isinstance(cache_seqlens, int):
+        cache_seqlens = torch.full(
+            (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
+        )
+        cache_seqlens = maybe_contiguous(cache_seqlens)
+
+    q, k_cache, k, v = [maybe_contiguous(x) for x in (q, k_cache, k, v)]
+    v_cache = (
+        v_cache.contiguous()
+        if v_cache.stride(-1) != 1 and v_cache.stride(-3) != 1
+        else v_cache
+    )
+    cu_seqlens_q, cu_seqlens_k_new = [
+        maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k_new)
+    ]
+    page_table, cache_batch_idx, cache_leftpad = [
+        maybe_contiguous(x) for x in (page_table, cache_batch_idx, cache_leftpad)
+    ]
+    rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
+    rotary_seqlens = maybe_contiguous(rotary_seqlens)
+
+    if cu_seqlens_q is None:  # !is_varlen_q
+        cu_seqlens_q = torch.arange(
+            0, q.size(0) + 1, dtype=torch.int, device=q.device
+        ) * q.size(1)
+        max_seqlen_q = q.size(1)
+        q = q.view(-1, q.size(-2), q.size(-1)).contiguous()
+    if cache_seqlens is not None:
+        assert cache_seqlens.size(0) + 1 == cu_seqlens_q.size(0)
+        cu_seqlens_k = cache_seqlens
+    out, softmax_lse, *rest = torch.ops.sgl_kernel.flash_decode.default(
+        q,
+        k_cache,
+        v_cache,
+        qv,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        page_table,
+        cache_batch_idx,
+        cache_leftpad,
+        rotary_cos,
+        rotary_sin,
+        rotary_seqlens,
+        q_descale,
+        k_descale,
+        v_descale,
+        softmax_scale,
+        sinks,
+        causal,
+        window_size[0],
+        window_size[1],
+        softcap,
+        rotary_interleaved,
+        scheduler_metadata,
+        num_splits,
+        pack_gqa,
+        sm_margin,
+    )
+    # return (out, softmax_lse) if return_softmax_lse else out
+    return (out, softmax_lse, *rest) if return_softmax_lse else out
+
+
 def flash_attn_varlen_func(
     q,
     k,
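
For reference, a hypothetical call sketch for the new wrapper (not part of the commit): one decode step over a paged KV cache. The shapes follow the docstring above; the device ("xpu"), dtype, sizes, and random inputs are assumptions and untested against the actual kernel constraints.

    import torch
    from sgl_kernel.flash_attn import flash_attn_with_kvcache_decode

    # Assumed setup: 2 sequences, GQA with 8 query heads over 2 KV heads, paged cache.
    batch, nheads, nheads_k, headdim = 2, 8, 2, 128
    num_pages, page_size, pages_per_seq = 16, 256, 4
    device, dtype = "xpu", torch.float16  # assumed backend/dtype

    q = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)  # one new token per sequence
    k_cache = torch.randn(num_pages, page_size, nheads_k, headdim, device=device, dtype=dtype)
    v_cache = torch.randn(num_pages, page_size, nheads_k, headdim, device=device, dtype=dtype)

    # Each sequence owns pages_per_seq cache pages; cache_seqlens holds tokens seen so far.
    page_table = torch.arange(
        batch * pages_per_seq, device=device, dtype=torch.int32
    ).view(batch, pages_per_seq)
    cache_seqlens = torch.tensor([100, 257], device=device, dtype=torch.int32)

    out = flash_attn_with_kvcache_decode(
        q,
        k_cache,
        v_cache,
        cache_seqlens=cache_seqlens,
        page_table=page_table,
        causal=True,
    )
    # out: (batch, 1, nheads, headdim) attention output for the new token.
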

src/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 # ATen XPU sources

-file(GLOB device_cpp "sycl/*.cpp" "sycl/*.sycl")
-file(GLOB host_cpp "./*.cpp" "./*.cc")
+file(GLOB device_cpp "sycl/*.cpp" "sycl/kernels/flash_attention/*.cpp" "sycl/*.sycl")
+file(GLOB host_cpp "./*.cpp" "sycl/kernels/flash_attention/*.cpp" "./*.cc")

 list(APPEND ATen_XPU_CPP_SRCS ${host_cpp})
 list(APPEND ATen_XPU_SYCL_SRCS ${device_cpp})
