
Commit a094121: "fix device lost"
Parent: 39c7bd0

File tree: 6 files changed (+150, -209 lines)


include/sgl_flash_kernel_ops.h

Lines changed: 1 addition & 6 deletions
```diff
@@ -53,18 +53,13 @@ std::vector<at::Tensor> mha_fwd(
     std::optional<const at::Tensor>&
         v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
     std::optional<const at::Tensor>& q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
-    std::optional<at::Tensor>& out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
     std::optional<const at::Tensor>& cu_seqlens_q_, // b+1
     std::optional<const at::Tensor>& cu_seqlens_k_, // b+1
     std::optional<const at::Tensor>& cu_seqlens_k_new_, // b+1
-    std::optional<const at::Tensor>&
-        seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
-    std::optional<const at::Tensor>&
-        seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
     std::optional<int> max_seqlen_q_,
-    // TODO: check if we need max_seqlen_k
     std::optional<int> max_seqlen_k_,
     std::optional<const at::Tensor>& page_table_, // (b_k, max_num_pages_per_seq)
+    std::optional<const at::Tensor>& num_pages_, // (b_k, )
     std::optional<const at::Tensor>& kv_batch_idx_, // b. indices to index into the KV cache
     std::optional<const at::Tensor>& leftpad_k_, // b
     std::optional<const at::Tensor>& rotary_cos_, // seqlen_ro x (rotary_dim / 2)
```
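The header change removes the `out_` output buffer and the per-sequence `seqused_q_`/`seqused_k_` overrides from `mha_fwd`, leaving the `cu_seqlens_*` offset tensors as the single way to express actual sequence lengths, and adds a `num_pages_` tensor of shape `(b_k,)` for the paged KV cache. As a minimal sketch of the `cu_seqlens` convention (toy lengths, not the kernel itself): per-sequence lengths become `b+1` cumulative offsets, so sequence `i` occupies `[cu_seqlens[i], cu_seqlens[i+1])` in the packed layout.

```python
import torch

# Toy illustration of the cu_seqlens convention that replaces the removed
# seqused_q_/seqused_k_ arguments (hypothetical lengths, b = 3 sequences).
seqlens = torch.tensor([3, 5, 2], dtype=torch.int32)
cu_seqlens = torch.concat(
    (torch.zeros(1, dtype=torch.int32), torch.cumsum(seqlens, 0))
).to(torch.int32)
print(cu_seqlens)  # tensor([ 0,  3,  8, 10], dtype=torch.int32), b+1 entries

# The per-sequence lengths are recoverable as the first difference:
assert torch.equal(cu_seqlens[1:] - cu_seqlens[:-1], seqlens)
```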

pyproject.toml

Lines changed: 1 addition & 0 deletions
```diff
@@ -31,6 +31,7 @@ exclude = [
 
 [tool.scikit-build]
 cmake.build-type = "Release"
+build-dir = "build"
 minimum-version = "build-system.requires"
 
 wheel.py-api = "cp39"
```
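Note on the scikit-build change: when `build-dir` is unset, scikit-build-core runs each build in a throwaway temporary directory; pinning it to `build` keeps the CMake cache in-tree, so repeated local builds can reuse it instead of configuring from scratch.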

python/sgl_kernel/flash_attn.py

Lines changed: 18 additions & 5 deletions
```diff
@@ -177,22 +177,35 @@ def flash_attn_with_kvcache(
     rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
     rotary_seqlens = maybe_contiguous(rotary_seqlens)
 
+    if cu_seqlens_q == None:  # !is_varlen_q
+        cu_seqlens_q = torch.arange(0, q.size(0)+1, dtype=torch.int, device=q.device) * q.size(1)
+        max_seqlen_q = q.size(1)
+        q = q.view(-1, q.size(-2), q.size(-1)).contiguous()
+    if cu_seqlens_k_new is None and k is not None:  # !is_varlen_k_new
+        cu_seqlens_k_new = torch.arange(0, k.size(0)+1, dtype=torch.int, device=k.device)
+    elif k is None:
+        cu_seqlens_k_new = torch.zeros_like(cu_seqlens_q, dtype=torch.int32, device=q.device)
+    if cache_seqlens is not None:
+        max_seqlen_k = cache_seqlens.max().item()
+        assert cache_seqlens.size(0) + 1 == cu_seqlens_q.size(0)
+        page_size = k_cache.size(1)
+        num_pages_per_seq = (cache_seqlens + page_size - 1) // page_size
+        cu_seqlens_k = torch.concat((torch.zeros(1, dtype=torch.int32, device=cache_seqlens.device), torch.cumsum(cache_seqlens, 0))).to(torch.int32)
+
     out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
         q,
         k_cache,
         v_cache,
         k,
         v,
         qv,
-        None,  # out
         cu_seqlens_q,
-        None,  # cu_seqlens_k
+        cu_seqlens_k,
         cu_seqlens_k_new,
-        None,  # seqused_q
-        cache_seqlens,
         max_seqlen_q,
-        None,  # max_seqlen_k
+        max_seqlen_k,
         page_table,
+        num_pages_per_seq,
         cache_batch_idx,
         cache_leftpad,
         rotary_cos,
```
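The wrapper now does this normalization on the host before the op call: a fixed-length `(b, s_q, h, d)` batch is flattened into varlen form with an arithmetic-progression `cu_seqlens_q`, and the paged-KV metadata (`cu_seqlens_k`, `num_pages_per_seq`) is derived from `cache_seqlens`. A self-contained sketch of that arithmetic with made-up sizes (it does not call the kernel):

```python
import torch

# Made-up sizes: b = 2 cached sequences, page_size = k_cache.size(1) = 16.
cache_seqlens = torch.tensor([19, 64], dtype=torch.int32)
page_size = 16

# Ceiling division gives the number of KV-cache pages each sequence touches,
# mirroring num_pages_per_seq = (cache_seqlens + page_size - 1) // page_size:
num_pages_per_seq = (cache_seqlens + page_size - 1) // page_size
print(num_pages_per_seq)  # tensor([2, 4], dtype=torch.int32)

# cu_seqlens_k is the b+1 exclusive prefix sum of the cached lengths:
cu_seqlens_k = torch.concat(
    (torch.zeros(1, dtype=torch.int32), torch.cumsum(cache_seqlens, 0))
).to(torch.int32)
print(cu_seqlens_k)  # tensor([ 0, 19, 83], dtype=torch.int32)

# For a fixed-length (b, s_q, ...) query batch, cu_seqlens_q is just an
# arithmetic progression with stride s_q, as in the diff above:
b, s_q = 2, 4
cu_seqlens_q = torch.arange(0, b + 1, dtype=torch.int) * s_q
print(cu_seqlens_q)  # tensor([0, 4, 8], dtype=torch.int32)
```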
