Skip to content

Commit d2e195a

Browse files
committed
enable causal
1 parent 8a3ddea commit d2e195a

File tree

3 files changed

+9
-15
lines changed

3 files changed

+9
-15
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable headers only mode in cutla
3838
FetchContent_Declare(
3939
repo-cutlass-sycl
4040
GIT_REPOSITORY https://github.com/sunjiweiswift/cutlass-sycl.git
41-
GIT_TAG 742d127cf5ee75cc6db4eac32c8b72f00c53d0fe
41+
GIT_TAG f46ae0df764a1751879ce3e22765c700b1d52eca
4242
GIT_SHALLOW OFF
4343
)
4444
FetchContent_MakeAvailable(repo-cutlass-sycl)

src/sycl/chunked_prefill.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ struct FMHAConfig {
392392
ElementOutput,
393393
GmemTiledCopyStore>;
394394
using CollectiveSoftmaxEpilogue = cutlass::flash_attention::collective::
395-
FlashChunkPrefillSoftmaxEpilogue<Causal, false, EpilogueDispatchPolicy, ElementAccumulator>;
395+
FlashChunkPrefillSoftmaxEpilogue<Causal, LocalMask, EpilogueDispatchPolicy, ElementAccumulator>;
396396

397397
using ProblemShapeRegular = cute::tuple<int, int, int, int, int, int, int, int>;
398398
using namespace cutlass::fmha::collective;
@@ -777,7 +777,7 @@ std::vector<at::Tensor> mha_fwd(
777777
params.cu_seqlens_knew = static_cast<int*>(cu_seqlens_k_new.data_ptr());
778778
}
779779
} else {
780-
TORCH_CHECK(cu_seqlens_k_new_.has_value(), "If k_new ");
780+
TORCH_CHECK(cu_seqlens_k_new_.has_value(), "cu_seqlens_k_new all zeros");
781781
params.seqlen_knew = 0;
782782
params.total_knew = 0;
783783
at::Tensor cu_seqlens_k_new = cu_seqlens_k_new_.value();

tests/test_flash_attention.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -479,8 +479,8 @@ def generate_qkv(
479479
# "causal,local",
480480
# [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else []),
481481
# )
482-
# @pytest.mark.parametrize("causal,local", [(False, False), (True, False)])
483-
@pytest.mark.parametrize("causal,local", [(False, False)])
482+
@pytest.mark.parametrize("causal,local", [(False, False), (True, False)])
483+
# @pytest.mark.parametrize("causal,local", [(True, False)])
484484
# @pytest.mark.parametrize(
485485
# "seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True]
486486
# )
@@ -566,6 +566,8 @@ def test_flash_attn_kvcache(
566566
batch_size = 5
567567
batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
568568
nheads = 16
569+
if seqlen_k <= seqlen_q:
570+
seqlen_k += seqlen_q
569571
# nheads = 1
570572
# rotary_dim must be a multiple of 16, and must be <= d
571573
rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
@@ -694,17 +696,9 @@ def test_flash_attn_kvcache(
694696
dtype_ref,
695697
)
696698
cache_seqlens = torch.randint(
697-
0 if new_kv else 1,
699+
seqlen_q,
698700
# If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
699-
(
700-
(
701-
seqlen_k
702-
- (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new)
703-
+ 1
704-
)
705-
if new_kv
706-
else (seqlen_k + 1)
707-
),
701+
seqlen_k,
708702
(batch_size,),
709703
dtype=torch.int32,
710704
device=device,

0 commit comments

Comments (0)