
Commit c394772

sunjiweiswift authored and airMeng committed
Update chunked_prefill.cpp piplinestage=2
1 parent 8b0d167 commit c394772

File tree

4 files changed: +10 -14 lines


CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable headers only mode in cutla
 FetchContent_Declare(
   repo-cutlass-sycl
   GIT_REPOSITORY https://github.com/sunjiweiswift/cutlass-sycl.git
-  GIT_TAG ab1f4b8ddfd5748e4c00317710cdbcecda58de28
+  GIT_TAG e02de57e31a20f1c5c7e472aecd322e9196b2792
   GIT_SHALLOW OFF
 )
 FetchContent_MakeAvailable(repo-cutlass-sycl)

cmake/BuildFlags.cmake

Lines changed: 4 additions & 12 deletions
@@ -113,18 +113,10 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
 
 
   set(AOT_TARGETS "bmg")
-  if(TORCH_XPU_ARCH_LIST)
-    set(AOT_TARGETS "${TORCH_XPU_ARCH_LIST}")
-  endif()
-  if(AOT_TARGETS STREQUAL "none")
-    set(TORCH_XPU_ARCH_LIST "" PARENT_SCOPE)
-  else()
-    set(SYCL_TARGETS_OPTION -fsycl-targets=spir64_gen)
-    set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} ${SYCL_TARGETS_OPTION})
-    set(SYCL_DEVICE_LINK_FLAGS ${SYCL_DEVICE_LINK_FLAGS} ${SYCL_TARGETS_OPTION})
-    set(SYCL_OFFLINE_COMPILER_AOT_OPTIONS "-device ${AOT_TARGETS}")
-    set(TORCH_XPU_ARCH_LIST ${AOT_TARGETS} PARENT_SCOPE)
-  endif()
+  set(SYCL_TARGETS_OPTION -fsycl-targets=spir64_gen)
+  set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} ${SYCL_TARGETS_OPTION})
+  set(SYCL_DEVICE_LINK_FLAGS ${SYCL_DEVICE_LINK_FLAGS} ${SYCL_TARGETS_OPTION})
+  set(SYCL_OFFLINE_COMPILER_AOT_OPTIONS "-device ${AOT_TARGETS}")
   message(STATUS "Compile Intel GPU AOT Targets for ${AOT_TARGETS}")
 
   set(SYCL_FLAGS ${SYCL_FLAGS} ${SYCL_KERNEL_OPTIONS})
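With the TORCH_XPU_ARCH_LIST override and the "none" escape hatch removed, the spir64_gen ahead-of-time path is now taken unconditionally for the default AOT_TARGETS ("bmg"). For a standalone translation unit, these options assemble roughly the icpx invocation sketched in the comment below; the file and the exact flag spelling are illustrative assumptions, not taken from this repo:

// aot_demo.cpp -- illustrative only. The CMake above amounts to an
// ahead-of-time compile along the lines of (device name comes from
// AOT_TARGETS; exact backend flags depend on the oneAPI toolchain):
//   icpx -fsycl -fsycl-targets=spir64_gen -Xs "-device bmg" aot_demo.cpp
// producing Gen ISA at build time instead of JIT compilation at run time.
#include <sycl/sycl.hpp>

int main() {
  sycl::queue q;
  int out = 0;
  {
    sycl::buffer<int, 1> b(&out, sycl::range<1>(1));
    q.submit([&](sycl::handler& h) {
      sycl::accessor a(b, h, sycl::write_only);
      h.single_task([=] { a[0] = 42; });
    });
  }  // buffer destruction waits for the kernel and copies the result back
  return out == 42 ? 0 : 1;
}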

src/sycl/chunked_prefill.cpp

Lines changed: 1 addition & 1 deletion
@@ -863,7 +863,7 @@ std::vector<at::Tensor> mha_fwd(
   at::Tensor out_accum, softmax_lse_accum;
   auto outaccum_type = at::ScalarType::Float;
 
-  constexpr int PipelineStages = 0;
+  constexpr int PipelineStages = 2;
   if (params.is_causal) {
     switch (params.d) {
       case 64:
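Raising PipelineStages from 0 to 2 enables a two-stage software pipeline in the kernel mainloop (the stage count named in the commit title): the tile loads for iteration i+1 can be issued while iteration i is still computing, which is presumably also why the cutlass-sycl pin moves in the same commit. A minimal generic sketch of the idea, with hypothetical names that are not from chunked_prefill.cpp:

#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical illustration, not the kernel in this repo: a two-stage
// software pipeline over tiles. With two stages, the load of tile i+1
// is issued before the compute on tile i, so the two can overlap on
// hardware with asynchronous copies; 0 or 1 stages degenerates to a
// serial load-then-compute loop.
template <int PipelineStages>
float pipelined_sum(const float* data, std::size_t n, std::size_t tile) {
  static_assert(PipelineStages == 2, "sketch shows the two-stage case");
  std::vector<float> buf[2];  // one staging buffer per pipeline stage
  const std::size_t num_tiles = (n + tile - 1) / tile;
  auto load = [&](std::size_t t, std::vector<float>& dst) {
    const std::size_t base = t * tile;
    const std::size_t len = std::min(tile, n - base);
    dst.assign(data + base, data + base + len);
  };
  float acc = 0.0f;
  if (num_tiles == 0) return acc;
  load(0, buf[0]);                        // prologue: fill stage 0
  for (std::size_t t = 0; t < num_tiles; ++t) {
    if (t + 1 < num_tiles)
      load(t + 1, buf[(t + 1) % 2]);      // issue next tile's load early
    for (float x : buf[t % 2]) acc += x;  // compute on the current tile
  }
  return acc;
}

int main() {
  std::vector<float> v(1000, 1.0f);
  return pipelined_sum<2>(v.data(), v.size(), 128) == 1000.0f ? 0 : 1;
}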

tests/test_flash_attention.py

Lines changed: 4 additions & 0 deletions
@@ -1022,6 +1022,10 @@ def _generate_block_kvcache(
 
 
 # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
+@pytest.mark.skipif(
+    True,
+    reason="flash_attn at sgl-kernel-xpu only supports paged cache",
+)
 @pytest.mark.parametrize(
     "dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])
 )
