From b46c6b0c6d8e35adee4f30ecba82f4d6a4e4c1ab Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Tue, 29 Jul 2025 04:26:13 +0800 Subject: [PATCH 01/25] initialize Cutlass support Add chunked prefill op --- CMakeLists.txt | 31 +- cmake/BuildFlags.cmake | 9 +- src/sycl/chunked_prefill.cpp | 1032 ++++++++++++++++++++++++++++++++++ src/sycl/cutlass_helper.hpp | 80 +++ 4 files changed, 1139 insertions(+), 13 deletions(-) create mode 100644 src/sycl/chunked_prefill.cpp create mode 100644 src/sycl/cutlass_helper.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 44cd62b..d242569 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.19.2) project(sgl_kernel) set(CMAKE_POSITION_INDEPENDENT_CODE ON) -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) # Torch find_package(Python3 COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED) @@ -27,20 +27,29 @@ include(${SGL_OPS_XPU_ROOT}/cmake/BuildFlags.cmake) include(FetchContent) -# # cutlass -# FetchContent_Declare( -# repo-cutlass-sycl -# GIT_REPOSITORY https://github.com/codeplaysoftware/cutlass-sycl.git -# GIT_TAG ef9797f4327886ad231bfe853099ca022060c293 -# GIT_SHALLOW OFF -# ) -# FetchContent_Populate(repo-cutlass-sycl) +# SYCL support in cutlass +add_compile_definitions(CUTLASS_ENABLE_SYCL) +add_compile_definitions(SYCL_INTEL_TARGET) +set(CUTLASS_ENABLE_SYCL ON CACHE BOOL "Enable SYCL in the cutlass" FORCE) +set(CUTLASS_ENABLE_BENCHMARKS OFF CACHE BOOL "Remove benchmark to avoid cmake version issue in google benchmark" FORCE) +set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable headers only mode in cutlass" FORCE) + +# cutlass +FetchContent_Declare( + repo-cutlass-sycl + GIT_REPOSITORY https://github.com/sunjiweiswift/cutlass-sycl.git + GIT_TAG b15779b2d99bc392bfaa8209d547fbb8b2f5c807 + GIT_SHALLOW OFF +) +FetchContent_MakeAvailable(repo-cutlass-sycl) + include_directories( ${CMAKE_CURRENT_SOURCE_DIR}/include ${CMAKE_CURRENT_SOURCE_DIR}/src - # ${repo-cutlass-sycl_SOURCE_DIR}/include - # ${repo-cutlass-sycl_SOURCE_DIR}/tools/util/include + ${repo-cutlass-sycl_SOURCE_DIR}/include + ${repo-cutlass-sycl_SOURCE_DIR}/tools/util/include + ${repo-cutlass-sycl_SOURCE_DIR}/applications ) add_subdirectory(${SGL_OPS_XPU_ROOT}/src) diff --git a/cmake/BuildFlags.cmake b/cmake/BuildFlags.cmake index a8d46b2..d2104dd 100644 --- a/cmake/BuildFlags.cmake +++ b/cmake/BuildFlags.cmake @@ -26,7 +26,7 @@ endfunction() if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") # # -- Host flags (SYCL_CXX_FLAGS) list(APPEND SYCL_HOST_FLAGS -fPIC) - list(APPEND SYCL_HOST_FLAGS -std=c++17) + list(APPEND SYCL_HOST_FLAGS -std=c++20) # SYCL headers warnings list(APPEND SYCL_HOST_FLAGS -Wno-deprecated-declarations) list(APPEND SYCL_HOST_FLAGS -Wno-deprecated) @@ -71,6 +71,9 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -fno-approx-func) set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -Wno-absolute-value) set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -no-ftz) + set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -fno-sycl-instrument-device-code) + set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -Xspirv-translator) + set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -spirv-ext=+SPV_INTEL_split_barrier) if(CMAKE_BUILD_TYPE MATCHES Debug) set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -g -O0 -Rno-debug-disables-optimization) @@ -116,7 +119,7 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") if(AOT_TARGETS STREQUAL "none") set(TORCH_XPU_ARCH_LIST "" PARENT_SCOPE) else() - 
set(SYCL_TARGETS_OPTION -fsycl-targets=spir64_gen,spir64) + set(SYCL_TARGETS_OPTION -fsycl-targets=spir64_gen) set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} ${SYCL_TARGETS_OPTION}) set(SYCL_DEVICE_LINK_FLAGS ${SYCL_DEVICE_LINK_FLAGS} ${SYCL_TARGETS_OPTION}) set(SYCL_OFFLINE_COMPILER_AOT_OPTIONS "-device ${AOT_TARGETS}") @@ -126,6 +129,8 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") set(SYCL_FLAGS ${SYCL_FLAGS} ${SYCL_KERNEL_OPTIONS}) + # set(SYCL_OFFLINE_COMPILER_CG_OPTIONS ${SYCL_OFFLINE_COMPILER_CG_OPTIONS} -fno-sycl-instrument-device-code) + # set(SYCL_OFFLINE_COMPILER_CG_OPTIONS ${SYCL_OFFLINE_COMPILER_CG_OPTIONS} ${SYCL_LINK_FLAGS}) set(SYCL_OFFLINE_COMPILER_FLAGS "${SYCL_OFFLINE_COMPILER_AOT_OPTIONS}${SYCL_OFFLINE_COMPILER_CG_OPTIONS}") else() message("Not compiling with XPU. Currently only support GCC compiler on Linux as CXX compiler.") diff --git a/src/sycl/chunked_prefill.cpp b/src/sycl/chunked_prefill.cpp new file mode 100644 index 0000000..d28bc56 --- /dev/null +++ b/src/sycl/chunked_prefill.cpp @@ -0,0 +1,1032 @@ +#include +#include +#include +#include + +#include +#include + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/sycl_event_manager.hpp" +#include "flash_attention_v2/collective/fmha_fusion.hpp" +#include "flash_attention_v2/collective/xe_flash_attn_chunk_prefill_epilogue.hpp" +#include "flash_attention_v2/collective/xe_flash_attn_chunk_prefill_softmax_epilogue.hpp" +#include "flash_attention_v2/kernel/tile_scheduler_chunk_prefill.hpp" +#include "flash_attention_v2/kernel/xe_chunk_prefill.hpp" + +using namespace cute; + +struct Flash_fwd_params { + using index_t = int64_t; + + // The QKV matrices. + void* __restrict__ q_ptr; + void* __restrict__ k_ptr; + void* __restrict__ v_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t q_batch_stride; + index_t k_batch_stride; + index_t v_batch_stride; + index_t q_row_stride; + index_t k_row_stride; + index_t v_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t v_head_stride; + index_t v_dim_stride; + + // The number of heads. + int h, h_k; + + // The O matrix (output). + void* __restrict__ o_ptr; + void* __restrict__ oaccum_ptr; + + // The stride between rows of O. + index_t o_batch_stride; + index_t o_row_stride; + index_t o_head_stride; + + // The pointer to the softmax sum. + void* __restrict__ softmax_lse_ptr; + void* __restrict__ softmax_lseaccum_ptr; + + // The dimensions. + int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim; + int total_q, total_k, total_knew; + int b_k; // When having KV cache and with cache_batch_idx, K & V might have larger batch size than Q + int dv, dv_rounded; // For the case where V headdim is different from Q/K headdim + + // The scaling factors for the kernel. + float scale_softmax; + float softcap; + + // array of length b+1 holding starting offset of each sequence. + int* __restrict__ cu_seqlens_q; + int* __restrict__ cu_seqlens_k; + int* __restrict__ cu_seqlens_knew; + int* __restrict__ leftpad_k; + + // If provided, the actual length of each q/k sequence. 
+ int* __restrict__ seqused_q; + int* __restrict__ seqused_k; + + // The stride between rows of Oaccum. + index_t oaccum_split_stride; + index_t oaccum_batch_stride; + index_t oaccum_row_stride; + index_t oaccum_head_stride; + + // The stride between rows of LSEaccum. + index_t lseaccum_split_stride; + index_t lseaccum_batch_stride; + index_t lseaccum_head_stride; + + // The K_new and V_new matrices. + void* __restrict__ knew_ptr; + void* __restrict__ vnew_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t knew_batch_stride; + index_t vnew_batch_stride; + index_t knew_row_stride; + index_t vnew_row_stride; + index_t knew_head_stride; + index_t vnew_head_stride; + + void* __restrict__ qv_ptr; + index_t qv_batch_stride; + index_t qv_row_stride; + index_t qv_head_stride; + + // The cos and sin matrices for rotary embedding. + void* __restrict__ rotary_cos_ptr; + void* __restrict__ rotary_sin_ptr; + int* __restrict__ seqlens_rotary; + + // The indices to index into the KV cache. + int* __restrict__ kv_batch_idx; + + // Paged KV cache + int* __restrict__ page_table; + index_t page_table_batch_stride; + int page_size; + int num_pages; + bool pagedkv_tma; + + // The dropout probability (probability of keeping an activation). + float p_dropout; + // uint32_t p_dropout_in_uint; + // uint16_t p_dropout_in_uint16_t; + uint8_t p_dropout_in_uint8_t; + + // Scale factor of 1 / (1 - p_dropout). + float rp_dropout; + + // Local window size + int window_size_left, window_size_right; + int attention_chunk; + + // Pointer to the RNG seed (idx 0) and offset (idx 1). + uint64_t* rng_state; + + bool is_bf16; + bool is_fp32; + bool is_e4m3; + bool is_causal; + bool is_local; + + bool is_rotary_interleaved; + + int num_splits; // For split-KV version + bool pack_gqa; + + int* __restrict__ tile_count_semaphore; + // int * __restrict__ num_m_blocks_ptr; + // int * __restrict__ num_n_blocks_ptr; + int* __restrict__ num_splits_dynamic_ptr; + bool skip_scheduler_metadata_computation; + + int arch; + int num_sm; +}; + +template +class KernelCur {}; + +// Flash Attention takes 3 input matrices: (K)eys, (Q)ueries and (V)alues. 
+using LayoutQ = cutlass::layout::RowMajor; +using LayoutK = cutlass::layout::ColumnMajor; +using LayoutV = cutlass::layout::RowMajor; +using LayoutO = cutlass::layout::RowMajor; + +template +struct ExampleRunner { + using StrideQ = typename FMHAChunkPrefillKernel::StrideQ; + using StrideK = typename FMHAChunkPrefillKernel::StrideK; + using StrideV = typename FMHAChunkPrefillKernel::StrideV; + using StrideO = typename FMHAChunkPrefillKernel::StrideO; + + using ElementQ = typename FMHAChunkPrefillKernel::ElementQ; + using ElementK = typename FMHAChunkPrefillKernel::ElementK; + using ElementV = typename FMHAChunkPrefillKernel::ElementV; + using ElementAcc = typename FMHAChunkPrefillKernel::ElementAccumulator; + + using CollectiveEpilogue = typename FMHAChunkPrefillKernel::CollectiveEpilogue; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename FMHAChunkPrefillKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideQ stride_Q; + StrideK stride_K; + StrideV stride_V; + StrideK stride_K_cache; + StrideV stride_V_cache; + StrideO stride_O; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_Q; + cutlass::DeviceAllocation block_K; + cutlass::DeviceAllocation block_V; + cutlass::DeviceAllocation block_K_cache; + cutlass::DeviceAllocation block_V_cache; + cutlass::DeviceAllocation block_O; + cutlass::DeviceAllocation block_ref_O; + + std::vector cumulative_seqlen_q; + std::vector cumulative_seqlen_kv; + std::vector cumulative_seqlen_kv_cache; + cutlass::DeviceAllocation device_cumulative_seqlen_q; + cutlass::DeviceAllocation device_cumulative_seqlen_kv; + cutlass::DeviceAllocation device_cumulative_seqlen_kv_cache; + + struct PagedKVParams { + cutlass::DeviceAllocation page_table; + int page_size = 0; + cutlass::DeviceAllocation num_pages_per_seq; + }; + PagedKVParams paged_kv_cache; + + template + auto initialize_varlen(const Flash_fwd_params& params, ProblemShape& problem_size) { + // Use Cacheline Size to calculate alignment + constexpr int cacheline_bytes = 64; + constexpr int AlignmentQ = cacheline_bytes / sizeof(ElementQ); // Alignment of Q matrix in units of elements + constexpr int AlignmentKV = cacheline_bytes / sizeof(ElementK); // Alignment of Kand V matrix in units of elements + + cumulative_seqlen_q = {0}; + cumulative_seqlen_kv = {0}; + cumulative_seqlen_kv_cache = {0}; + + int total_seqlen_q = 0; + int total_seqlen_kv = 0; + int total_seqlen_kv_cache = 0; + int max_seqlen_q = 0; + int max_seqlen_kv = 0; + int max_seqlen_kv_cache = 0; + + ProblemShape problem_size_for_init = problem_size; + get<0>(problem_size_for_init) = 1; + get<3>(problem_size_for_init) = params.total_q; + get<4>(problem_size_for_init) = params.total_knew; + get<5>(problem_size_for_init) = params.total_k; + + ProblemShapeType problem_size_for_launch; + + get<3>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{params.total_q}; + get<4>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{params.total_knew}; + get<5>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{params.total_k}; + get<6>(problem_size_for_launch) = get<6>(problem_size); + get<7>(problem_size_for_launch) = get<7>(problem_size); + get<0>(problem_size_for_launch) = get<0>(problem_size); + get<1>(problem_size_for_launch) = get<1>(problem_size); + 
get<2>(problem_size_for_launch) = get<2>(problem_size); + + return cute::make_tuple(problem_size_for_init, problem_size_for_launch); + } + + /// Initialize operands to be used in the GEMM and reference GEMM + ProblemShapeType initialize(const Flash_fwd_params& params) { + auto problem_shape_in = cute::make_tuple( + params.b, // batch + params.h, // num_heads_q + params.h_k, // num_heads_kv + params.seqlen_q, + params.seqlen_knew, + params.seqlen_k, + params.d, + params.dv); + + ProblemShapeType problem_shape; + decltype(problem_shape_in) problem_size; + + if constexpr (isVarLen) { + auto [problem_shape_init, problem_shape_launch] = initialize_varlen(params, problem_shape_in); + problem_shape = problem_shape_launch; + problem_size = problem_shape_init; + } else { + problem_size = problem_shape_in; + problem_shape = problem_shape_in; + } + + auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = + problem_size; + auto group_q_size = num_heads_q / num_heads_kv; + ; + auto group_q_num = num_heads_q / group_q_size; + + // stride_Q = cutlass::make_cute_packed_stride(StrideQ{}, cute::make_shape(seq_len_qo * group_q_size, head_size_qk, + // batch * group_q_num)); + stride_Q = + cutlass::make_cute_packed_stride(StrideQ{}, cute::make_shape(seq_len_qo, num_heads_q * head_size_qk, batch)); + stride_K = + cutlass::make_cute_packed_stride(StrideK{}, cute::make_shape(seq_len_kv, num_heads_kv * head_size_qk, batch)); + stride_V = + cutlass::make_cute_packed_stride(StrideV{}, cute::make_shape(head_size_vo * num_heads_kv, seq_len_kv, batch)); + + stride_K_cache = cutlass::make_cute_packed_stride( + StrideK{}, cute::make_shape(seq_len_kv_cache, num_heads_kv * head_size_qk, batch)); + stride_V_cache = cutlass::make_cute_packed_stride( + StrideV{}, cute::make_shape(head_size_vo, seq_len_kv_cache, batch * num_heads_kv)); + // stride_O = cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo * group_q_size, head_size_vo, + // batch * group_q_num)); + stride_O = cutlass::make_cute_packed_stride( + StrideO{}, cute::make_shape(seq_len_qo * group_q_size, group_q_num * head_size_vo, batch)); + + block_Q.reset(batch * num_heads_q * seq_len_qo * head_size_qk); + block_K.reset(batch * num_heads_kv * seq_len_kv * head_size_qk); + block_V.reset(batch * num_heads_kv * seq_len_kv * head_size_vo); + block_K_cache.reset(batch * num_heads_kv * seq_len_kv_cache * head_size_qk); + block_V_cache.reset(batch * num_heads_kv * seq_len_kv_cache * head_size_vo); + block_O.reset(batch * num_heads_q * seq_len_qo * head_size_vo); + block_ref_O.reset(batch * num_heads_q * seq_len_qo * head_size_vo); + + if constexpr (isVarLen) { + get<3>(problem_shape).cumulative_length = params.cu_seqlens_q; + get<4>(problem_shape).cumulative_length = params.cu_seqlens_knew; + get<5>(problem_shape).cumulative_length = params.cu_seqlens_k; + } + + return problem_shape; + } + + // Note that the GemmUniversalAdapter currently doesn't support flash attention, which is why this + // secondary `run` function is required to launch the kernel. 
+ static void run(typename FMHAChunkPrefillKernel::Params params) { + dim3 const block = FMHAChunkPrefillKernel::get_block_shape(); + dim3 const grid = FMHAChunkPrefillKernel::get_grid_shape(params); + + // configure smem size and carveout + int smem_size = FMHAChunkPrefillKernel::SharedStorageSize; + + const auto sycl_block = syclcompat::dim3(block.x, block.y, block.z); + const auto sycl_grid = syclcompat::dim3(grid.x, grid.y, grid.z); + + syclcompat::experimental::launch_properties launch_props{ + sycl::ext::oneapi::experimental::work_group_scratch_size(smem_size), + }; + syclcompat::experimental::kernel_properties kernel_props{ + sycl::ext::oneapi::experimental::sub_group_size}; + syclcompat::experimental::launch_policy policy{sycl_grid, sycl_block, launch_props, kernel_props}; + // auto event = syclcompat::experimental::launch>(policy, params); + + sycl::ext::oneapi::experimental::launch_config config(policy.get_range(), policy.get_launch_properties()); + auto cgf = [&](::sycl::handler& cgh) { + auto KernelFunctor = + syclcompat::experimental::detail::build_kernel_functor>( + cgh, policy, params); + sycl::ext::oneapi::experimental::detail:: + LaunchConfigAccess, decltype(policy.get_launch_properties())> + ConfigAccess(config); + cgh.parallel_for>( + ConfigAccess.getRange(), ConfigAccess.getProperties(), KernelFunctor); + }; + auto q = syclcompat::get_default_queue(); + auto event = q.submit(cgf); + + EventManager::getInstance().addEvent(event); + } + + cutlass::Status run(const Flash_fwd_params& params, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = initialize(params); + + typename FMHAChunkPrefillKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_Q.get(), + stride_Q, + block_K.get(), + stride_K, + block_V.get(), + stride_V, + block_K_cache.get(), + stride_K_cache, + block_V_cache.get(), + stride_V_cache, + params.page_table, + params.page_size, + params.cu_seqlens_k}, + {params.scale_softmax}, + {block_O.get(), stride_O}, + hw_info}; + + // Define device-global scratch memory + size_t workspace_size = FMHAChunkPrefillKernel::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + if (!FMHAChunkPrefillKernel::can_implement(arguments)) { + return cutlass::Status::kErrorInvalidProblem; + } + + // Initialize the workspace + (FMHAChunkPrefillKernel::initialize_workspace(arguments, workspace.get())); + + // Convert host-side arguments to device-side arguments to be passed to the kernel + auto params_kernel = FMHAChunkPrefillKernel::to_underlying_arguments(arguments, workspace.get()); + + // Run the Flash Attention implementation. 
+ run(params_kernel); + return cutlass::Status::kSuccess; + } +}; + +// the default value used for the case BF16 +template < + bool Causal, + typename TileShapeQK, + typename TileShapePV, + typename TileShapeOutput, + typename SubgroupLayout, + int PipelineStages, + typename ElementInputQ = bfloat16_t, + typename ElementInputKV = bfloat16_t, + typename MMAOperation = XE_8x16x16_F32BF16BF16F32_TT, + typename GmemTiledCopyQ = XE_2D_U16x8x32_LD_N, + typename GmemTiledCopyK = XE_2D_U16x16x16_LD_T, // _T designates a transposed block load operation + typename GmemTiledCopyV = XE_2D_U16x16x32_LD_V, + typename ElementAccumulator = float, + typename ElementComputeEpilogue = float, + typename ElementOutput = float, + typename GmemTiledCopyStore = XE_2D_U32x8x16_ST_N> +struct FMHAConfig { + template + static int run(const Flash_fwd_params& params) { + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + using CollectiveEpilogue = cutlass::flash_attention::collective::FlashChunkPrefillEpilogue< + EpilogueDispatchPolicy, + MMAOperation, + TileShapeOutput, + SubgroupLayout, + ElementComputeEpilogue, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + GmemTiledCopyStore>; + using CollectiveSoftmaxEpilogue = cutlass::flash_attention::collective:: + FlashChunkPrefillSoftmaxEpilogue; + + using ProblemShapeRegular = cute::tuple; + using namespace cutlass::fmha::collective; + using ProblemShapeVarlen = cute::tuple; + using ProblemShapeType = std::conditional_t; + + // Mainloop + using CollectiveMainloop = cutlass::flash_attention::collective::FlashChunkPrefillMma< + GEMMDispatchPolicy, + ProblemShapeType, + ElementInputQ, + cutlass::gemm::TagToStrideA_t, + ElementInputKV, + cutlass::gemm::TagToStrideB_t, + ElementInputKV, + cutlass::gemm::TagToStrideB_t, + MMAOperation, + TileShapeQK, + TileShapePV, + SubgroupLayout, + GmemTiledCopyQ, // Q + GmemTiledCopyK, // K + GmemTiledCopyV, // V, + Causal, + PagedKV>; + + using FMHAChunkPrefillKernel = cutlass::flash_attention::kernel::FMHAPrefillChunk< + ProblemShapeType, + CollectiveMainloop, + CollectiveSoftmaxEpilogue, + CollectiveEpilogue, + Scheduler>; + + ExampleRunner runner; + + (runner.run(params, hw_info)); + return 0; + } + + static int run(const Flash_fwd_params& params) { + // only support varlen and paged kv now + return run(params); + } +}; + +#define CHECK_DEVICE(x) TORCH_CHECK(x.is_xpu(), #x " must be on XPU") +#define CHECK_SHAPE(x, ...) \ + TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +inline int round_up_headdim(int head_size) { + if (head_size <= 64) { + return 64; + } + if (head_size <= 96) { + return 96; + } + if (head_size <= 128) { + return 128; + } + if (head_size <= 192) { + return 192; + } + if (head_size <= 256) { + return 256; + } + return 256; +} + +std::vector mha_fwd( + at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, + // h_k, d) if there is page_table. 
+ const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, + // page_size, h_k, dv) if there is page_table. + std::optional& + k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new + std::optional& + v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new + std::optional& q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q + std::optional& out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + std::optional& cu_seqlens_q_, // b+1 + std::optional& cu_seqlens_k_, // b+1 + std::optional& cu_seqlens_k_new_, // b+1 + std::optional& + seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + std::optional& + seqused_k_, // b. If given, only this many elements of each batch element's keys are used. + std::optional max_seqlen_q_, + // TODO: check if we need max_seqlen_k + std::optional max_seqlen_k_, + std::optional& page_table_, // (b_k, max_num_pages_per_seq) + std::optional& kv_batch_idx_, // b. indices to index into the KV cache + std::optional& leftpad_k_, // b + std::optional& rotary_cos_, // seqlen_ro x (rotary_dim / 2) + std::optional& rotary_sin_, // seqlen_ro x (rotary_dim / 2) + std::optional& seqlens_rotary_, // b + // std::optional &q_descale_, // (b, h_k), not (b, h) + // std::optional &k_descale_, // (b, h_k) + // std::optional &v_descale_, // (b, h_k) + std::optional softmax_scale_, + bool is_causal, + int window_size_left, + int window_size_right, + int attention_chunk, + float const softcap, + bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 + std::optional& scheduler_metadata_, // (b + 1) + int num_splits, + std::optional pack_gqa_, + int const sm_margin) { + // TODO: check GPU support + // auto dprops = at::cuda::getCurrentDeviceProperties(); + // TORCH_CHECK(drops->name.find("B580") != std::string::npos, "sgl_kernel_xpu only supports BMG+"); + + auto q_type = q.scalar_type(); + TORCH_CHECK( + q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, + "SGL Kernel XPU only supports fp16 and bf16 type"); + + TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype"); + TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype"); + + CHECK_DEVICE(q); + CHECK_DEVICE(k); + CHECK_DEVICE(v); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + at::Tensor page_table; + const bool paged_KV = page_table_.has_value(); + if (paged_KV) { + page_table = page_table_.value(); + CHECK_DEVICE(page_table); + TORCH_CHECK(page_table.dtype() == torch::kInt32, "page_table must have dtype torch.int32"); + TORCH_CHECK(page_table.stride(-1) == 1, "page_table must have contiguous last dimension"); + } + at::Tensor cu_seqlens_q; + bool const is_varlen_q = cu_seqlens_q_.has_value(); + if (is_varlen_q) { + cu_seqlens_q = cu_seqlens_q_.value(); + CHECK_DEVICE(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_q); + TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32"); + TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided"); + } + at::Tensor cu_seqlens_k; + bool const is_varlen_k = cu_seqlens_k_.has_value(); 
+ if (is_varlen_k) { + cu_seqlens_k = cu_seqlens_k_.value(); + CHECK_DEVICE(cu_seqlens_k); + CHECK_CONTIGUOUS(cu_seqlens_k); + TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32"); + TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided"); + TORCH_CHECK(!paged_KV, "If cu_seqlens_k is passed in, then page table is not supported"); + TORCH_CHECK(!kv_batch_idx_.has_value(), "If cu_seqlens_k is passed in, then page table is not supported"); + } + + auto const sizes = q.sizes(); + const int batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1; + int seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value(); + int total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0]; + int num_heads = q.size(-2); + int const head_size = q.size(-1); + int const head_size_v = v.size(-1); + int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.size(1); + int const num_pages = !paged_KV ? 0 : k.size(0); + int const page_size = !paged_KV ? 1 : k.size(1); + int const seqlen_k = + !is_varlen_k ? (!paged_KV ? k.size(1) : max_num_pages_per_seq * page_size) : max_seqlen_k_.value(); + int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0); + int const num_heads_k = k.size(-2); + int const batch_size_k = !paged_KV ? (!is_varlen_k ? k.size(0) : cu_seqlens_k.size(0) - 1) : page_table.size(0); + double softmax_scale = 1.0 / sqrt(double(head_size)); + if (softmax_scale_.has_value()) { + softmax_scale = softmax_scale_.value(); + } + if (!kv_batch_idx_.has_value()) { + TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k"); + } + + // Currently only support head dims <= 256 + static constexpr int max_headdim = 256; + TORCH_CHECK( + head_size <= max_headdim, + "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim)); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM + // TODO: check this + if (window_size_left >= seqlen_k - 1) { + window_size_left = -1; + } + if (window_size_right >= seqlen_q - 1) { + window_size_right = -1; + } + // causal=true is the same as causal=false in this case + if (is_causal) { + window_size_right = 0; + } + + if (!is_varlen_q) { + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); + } else { + CHECK_SHAPE(q, total_q, num_heads, head_size); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + } + if (!paged_KV) { + if (!is_varlen_k) { + CHECK_SHAPE(k, batch_size_k, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size_v); + } else { + CHECK_SHAPE(k, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_v); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + } + } else { + CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size); + CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v); + CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq); + } + + if (seqused_q_.has_value()) { + auto seqused_q = seqused_q_.value(); + TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32"); + CHECK_DEVICE(seqused_q); + CHECK_CONTIGUOUS(seqused_q); + CHECK_SHAPE(seqused_q, batch_size); + } + if (seqused_k_.has_value()) { + auto seqused_k = seqused_k_.value(); + TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype 
int32"); + CHECK_DEVICE(seqused_k); + CHECK_CONTIGUOUS(seqused_k); + CHECK_SHAPE(seqused_k, batch_size); + } + + if (leftpad_k_.has_value()) { + auto leftpad_k = leftpad_k_.value(); + TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); + CHECK_DEVICE(leftpad_k); + CHECK_CONTIGUOUS(leftpad_k); + CHECK_SHAPE(leftpad_k, batch_size); + } + + bool const is_varlen = + is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value() || leftpad_k_.has_value(); + + static constexpr int alignment = 8; + TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment)); + TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment)); + + auto opts = q.options(); + auto out_type = q_type; + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK( + out.scalar_type() == out_type, + "For FP16/BF16 input, output must have the same dtype as inputs. For FP8 input, output must have dtype BF16"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + if (!is_varlen_q) { + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v); + } else { + CHECK_SHAPE(out, total_q, num_heads, head_size_v); + } + } else { + out = !is_varlen_q ? torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type)) + : torch::empty({total_q, num_heads, head_size_v}, opts.dtype(out_type)); + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + int const head_size_rounded = round_up_headdim(head_size); + int const head_size_v_rounded = head_size_v == head_size ? head_size_rounded : round_up_headdim(head_size_v); + int const seqlen_q_rounded = round_multiple(seqlen_q, 128); + int const seqlen_k_rounded = round_multiple(seqlen_k, 128); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + c10::DeviceGuard device_guard(q.device()); + + at::Tensor softmax_lse; + if (!is_varlen_q) { + softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + } else { + softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); + } + + Flash_fwd_params params; + params.is_bf16 = q.dtype() == torch::kBFloat16; + + // Set the pointers and strides. + params.q_ptr = q.data_ptr(); + params.k_ptr = k.data_ptr(); + params.v_ptr = v.data_ptr(); + // All stride are in elements, not bytes. + params.q_row_stride = q.stride(-3); + params.k_row_stride = k.stride(-3); + params.v_row_stride = v.stride(-3); + params.q_head_stride = q.stride(-2); + params.k_head_stride = k.stride(-2); + params.v_head_stride = v.stride(-2); + params.v_dim_stride = v.stride(-1); + params.o_ptr = out.data_ptr(); + params.o_row_stride = out.stride(-3); + params.o_head_stride = out.stride(-2); + + if (!is_varlen_q) { + params.q_batch_stride = q.stride(0); + params.o_batch_stride = out.stride(0); + } + if (!is_varlen_k) { + params.k_batch_stride = k.stride(0); + params.v_batch_stride = v.stride(0); + } + + params.cu_seqlens_q = !is_varlen_q ? nullptr : static_cast(cu_seqlens_q.data_ptr()); + params.cu_seqlens_k = !is_varlen_k ? nullptr : static_cast(cu_seqlens_k.data_ptr()); + params.seqused_q = seqused_q_.has_value() ? static_cast(seqused_q_.value().data_ptr()) : nullptr; + params.seqused_k = seqused_k_.has_value() ? 
static_cast(seqused_k_.value().data_ptr()) : nullptr; + + // Softmax sum + params.softmax_lse_ptr = softmax_lse.data_ptr(); + + // Set the dimensions. + params.b = batch_size; + params.h = num_heads; + params.h_k = num_heads_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = head_size; + params.d_rounded = head_size_rounded; + + // Set the different scale values. + params.scale_softmax = softmax_scale; + params.softcap = softcap; + + // Set this to probability of keeping an element to simplify things. + params.p_dropout = 1.f; + + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. + params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0; + params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal; + + // TODO: check this + if (window_size_left < 0) { + window_size_left = seqlen_k - 1; + } + if (window_size_right < 0) { + window_size_right = seqlen_q - 1; + } + if (attention_chunk > 0) { + window_size_left = std::min(window_size_left, attention_chunk - 1); + window_size_right = std::min(window_size_right, attention_chunk - 1); + } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + params.attention_chunk = attention_chunk; + + params.total_q = total_q; + params.total_k = total_k; + params.b_k = batch_size_k; + params.dv = head_size_v; + if (paged_KV) { + params.page_table = page_table.data_ptr(); + params.page_table_batch_stride = page_table.stride(0); + } + params.page_size = page_size; + params.num_pages = num_pages; + + if (k_new_.has_value()) { // This needs to be set before get_pagedkv_tma + at::Tensor k_new, v_new; + TORCH_CHECK(v_new_.has_value(), "If k_new is supplied, v_new must also be passed in"); + TORCH_CHECK(seqused_k_.has_value(), "If k_new is supplied, seqlens_k must also be passed in"); + TORCH_CHECK(seqlen_q <= seqlen_k, "If k_new is supplied, it must have seqlen <= the seqlen of the KV cache"); + at::Tensor cu_seqlens_k_new; + bool const is_varlen_k_new = cu_seqlens_k_new_.has_value(); + if (is_varlen_k_new) { + cu_seqlens_k_new = cu_seqlens_k_new_.value(); + CHECK_DEVICE(cu_seqlens_k_new); + CHECK_CONTIGUOUS(cu_seqlens_k_new); + TORCH_CHECK(cu_seqlens_k_new.dtype() == torch::kInt32, "cu_seqlens_k_new must have dtype torch.int32"); + } + k_new = k_new_.value(); + v_new = v_new_.value(); + TORCH_CHECK(k_new.dtype() == q_type, "k_new must have the same dtype as query"); + TORCH_CHECK(v_new.dtype() == q_type, "v_new must have the same dtype as query"); + CHECK_DEVICE(k_new); + CHECK_DEVICE(v_new); + TORCH_CHECK(k_new.stride(-1) == 1, "k_new tensor must have contiguous last dimension"); + TORCH_CHECK(v_new.stride(-1) == 1, "v_new tensor must have contiguous last dimension"); + // We don't need max_seqlen_k_new, so seqlen_k_new can be whatever when is_varlen_k_new + int seqlen_k_new = !is_varlen_k_new ? k_new.size(1) : 0; + int total_k_new = !is_varlen_k_new ? 
batch_size * k_new.size(1) : k_new.size(0); + if (!is_varlen_k_new) { + CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size); + CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size_v); + } else { + CHECK_SHAPE(k_new, total_k_new, num_heads_k, head_size); + CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size_v); + CHECK_SHAPE(cu_seqlens_k_new, batch_size + 1); + } + params.seqlen_knew = seqlen_k_new; + params.total_knew = total_k_new; + params.knew_ptr = k_new.data_ptr(); + params.vnew_ptr = v_new.data_ptr(); + // All stride are in elements, not bytes. + params.knew_row_stride = k_new.stride(-3); + params.vnew_row_stride = v_new.stride(-3); + params.knew_head_stride = k_new.stride(-2); + params.vnew_head_stride = v_new.stride(-2); + if (!is_varlen_k_new) { + params.knew_batch_stride = k_new.stride(0); + params.vnew_batch_stride = v_new.stride(0); + } + if (is_varlen_k_new) { + params.cu_seqlens_knew = static_cast(cu_seqlens_k_new.data_ptr()); + } + } + + if (q_v_.has_value()) { + TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); + TORCH_CHECK( + q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, + "q_v is only supported for fp16 and bf16 data type"); + TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs"); + at::Tensor q_v = q_v_.value(); + TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query"); + CHECK_DEVICE(q_v); + TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension"); + if (!is_varlen_q) { + CHECK_SHAPE(q_v, batch_size, seqlen_q, num_heads, head_size_v); + } else { + CHECK_SHAPE(q_v, total_q, num_heads, head_size_v); + } + params.qv_ptr = q_v.data_ptr(); + // All stride are in elements, not bytes. + params.qv_row_stride = q_v.stride(-3); + params.qv_head_stride = q_v.stride(-2); + if (!is_varlen_q) { + params.qv_batch_stride = q_v.stride(0); + } + } + + if (rotary_cos_.has_value()) { + TORCH_CHECK( + k_new_.has_value(), + "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided"); + auto rotary_cos = rotary_cos_.value(); + CHECK_DEVICE(rotary_cos); + CHECK_CONTIGUOUS(rotary_cos); + params.rotary_dim = rotary_cos.size(1) * 2; + TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim"); + TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported"); + const int seqlen_ro = rotary_cos.size(0); + if (paged_KV) { + TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache"); + } + CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2); + TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); + + TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided"); + auto rotary_sin = rotary_sin_.value(); + CHECK_DEVICE(rotary_sin); + CHECK_CONTIGUOUS(rotary_sin); + CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2); + TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); + params.rotary_cos_ptr = rotary_cos.data_ptr(); + params.rotary_sin_ptr = rotary_sin.data_ptr(); + params.is_rotary_interleaved = is_rotary_interleaved; + if (seqlens_rotary_.has_value()) { + at::Tensor seqlens_rotary = seqlens_rotary_.value(); + CHECK_DEVICE(seqlens_rotary); + CHECK_CONTIGUOUS(seqlens_rotary); + TORCH_CHECK(seqlens_rotary.dtype() == torch::kInt32, "seqlens_rotary must have dtype torch.int32"); + 
CHECK_SHAPE(seqlens_rotary, batch_size); + params.seqlens_rotary = seqlens_rotary.data_ptr(); + } + } else { + params.rotary_dim = 0; + } + + if (kv_batch_idx_.has_value()) { + auto kv_batch_idx = kv_batch_idx_.value(); + CHECK_DEVICE(kv_batch_idx); + CHECK_CONTIGUOUS(kv_batch_idx); + TORCH_CHECK(kv_batch_idx.scalar_type() == torch::kInt32, "kv_batch_idx must have dtype int32"); + params.kv_batch_idx = reinterpret_cast(kv_batch_idx.data_ptr()); + } + + at::Tensor out_accum, softmax_lse_accum; + auto outaccum_type = at::ScalarType::Float; + + constexpr int PipelineStages = 2; + if (params.is_causal) { + switch (params.d) { + case 64: + FMHAConfig< + true, + Shape<_128, _64, _64>, + Shape<_128, _32, _64>, + Shape<_128, _64, _64>, + Layout, Stride<_1, _1, _1>>, + PipelineStages>::run(params); + break; + case 96: + FMHAConfig< + true, + Shape<_128, _64, _32>, + Shape<_128, _32, _64>, + Shape<_128, _96, _64>, + Layout, Stride<_1, _1, _1>>, + PipelineStages>::run(params); + break; + case 128: + FMHAConfig< + true, + Shape<_128, _64, _64>, + Shape<_128, _32, _64>, + Shape<_128, _128, _64>, + Layout, Stride<_1, _1, _1>>, + PipelineStages>::run(params); + break; + case 192: + FMHAConfig< + true, + Shape<_256, _64, _64>, + Shape<_256, _32, _64>, + Shape<_256, _192, _64>, + Layout, Stride<_1, _1, _1>>, + PipelineStages>::run(params); + break; + default: + TORCH_CHECK(false, "Unsupported head size for causal attention"); + } + } else { + switch (params.d) { + case 64: + FMHAConfig< + false, + Shape<_128, _64, _64>, + Shape<_128, _32, _64>, + Shape<_128, _64, _64>, + Layout, Stride<_1, _1, _1>>, + PipelineStages>::run(params); + break; + case 96: + FMHAConfig< + false, + Shape<_128, _64, _32>, + Shape<_128, _32, _64>, + Shape<_128, _96, _64>, + Layout, Stride<_1, _1, _1>>, + PipelineStages>::run(params); + break; + case 128: + FMHAConfig< + false, + Shape<_128, _64, _64>, + Shape<_128, _32, _64>, + Shape<_128, _128, _64>, + Layout, Stride<_1, _1, _1>>, + PipelineStages>::run(params); + break; + case 192: + FMHAConfig< + false, + Shape<_256, _64, _64>, + Shape<_256, _32, _64>, + Shape<_256, _192, _64>, + Layout, Stride<_1, _1, _1>>, + PipelineStages>::run(params); + break; + default: + TORCH_CHECK(false, "Unsupported head size for causal attention"); + } + } + // return {out, softmax_lse}; + return {out, softmax_lse, out_accum, softmax_lse_accum}; +} diff --git a/src/sycl/cutlass_helper.hpp b/src/sycl/cutlass_helper.hpp new file mode 100644 index 0000000..2501b39 --- /dev/null +++ b/src/sycl/cutlass_helper.hpp @@ -0,0 +1,80 @@ +#pragma once + +#include +#include "sycl/ext/oneapi/experimental/enqueue_functions.hpp" +#include "sycl/ext/oneapi/properties/properties.hpp" +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace sycl_exp = sycl::ext::oneapi::experimental; + +template +struct KernelFunctor { + KernelFunctor(KProps kernel_props, Args... args) + : _kernel_properties{kernel_props}, + _argument_tuple(std::make_tuple(args...)) {} + + KernelFunctor(KProps kernel_props, sycl::local_accessor local_acc, + Args... 
args) + : _kernel_properties{kernel_props}, _local_acc{local_acc}, + _argument_tuple(std::make_tuple(args...)) {} + + auto get(sycl_exp::properties_tag) const { return _kernel_properties; } + + __syclcompat_inline__ void + operator()(syclcompat::detail::range_to_item_t) const { + if constexpr (HasLocalMem) { + char *local_mem_ptr = static_cast( + _local_acc.template get_multi_ptr() + .get()); + apply_helper( + [lmem_ptr = local_mem_ptr](auto &&...args) { + [[clang::always_inline]] F(args..., lmem_ptr); + }, + _argument_tuple); + } else { + apply_helper([](auto &&...args) { [[clang::always_inline]] F(args...); }, + _argument_tuple); + } + } + + KProps _kernel_properties; + std::tuple _argument_tuple; + std::conditional_t, std::monostate> + _local_acc; // monostate for empty type +}; + + +template +sycl::event launch(LaunchPolicy launch_policy, sycl::queue q, Args... args) { + static_assert(syclcompat::args_compatible, + "Mismatch between device function signature and supplied " + "arguments. Have you correctly handled local memory/char*?"); + + sycl_exp::launch_config config(launch_policy.get_range(), + launch_policy.get_launch_properties()); + + return sycl_exp::submit_with_event(q, [&](sycl::handler &cgh) { + auto KernelFunctor = build_kernel_functor(cgh, launch_policy, args...); + if constexpr (syclcompat::detail::is_range_v< + typename LaunchPolicy::RangeT>) { + parallel_for(cgh, config, KernelFunctor); + } else { + static_assert( + syclcompat::detail::is_nd_range_v); + nd_launch(cgh, config, KernelFunctor); + } + }); +} From 866ab6cbe570587e03dfd501e6a4fc24516a84a1 Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Wed, 13 Aug 2025 17:41:10 +0800 Subject: [PATCH 02/25] remove the experimental launch headers --- src/sycl/cutlass_helper.hpp | 80 ------------------------------------- 1 file changed, 80 deletions(-) delete mode 100644 src/sycl/cutlass_helper.hpp diff --git a/src/sycl/cutlass_helper.hpp b/src/sycl/cutlass_helper.hpp deleted file mode 100644 index 2501b39..0000000 --- a/src/sycl/cutlass_helper.hpp +++ /dev/null @@ -1,80 +0,0 @@ -#pragma once - -#include -#include "sycl/ext/oneapi/experimental/enqueue_functions.hpp" -#include "sycl/ext/oneapi/properties/properties.hpp" -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include - -namespace sycl_exp = sycl::ext::oneapi::experimental; - -template -struct KernelFunctor { - KernelFunctor(KProps kernel_props, Args... args) - : _kernel_properties{kernel_props}, - _argument_tuple(std::make_tuple(args...)) {} - - KernelFunctor(KProps kernel_props, sycl::local_accessor local_acc, - Args... args) - : _kernel_properties{kernel_props}, _local_acc{local_acc}, - _argument_tuple(std::make_tuple(args...)) {} - - auto get(sycl_exp::properties_tag) const { return _kernel_properties; } - - __syclcompat_inline__ void - operator()(syclcompat::detail::range_to_item_t) const { - if constexpr (HasLocalMem) { - char *local_mem_ptr = static_cast( - _local_acc.template get_multi_ptr() - .get()); - apply_helper( - [lmem_ptr = local_mem_ptr](auto &&...args) { - [[clang::always_inline]] F(args..., lmem_ptr); - }, - _argument_tuple); - } else { - apply_helper([](auto &&...args) { [[clang::always_inline]] F(args...); }, - _argument_tuple); - } - } - - KProps _kernel_properties; - std::tuple _argument_tuple; - std::conditional_t, std::monostate> - _local_acc; // monostate for empty type -}; - - -template -sycl::event launch(LaunchPolicy launch_policy, sycl::queue q, Args... 
args) { - static_assert(syclcompat::args_compatible, - "Mismatch between device function signature and supplied " - "arguments. Have you correctly handled local memory/char*?"); - - sycl_exp::launch_config config(launch_policy.get_range(), - launch_policy.get_launch_properties()); - - return sycl_exp::submit_with_event(q, [&](sycl::handler &cgh) { - auto KernelFunctor = build_kernel_functor(cgh, launch_policy, args...); - if constexpr (syclcompat::detail::is_range_v< - typename LaunchPolicy::RangeT>) { - parallel_for(cgh, config, KernelFunctor); - } else { - static_assert( - syclcompat::detail::is_nd_range_v); - nd_launch(cgh, config, KernelFunctor); - } - }); -} From 20d35e0ac0e9d86aaded1c989c81533fa9e02d1e Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Wed, 13 Aug 2025 17:43:44 +0800 Subject: [PATCH 03/25] add bindings --- src/torch_extension_sycl.cc | 39 +++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/src/torch_extension_sycl.cc b/src/torch_extension_sycl.cc index 7e07662..53ddb56 100644 --- a/src/torch_extension_sycl.cc +++ b/src/torch_extension_sycl.cc @@ -50,6 +50,45 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { // "fp8_blockwise_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, // -> Tensor"); // m.impl("fp8_blockwise_scaled_mm", torch::kXPU, &fp8_blockwise_scaled_mm); + + /* + * From cutlass attention + */ + m.def( + "fwd(Tensor! q," + " Tensor k," + " Tensor v," + " Tensor? k_new," + " Tensor? v_new," + " Tensor? q_v," + " Tensor!? out," + " Tensor? cu_seqlens_q," + " Tensor? cu_seqlens_k," + " Tensor? cu_seqlens_k_new," + " Tensor? seqused_q," + " Tensor? seqused_k," + " int? max_seqlen_q," + " int? max_seqlen_k," + " Tensor? page_table," + " Tensor? kv_batch_idx," + " Tensor? leftpad_k," + " Tensor? rotary_cos," + " Tensor? rotary_sin," + " Tensor? seqlens_rotary," + " Tensor? q_descale," + " Tensor? k_descale," + " Tensor? v_descale," + " float softmax_scale," + " bool is_causal," + " int window_size_left," + " int window_size_right," + " float softcap," + " bool is_rotary_interleaved," + " Tensor? scheduler_metadata," + " int num_splits," + " bool? pack_gqa," + " int sm_margin) -> Tensor[]"); + m.impl("fwd", torch::kXPU, make_pytorch_shim(&mha_fwd)); } REGISTER_EXTENSION(common_ops) From 25c7bd56a834b1abbc240b1c60ebb9a5773c26fb Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Wed, 13 Aug 2025 23:30:12 +0800 Subject: [PATCH 04/25] enable ut --- pyproject.toml | 1 - src/sycl/chunked_prefill.cpp | 2 +- src/torch_extension_sycl.cc | 2 ++ tests/test_flash_attention.py | 17 +++++++++++------ 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a9efbff..d4bb45d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,6 @@ [build-system] requires = [ "scikit-build-core>=0.10", - "pytorch-triton-xpu @ https://download.pytorch.org/whl/test/pytorch_triton_xpu-3.3.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", "wheel", ] build-backend = "scikit_build_core.build" diff --git a/src/sycl/chunked_prefill.cpp b/src/sycl/chunked_prefill.cpp index d28bc56..1f9b8a2 100644 --- a/src/sycl/chunked_prefill.cpp +++ b/src/sycl/chunked_prefill.cpp @@ -162,7 +162,7 @@ struct Flash_fwd_params { template class KernelCur {}; -// Flash Attention takes 3 input matrices: (K)eys, (Q)ueries and (V)alues. +// Flash Attention takes 3 input matrices: Keys, Queries and Values. 
using LayoutQ = cutlass::layout::RowMajor; using LayoutK = cutlass::layout::ColumnMajor; using LayoutV = cutlass::layout::RowMajor; diff --git a/src/torch_extension_sycl.cc b/src/torch_extension_sycl.cc index 53ddb56..534d575 100644 --- a/src/torch_extension_sycl.cc +++ b/src/torch_extension_sycl.cc @@ -16,7 +16,9 @@ limitations under the License. #include #include +#include "sgl_flash_kernel_ops.h" #include "sgl_kernel_ops.h" +#include "sgl_kernel_torch_shim.h" TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { /* diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py index def092a..a80dcbb 100644 --- a/tests/test_flash_attention.py +++ b/tests/test_flash_attention.py @@ -6,8 +6,11 @@ import pytest import torch import torch.nn.functional as F +import utils from einops import rearrange, repeat +device = utils.get_device() + apply_rotary_emb = None @@ -25,10 +28,14 @@ def is_fa3_supported(device=None) -> bool: # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x # And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a. # That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3. - return ( - torch.cuda.get_device_capability(device)[0] == 9 - or torch.cuda.get_device_capability(device)[0] == 8 - ) and (torch.version.cuda >= "12.3") + if torch.cuda.is_available(): + return ( + torch.cuda.get_device_capability(device)[0] == 9 + or torch.cuda.get_device_capability(device)[0] == 8 + ) and (torch.version.cuda >= "12.3") + elif torch.xpu.is_available(): + device_name = torch.xpu.get_device_properties(0).name + return "B580" in device_name or "e211" in device_name DISABLE_BACKWARD = True @@ -551,7 +558,6 @@ def test_flash_attn_kvcache( pytest.skip() if rotary_fraction == 0.0 and has_rotary_seqlens: pytest.skip() - device = "cuda" # set seed torch.random.manual_seed(0) batch_size = 5 @@ -1077,7 +1083,6 @@ def test_flash_attn_varlen_output( ): from sgl_kernel.flash_attn import flash_attn_varlen_func - device = "cuda" # set seed torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) # batch_size = 40 From 6bd00d72de5e174dd356e35ab76cf5fbba13103b Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Thu, 14 Aug 2025 22:45:30 +0800 Subject: [PATCH 05/25] align the interface --- python/sgl_kernel/flash_attn.py | 18 ++++++++++++------ src/sycl/chunked_prefill.cpp | 25 ++++++++----------------- tests/test_flash_attention.py | 5 ++++- 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/python/sgl_kernel/flash_attn.py b/python/sgl_kernel/flash_attn.py index fbf0b0d..981b59d 100644 --- a/python/sgl_kernel/flash_attn.py +++ b/python/sgl_kernel/flash_attn.py @@ -3,10 +3,10 @@ import torch import torch.nn as nn -try: - from sgl_kernel import flash_ops -except: - raise ImportError("Can not import sgl_kernel. Please check your installation.") +# try: +# from sgl_kernel import flash_ops +# except: +# raise ImportError("Can not import sgl_kernel. Please check your installation.") def is_fa3_supported(device=None) -> bool: @@ -18,10 +18,16 @@ def is_fa3_supported(device=None) -> bool: # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x # And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a. # That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3. 
- return ( + if torch.cuda.is_available(): + return ( torch.cuda.get_device_capability(device)[0] == 9 or torch.cuda.get_device_capability(device)[0] == 8 - ) and (torch.version.cuda >= "12.3") + ) and (torch.version.cuda >= "12.3") + elif torch.xpu.is_available(): + device_name = torch.xpu.get_device_properties(0).name + return "B580" in device_name or "e211" in device_name + else: + return False def maybe_contiguous(x): diff --git a/src/sycl/chunked_prefill.cpp b/src/sycl/chunked_prefill.cpp index 1f9b8a2..02abf39 100644 --- a/src/sycl/chunked_prefill.cpp +++ b/src/sycl/chunked_prefill.cpp @@ -133,7 +133,6 @@ struct Flash_fwd_params { // Local window size int window_size_left, window_size_right; - int attention_chunk; // Pointer to the RNG seed (idx 0) and offset (idx 1). uint64_t* rng_state; @@ -541,14 +540,13 @@ std::vector mha_fwd( std::optional& rotary_cos_, // seqlen_ro x (rotary_dim / 2) std::optional& rotary_sin_, // seqlen_ro x (rotary_dim / 2) std::optional& seqlens_rotary_, // b - // std::optional &q_descale_, // (b, h_k), not (b, h) - // std::optional &k_descale_, // (b, h_k) - // std::optional &v_descale_, // (b, h_k) - std::optional softmax_scale_, + std::optional& q_descale_, // (b, h_k), not (b, h) + std::optional& k_descale_, // (b, h_k) + std::optional& v_descale_, // (b, h_k) + const float softmax_scale_, bool is_causal, int window_size_left, int window_size_right, - int attention_chunk, float const softcap, bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 std::optional& scheduler_metadata_, // (b + 1) @@ -619,10 +617,8 @@ std::vector mha_fwd( int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0); int const num_heads_k = k.size(-2); int const batch_size_k = !paged_KV ? (!is_varlen_k ? k.size(0) : cu_seqlens_k.size(0) - 1) : page_table.size(0); - double softmax_scale = 1.0 / sqrt(double(head_size)); - if (softmax_scale_.has_value()) { - softmax_scale = softmax_scale_.value(); - } + float softmax_scale = softmax_scale_; + if (!kv_batch_idx_.has_value()) { TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k"); } @@ -791,8 +787,8 @@ std::vector mha_fwd( // Causal is the special case where window_size_right == 0 and window_size_left < 0. // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
- params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0; - params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal; + params.is_causal = window_size_left < 0 && window_size_right == 0; + params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal; // TODO: check this if (window_size_left < 0) { @@ -801,13 +797,8 @@ std::vector mha_fwd( if (window_size_right < 0) { window_size_right = seqlen_q - 1; } - if (attention_chunk > 0) { - window_size_left = std::min(window_size_left, attention_chunk - 1); - window_size_right = std::min(window_size_right, attention_chunk - 1); - } params.window_size_left = window_size_left; params.window_size_right = window_size_right; - params.attention_chunk = attention_chunk; params.total_q = total_q; params.total_k = total_k; diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py index a80dcbb..5cbe6a5 100644 --- a/tests/test_flash_attention.py +++ b/tests/test_flash_attention.py @@ -16,7 +16,10 @@ def is_hopper(): # Only Hopper supports different V headdim - return torch.cuda.get_device_properties(0).major >= 9 + if torch.cuda.is_available(): + return torch.cuda.get_device_properties(0).major >= 9 + else: + return False def is_fa3_supported(device=None) -> bool: From d5a32ec89c5dd31826cf4001ffeff53bde030a20 Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Fri, 5 Sep 2025 22:29:19 +0800 Subject: [PATCH 06/25] fix device lost --- include/sgl_flash_kernel_ops.h | 7 +- pyproject.toml | 1 + python/sgl_kernel/flash_attn.py | 23 ++++- src/sycl/chunked_prefill.cpp | 178 ++++++++++---------------------- src/torch_extension_sycl.cc | 4 +- tests/test_flash_attention.py | 146 +++++++++++++------------- 6 files changed, 150 insertions(+), 209 deletions(-) diff --git a/include/sgl_flash_kernel_ops.h b/include/sgl_flash_kernel_ops.h index c406fa9..e4b2047 100644 --- a/include/sgl_flash_kernel_ops.h +++ b/include/sgl_flash_kernel_ops.h @@ -53,18 +53,13 @@ std::vector mha_fwd( std::optional& v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new std::optional& q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q - std::optional& out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q std::optional& cu_seqlens_q_, // b+1 std::optional& cu_seqlens_k_, // b+1 std::optional& cu_seqlens_k_new_, // b+1 - std::optional& - seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. - std::optional& - seqused_k_, // b. If given, only this many elements of each batch element's keys are used. std::optional max_seqlen_q_, - // TODO: check if we need max_seqlen_k std::optional max_seqlen_k_, std::optional& page_table_, // (b_k, max_num_pages_per_seq) + std::optional& num_pages_, // (b_k, ) std::optional& kv_batch_idx_, // b. 
indices to index into the KV cache std::optional& leftpad_k_, // b std::optional& rotary_cos_, // seqlen_ro x (rotary_dim / 2) diff --git a/pyproject.toml b/pyproject.toml index d4bb45d..74df1c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ exclude = [ [tool.scikit-build] cmake.build-type = "Release" +build-dir = "build" minimum-version = "build-system.requires" wheel.py-api = "cp39" diff --git a/python/sgl_kernel/flash_attn.py b/python/sgl_kernel/flash_attn.py index 981b59d..a1c7d8c 100644 --- a/python/sgl_kernel/flash_attn.py +++ b/python/sgl_kernel/flash_attn.py @@ -177,6 +177,21 @@ def flash_attn_with_kvcache( rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)] rotary_seqlens = maybe_contiguous(rotary_seqlens) + if cu_seqlens_q == None: # !is_varlen_q + cu_seqlens_q = torch.arange(0, q.size(0)+1, dtype=torch.int, device=q.device) * q.size(1) + max_seqlen_q = q.size(1) + q = q.view(-1, q.size(-2), q.size(-1)).contiguous() + if cu_seqlens_k_new is None and k is not None: # !is_varlen_k_new + cu_seqlens_k_new = torch.arange(0, k.size(0)+1, dtype=torch.int, device=k.device) + elif k is None: + cu_seqlens_k_new = torch.zeros_like(cu_seqlens_q, dtype=torch.int32, device=q.device) + if cache_seqlens is not None: + max_seqlen_k = cache_seqlens.max().item() + assert cache_seqlens.size(0) + 1 == cu_seqlens_q.size(0) + page_size = k_cache.size(1) + num_pages_per_seq = (cache_seqlens + page_size - 1) // page_size + cu_seqlens_k = torch.concat((torch.zeros(1, dtype=torch.int32, device=cache_seqlens.device), torch.cumsum(cache_seqlens, 0))).to(torch.int32) + out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default( q, k_cache, @@ -184,15 +199,13 @@ def flash_attn_with_kvcache( k, v, qv, - None, # out cu_seqlens_q, - None, # cu_seqlens_k + cu_seqlens_k, cu_seqlens_k_new, - None, # seqused_q - cache_seqlens, max_seqlen_q, - None, # max_seqlen_k + max_seqlen_k, page_table, + num_pages_per_seq, cache_batch_idx, cache_leftpad, rotary_cos, diff --git a/src/sycl/chunked_prefill.cpp b/src/sycl/chunked_prefill.cpp index 02abf39..0786d83 100644 --- a/src/sycl/chunked_prefill.cpp +++ b/src/sycl/chunked_prefill.cpp @@ -61,7 +61,8 @@ struct Flash_fwd_params { // The dimensions. 
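The flash_attn.py changes in this commit derive varlen metadata on the fly when callers pass batched tensors plus per-sequence cache lengths. A toy sketch of that derivation (made-up lengths, page size assumed to be 64; not part of the patch):

import torch

batch, seqlen_q, page_size = 3, 2, 64
cache_seqlens = torch.tensor([100, 37, 250], dtype=torch.int32)

# Prefix sums turn per-sequence lengths into the b+1 cumulative offsets the kernel expects.
cu_seqlens_q = torch.arange(0, batch + 1, dtype=torch.int32) * seqlen_q
cu_seqlens_k = torch.cat(
    (torch.zeros(1, dtype=torch.int32), torch.cumsum(cache_seqlens, 0))
).to(torch.int32)

# Number of KV-cache pages each sequence touches, given a fixed page size.
pages_per_seq = (cache_seqlens + page_size - 1) // page_size

print(cu_seqlens_q)   # -> [0, 2, 4, 6]
print(cu_seqlens_k)   # -> [0, 100, 137, 387]
print(pages_per_seq)  # -> [2, 1, 4]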
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim; - int total_q, total_k, total_knew; + int total_q, total_k; + int total_knew = 0; int b_k; // When having KV cache and with cache_batch_idx, K & V might have larger batch size than Q int dv, dv_rounded; // For the case where V headdim is different from Q/K headdim @@ -117,6 +118,7 @@ struct Flash_fwd_params { // Paged KV cache int* __restrict__ page_table; + int* __restrict__ num_pages_per_seq; index_t page_table_batch_stride; int page_size; int num_pages; @@ -197,29 +199,6 @@ struct ExampleRunner { StrideK stride_K_cache; StrideV stride_V_cache; StrideO stride_O; - uint64_t seed = 0; - - cutlass::DeviceAllocation block_Q; - cutlass::DeviceAllocation block_K; - cutlass::DeviceAllocation block_V; - cutlass::DeviceAllocation block_K_cache; - cutlass::DeviceAllocation block_V_cache; - cutlass::DeviceAllocation block_O; - cutlass::DeviceAllocation block_ref_O; - - std::vector cumulative_seqlen_q; - std::vector cumulative_seqlen_kv; - std::vector cumulative_seqlen_kv_cache; - cutlass::DeviceAllocation device_cumulative_seqlen_q; - cutlass::DeviceAllocation device_cumulative_seqlen_kv; - cutlass::DeviceAllocation device_cumulative_seqlen_kv_cache; - - struct PagedKVParams { - cutlass::DeviceAllocation page_table; - int page_size = 0; - cutlass::DeviceAllocation num_pages_per_seq; - }; - PagedKVParams paged_kv_cache; template auto initialize_varlen(const Flash_fwd_params& params, ProblemShape& problem_size) { @@ -228,33 +207,22 @@ struct ExampleRunner { constexpr int AlignmentQ = cacheline_bytes / sizeof(ElementQ); // Alignment of Q matrix in units of elements constexpr int AlignmentKV = cacheline_bytes / sizeof(ElementK); // Alignment of Kand V matrix in units of elements - cumulative_seqlen_q = {0}; - cumulative_seqlen_kv = {0}; - cumulative_seqlen_kv_cache = {0}; - - int total_seqlen_q = 0; - int total_seqlen_kv = 0; - int total_seqlen_kv_cache = 0; - int max_seqlen_q = 0; - int max_seqlen_kv = 0; - int max_seqlen_kv_cache = 0; - ProblemShape problem_size_for_init = problem_size; get<0>(problem_size_for_init) = 1; get<3>(problem_size_for_init) = params.total_q; - get<4>(problem_size_for_init) = params.total_knew; + get<4>(problem_size_for_init) = 0; get<5>(problem_size_for_init) = params.total_k; ProblemShapeType problem_size_for_launch; - get<3>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{params.total_q}; - get<4>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{params.total_knew}; - get<5>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{params.total_k}; - get<6>(problem_size_for_launch) = get<6>(problem_size); - get<7>(problem_size_for_launch) = get<7>(problem_size); get<0>(problem_size_for_launch) = get<0>(problem_size); get<1>(problem_size_for_launch) = get<1>(problem_size); get<2>(problem_size_for_launch) = get<2>(problem_size); + get<3>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{params.seqlen_q}; + get<4>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{params.seqlen_knew}; + get<5>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{params.seqlen_k}; + get<6>(problem_size_for_launch) = get<6>(problem_size); + get<7>(problem_size_for_launch) = get<7>(problem_size); return cute::make_tuple(problem_size_for_init, problem_size_for_launch); } @@ -276,8 +244,8 @@ struct ExampleRunner { if constexpr (isVarLen) { auto [problem_shape_init, problem_shape_launch] = 
initialize_varlen(params, problem_shape_in); - problem_shape = problem_shape_launch; problem_size = problem_shape_init; + problem_shape = problem_shape_launch; } else { problem_size = problem_shape_in; problem_shape = problem_shape_in; @@ -286,11 +254,8 @@ struct ExampleRunner { auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = problem_size; auto group_q_size = num_heads_q / num_heads_kv; - ; auto group_q_num = num_heads_q / group_q_size; - // stride_Q = cutlass::make_cute_packed_stride(StrideQ{}, cute::make_shape(seq_len_qo * group_q_size, head_size_qk, - // batch * group_q_num)); stride_Q = cutlass::make_cute_packed_stride(StrideQ{}, cute::make_shape(seq_len_qo, num_heads_q * head_size_qk, batch)); stride_K = @@ -301,19 +266,9 @@ struct ExampleRunner { stride_K_cache = cutlass::make_cute_packed_stride( StrideK{}, cute::make_shape(seq_len_kv_cache, num_heads_kv * head_size_qk, batch)); stride_V_cache = cutlass::make_cute_packed_stride( - StrideV{}, cute::make_shape(head_size_vo, seq_len_kv_cache, batch * num_heads_kv)); - // stride_O = cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo * group_q_size, head_size_vo, - // batch * group_q_num)); + StrideV{}, cute::make_shape(head_size_vo * head_size_qk, seq_len_kv_cache, batch * num_heads_kv)); stride_O = cutlass::make_cute_packed_stride( - StrideO{}, cute::make_shape(seq_len_qo * group_q_size, group_q_num * head_size_vo, batch)); - - block_Q.reset(batch * num_heads_q * seq_len_qo * head_size_qk); - block_K.reset(batch * num_heads_kv * seq_len_kv * head_size_qk); - block_V.reset(batch * num_heads_kv * seq_len_kv * head_size_vo); - block_K_cache.reset(batch * num_heads_kv * seq_len_kv_cache * head_size_qk); - block_V_cache.reset(batch * num_heads_kv * seq_len_kv_cache * head_size_vo); - block_O.reset(batch * num_heads_q * seq_len_qo * head_size_vo); - block_ref_O.reset(batch * num_heads_q * seq_len_qo * head_size_vo); + StrideQ{}, cute::make_shape(seq_len_qo * group_q_size, group_q_num * head_size_vo, batch)); if constexpr (isVarLen) { get<3>(problem_shape).cumulative_length = params.cu_seqlens_q; @@ -356,9 +311,10 @@ struct ExampleRunner { ConfigAccess.getRange(), ConfigAccess.getProperties(), KernelFunctor); }; auto q = syclcompat::get_default_queue(); - auto event = q.submit(cgf); + q.submit(cgf).wait(); + // auto event = q.submit(cgf); - EventManager::getInstance().addEvent(event); + // EventManager::getInstance().addEvent(event); } cutlass::Status run(const Flash_fwd_params& params, const cutlass::KernelHardwareInfo& hw_info) { @@ -367,21 +323,24 @@ struct ExampleRunner { typename FMHAChunkPrefillKernel::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, problem_size, - {block_Q.get(), + {// static_cast(params.q_ptr), + static_cast(params.q_ptr), stride_Q, - block_K.get(), + static_cast(params.knew_ptr), stride_K, - block_V.get(), + static_cast(params.vnew_ptr), stride_V, - block_K_cache.get(), + static_cast(params.k_ptr), stride_K_cache, - block_V_cache.get(), + static_cast(params.v_ptr), stride_V_cache, params.page_table, params.page_size, - params.cu_seqlens_k}, - {params.scale_softmax}, - {block_O.get(), stride_O}, + params.num_pages_per_seq, + -1, + -1}, + {(ElementQ)params.scale_softmax}, + {static_cast(params.o_ptr), stride_O}, hw_info}; // Define device-global scratch memory @@ -412,6 +371,7 @@ template < typename TileShapeOutput, typename SubgroupLayout, int PipelineStages, + bool LocalMask = false, typename ElementInputQ = bfloat16_t, 
typename ElementInputKV = bfloat16_t, typename MMAOperation = XE_8x16x16_F32BF16BF16F32_TT, @@ -420,8 +380,8 @@ template < typename GmemTiledCopyV = XE_2D_U16x16x32_LD_V, typename ElementAccumulator = float, typename ElementComputeEpilogue = float, - typename ElementOutput = float, - typename GmemTiledCopyStore = XE_2D_U32x8x16_ST_N> + typename ElementOutput = bfloat16_t, + typename GmemTiledCopyStore = XE_2D_U16x8x16_ST_N> struct FMHAConfig { template static int run(const Flash_fwd_params& params) { @@ -442,7 +402,7 @@ struct FMHAConfig { ElementOutput, GmemTiledCopyStore>; using CollectiveSoftmaxEpilogue = cutlass::flash_attention::collective:: - FlashChunkPrefillSoftmaxEpilogue; + FlashChunkPrefillSoftmaxEpilogue; using ProblemShapeRegular = cute::tuple; using namespace cutlass::fmha::collective; @@ -467,6 +427,7 @@ struct FMHAConfig { GmemTiledCopyK, // K GmemTiledCopyV, // V, Causal, + LocalMask, PagedKV>; using FMHAChunkPrefillKernel = cutlass::flash_attention::kernel::FMHAPrefillChunk< @@ -484,7 +445,11 @@ struct FMHAConfig { static int run(const Flash_fwd_params& params) { // only support varlen and paged kv now - return run(params); + if (params.page_table != nullptr && params.cu_seqlens_k != nullptr) { + return run(params); + } else { + return 0; + } } }; @@ -523,18 +488,13 @@ std::vector mha_fwd( std::optional& v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new std::optional& q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q - std::optional& out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q std::optional& cu_seqlens_q_, // b+1 std::optional& cu_seqlens_k_, // b+1 std::optional& cu_seqlens_k_new_, // b+1 - std::optional& - seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. - std::optional& - seqused_k_, // b. If given, only this many elements of each batch element's keys are used. std::optional max_seqlen_q_, - // TODO: check if we need max_seqlen_k std::optional max_seqlen_k_, std::optional& page_table_, // (b_k, max_num_pages_per_seq) + std::optional& num_pages_, // (b_k, ) std::optional& kv_batch_idx_, // b. indices to index into the KV cache std::optional& leftpad_k_, // b std::optional& rotary_cos_, // seqlen_ro x (rotary_dim / 2) @@ -582,7 +542,7 @@ std::vector mha_fwd( TORCH_CHECK(page_table.stride(-1) == 1, "page_table must have contiguous last dimension"); } at::Tensor cu_seqlens_q; - bool const is_varlen_q = cu_seqlens_q_.has_value(); + bool const is_varlen_q = q.dim() == 3; // variable length if 3 dimensions if (is_varlen_q) { cu_seqlens_q = cu_seqlens_q_.value(); CHECK_DEVICE(cu_seqlens_q); @@ -596,10 +556,8 @@ std::vector mha_fwd( cu_seqlens_k = cu_seqlens_k_.value(); CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k); - TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32"); TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided"); - TORCH_CHECK(!paged_KV, "If cu_seqlens_k is passed in, then page table is not supported"); - TORCH_CHECK(!kv_batch_idx_.has_value(), "If cu_seqlens_k is passed in, then page table is not supported"); + TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32"); } auto const sizes = q.sizes(); @@ -614,7 +572,7 @@ std::vector mha_fwd( int const page_size = !paged_KV ? 1 : k.size(1); int const seqlen_k = !is_varlen_k ? (!paged_KV ? 
k.size(1) : max_num_pages_per_seq * page_size) : max_seqlen_k_.value(); - int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0); + int const total_k = !is_varlen_k ? batch_size * k.size(1) : cu_seqlens_k[-1].item(); int const num_heads_k = k.size(-2); int const batch_size_k = !paged_KV ? (!is_varlen_k ? k.size(0) : cu_seqlens_k.size(0) - 1) : page_table.size(0); float softmax_scale = softmax_scale_; @@ -664,21 +622,6 @@ std::vector mha_fwd( CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq); } - if (seqused_q_.has_value()) { - auto seqused_q = seqused_q_.value(); - TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32"); - CHECK_DEVICE(seqused_q); - CHECK_CONTIGUOUS(seqused_q); - CHECK_SHAPE(seqused_q, batch_size); - } - if (seqused_k_.has_value()) { - auto seqused_k = seqused_k_.value(); - TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32"); - CHECK_DEVICE(seqused_k); - CHECK_CONTIGUOUS(seqused_k); - CHECK_SHAPE(seqused_k, batch_size); - } - if (leftpad_k_.has_value()) { auto leftpad_k = leftpad_k_.value(); TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); @@ -687,32 +630,16 @@ std::vector mha_fwd( CHECK_SHAPE(leftpad_k, batch_size); } - bool const is_varlen = - is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value() || leftpad_k_.has_value(); + bool const is_varlen = is_varlen_q || is_varlen_k || leftpad_k_.has_value(); static constexpr int alignment = 8; TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment)); TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment)); auto opts = q.options(); - auto out_type = q_type; at::Tensor out; - if (out_.has_value()) { - out = out_.value(); - TORCH_CHECK( - out.scalar_type() == out_type, - "For FP16/BF16 input, output must have the same dtype as inputs. For FP8 input, output must have dtype BF16"); - CHECK_DEVICE(out); - TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); - if (!is_varlen_q) { - CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v); - } else { - CHECK_SHAPE(out, total_q, num_heads, head_size_v); - } - } else { - out = !is_varlen_q ? torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type)) - : torch::empty({total_q, num_heads, head_size_v}, opts.dtype(out_type)); - } + out = !is_varlen_q ? torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts) + : torch::empty({total_q, num_heads, head_size_v}, opts); auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; int const head_size_rounded = round_up_headdim(head_size); @@ -731,6 +658,7 @@ std::vector mha_fwd( softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); } + // align with FA3 Flash_fwd_params params; params.is_bf16 = q.dtype() == torch::kBFloat16; @@ -761,8 +689,6 @@ std::vector mha_fwd( params.cu_seqlens_q = !is_varlen_q ? nullptr : static_cast(cu_seqlens_q.data_ptr()); params.cu_seqlens_k = !is_varlen_k ? nullptr : static_cast(cu_seqlens_k.data_ptr()); - params.seqused_q = seqused_q_.has_value() ? static_cast(seqused_q_.value().data_ptr()) : nullptr; - params.seqused_k = seqused_k_.has_value() ? 
static_cast(seqused_k_.value().data_ptr()) : nullptr; // Softmax sum params.softmax_lse_ptr = softmax_lse.data_ptr(); @@ -805,8 +731,10 @@ std::vector mha_fwd( params.b_k = batch_size_k; params.dv = head_size_v; if (paged_KV) { + TORCH_CHECK(num_pages_.has_value(), "num_pages must be provided if page_table is provided"); params.page_table = page_table.data_ptr(); params.page_table_batch_stride = page_table.stride(0); + params.num_pages_per_seq = num_pages_.value().data_ptr(); } params.page_size = page_size; params.num_pages = num_pages; @@ -814,10 +742,9 @@ std::vector mha_fwd( if (k_new_.has_value()) { // This needs to be set before get_pagedkv_tma at::Tensor k_new, v_new; TORCH_CHECK(v_new_.has_value(), "If k_new is supplied, v_new must also be passed in"); - TORCH_CHECK(seqused_k_.has_value(), "If k_new is supplied, seqlens_k must also be passed in"); TORCH_CHECK(seqlen_q <= seqlen_k, "If k_new is supplied, it must have seqlen <= the seqlen of the KV cache"); at::Tensor cu_seqlens_k_new; - bool const is_varlen_k_new = cu_seqlens_k_new_.has_value(); + bool const is_varlen_k_new = k_new_.value().dim() == 3; if (is_varlen_k_new) { cu_seqlens_k_new = cu_seqlens_k_new_.value(); CHECK_DEVICE(cu_seqlens_k_new); @@ -832,8 +759,7 @@ std::vector mha_fwd( CHECK_DEVICE(v_new); TORCH_CHECK(k_new.stride(-1) == 1, "k_new tensor must have contiguous last dimension"); TORCH_CHECK(v_new.stride(-1) == 1, "v_new tensor must have contiguous last dimension"); - // We don't need max_seqlen_k_new, so seqlen_k_new can be whatever when is_varlen_k_new - int seqlen_k_new = !is_varlen_k_new ? k_new.size(1) : 0; + int seqlen_k_new = !is_varlen_k_new ? k_new.size(1) : 1; int total_k_new = !is_varlen_k_new ? batch_size * k_new.size(1) : k_new.size(0); if (!is_varlen_k_new) { CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size); @@ -859,6 +785,12 @@ std::vector mha_fwd( if (is_varlen_k_new) { params.cu_seqlens_knew = static_cast(cu_seqlens_k_new.data_ptr()); } + } else { + TORCH_CHECK(cu_seqlens_k_new_.has_value(), "If k_new "); + params.seqlen_knew = 0; + params.total_knew = 0; + at::Tensor cu_seqlens_k_new = cu_seqlens_k_new_.value(); + params.cu_seqlens_knew = static_cast(cu_seqlens_k_new.data_ptr()); } if (q_v_.has_value()) { @@ -934,7 +866,7 @@ std::vector mha_fwd( at::Tensor out_accum, softmax_lse_accum; auto outaccum_type = at::ScalarType::Float; - constexpr int PipelineStages = 2; + constexpr int PipelineStages = 0; if (params.is_causal) { switch (params.d) { case 64: diff --git a/src/torch_extension_sycl.cc b/src/torch_extension_sycl.cc index 534d575..ff46dd9 100644 --- a/src/torch_extension_sycl.cc +++ b/src/torch_extension_sycl.cc @@ -63,15 +63,13 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { " Tensor? k_new," " Tensor? v_new," " Tensor? q_v," - " Tensor!? out," " Tensor? cu_seqlens_q," " Tensor? cu_seqlens_k," " Tensor? cu_seqlens_k_new," - " Tensor? seqused_q," - " Tensor? seqused_k," " int? max_seqlen_q," " int? max_seqlen_k," " Tensor? page_table," + " Tensor? num_pages," " Tensor? kv_batch_idx," " Tensor? leftpad_k," " Tensor? 
rotary_cos," diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py index 5cbe6a5..a3510e1 100644 --- a/tests/test_flash_attention.py +++ b/tests/test_flash_attention.py @@ -56,7 +56,7 @@ def is_fa3_supported(device=None) -> bool: # ) DISABLE_SPLIT = True -DISABLE_PAGEDKV = True +DISABLE_PAGEDKV = False DISABLE_APPENDKV = True DISABLE_LOCAL = True DISABLE_SOFTCAP = True @@ -471,18 +471,18 @@ def generate_qkv( # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) # @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) @pytest.mark.parametrize("mha_type", ["mha"]) -@pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else [])) -# @pytest.mark.parametrize("new_kv", [True]) +# @pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else [])) +@pytest.mark.parametrize("new_kv", [False]) # @pytest.mark.parametrize( # "causal,local", # [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else []), # ) # @pytest.mark.parametrize("causal,local", [(False, False), (True, False)]) @pytest.mark.parametrize("causal,local", [(False, False)]) -@pytest.mark.parametrize( - "seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True] -) -# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) +# @pytest.mark.parametrize( +# "seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True] +# ) +@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) # @pytest.mark.parametrize("has_rotary_seqlens", [False, True]) @pytest.mark.parametrize("has_rotary_seqlens", [False]) @pytest.mark.parametrize( @@ -563,10 +563,10 @@ def test_flash_attn_kvcache( pytest.skip() # set seed torch.random.manual_seed(0) - batch_size = 5 - # batch_size = 1 + # batch_size = 5 + batch_size = 1 batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 - nheads = 6 + nheads = 1 # nheads = 1 # rotary_dim must be a multiple of 16, and must be <= d rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 @@ -910,72 +910,74 @@ def test_flash_attn_kvcache( # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) # probs = torch.softmax(qk, dim=-1) + torch.xpu.synchronize() print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - # breakpoint() + # # breakpoint() - # Check that FlashAttention's numerical error is at most twice the numerical error - # of a Pytorch implementation. - if new_kv: - if page_size is None: - k_cache_select = ( - k_cache.to(dtype_ref) - if not has_batch_idx - else k_cache.to(dtype_ref)[cache_batch_idx] - ) - v_cache_select = ( - v_cache.to(dtype_ref) - if not has_batch_idx - else v_cache.to(dtype_ref)[cache_batch_idx] - ) - else: - k_cache_select = rearrange( - k_cache_paged.to(dtype_ref)[ - ( - page_table - if not has_batch_idx - else page_table[cache_batch_idx] - ).flatten() - ], - "(b nblocks) block_size ... -> b (nblocks block_size) ...", - b=batch_size, - )[:, :seqlen_k].to(dtype_ref) - v_cache_select = rearrange( - v_cache_paged.to(dtype_ref)[ - ( - page_table - if not has_batch_idx - else page_table[cache_batch_idx] - ).flatten() - ], - "(b nblocks) block_size ... 
-> b (nblocks block_size) ...", - b=batch_size, - )[:, :seqlen_k].to(dtype_ref) - k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) - v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) - if dtype is not torch.float8_e4m3fn: - assert torch.equal(v_cache_select, v_cache_ref) - else: - assert torch.allclose( - v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3 - ) - # breakpoint() - # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: - if rotary_dim == 0: - assert torch.equal(k_cache_select, k_cache_ref) - else: - # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): - # breakpoint() - if dtype is not torch.float8_e4m3fn: - assert torch.allclose( - k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3 - ) - else: - assert torch.allclose( - k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1 - ) + # # Check that FlashAttention's numerical error is at most twice the numerical error + # # of a Pytorch implementation. + # if new_kv: + # if page_size is None: + # k_cache_select = ( + # k_cache.to(dtype_ref) + # if not has_batch_idx + # else k_cache.to(dtype_ref)[cache_batch_idx] + # ) + # v_cache_select = ( + # v_cache.to(dtype_ref) + # if not has_batch_idx + # else v_cache.to(dtype_ref)[cache_batch_idx] + # ) + # else: + # k_cache_select = rearrange( + # k_cache_paged.to(dtype_ref)[ + # ( + # page_table + # if not has_batch_idx + # else page_table[cache_batch_idx] + # ).flatten() + # ], + # "(b nblocks) block_size ... -> b (nblocks block_size) ...", + # b=batch_size, + # )[:, :seqlen_k].to(dtype_ref) + # v_cache_select = rearrange( + # v_cache_paged.to(dtype_ref)[ + # ( + # page_table + # if not has_batch_idx + # else page_table[cache_batch_idx] + # ).flatten() + # ], + # "(b nblocks) block_size ... -> b (nblocks block_size) ...", + # b=batch_size, + # )[:, :seqlen_k].to(dtype_ref) + # k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) + # v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) + # # if dtype is not torch.float8_e4m3fn: + # # import pdb; pdb.set_trace() + # # assert torch.equal(v_cache_select, v_cache_ref) + # # else: + # # assert torch.allclose( + # # v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3 + # # ) + # # breakpoint() + # # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: + # # if rotary_dim == 0: + # # assert torch.equal(k_cache_select, k_cache_ref) + # # else: + # # # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): + # # # breakpoint() + # # if dtype is not torch.float8_e4m3fn: + # # assert torch.allclose( + # # k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3 + # # ) + # # else: + # # assert torch.allclose( + # # k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1 + # # ) mult = 4 if dtype == torch.float8_e4m3fn else 2 assert (out - out_ref).abs().max().item() <= mult * ( out_pt - out_ref @@ -989,7 +991,7 @@ def test_flash_attn_kvcache( def _generate_block_kvcache( seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref ): - num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 + num_blocks = math.ceil(seqlen_k / page_size) * batch_size k_cache_paged = ( torch.randn(num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref) .to(dtype) From 77d3545d94e99e99c5848eba49dd42e889c325fb Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Fri, 5 Sep 2025 22:39:09 +0800 Subject: [PATCH 07/25] fix lint --- python/sgl_kernel/flash_attn.py | 32 +++++++++++++++++++------------- src/sycl/chunked_prefill.cpp | 4 ++-- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git 
a/python/sgl_kernel/flash_attn.py b/python/sgl_kernel/flash_attn.py index a1c7d8c..afc3a6d 100644 --- a/python/sgl_kernel/flash_attn.py +++ b/python/sgl_kernel/flash_attn.py @@ -3,11 +3,6 @@ import torch import torch.nn as nn -# try: -# from sgl_kernel import flash_ops -# except: -# raise ImportError("Can not import sgl_kernel. Please check your installation.") - def is_fa3_supported(device=None) -> bool: # There some fa3 FYI @@ -19,9 +14,9 @@ def is_fa3_supported(device=None) -> bool: # And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a. # That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3. if torch.cuda.is_available(): - return ( - torch.cuda.get_device_capability(device)[0] == 9 - or torch.cuda.get_device_capability(device)[0] == 8 + return ( + torch.cuda.get_device_capability(device)[0] == 9 + or torch.cuda.get_device_capability(device)[0] == 8 ) and (torch.version.cuda >= "12.3") elif torch.xpu.is_available(): device_name = torch.xpu.get_device_properties(0).name @@ -177,20 +172,31 @@ def flash_attn_with_kvcache( rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)] rotary_seqlens = maybe_contiguous(rotary_seqlens) - if cu_seqlens_q == None: # !is_varlen_q - cu_seqlens_q = torch.arange(0, q.size(0)+1, dtype=torch.int, device=q.device) * q.size(1) + if cu_seqlens_q == None: # !is_varlen_q + cu_seqlens_q = torch.arange( + 0, q.size(0) + 1, dtype=torch.int, device=q.device + ) * q.size(1) max_seqlen_q = q.size(1) q = q.view(-1, q.size(-2), q.size(-1)).contiguous() if cu_seqlens_k_new is None and k is not None: # !is_varlen_k_new - cu_seqlens_k_new = torch.arange(0, k.size(0)+1, dtype=torch.int, device=k.device) + cu_seqlens_k_new = torch.arange( + 0, k.size(0) + 1, dtype=torch.int, device=k.device + ) elif k is None: - cu_seqlens_k_new = torch.zeros_like(cu_seqlens_q, dtype=torch.int32, device=q.device) + cu_seqlens_k_new = torch.zeros_like( + cu_seqlens_q, dtype=torch.int32, device=q.device + ) if cache_seqlens is not None: max_seqlen_k = cache_seqlens.max().item() assert cache_seqlens.size(0) + 1 == cu_seqlens_q.size(0) page_size = k_cache.size(1) num_pages_per_seq = (cache_seqlens + page_size - 1) // page_size - cu_seqlens_k = torch.concat((torch.zeros(1, dtype=torch.int32, device=cache_seqlens.device), torch.cumsum(cache_seqlens, 0))).to(torch.int32) + cu_seqlens_k = torch.concat( + ( + torch.zeros(1, dtype=torch.int32, device=cache_seqlens.device), + torch.cumsum(cache_seqlens, 0), + ) + ).to(torch.int32) out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default( q, diff --git a/src/sycl/chunked_prefill.cpp b/src/sycl/chunked_prefill.cpp index 0786d83..d041654 100644 --- a/src/sycl/chunked_prefill.cpp +++ b/src/sycl/chunked_prefill.cpp @@ -170,7 +170,7 @@ using LayoutV = cutlass::layout::RowMajor; using LayoutO = cutlass::layout::RowMajor; template -struct ExampleRunner { +struct KernelRunner { using StrideQ = typename FMHAChunkPrefillKernel::StrideQ; using StrideK = typename FMHAChunkPrefillKernel::StrideK; using StrideV = typename FMHAChunkPrefillKernel::StrideV; @@ -437,7 +437,7 @@ struct FMHAConfig { CollectiveEpilogue, Scheduler>; - ExampleRunner runner; + KernelRunner runner; (runner.run(params, hw_info)); return 0; From 31f6fa934ac88cb62a6ce4944f205859f09fe8d2 Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Mon, 8 Sep 2025 22:41:44 +0800 Subject: [PATCH 08/25] backup --- CMakeLists.txt | 2 +- python/sgl_kernel/flash_attn.py | 7 ++++++- tests/test_flash_attention.py | 9 ++++++--- 3 files changed, 13 
insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index d242569..1707774 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,7 +38,7 @@ set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable headers only mode in cutla FetchContent_Declare( repo-cutlass-sycl GIT_REPOSITORY https://github.com/sunjiweiswift/cutlass-sycl.git - GIT_TAG b15779b2d99bc392bfaa8209d547fbb8b2f5c807 + GIT_TAG 28f1fe81a92b6e51aa98d89b4260cbe8022596a1 GIT_SHALLOW OFF ) FetchContent_MakeAvailable(repo-cutlass-sycl) diff --git a/python/sgl_kernel/flash_attn.py b/python/sgl_kernel/flash_attn.py index afc3a6d..99babe2 100644 --- a/python/sgl_kernel/flash_attn.py +++ b/python/sgl_kernel/flash_attn.py @@ -190,7 +190,12 @@ def flash_attn_with_kvcache( max_seqlen_k = cache_seqlens.max().item() assert cache_seqlens.size(0) + 1 == cu_seqlens_q.size(0) page_size = k_cache.size(1) - num_pages_per_seq = (cache_seqlens + page_size - 1) // page_size + num_pages_per_seq = torch.concat( + ( + torch.zeros(1, dtype=torch.int32, device=cache_seqlens.device), + torch.cumsum((cache_seqlens + page_size - 1) // page_size, 0), + ) + ).to(torch.int32) cu_seqlens_k = torch.concat( ( torch.zeros(1, dtype=torch.int32, device=cache_seqlens.device), diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py index a3510e1..47ba37d 100644 --- a/tests/test_flash_attention.py +++ b/tests/test_flash_attention.py @@ -499,7 +499,7 @@ def generate_qkv( ) # @pytest.mark.parametrize("rotary_fraction", [0.0]) @pytest.mark.parametrize( - "page_size", [None] + ([1, 4, 128] if not DISABLE_PAGEDKV else []) + "page_size", [64, 128] ) # @pytest.mark.parametrize("page_size", [None]) # @pytest.mark.parametrize("has_leftpad", [False, True]) @@ -566,7 +566,7 @@ def test_flash_attn_kvcache( # batch_size = 5 batch_size = 1 batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 - nheads = 1 + nheads = 16 # nheads = 1 # rotary_dim must be a multiple of 16, and must be <= d rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 @@ -911,6 +911,9 @@ def test_flash_attn_kvcache( # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) # probs = torch.softmax(qk, dim=-1) torch.xpu.synchronize() + out = out.flatten() + out_ref = out_ref.flatten() + out_pt = out_pt.flatten() print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") @@ -1003,7 +1006,7 @@ def _generate_block_kvcache( .to(dtype_ref) ) page_table = rearrange( - torch.randperm(num_blocks, dtype=torch.int32, device=device), + torch.arange(num_blocks, dtype=torch.int32, device=device), "(b nblocks) -> b nblocks", b=batch_size, ) From 5388444fef87222636b3e54e70488a5ded792c21 Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Tue, 9 Sep 2025 19:41:02 +0800 Subject: [PATCH 09/25] fix unused page --- CMakeLists.txt | 2 +- python/sgl_kernel/flash_attn.py | 11 ++++------- src/sycl/chunked_prefill.cpp | 20 ++++++++------------ tests/test_flash_attention.py | 12 +++++++----- 4 files changed, 20 insertions(+), 25 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1707774..55fc5fe 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,7 +38,7 @@ set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable headers only mode in cutla FetchContent_Declare( repo-cutlass-sycl GIT_REPOSITORY https://github.com/sunjiweiswift/cutlass-sycl.git - GIT_TAG 28f1fe81a92b6e51aa98d89b4260cbe8022596a1 + GIT_TAG 
ab1f4b8ddfd5748e4c00317710cdbcecda58de28 GIT_SHALLOW OFF ) FetchContent_MakeAvailable(repo-cutlass-sycl) diff --git a/python/sgl_kernel/flash_attn.py b/python/sgl_kernel/flash_attn.py index 99babe2..0f4b2cc 100644 --- a/python/sgl_kernel/flash_attn.py +++ b/python/sgl_kernel/flash_attn.py @@ -189,13 +189,8 @@ def flash_attn_with_kvcache( if cache_seqlens is not None: max_seqlen_k = cache_seqlens.max().item() assert cache_seqlens.size(0) + 1 == cu_seqlens_q.size(0) - page_size = k_cache.size(1) - num_pages_per_seq = torch.concat( - ( - torch.zeros(1, dtype=torch.int32, device=cache_seqlens.device), - torch.cumsum((cache_seqlens + page_size - 1) // page_size, 0), - ) - ).to(torch.int32) + max_page_size_per_seq = page_table.shape(1) + num_pages_per_seq = torch.arange(0, cache_seqlens.size(0) * max_page_size_per_seq, max_page_size_per_seq, device=cache_seqlens.device).to(torch.int32) cu_seqlens_k = torch.concat( ( torch.zeros(1, dtype=torch.int32, device=cache_seqlens.device), @@ -203,6 +198,8 @@ def flash_attn_with_kvcache( ) ).to(torch.int32) + import pdb; pdb.set_trace() + out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default( q, k_cache, diff --git a/src/sycl/chunked_prefill.cpp b/src/sycl/chunked_prefill.cpp index d041654..7e25367 100644 --- a/src/sycl/chunked_prefill.cpp +++ b/src/sycl/chunked_prefill.cpp @@ -202,15 +202,10 @@ struct KernelRunner { template auto initialize_varlen(const Flash_fwd_params& params, ProblemShape& problem_size) { - // Use Cacheline Size to calculate alignment - constexpr int cacheline_bytes = 64; - constexpr int AlignmentQ = cacheline_bytes / sizeof(ElementQ); // Alignment of Q matrix in units of elements - constexpr int AlignmentKV = cacheline_bytes / sizeof(ElementK); // Alignment of Kand V matrix in units of elements - ProblemShape problem_size_for_init = problem_size; - get<0>(problem_size_for_init) = 1; + get<0>(problem_size_for_init) = 1; // concentrated batch get<3>(problem_size_for_init) = params.total_q; - get<4>(problem_size_for_init) = 0; + get<4>(problem_size_for_init) = params.total_knew; get<5>(problem_size_for_init) = params.total_k; ProblemShapeType problem_size_for_launch; @@ -218,9 +213,9 @@ struct KernelRunner { get<0>(problem_size_for_launch) = get<0>(problem_size); get<1>(problem_size_for_launch) = get<1>(problem_size); get<2>(problem_size_for_launch) = get<2>(problem_size); - get<3>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{params.seqlen_q}; - get<4>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{params.seqlen_knew}; - get<5>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{params.seqlen_k}; + get<3>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{params.seqlen_q, params.total_q}; + get<4>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{params.seqlen_knew, params.total_knew}; + get<5>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{params.seqlen_k, params.total_k}; get<6>(problem_size_for_launch) = get<6>(problem_size); get<7>(problem_size_for_launch) = get<7>(problem_size); @@ -571,8 +566,9 @@ std::vector mha_fwd( int const num_pages = !paged_KV ? 0 : k.size(0); int const page_size = !paged_KV ? 1 : k.size(1); int const seqlen_k = - !is_varlen_k ? (!paged_KV ? k.size(1) : max_num_pages_per_seq * page_size) : max_seqlen_k_.value(); - int const total_k = !is_varlen_k ? batch_size * k.size(1) : cu_seqlens_k[-1].item(); + !is_varlen_k ? k.size(1) : (!paged_KV ? 
max_seqlen_k_.value() : max_num_pages_per_seq * page_size); + int const total_k = + !is_varlen_k ? batch_size * k.size(1) : (!paged_KV ? cu_seqlens_k[-1].item() : num_pages * page_size); int const num_heads_k = k.size(-2); int const batch_size_k = !paged_KV ? (!is_varlen_k ? k.size(0) : cu_seqlens_k.size(0) - 1) : page_table.size(0); float softmax_scale = softmax_scale_; diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py index 47ba37d..ab17a71 100644 --- a/tests/test_flash_attention.py +++ b/tests/test_flash_attention.py @@ -499,7 +499,7 @@ def generate_qkv( ) # @pytest.mark.parametrize("rotary_fraction", [0.0]) @pytest.mark.parametrize( - "page_size", [64, 128] + "page_size", [64, 128, 256] ) # @pytest.mark.parametrize("page_size", [None]) # @pytest.mark.parametrize("has_leftpad", [False, True]) @@ -563,8 +563,7 @@ def test_flash_attn_kvcache( pytest.skip() # set seed torch.random.manual_seed(0) - # batch_size = 5 - batch_size = 1 + batch_size = 5 batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 nheads = 16 # nheads = 1 @@ -1006,7 +1005,7 @@ def _generate_block_kvcache( .to(dtype_ref) ) page_table = rearrange( - torch.arange(num_blocks, dtype=torch.int32, device=device), + torch.randperm(num_blocks, dtype=torch.int32, device=device), "(b nblocks) -> b nblocks", b=batch_size, ) @@ -1022,7 +1021,10 @@ def _generate_block_kvcache( )[:, :seqlen_k] return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks - +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="flash_attn at sgl-kernel-xpu only supports paged cache", +) # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize( "dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []) From 2863442b486249d8074ccc847f48569a89e357ee Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Tue, 9 Sep 2025 19:43:02 +0800 Subject: [PATCH 10/25] enable test --- .github/workflows/pr-test-xpu.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pr-test-xpu.yml b/.github/workflows/pr-test-xpu.yml index b8a424d..93e5d9b 100644 --- a/.github/workflows/pr-test-xpu.yml +++ b/.github/workflows/pr-test-xpu.yml @@ -48,7 +48,7 @@ jobs: timeout-minutes: 20 run: | docker exec -w /root/sglang ci_sglang_xpu \ - /bin/bash -c "cd /root/sglang/sgl-kernel-xpu/tests && python3 -m pytest -v -s test_awq_dequant.py" + /bin/bash -c "cd /root/sglang/sgl-kernel-xpu/tests && python3 -m pytest -v -s test_awq_dequant.py && python3 -m pytest -v -s test_flash_attention.py" - name: Run E2E Bfloat16 tests timeout-minutes: 20 From b8a60748935f358a853d0bc87fecc9716f22b173 Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Tue, 9 Sep 2025 19:50:43 +0800 Subject: [PATCH 11/25] small fix --- cmake/BuildFlags.cmake | 2 - python/sgl_kernel/flash_attn.py | 2 - tests/test_flash_attention.py | 117 ++++++++++++++++---------------- 3 files changed, 58 insertions(+), 63 deletions(-) diff --git a/cmake/BuildFlags.cmake b/cmake/BuildFlags.cmake index d2104dd..4adb70e 100644 --- a/cmake/BuildFlags.cmake +++ b/cmake/BuildFlags.cmake @@ -129,8 +129,6 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") set(SYCL_FLAGS ${SYCL_FLAGS} ${SYCL_KERNEL_OPTIONS}) - # set(SYCL_OFFLINE_COMPILER_CG_OPTIONS ${SYCL_OFFLINE_COMPILER_CG_OPTIONS} -fno-sycl-instrument-device-code) - # set(SYCL_OFFLINE_COMPILER_CG_OPTIONS ${SYCL_OFFLINE_COMPILER_CG_OPTIONS} ${SYCL_LINK_FLAGS}) set(SYCL_OFFLINE_COMPILER_FLAGS 
"${SYCL_OFFLINE_COMPILER_AOT_OPTIONS}${SYCL_OFFLINE_COMPILER_CG_OPTIONS}") else() message("Not compiling with XPU. Currently only support GCC compiler on Linux as CXX compiler.") diff --git a/python/sgl_kernel/flash_attn.py b/python/sgl_kernel/flash_attn.py index 0f4b2cc..77f1fe3 100644 --- a/python/sgl_kernel/flash_attn.py +++ b/python/sgl_kernel/flash_attn.py @@ -198,8 +198,6 @@ def flash_attn_with_kvcache( ) ).to(torch.int32) - import pdb; pdb.set_trace() - out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default( q, k_cache, diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py index ab17a71..631d21b 100644 --- a/tests/test_flash_attention.py +++ b/tests/test_flash_attention.py @@ -921,65 +921,64 @@ def test_flash_attn_kvcache( # # Check that FlashAttention's numerical error is at most twice the numerical error # # of a Pytorch implementation. - # if new_kv: - # if page_size is None: - # k_cache_select = ( - # k_cache.to(dtype_ref) - # if not has_batch_idx - # else k_cache.to(dtype_ref)[cache_batch_idx] - # ) - # v_cache_select = ( - # v_cache.to(dtype_ref) - # if not has_batch_idx - # else v_cache.to(dtype_ref)[cache_batch_idx] - # ) - # else: - # k_cache_select = rearrange( - # k_cache_paged.to(dtype_ref)[ - # ( - # page_table - # if not has_batch_idx - # else page_table[cache_batch_idx] - # ).flatten() - # ], - # "(b nblocks) block_size ... -> b (nblocks block_size) ...", - # b=batch_size, - # )[:, :seqlen_k].to(dtype_ref) - # v_cache_select = rearrange( - # v_cache_paged.to(dtype_ref)[ - # ( - # page_table - # if not has_batch_idx - # else page_table[cache_batch_idx] - # ).flatten() - # ], - # "(b nblocks) block_size ... -> b (nblocks block_size) ...", - # b=batch_size, - # )[:, :seqlen_k].to(dtype_ref) - # k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) - # v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) - # # if dtype is not torch.float8_e4m3fn: - # # import pdb; pdb.set_trace() - # # assert torch.equal(v_cache_select, v_cache_ref) - # # else: - # # assert torch.allclose( - # # v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3 - # # ) - # # breakpoint() - # # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: - # # if rotary_dim == 0: - # # assert torch.equal(k_cache_select, k_cache_ref) - # # else: - # # # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): - # # # breakpoint() - # # if dtype is not torch.float8_e4m3fn: - # # assert torch.allclose( - # # k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3 - # # ) - # # else: - # # assert torch.allclose( - # # k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1 - # # ) + if new_kv: + if page_size is None: + k_cache_select = ( + k_cache.to(dtype_ref) + if not has_batch_idx + else k_cache.to(dtype_ref)[cache_batch_idx] + ) + v_cache_select = ( + v_cache.to(dtype_ref) + if not has_batch_idx + else v_cache.to(dtype_ref)[cache_batch_idx] + ) + else: + k_cache_select = rearrange( + k_cache_paged.to(dtype_ref)[ + ( + page_table + if not has_batch_idx + else page_table[cache_batch_idx] + ).flatten() + ], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + v_cache_select = rearrange( + v_cache_paged.to(dtype_ref)[ + ( + page_table + if not has_batch_idx + else page_table[cache_batch_idx] + ).flatten() + ], + "(b nblocks) block_size ... 
-> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) + v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) + if dtype is not torch.float8_e4m3fn: + import pdb; pdb.set_trace() + assert torch.equal(v_cache_select, v_cache_ref) + else: + assert torch.allclose( + v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3 + ) + breakpoint() + if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: + assert torch.equal(k_cache_select, k_cache_ref) + else: + # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): + # breakpoint() + if dtype is not torch.float8_e4m3fn: + assert torch.allclose( + k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3 + ) + else: + assert torch.allclose( + k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1 + ) mult = 4 if dtype == torch.float8_e4m3fn else 2 assert (out - out_ref).abs().max().item() <= mult * ( out_pt - out_ref From 6ad98d8b912e27fc672a18278678d890269ede12 Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Tue, 9 Sep 2025 20:00:35 +0800 Subject: [PATCH 12/25] fix lint --- python/sgl_kernel/flash_attn.py | 7 ++++++- tests/test_flash_attention.py | 17 ++++++++--------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/python/sgl_kernel/flash_attn.py b/python/sgl_kernel/flash_attn.py index 77f1fe3..e77e68b 100644 --- a/python/sgl_kernel/flash_attn.py +++ b/python/sgl_kernel/flash_attn.py @@ -190,7 +190,12 @@ def flash_attn_with_kvcache( max_seqlen_k = cache_seqlens.max().item() assert cache_seqlens.size(0) + 1 == cu_seqlens_q.size(0) max_page_size_per_seq = page_table.shape(1) - num_pages_per_seq = torch.arange(0, cache_seqlens.size(0) * max_page_size_per_seq, max_page_size_per_seq, device=cache_seqlens.device).to(torch.int32) + num_pages_per_seq = torch.arange( + 0, + cache_seqlens.size(0) * max_page_size_per_seq, + max_page_size_per_seq, + device=cache_seqlens.device, + ).to(torch.int32) cu_seqlens_k = torch.concat( ( torch.zeros(1, dtype=torch.int32, device=cache_seqlens.device), diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py index 631d21b..76b2fb9 100644 --- a/tests/test_flash_attention.py +++ b/tests/test_flash_attention.py @@ -498,9 +498,7 @@ def generate_qkv( ), ) # @pytest.mark.parametrize("rotary_fraction", [0.0]) -@pytest.mark.parametrize( - "page_size", [64, 128, 256] -) +@pytest.mark.parametrize("page_size", [64, 128, 256]) # @pytest.mark.parametrize("page_size", [None]) # @pytest.mark.parametrize("has_leftpad", [False, True]) @pytest.mark.parametrize("has_leftpad", [False]) @@ -917,10 +915,10 @@ def test_flash_attn_kvcache( print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - # # breakpoint() + # breakpoint() - # # Check that FlashAttention's numerical error is at most twice the numerical error - # # of a Pytorch implementation. + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. 
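The check that follows gathers the paged cache back into per-batch order via the page table before comparing it with the reference, and then bounds the kernel's output error by a multiple of the PyTorch reference error. A self-contained sketch with toy shapes (einops `rearrange`, as already used in this test; shapes are illustrative only):

import torch
from einops import rearrange

batch, nblocks_per_seq, page_size, nheads_k, d, seqlen_k = 2, 3, 4, 1, 8, 10
num_blocks = batch * nblocks_per_seq
k_cache_paged = torch.randn(num_blocks, page_size, nheads_k, d)
page_table = rearrange(
    torch.randperm(num_blocks, dtype=torch.int32),
    "(b nblocks) -> b nblocks", b=batch,
)

# Gather each sequence's blocks, flatten them back into a contiguous cache,
# then trim to the logical sequence length.
k_cache_select = rearrange(
    k_cache_paged[page_table.flatten().long()],
    "(b nblocks) block_size ... -> b (nblocks block_size) ...", b=batch,
)[:, :seqlen_k]
assert k_cache_select.shape == (batch, seqlen_k, nheads_k, d)

# Error criterion used below: the kernel may be at most `mult` times worse than
# a pure PyTorch implementation run in the same low precision.
def within_tolerance(out, out_ref, out_pt, mult=2):
    return (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item()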
if new_kv: if page_size is None: k_cache_select = ( @@ -959,14 +957,14 @@ def test_flash_attn_kvcache( k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) if dtype is not torch.float8_e4m3fn: - import pdb; pdb.set_trace() assert torch.equal(v_cache_select, v_cache_ref) else: assert torch.allclose( v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3 ) - breakpoint() - if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: + # breakpoint() + # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: + if rotary_dim == 0: assert torch.equal(k_cache_select, k_cache_ref) else: # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): @@ -1020,6 +1018,7 @@ def _generate_block_kvcache( )[:, :seqlen_k] return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks + @pytest.mark.skipif( not torch.cuda.is_available(), reason="flash_attn at sgl-kernel-xpu only supports paged cache", From 1550a6a155e737c153d90c1c20790121ee4d4a7b Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Wed, 10 Sep 2025 17:48:02 +0800 Subject: [PATCH 13/25] small fix --- python/sgl_kernel/flash_attn.py | 23 +++++++++++++++++------ src/sycl/chunked_prefill.cpp | 26 +++++++++++++------------- tests/test_flash_attention.py | 6 ++---- 3 files changed, 32 insertions(+), 23 deletions(-) diff --git a/python/sgl_kernel/flash_attn.py b/python/sgl_kernel/flash_attn.py index e77e68b..d19c0d7 100644 --- a/python/sgl_kernel/flash_attn.py +++ b/python/sgl_kernel/flash_attn.py @@ -189,7 +189,7 @@ def flash_attn_with_kvcache( if cache_seqlens is not None: max_seqlen_k = cache_seqlens.max().item() assert cache_seqlens.size(0) + 1 == cu_seqlens_q.size(0) - max_page_size_per_seq = page_table.shape(1) + max_page_size_per_seq = page_table.size(1) num_pages_per_seq = torch.arange( 0, cache_seqlens.size(0) * max_page_size_per_seq, @@ -265,13 +265,26 @@ def flash_attn_varlen_func( ): if not is_fa3_supported(): raise NotImplementedError( - "flash_attn at sgl-kernel is only supported on sm90 and above" + "flash_attn at sgl-kernel-xpu is only supported on BMG and later" ) if softmax_scale is None: softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** ( -0.5 ) + if cu_seqlens_q == None: # !is_varlen_q + cu_seqlens_q = torch.arange( + 0, q.size(0) + 1, dtype=torch.int, device=q.device + ) * q.size(1) + max_seqlen_q = q.size(1) + q = q.view(-1, q.size(-2), q.size(-1)).contiguous() + batch_size = cu_seqlens_q.numel() - 1 + page_table = ( + torch.arange(0, batch_size, device=q.device) + .to(torch.int32) + .reshape([batch_size, 1]) + .contiguous() + ) out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default( q, @@ -280,15 +293,13 @@ def flash_attn_varlen_func( None, # k_new None, # v_new qv, # qv - None, # out cu_seqlens_q, cu_seqlens_k, None, # cu_seqlens_k_new - seqused_q, - seqused_k, max_seqlen_q, max_seqlen_k, - None, # page_table, + page_table, # page_table, + page_table, # num_pages_per_seq None, # kv_batch_idx None, # leftpad_k None, # rotary cos diff --git a/src/sycl/chunked_prefill.cpp b/src/sycl/chunked_prefill.cpp index 7e25367..da82634 100644 --- a/src/sycl/chunked_prefill.cpp +++ b/src/sycl/chunked_prefill.cpp @@ -488,16 +488,16 @@ std::vector mha_fwd( std::optional& cu_seqlens_k_new_, // b+1 std::optional max_seqlen_q_, std::optional max_seqlen_k_, - std::optional& page_table_, // (b_k, max_num_pages_per_seq) - std::optional& num_pages_, // (b_k, ) - std::optional& kv_batch_idx_, // b. 
indices to index into the KV cache - std::optional& leftpad_k_, // b - std::optional& rotary_cos_, // seqlen_ro x (rotary_dim / 2) - std::optional& rotary_sin_, // seqlen_ro x (rotary_dim / 2) - std::optional& seqlens_rotary_, // b - std::optional& q_descale_, // (b, h_k), not (b, h) - std::optional& k_descale_, // (b, h_k) - std::optional& v_descale_, // (b, h_k) + std::optional& page_table_, // (b_k, max_num_pages_per_seq) + std::optional& num_pages_per_seq_, // (b_k, ) + std::optional& kv_batch_idx_, // b. indices to index into the KV cache + std::optional& leftpad_k_, // b + std::optional& rotary_cos_, // seqlen_ro x (rotary_dim / 2) + std::optional& rotary_sin_, // seqlen_ro x (rotary_dim / 2) + std::optional& seqlens_rotary_, // b + std::optional& q_descale_, // (b, h_k), not (b, h) + std::optional& k_descale_, // (b, h_k) + std::optional& v_descale_, // (b, h_k) const float softmax_scale_, bool is_causal, int window_size_left, @@ -730,10 +730,10 @@ std::vector mha_fwd( TORCH_CHECK(num_pages_.has_value(), "num_pages must be provided if page_table is provided"); params.page_table = page_table.data_ptr(); params.page_table_batch_stride = page_table.stride(0); - params.num_pages_per_seq = num_pages_.value().data_ptr(); + params.num_pages_per_seq = num_pages_per_seq_.value().data_ptr(); + params.page_size = page_size; + params.num_pages = num_pages; } - params.page_size = page_size; - params.num_pages = num_pages; if (k_new_.has_value()) { // This needs to be set before get_pagedkv_tma at::Tensor k_new, v_new; diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py index 76b2fb9..580153f 100644 --- a/tests/test_flash_attention.py +++ b/tests/test_flash_attention.py @@ -39,6 +39,8 @@ def is_fa3_supported(device=None) -> bool: elif torch.xpu.is_available(): device_name = torch.xpu.get_device_properties(0).name return "B580" in device_name or "e211" in device_name + else: + return False DISABLE_BACKWARD = True @@ -1019,10 +1021,6 @@ def _generate_block_kvcache( return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks -@pytest.mark.skipif( - not torch.cuda.is_available(), - reason="flash_attn at sgl-kernel-xpu only supports paged cache", -) # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize( "dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []) From f5c2c893485598633eefb1e9b930b668f9e0892f Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Thu, 11 Sep 2025 23:36:53 +0800 Subject: [PATCH 14/25] small typo --- include/sgl_flash_kernel_ops.h | 20 ++++++++++---------- src/sycl/chunked_prefill.cpp | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/include/sgl_flash_kernel_ops.h b/include/sgl_flash_kernel_ops.h index e4b2047..9e8b209 100644 --- a/include/sgl_flash_kernel_ops.h +++ b/include/sgl_flash_kernel_ops.h @@ -58,16 +58,16 @@ std::vector mha_fwd( std::optional& cu_seqlens_k_new_, // b+1 std::optional max_seqlen_q_, std::optional max_seqlen_k_, - std::optional& page_table_, // (b_k, max_num_pages_per_seq) - std::optional& num_pages_, // (b_k, ) - std::optional& kv_batch_idx_, // b. 
indices to index into the KV cache - std::optional& leftpad_k_, // b - std::optional& rotary_cos_, // seqlen_ro x (rotary_dim / 2) - std::optional& rotary_sin_, // seqlen_ro x (rotary_dim / 2) - std::optional& seqlens_rotary_, // b - std::optional& q_descale_, // (b, h_k), not (b, h) - std::optional& k_descale_, // (b, h_k) - std::optional& v_descale_, // (b, h_k) + std::optional& page_table_, // (b_k, max_num_pages_per_seq) + std::optional& num_pages_per_seq_, // (b_k, ) + std::optional& kv_batch_idx_, // b. indices to index into the KV cache + std::optional& leftpad_k_, // b + std::optional& rotary_cos_, // seqlen_ro x (rotary_dim / 2) + std::optional& rotary_sin_, // seqlen_ro x (rotary_dim / 2) + std::optional& seqlens_rotary_, // b + std::optional& q_descale_, // (b, h_k), not (b, h) + std::optional& k_descale_, // (b, h_k) + std::optional& v_descale_, // (b, h_k) float const softmax_scale, bool is_causal, int window_size_left, diff --git a/src/sycl/chunked_prefill.cpp b/src/sycl/chunked_prefill.cpp index da82634..b8af394 100644 --- a/src/sycl/chunked_prefill.cpp +++ b/src/sycl/chunked_prefill.cpp @@ -727,7 +727,7 @@ std::vector mha_fwd( params.b_k = batch_size_k; params.dv = head_size_v; if (paged_KV) { - TORCH_CHECK(num_pages_.has_value(), "num_pages must be provided if page_table is provided"); + TORCH_CHECK(num_pages_per_seq_.has_value(), "num_pages must be provided if page_table is provided"); params.page_table = page_table.data_ptr(); params.page_table_batch_stride = page_table.stride(0); params.num_pages_per_seq = num_pages_per_seq_.value().data_ptr(); From a05d6ce5ae3b52b07f113a050aba7131cc5b2e08 Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Tue, 16 Sep 2025 17:24:05 +0800 Subject: [PATCH 15/25] spirv build flags --- cmake/BuildFlags.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/BuildFlags.cmake b/cmake/BuildFlags.cmake index 4adb70e..95d373a 100644 --- a/cmake/BuildFlags.cmake +++ b/cmake/BuildFlags.cmake @@ -73,7 +73,7 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -no-ftz) set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -fno-sycl-instrument-device-code) set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -Xspirv-translator) - set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -spirv-ext=+SPV_INTEL_split_barrier) + set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -spirv-ext=+SPV_INTEL_split_barrier,+SPV_INTEL_2d_block_io,+SPV_INTEL_subgroup_matrix_multiply_accumulate) if(CMAKE_BUILD_TYPE MATCHES Debug) set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -g -O0 -Rno-debug-disables-optimization) From 8b0d1673de6b174af76350fa5dbb493425d928b3 Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Wed, 17 Sep 2025 23:02:40 +0800 Subject: [PATCH 16/25] fix different queue --- src/sycl/chunked_prefill.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/sycl/chunked_prefill.cpp b/src/sycl/chunked_prefill.cpp index b8af394..62a5e6d 100644 --- a/src/sycl/chunked_prefill.cpp +++ b/src/sycl/chunked_prefill.cpp @@ -305,7 +305,8 @@ struct KernelRunner { cgh.parallel_for>( ConfigAccess.getRange(), ConfigAccess.getProperties(), KernelFunctor); }; - auto q = syclcompat::get_default_queue(); + auto stream = at::xpu::getCurrentXPUStream(); + auto q = stream.queue(); q.submit(cgf).wait(); // auto event = q.submit(cgf); From c3947725f61bcb227445c720265c39ad853f300f Mon Sep 17 00:00:00 2001 From: "Swift.Sun" Date: Thu, 18 Sep 2025 16:53:18 +0800 Subject: [PATCH 17/25] Update chunked_prefill.cpp piplinestage=2 
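The "fix different queue" change above submits the kernel on the queue of the current ATen XPU stream rather than a default queue. From the Python side the practical consequence, which the tests also rely on, is that the host must synchronize with that stream before inspecting results; a minimal sketch (XPU device assumed available, the multiply stands in for an asynchronous kernel launch):

import torch

if torch.xpu.is_available():
    q = torch.randn(1, 4, 2, 64, device="xpu", dtype=torch.bfloat16)
    out = q * 2.0                  # stand-in for work queued on the current XPU stream
    torch.xpu.synchronize()        # wait for that stream before host-side checks
    print(out.abs().max().item())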
--- CMakeLists.txt | 2 +- cmake/BuildFlags.cmake | 16 ++++------------ src/sycl/chunked_prefill.cpp | 2 +- tests/test_flash_attention.py | 4 ++++ 4 files changed, 10 insertions(+), 14 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 55fc5fe..7f9a97a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,7 +38,7 @@ set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable headers only mode in cutla FetchContent_Declare( repo-cutlass-sycl GIT_REPOSITORY https://github.com/sunjiweiswift/cutlass-sycl.git - GIT_TAG ab1f4b8ddfd5748e4c00317710cdbcecda58de28 + GIT_TAG e02de57e31a20f1c5c7e472aecd322e9196b2792 GIT_SHALLOW OFF ) FetchContent_MakeAvailable(repo-cutlass-sycl) diff --git a/cmake/BuildFlags.cmake b/cmake/BuildFlags.cmake index 95d373a..111b631 100644 --- a/cmake/BuildFlags.cmake +++ b/cmake/BuildFlags.cmake @@ -113,18 +113,10 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") set(AOT_TARGETS "bmg") - if(TORCH_XPU_ARCH_LIST) - set(AOT_TARGETS "${TORCH_XPU_ARCH_LIST}") - endif() - if(AOT_TARGETS STREQUAL "none") - set(TORCH_XPU_ARCH_LIST "" PARENT_SCOPE) - else() - set(SYCL_TARGETS_OPTION -fsycl-targets=spir64_gen) - set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} ${SYCL_TARGETS_OPTION}) - set(SYCL_DEVICE_LINK_FLAGS ${SYCL_DEVICE_LINK_FLAGS} ${SYCL_TARGETS_OPTION}) - set(SYCL_OFFLINE_COMPILER_AOT_OPTIONS "-device ${AOT_TARGETS}") - set(TORCH_XPU_ARCH_LIST ${AOT_TARGETS} PARENT_SCOPE) - endif() + set(SYCL_TARGETS_OPTION -fsycl-targets=spir64_gen) + set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} ${SYCL_TARGETS_OPTION}) + set(SYCL_DEVICE_LINK_FLAGS ${SYCL_DEVICE_LINK_FLAGS} ${SYCL_TARGETS_OPTION}) + set(SYCL_OFFLINE_COMPILER_AOT_OPTIONS "-device ${AOT_TARGETS}") message(STATUS "Compile Intel GPU AOT Targets for ${AOT_TARGETS}") set(SYCL_FLAGS ${SYCL_FLAGS} ${SYCL_KERNEL_OPTIONS}) diff --git a/src/sycl/chunked_prefill.cpp b/src/sycl/chunked_prefill.cpp index 62a5e6d..00a9f49 100644 --- a/src/sycl/chunked_prefill.cpp +++ b/src/sycl/chunked_prefill.cpp @@ -863,7 +863,7 @@ std::vector mha_fwd( at::Tensor out_accum, softmax_lse_accum; auto outaccum_type = at::ScalarType::Float; - constexpr int PipelineStages = 0; + constexpr int PipelineStages = 2; if (params.is_causal) { switch (params.d) { case 64: diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py index 580153f..13199b4 100644 --- a/tests/test_flash_attention.py +++ b/tests/test_flash_attention.py @@ -1022,6 +1022,10 @@ def _generate_block_kvcache( # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.skipif( + True, + reason="flash_attn at sgl-kernel-xpu only supports paged cache", +) @pytest.mark.parametrize( "dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []) ) From 0079b6e1bd120eb1e903cd129ac135847b2e08c3 Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Mon, 22 Sep 2025 19:44:56 +0800 Subject: [PATCH 18/25] revert spirv flags; update cutlass --- CMakeLists.txt | 2 +- cmake/BuildFlags.cmake | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 7f9a97a..385d93f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,7 +38,7 @@ set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable headers only mode in cutla FetchContent_Declare( repo-cutlass-sycl GIT_REPOSITORY https://github.com/sunjiweiswift/cutlass-sycl.git - GIT_TAG e02de57e31a20f1c5c7e472aecd322e9196b2792 + GIT_TAG 742d127cf5ee75cc6db4eac32c8b72f00c53d0fe GIT_SHALLOW OFF ) FetchContent_MakeAvailable(repo-cutlass-sycl) diff --git 
a/cmake/BuildFlags.cmake b/cmake/BuildFlags.cmake index 111b631..885e39d 100644 --- a/cmake/BuildFlags.cmake +++ b/cmake/BuildFlags.cmake @@ -73,7 +73,7 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -no-ftz) set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -fno-sycl-instrument-device-code) set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -Xspirv-translator) - set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -spirv-ext=+SPV_INTEL_split_barrier,+SPV_INTEL_2d_block_io,+SPV_INTEL_subgroup_matrix_multiply_accumulate) + set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -spirv-ext=+SPV_INTEL_split_barrier) #,+SPV_INTEL_2d_block_io,+SPV_INTEL_subgroup_matrix_multiply_accumulate) if(CMAKE_BUILD_TYPE MATCHES Debug) set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -g -O0 -Rno-debug-disables-optimization) From 8a3ddeabf92025ebabec748c69c020eac6f5fbe8 Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Mon, 22 Sep 2025 20:01:29 +0800 Subject: [PATCH 19/25] remove syclcompat dependency --- src/sycl/chunked_prefill.cpp | 20 ++---- src/sycl/helper.h | 125 +++++++++++++++++++++++++++++++++++ src/sycl/sycl_common.hpp | 51 ++++++++++++++ 3 files changed, 183 insertions(+), 13 deletions(-) create mode 100644 src/sycl/helper.h create mode 100644 src/sycl/sycl_common.hpp diff --git a/src/sycl/chunked_prefill.cpp b/src/sycl/chunked_prefill.cpp index 00a9f49..32a5697 100644 --- a/src/sycl/chunked_prefill.cpp +++ b/src/sycl/chunked_prefill.cpp @@ -4,16 +4,10 @@ #include #include -#include #include "cutlass/epilogue/collective/default_epilogue.hpp" -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/util/GPU_Clock.hpp" -#include "cutlass/util/command_line.h" #include "cutlass/util/device_memory.h" #include "cutlass/util/packed_stride.hpp" -#include "cutlass/util/reference/device/gemm_complex.h" -#include "cutlass/util/reference/device/tensor_compare.h" #include "cutlass/util/sycl_event_manager.hpp" #include "flash_attention_v2/collective/fmha_fusion.hpp" #include "flash_attention_v2/collective/xe_flash_attn_chunk_prefill_epilogue.hpp" @@ -283,21 +277,21 @@ struct KernelRunner { // configure smem size and carveout int smem_size = FMHAChunkPrefillKernel::SharedStorageSize; - const auto sycl_block = syclcompat::dim3(block.x, block.y, block.z); - const auto sycl_grid = syclcompat::dim3(grid.x, grid.y, grid.z); + const auto sycl_block = compat::dim3(block.x, block.y, block.z); + const auto sycl_grid = compat::dim3(grid.x, grid.y, grid.z); - syclcompat::experimental::launch_properties launch_props{ + using namespace compat::experimental; + compat::experimental::launch_properties launch_props{ sycl::ext::oneapi::experimental::work_group_scratch_size(smem_size), }; - syclcompat::experimental::kernel_properties kernel_props{ + compat::experimental::kernel_properties kernel_props{ sycl::ext::oneapi::experimental::sub_group_size}; - syclcompat::experimental::launch_policy policy{sycl_grid, sycl_block, launch_props, kernel_props}; - // auto event = syclcompat::experimental::launch>(policy, params); + compat::experimental::launch_policy policy{sycl_grid, sycl_block, launch_props, kernel_props}; sycl::ext::oneapi::experimental::launch_config config(policy.get_range(), policy.get_launch_properties()); auto cgf = [&](::sycl::handler& cgh) { auto KernelFunctor = - syclcompat::experimental::detail::build_kernel_functor>( + compat::experimental::detail::build_kernel_functor>( cgh, policy, params); sycl::ext::oneapi::experimental::detail:: LaunchConfigAccess, 
decltype(policy.get_launch_properties())> diff --git a/src/sycl/helper.h b/src/sycl/helper.h new file mode 100644 index 0000000..79e9746 --- /dev/null +++ b/src/sycl/helper.h @@ -0,0 +1,125 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#if defined(CUTLASS_ENABLE_SYCL) +#include "cutlass/util/sycl_timer.hpp" +#else +#include +#endif +#include + +/** + * Panic wrapper for unwinding CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +/** + * Panic wrapper for unwinding CUDA runtime errors + */ +#define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +/** + * GPU timer for recording the elapsed time across kernel(s) launched in GPU + * stream + */ +struct GpuTimer { +#if defined(CUTLASS_ENABLE_SYCL) + using cudaStream_t = int; + SYCLTimer syclTimer; +#else + cudaEvent_t _start; + cudaEvent_t _stop; +#endif + cudaStream_t _stream_id; + + /// Constructor + GpuTimer() : _stream_id(0) { +#if !defined(CUTLASS_ENABLE_SYCL) + CUDA_CHECK(cudaEventCreate(&_start)); + CUDA_CHECK(cudaEventCreate(&_stop)); +#endif + } + + /// Destructor + ~GpuTimer() { +#if !defined(CUTLASS_ENABLE_SYCL) + CUDA_CHECK(cudaEventDestroy(_start)); + CUDA_CHECK(cudaEventDestroy(_stop)); +#endif + } + + /// Start the timer for a given stream (defaults to the default stream) + void start(cudaStream_t stream_id = 0) { + _stream_id = stream_id; +#if defined(CUTLASS_ENABLE_SYCL) + syclTimer.start(); +#else + CUDA_CHECK(cudaEventRecord(_start, _stream_id)); +#endif + } + + /// Stop the timer + void stop() { +#if defined(CUTLASS_ENABLE_SYCL) + syclTimer.stop(); +#else + CUDA_CHECK(cudaEventRecord(_stop, _stream_id)); +#endif + } + + /// Return the elapsed time (in milliseconds) + float elapsed_millis() { +#if defined(CUTLASS_ENABLE_SYCL) + return syclTimer.milliseconds(); +#else + float elapsed = 0.0; + CUDA_CHECK(cudaEventSynchronize(_stop)); + CUDA_CHECK(cudaEventElapsedTime(&elapsed, _start, _stop)); + return elapsed; +#endif + } +}; diff --git a/src/sycl/sycl_common.hpp b/src/sycl/sycl_common.hpp new file mode 100644 index 0000000..24c0e3d --- /dev/null +++ b/src/sycl/sycl_common.hpp @@ -0,0 +1,51 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/mixed_dtype_utils.hpp" +#include "cutlass/util/reference/device/sycl_tensor_fill.h" + +template +inline bool is_close(T a, T b, float atol, float rtol) { + return std::abs((float)a - (float)b) <= atol + rtol * std::abs((float)b); +} + +// TODO(Codeplay): use on device initialisation for this + +template +void convert_dtype(const SrcT* d_src, DstT* d_dst, size_t size) { + syclcompat::get_default_queue() + .parallel_for(size, [=](auto indx) { d_dst[indx] = static_cast(d_src[indx]); }) + .wait(); +} From 67a20fe291a639f32680250b8a586d1850e47d49 Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Wed, 24 Sep 2025 16:18:42 +0800 Subject: [PATCH 20/25] enable causal --- CMakeLists.txt | 2 +- src/sycl/chunked_prefill.cpp | 4 +-- src/sycl/sycl_common.hpp | 51 ----------------------------------- tests/test_flash_attention.py | 18 +++++-------- 4 files changed, 9 insertions(+), 66 deletions(-) delete mode 100644 src/sycl/sycl_common.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 385d93f..10af489 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,7 +38,7 @@ set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable headers only mode in cutla FetchContent_Declare( repo-cutlass-sycl GIT_REPOSITORY https://github.com/sunjiweiswift/cutlass-sycl.git - GIT_TAG 742d127cf5ee75cc6db4eac32c8b72f00c53d0fe + GIT_TAG f46ae0df764a1751879ce3e22765c700b1d52eca GIT_SHALLOW OFF ) FetchContent_MakeAvailable(repo-cutlass-sycl) diff --git a/src/sycl/chunked_prefill.cpp b/src/sycl/chunked_prefill.cpp index 32a5697..11224ec 100644 --- a/src/sycl/chunked_prefill.cpp +++ b/src/sycl/chunked_prefill.cpp @@ -392,7 +392,7 @@ struct FMHAConfig { ElementOutput, GmemTiledCopyStore>; using CollectiveSoftmaxEpilogue = cutlass::flash_attention::collective:: - FlashChunkPrefillSoftmaxEpilogue; + FlashChunkPrefillSoftmaxEpilogue; using ProblemShapeRegular = cute::tuple; using namespace cutlass::fmha::collective; @@ -777,7 +777,7 @@ std::vector mha_fwd( params.cu_seqlens_knew = static_cast(cu_seqlens_k_new.data_ptr()); } } else { - TORCH_CHECK(cu_seqlens_k_new_.has_value(), "If k_new "); + TORCH_CHECK(cu_seqlens_k_new_.has_value(), "cu_seqlens_k_new all zeros"); params.seqlen_knew = 0; params.total_knew = 0; at::Tensor cu_seqlens_k_new = cu_seqlens_k_new_.value(); diff --git a/src/sycl/sycl_common.hpp b/src/sycl/sycl_common.hpp deleted file mode 100644 index 24c0e3d..0000000 --- a/src/sycl/sycl_common.hpp +++ /dev/null @@ -1,51 +0,0 @@ 
-/*************************************************************************************************** - * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/util/device_memory.h" -#include "cutlass/util/mixed_dtype_utils.hpp" -#include "cutlass/util/reference/device/sycl_tensor_fill.h" - -template -inline bool is_close(T a, T b, float atol, float rtol) { - return std::abs((float)a - (float)b) <= atol + rtol * std::abs((float)b); -} - -// TODO(Codeplay): use on device initialisation for this - -template -void convert_dtype(const SrcT* d_src, DstT* d_dst, size_t size) { - syclcompat::get_default_queue() - .parallel_for(size, [=](auto indx) { d_dst[indx] = static_cast(d_src[indx]); }) - .wait(); -} diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py index 13199b4..f9a0756 100644 --- a/tests/test_flash_attention.py +++ b/tests/test_flash_attention.py @@ -479,8 +479,8 @@ def generate_qkv( # "causal,local", # [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else []), # ) -# @pytest.mark.parametrize("causal,local", [(False, False), (True, False)]) -@pytest.mark.parametrize("causal,local", [(False, False)]) +@pytest.mark.parametrize("causal,local", [(False, False), (True, False)]) +# @pytest.mark.parametrize("causal,local", [(True, False)]) # @pytest.mark.parametrize( # "seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True] # ) @@ -566,6 +566,8 @@ def test_flash_attn_kvcache( batch_size = 5 batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 nheads = 16 + if seqlen_k <= seqlen_q: + seqlen_k += seqlen_q # nheads = 1 # rotary_dim must be a multiple of 16, and must be <= d rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 @@ -694,17 +696,9 @@ def test_flash_attn_kvcache( dtype_ref, 
) cache_seqlens = torch.randint( - 0 if new_kv else 1, + seqlen_q, # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough - ( - ( - seqlen_k - - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) - + 1 - ) - if new_kv - else (seqlen_k + 1) - ), + seqlen_k, (batch_size,), dtype=torch.int32, device=device, From 53850dcfcfb6919dd7922e81f68ad041bffb2d88 Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Wed, 24 Sep 2025 16:34:00 +0800 Subject: [PATCH 21/25] update cutlass commit --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 10af489..6735183 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,7 +38,7 @@ set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable headers only mode in cutla FetchContent_Declare( repo-cutlass-sycl GIT_REPOSITORY https://github.com/sunjiweiswift/cutlass-sycl.git - GIT_TAG f46ae0df764a1751879ce3e22765c700b1d52eca + GIT_TAG 355feacef0916f277e5b3b7d02f15c9848b391df GIT_SHALLOW OFF ) FetchContent_MakeAvailable(repo-cutlass-sycl) From 884200b78efb70d8db62e7e68c6d7638e89201fb Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Wed, 24 Sep 2025 17:49:03 +0800 Subject: [PATCH 22/25] remove the useless wait --- src/sycl/chunked_prefill.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/sycl/chunked_prefill.cpp b/src/sycl/chunked_prefill.cpp index 11224ec..2af2dbb 100644 --- a/src/sycl/chunked_prefill.cpp +++ b/src/sycl/chunked_prefill.cpp @@ -301,10 +301,7 @@ struct KernelRunner { }; auto stream = at::xpu::getCurrentXPUStream(); auto q = stream.queue(); - q.submit(cgf).wait(); - // auto event = q.submit(cgf); - - // EventManager::getInstance().addEvent(event); + q.submit(cgf); } cutlass::Status run(const Flash_fwd_params& params, const cutlass::KernelHardwareInfo& hw_info) { From 6433273a84b1e1c6eef574acbb99ea2bee64fc5c Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Thu, 2 Oct 2025 23:41:31 +0800 Subject: [PATCH 23/25] delete kv_new and cur_kv_new --- CMakeLists.txt | 5 +- include/sgl_flash_kernel_ops.h | 6 - python/sgl_kernel/flash_attn.py | 35 +- src/sycl/chunked_prefill.cpp | 140 ++--- src/sycl/helper.h | 125 ---- .../kernels/chunk_prefill/fmha_fusion.hpp | 95 +++ .../tile_scheduler_chunk_prefill.hpp | 222 +++++++ .../chunk_prefill/xe_chunk_prefill.hpp | 594 ++++++++++++++++++ .../xe_flash_attn_chunk_prefill_epilogue.hpp | 290 +++++++++ .../xe_flash_attn_chunk_prefill_mma.hpp | 467 ++++++++++++++ ...sh_attn_chunk_prefill_softmax_epilogue.hpp | 222 +++++++ src/torch_extension_sycl.cc | 4 - 12 files changed, 1973 insertions(+), 232 deletions(-) delete mode 100644 src/sycl/helper.h create mode 100644 src/sycl/kernels/chunk_prefill/fmha_fusion.hpp create mode 100644 src/sycl/kernels/chunk_prefill/tile_scheduler_chunk_prefill.hpp create mode 100644 src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp create mode 100644 src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_epilogue.hpp create mode 100644 src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp create mode 100644 src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_softmax_epilogue.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 6735183..7bf44f0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -37,8 +37,8 @@ set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable headers only mode in cutla # cutlass FetchContent_Declare( repo-cutlass-sycl - GIT_REPOSITORY https://github.com/sunjiweiswift/cutlass-sycl.git - GIT_TAG 
355feacef0916f277e5b3b7d02f15c9848b391df + GIT_REPOSITORY https://github.com/intel/sycl-tla.git + GIT_TAG 8cdf47660e5c64c0f2191b11525a87bc76d71d9a GIT_SHALLOW OFF ) FetchContent_MakeAvailable(repo-cutlass-sycl) @@ -49,7 +49,6 @@ include_directories( ${CMAKE_CURRENT_SOURCE_DIR}/src ${repo-cutlass-sycl_SOURCE_DIR}/include ${repo-cutlass-sycl_SOURCE_DIR}/tools/util/include - ${repo-cutlass-sycl_SOURCE_DIR}/applications ) add_subdirectory(${SGL_OPS_XPU_ROOT}/src) diff --git a/include/sgl_flash_kernel_ops.h b/include/sgl_flash_kernel_ops.h index 9e8b209..a4b6f86 100644 --- a/include/sgl_flash_kernel_ops.h +++ b/include/sgl_flash_kernel_ops.h @@ -48,18 +48,12 @@ std::vector mha_fwd( // h_k, d) if there is page_table. const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, // page_size, h_k, dv) if there is page_table. - std::optional& - k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new - std::optional& - v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new std::optional& q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q std::optional& cu_seqlens_q_, // b+1 std::optional& cu_seqlens_k_, // b+1 - std::optional& cu_seqlens_k_new_, // b+1 std::optional max_seqlen_q_, std::optional max_seqlen_k_, std::optional& page_table_, // (b_k, max_num_pages_per_seq) - std::optional& num_pages_per_seq_, // (b_k, ) std::optional& kv_batch_idx_, // b. indices to index into the KV cache std::optional& leftpad_k_, // b std::optional& rotary_cos_, // seqlen_ro x (rotary_dim / 2) diff --git a/python/sgl_kernel/flash_attn.py b/python/sgl_kernel/flash_attn.py index d19c0d7..14ce759 100644 --- a/python/sgl_kernel/flash_attn.py +++ b/python/sgl_kernel/flash_attn.py @@ -178,24 +178,25 @@ def flash_attn_with_kvcache( ) * q.size(1) max_seqlen_q = q.size(1) q = q.view(-1, q.size(-2), q.size(-1)).contiguous() - if cu_seqlens_k_new is None and k is not None: # !is_varlen_k_new - cu_seqlens_k_new = torch.arange( - 0, k.size(0) + 1, dtype=torch.int, device=k.device - ) - elif k is None: - cu_seqlens_k_new = torch.zeros_like( - cu_seqlens_q, dtype=torch.int32, device=q.device - ) + # if cu_seqlens_k_new is None and k is not None: # !is_varlen_k_new + # cu_seqlens_k_new = torch.arange( + # 0, k.size(0) + 1, dtype=torch.int, device=k.device + # ) + # elif k is None: + # cu_seqlens_k_new = torch.zeros_like( + # cu_seqlens_q, dtype=torch.int32, device=q.device + # ) if cache_seqlens is not None: max_seqlen_k = cache_seqlens.max().item() assert cache_seqlens.size(0) + 1 == cu_seqlens_q.size(0) - max_page_size_per_seq = page_table.size(1) - num_pages_per_seq = torch.arange( - 0, - cache_seqlens.size(0) * max_page_size_per_seq, - max_page_size_per_seq, - device=cache_seqlens.device, - ).to(torch.int32) + # max_page_size_per_seq = page_table.size(1) + # # will delete later + # num_pages_per_seq = torch.arange( + # 0, + # cache_seqlens.size(0) * max_page_size_per_seq, + # max_page_size_per_seq, + # device=cache_seqlens.device, + # ).to(torch.int32) cu_seqlens_k = torch.concat( ( torch.zeros(1, dtype=torch.int32, device=cache_seqlens.device), @@ -207,16 +208,12 @@ def flash_attn_with_kvcache( q, k_cache, v_cache, - k, - v, qv, cu_seqlens_q, cu_seqlens_k, - cu_seqlens_k_new, max_seqlen_q, max_seqlen_k, page_table, - num_pages_per_seq, cache_batch_idx, cache_leftpad, rotary_cos, diff --git a/src/sycl/chunked_prefill.cpp b/src/sycl/chunked_prefill.cpp index 2af2dbb..1fd23b7 100644 --- 
a/src/sycl/chunked_prefill.cpp +++ b/src/sycl/chunked_prefill.cpp @@ -9,11 +9,11 @@ #include "cutlass/util/device_memory.h" #include "cutlass/util/packed_stride.hpp" #include "cutlass/util/sycl_event_manager.hpp" -#include "flash_attention_v2/collective/fmha_fusion.hpp" -#include "flash_attention_v2/collective/xe_flash_attn_chunk_prefill_epilogue.hpp" -#include "flash_attention_v2/collective/xe_flash_attn_chunk_prefill_softmax_epilogue.hpp" -#include "flash_attention_v2/kernel/tile_scheduler_chunk_prefill.hpp" -#include "flash_attention_v2/kernel/xe_chunk_prefill.hpp" +#include "kernels/chunk_prefill/fmha_fusion.hpp" +#include "kernels/chunk_prefill/tile_scheduler_chunk_prefill.hpp" +#include "kernels/chunk_prefill/xe_chunk_prefill.hpp" +#include "kernels/chunk_prefill/xe_flash_attn_chunk_prefill_epilogue.hpp" +#include "kernels/chunk_prefill/xe_flash_attn_chunk_prefill_softmax_epilogue.hpp" using namespace cute; @@ -112,7 +112,7 @@ struct Flash_fwd_params { // Paged KV cache int* __restrict__ page_table; - int* __restrict__ num_pages_per_seq; + int max_num_pages_per_seq; index_t page_table_batch_stride; int page_size; int num_pages; @@ -313,17 +313,17 @@ struct KernelRunner { {// static_cast(params.q_ptr), static_cast(params.q_ptr), stride_Q, - static_cast(params.knew_ptr), - stride_K, - static_cast(params.vnew_ptr), - stride_V, + // static_cast(params.knew_ptr), + // stride_K, + // static_cast(params.vnew_ptr), + // stride_V, static_cast(params.k_ptr), stride_K_cache, static_cast(params.v_ptr), stride_V_cache, params.page_table, params.page_size, - params.num_pages_per_seq, + params.max_num_pages_per_seq, -1, -1}, {(ElementQ)params.scale_softmax}, @@ -470,18 +470,12 @@ std::vector mha_fwd( // h_k, d) if there is page_table. const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, // page_size, h_k, dv) if there is page_table. - std::optional& - k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new - std::optional& - v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new std::optional& q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q std::optional& cu_seqlens_q_, // b+1 std::optional& cu_seqlens_k_, // b+1 - std::optional& cu_seqlens_k_new_, // b+1 std::optional max_seqlen_q_, std::optional max_seqlen_k_, std::optional& page_table_, // (b_k, max_num_pages_per_seq) - std::optional& num_pages_per_seq_, // (b_k, ) std::optional& kv_batch_idx_, // b. 
indices to index into the KV cache std::optional& leftpad_k_, // b std::optional& rotary_cos_, // seqlen_ro x (rotary_dim / 2) @@ -719,67 +713,66 @@ std::vector mha_fwd( params.b_k = batch_size_k; params.dv = head_size_v; if (paged_KV) { - TORCH_CHECK(num_pages_per_seq_.has_value(), "num_pages must be provided if page_table is provided"); params.page_table = page_table.data_ptr(); params.page_table_batch_stride = page_table.stride(0); - params.num_pages_per_seq = num_pages_per_seq_.value().data_ptr(); + params.max_num_pages_per_seq = max_num_pages_per_seq; params.page_size = page_size; params.num_pages = num_pages; } - if (k_new_.has_value()) { // This needs to be set before get_pagedkv_tma - at::Tensor k_new, v_new; - TORCH_CHECK(v_new_.has_value(), "If k_new is supplied, v_new must also be passed in"); - TORCH_CHECK(seqlen_q <= seqlen_k, "If k_new is supplied, it must have seqlen <= the seqlen of the KV cache"); - at::Tensor cu_seqlens_k_new; - bool const is_varlen_k_new = k_new_.value().dim() == 3; - if (is_varlen_k_new) { - cu_seqlens_k_new = cu_seqlens_k_new_.value(); - CHECK_DEVICE(cu_seqlens_k_new); - CHECK_CONTIGUOUS(cu_seqlens_k_new); - TORCH_CHECK(cu_seqlens_k_new.dtype() == torch::kInt32, "cu_seqlens_k_new must have dtype torch.int32"); - } - k_new = k_new_.value(); - v_new = v_new_.value(); - TORCH_CHECK(k_new.dtype() == q_type, "k_new must have the same dtype as query"); - TORCH_CHECK(v_new.dtype() == q_type, "v_new must have the same dtype as query"); - CHECK_DEVICE(k_new); - CHECK_DEVICE(v_new); - TORCH_CHECK(k_new.stride(-1) == 1, "k_new tensor must have contiguous last dimension"); - TORCH_CHECK(v_new.stride(-1) == 1, "v_new tensor must have contiguous last dimension"); - int seqlen_k_new = !is_varlen_k_new ? k_new.size(1) : 1; - int total_k_new = !is_varlen_k_new ? batch_size * k_new.size(1) : k_new.size(0); - if (!is_varlen_k_new) { - CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size); - CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size_v); - } else { - CHECK_SHAPE(k_new, total_k_new, num_heads_k, head_size); - CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size_v); - CHECK_SHAPE(cu_seqlens_k_new, batch_size + 1); - } - params.seqlen_knew = seqlen_k_new; - params.total_knew = total_k_new; - params.knew_ptr = k_new.data_ptr(); - params.vnew_ptr = v_new.data_ptr(); - // All stride are in elements, not bytes. 
- params.knew_row_stride = k_new.stride(-3); - params.vnew_row_stride = v_new.stride(-3); - params.knew_head_stride = k_new.stride(-2); - params.vnew_head_stride = v_new.stride(-2); - if (!is_varlen_k_new) { - params.knew_batch_stride = k_new.stride(0); - params.vnew_batch_stride = v_new.stride(0); - } - if (is_varlen_k_new) { - params.cu_seqlens_knew = static_cast(cu_seqlens_k_new.data_ptr()); - } - } else { - TORCH_CHECK(cu_seqlens_k_new_.has_value(), "cu_seqlens_k_new all zeros"); - params.seqlen_knew = 0; - params.total_knew = 0; - at::Tensor cu_seqlens_k_new = cu_seqlens_k_new_.value(); - params.cu_seqlens_knew = static_cast(cu_seqlens_k_new.data_ptr()); - } + // if (k_new_.has_value()) { // This needs to be set before get_pagedkv_tma + // at::Tensor k_new, v_new; + // TORCH_CHECK(v_new_.has_value(), "If k_new is supplied, v_new must also be passed in"); + // TORCH_CHECK(seqlen_q <= seqlen_k, "If k_new is supplied, it must have seqlen <= the seqlen of the KV cache"); + // at::Tensor cu_seqlens_k_new; + // bool const is_varlen_k_new = k_new_.value().dim() == 3; + // if (is_varlen_k_new) { + // cu_seqlens_k_new = cu_seqlens_k_new_.value(); + // CHECK_DEVICE(cu_seqlens_k_new); + // CHECK_CONTIGUOUS(cu_seqlens_k_new); + // TORCH_CHECK(cu_seqlens_k_new.dtype() == torch::kInt32, "cu_seqlens_k_new must have dtype torch.int32"); + // } + // k_new = k_new_.value(); + // v_new = v_new_.value(); + // TORCH_CHECK(k_new.dtype() == q_type, "k_new must have the same dtype as query"); + // TORCH_CHECK(v_new.dtype() == q_type, "v_new must have the same dtype as query"); + // CHECK_DEVICE(k_new); + // CHECK_DEVICE(v_new); + // TORCH_CHECK(k_new.stride(-1) == 1, "k_new tensor must have contiguous last dimension"); + // TORCH_CHECK(v_new.stride(-1) == 1, "v_new tensor must have contiguous last dimension"); + // int seqlen_k_new = !is_varlen_k_new ? k_new.size(1) : 1; + // int total_k_new = !is_varlen_k_new ? batch_size * k_new.size(1) : k_new.size(0); + // if (!is_varlen_k_new) { + // CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size); + // CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size_v); + // } else { + // CHECK_SHAPE(k_new, total_k_new, num_heads_k, head_size); + // CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size_v); + // CHECK_SHAPE(cu_seqlens_k_new, batch_size + 1); + // } + // params.seqlen_knew = seqlen_k_new; + // params.total_knew = total_k_new; + // params.knew_ptr = k_new.data_ptr(); + // params.vnew_ptr = v_new.data_ptr(); + // // All stride are in elements, not bytes. 
+ // params.knew_row_stride = k_new.stride(-3); + // params.vnew_row_stride = v_new.stride(-3); + // params.knew_head_stride = k_new.stride(-2); + // params.vnew_head_stride = v_new.stride(-2); + // if (!is_varlen_k_new) { + // params.knew_batch_stride = k_new.stride(0); + // params.vnew_batch_stride = v_new.stride(0); + // } + // if (is_varlen_k_new) { + // params.cu_seqlens_knew = static_cast(cu_seqlens_k_new.data_ptr()); + // } + // } else { + // TORCH_CHECK(cu_seqlens_k_new_.has_value(), "cu_seqlens_k_new all zeros"); + // params.seqlen_knew = 0; + // params.total_knew = 0; + // at::Tensor cu_seqlens_k_new = cu_seqlens_k_new_.value(); + // params.cu_seqlens_knew = static_cast(cu_seqlens_k_new.data_ptr()); + // } if (q_v_.has_value()) { TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); @@ -806,9 +799,6 @@ std::vector mha_fwd( } if (rotary_cos_.has_value()) { - TORCH_CHECK( - k_new_.has_value(), - "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided"); auto rotary_cos = rotary_cos_.value(); CHECK_DEVICE(rotary_cos); CHECK_CONTIGUOUS(rotary_cos); diff --git a/src/sycl/helper.h b/src/sycl/helper.h deleted file mode 100644 index 79e9746..0000000 --- a/src/sycl/helper.h +++ /dev/null @@ -1,125 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights - *reserved. SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. 
- * - **************************************************************************************************/ -#pragma once - -#if defined(CUTLASS_ENABLE_SYCL) -#include "cutlass/util/sycl_timer.hpp" -#else -#include -#endif -#include - -/** - * Panic wrapper for unwinding CUTLASS errors - */ -#define CUTLASS_CHECK(status) \ - { \ - cutlass::Status error = status; \ - if (error != cutlass::Status::kSuccess) { \ - std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ << std::endl; \ - exit(EXIT_FAILURE); \ - } \ - } - -/** - * Panic wrapper for unwinding CUDA runtime errors - */ -#define CUDA_CHECK(status) \ - { \ - cudaError_t error = status; \ - if (error != cudaSuccess) { \ - std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) << " at line: " << __LINE__ << std::endl; \ - exit(EXIT_FAILURE); \ - } \ - } - -/** - * GPU timer for recording the elapsed time across kernel(s) launched in GPU - * stream - */ -struct GpuTimer { -#if defined(CUTLASS_ENABLE_SYCL) - using cudaStream_t = int; - SYCLTimer syclTimer; -#else - cudaEvent_t _start; - cudaEvent_t _stop; -#endif - cudaStream_t _stream_id; - - /// Constructor - GpuTimer() : _stream_id(0) { -#if !defined(CUTLASS_ENABLE_SYCL) - CUDA_CHECK(cudaEventCreate(&_start)); - CUDA_CHECK(cudaEventCreate(&_stop)); -#endif - } - - /// Destructor - ~GpuTimer() { -#if !defined(CUTLASS_ENABLE_SYCL) - CUDA_CHECK(cudaEventDestroy(_start)); - CUDA_CHECK(cudaEventDestroy(_stop)); -#endif - } - - /// Start the timer for a given stream (defaults to the default stream) - void start(cudaStream_t stream_id = 0) { - _stream_id = stream_id; -#if defined(CUTLASS_ENABLE_SYCL) - syclTimer.start(); -#else - CUDA_CHECK(cudaEventRecord(_start, _stream_id)); -#endif - } - - /// Stop the timer - void stop() { -#if defined(CUTLASS_ENABLE_SYCL) - syclTimer.stop(); -#else - CUDA_CHECK(cudaEventRecord(_stop, _stream_id)); -#endif - } - - /// Return the elapsed time (in milliseconds) - float elapsed_millis() { -#if defined(CUTLASS_ENABLE_SYCL) - return syclTimer.milliseconds(); -#else - float elapsed = 0.0; - CUDA_CHECK(cudaEventSynchronize(_stop)); - CUDA_CHECK(cudaEventElapsedTime(&elapsed, _start, _stop)); - return elapsed; -#endif - } -}; diff --git a/src/sycl/kernels/chunk_prefill/fmha_fusion.hpp b/src/sycl/kernels/chunk_prefill/fmha_fusion.hpp new file mode 100644 index 0000000..39e6fa1 --- /dev/null +++ b/src/sycl/kernels/chunk_prefill/fmha_fusion.hpp @@ -0,0 +1,95 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include "cutlass/cutlass.h" + +namespace cutlass::fmha::collective { + +using namespace cute; + +struct VariableLength { + int max_length; + int total_length = 0; + int* cumulative_length = nullptr; + + CUTE_HOST_DEVICE operator int() const { + return max_length; + } +}; + +template +struct is_variable_length : std::false_type {}; +template <> +struct is_variable_length : std::true_type {}; +template +constexpr bool is_variable_length_v = is_variable_length::value; + +template +CUTE_HOST_DEVICE constexpr auto apply_variable_length(Shape const& shape, Idx const& idx) { + return transform_leaf(shape, [&](auto const& s) { + if constexpr (is_variable_length_v>) { + return s.cumulative_length[idx + 1] - s.cumulative_length[idx]; + } else { + return s; + } + }); +} + +template +CUTE_HOST_DEVICE constexpr auto apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) { + auto new_shape = apply_variable_length(shape, idx); + auto new_coord = transform_leaf(shape, coord, [&](auto const& s, auto const& c) { + if constexpr (is_variable_length_v>) { + return cute::make_tuple(c, s.cumulative_length[idx]); + } else { + return c; + } + }); + return cute::make_tuple(new_shape, new_coord); +} + +} // namespace cutlass::fmha::collective + +namespace cute { + +template <> +struct is_integral : true_type {}; + +CUTE_HOST_DEVICE +void print(cutlass::fmha::collective::VariableLength a) { + printf("Varlen<%d, %p>", a.max_length, a.cumulative_length); +} + +} // namespace cute diff --git a/src/sycl/kernels/chunk_prefill/tile_scheduler_chunk_prefill.hpp b/src/sycl/kernels/chunk_prefill/tile_scheduler_chunk_prefill.hpp new file mode 100644 index 0000000..823c5a3 --- /dev/null +++ b/src/sycl/kernels/chunk_prefill/tile_scheduler_chunk_prefill.hpp @@ -0,0 +1,222 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in b96inary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.h" + +namespace cutlass::flash_attention { + +namespace kernel { + +struct XeFlashIndividualTileScheduler { + struct Params { + dim3 grid; + // FastDivmod divmod_num_heads; + }; + + bool valid_ = true; + Params params; + + CUTLASS_DEVICE + XeFlashIndividualTileScheduler(Params const& params) : params(params) {} + + template + static Params + to_underlying_arguments(ProblemSize const& problem_size, KernelHardwareInfo hw_info, TileShape const& tile_shape) { + using namespace cute; + // problem_size = [batch, num_heads_q , num_heads_kv, seq_len_qo, + // seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] + + // dim3 grid(size(ceil_div(shape<7>(problem_size), shape<1>(tile_shape))), + // size(ceil_div(shape<3>(problem_size), shape<0>(tile_shape))), + // size(shape<0>(problem_size) * shape<1>(problem_size))); + + int batch = size<0>(problem_size); + int num_heads_q = size<1>(problem_size); + int num_heads_kv = size<2>(problem_size); + int seq_len_qo = size<3>(problem_size); // if varlen seq_len_qo = max_seq_len + int seq_len_kv = size<4>(problem_size); // if varlen seq_len_qo = max_seq_len + int seq_len_kv_cache = size<5>(problem_size); + int head_size_qk = size<6>(problem_size); + int head_size_vo = size<7>(problem_size); + auto group_heads_q = num_heads_q / num_heads_kv; + + dim3 grid( + size(ceil_div(shape<3>(problem_size), shape<0>(tile_shape))), + size(shape<1>(problem_size)), + size(shape<0>(problem_size))); + return Params{grid}; + } + + template + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + return make_coord(BlockIdxX(), BlockIdxY(), BlockIdxZ()); + } + + CUTLASS_DEVICE + XeFlashIndividualTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct XeFlashPersistentTileScheduler { + struct Params { + int num_blocks; + FastDivmod divmod_seq_len_block; + FastDivmod divmod_head_size_block; + FastDivmod divmod_num_heads; + + KernelHardwareInfo hw_info; + }; + + int block_idx = 0; + Params params; + + CUTLASS_DEVICE + XeFlashPersistentTileScheduler(Params const& params) : block_idx(BlockIdxX()), params(params) {} + + 
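+  // How this persistent scheduler distributes work: get_grid_shape() launches a
+  // fixed-size grid derived from the hardware SM count, and each work-group then
+  // strides through the flattened (head_size_block, seq_len_block, batch * num_heads)
+  // tile space in steps of GridDimX() via operator++, decoding its linear block
+  // index back into coordinates with the FastDivmod fields in get_block_coord().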
template + static Params + to_underlying_arguments(ProblemSize const& problem_size, KernelHardwareInfo hw_info, TileShape const& tile_shape) { + using namespace cute; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST( + " WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments " + "KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + hw_info.sm_count = sm_count; + + // problem_size = [batch, num_heads_q, numhead_kv, seq_len_qo, seq_len_kv, + // seq_len_kv_cache, head_size_qk, head_size_vo] + int num_head_size_blocks = size(ceil_div(shape<7>(problem_size), shape<1>(tile_shape))); + int num_seq_len_blocks = size(ceil_div(shape<3>(problem_size), shape<0>(tile_shape))); + int num_blocks = num_seq_len_blocks * num_head_size_blocks * size(shape<0>(problem_size) * shape<1>(problem_size)); + + return Params{num_blocks, {num_seq_len_blocks}, {num_head_size_blocks}, {shape<1>(problem_size)}, hw_info}; + } + + template + static dim3 get_grid_shape(Params const& params) { + auto queue = compat::get_default_queue(); + auto dev = queue.get_device(); + const size_t maxSubgroups = dev.template get_info(); + // TODO (Codeplay): revert this back to std::min(params.num_blocks, + // params.hw_info.sm_count) once performance issue is fixed. + dim3 grid(std::min(params.num_blocks, ceil_div(params.hw_info.sm_count * maxSubgroups, Num_SGs)), 1, 1); + return grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return block_idx < params.num_blocks; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + int block_decode = block_idx; + int seq_len_block, head_size_block, bidh; + params.divmod_head_size_block(block_decode, head_size_block, block_decode); + params.divmod_seq_len_block(block_decode, seq_len_block, block_decode); + params.divmod_num_heads(block_decode, bidh, block_decode); + return make_coord(head_size_block, seq_len_block, block_decode, bidh); + } + + CUTLASS_DEVICE + XeFlashPersistentTileScheduler& operator++() { + block_idx += GridDimX(); + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// +} // namespace kernel + +struct IndividualScheduler {}; +struct PersistentScheduler {}; + +namespace detail { + +template +struct TileSchedulerSelector { + static_assert(cutlass::detail::dependent_false, "Could not select a tile scheduler for given parameters."); +}; + +// Default (void) maps to XeFlashIndividualTileScheduler +template +struct TileSchedulerSelector>> { + using Scheduler = typename TileSchedulerSelector::Scheduler; +}; + +template +struct TileSchedulerSelector< + IndividualScheduler, + ArchTag, + cute::enable_if_t>> { + using Scheduler = kernel::XeFlashIndividualTileScheduler; +}; + +template +struct TileSchedulerSelector< + PersistentScheduler, + ArchTag, + cute::enable_if_t>> { + using Scheduler = kernel::XeFlashPersistentTileScheduler; +}; +} // namespace detail + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::flash_attention diff --git a/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp b/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp new file mode 100644 index 0000000..e944945 --- /dev/null +++ 
b/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp @@ -0,0 +1,594 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice,this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "xe_flash_attn_chunk_prefill_mma.hpp" + +namespace cutlass::flash_attention::kernel { + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveSoftmaxEpilogue_, + class CollectiveEpilogue_, + class TileScheduler_ = void> +class FMHAPrefillChunk; +/////////////////////////////////////////////////////////////////////////////// +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveSoftmaxEpilogue_, + class CollectiveEpilogue_, + class TileScheduler_> +class FMHAPrefillChunk { + public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + + // ProblemShape: + static_assert( + rank(ProblemShape{}) == 8, + "ProblemShape{} should be "); + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShapeQK = typename CollectiveMainloop::TileShapeQK; + using TileShapePV = typename CollectiveMainloop::TileShapePV; + using TiledMmaQK = typename CollectiveMainloop::TiledMmaQK; + using TiledMmaPV = typename CollectiveMainloop::TiledMmaPV; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementQ = typename CollectiveMainloop::ElementQ; + using StrideQ = typename CollectiveMainloop::StrideQ; + using ElementK = typename CollectiveMainloop::ElementK; + using StrideK = typename CollectiveMainloop::StrideK; + using ElementV = typename CollectiveMainloop::ElementV; + using StrideV = typename CollectiveMainloop::StrideV; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + + using CollectiveSoftmaxEpilogue = CollectiveSoftmaxEpilogue_; + using SoftmaxArguments = typename CollectiveSoftmaxEpilogue::Arguments; + using SoftmaxParams = typename CollectiveSoftmaxEpilogue::Params; + + static_assert( + cute::is_void_v or cute::is_same_v or + cute::is_same_v, + "Unsupported TileScheduler for Intel Xe."); + using TileSchedulerTag = TileScheduler_; + using TileScheduler = typename detail::TileSchedulerSelector::Scheduler; + using TileSchedulerParams = typename TileScheduler::Params; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementO = typename CollectiveEpilogue::ElementO; + using StrideO = typename CollectiveEpilogue::StrideO; + using ElementLSE = typename CollectiveEpilogue::ElementLSE; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + using TileShapeOutput = typename CollectiveEpilogue::TileShapeOutput; + using TiledMmaOutput = typename CollectiveEpilogue::TiledMmaOutput; + + static_assert( + cute::is_same_v, + "Mainloop and epilogue do not agree on accumulator value type."); + // MSVC requires the cast to fix a warning-as-error. 
+ static constexpr int SharedStorageSize = 0; + + static constexpr bool CausalMask = CollectiveMainloop::CausalMask; + static constexpr bool LocalMask = CollectiveMainloop::LocalMask; + + static_assert(!(CausalMask && LocalMask), "Cannot be both causal and local"); + static constexpr bool PagedKV = CollectiveMainloop::PagedKV; + + static constexpr int SubgroupSize = CollectiveMainloop::SubgroupSize; // sub_group size + static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock; + using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape; // 8,16,16 + + static constexpr int QK_BLK_M = CollectiveMainloop::QK_BLK_M; + static constexpr int QK_BLK_N = CollectiveMainloop::QK_BLK_N; + static constexpr int QK_BLK_K = CollectiveMainloop::QK_BLK_K; + + static constexpr int QK_ATOM_N = CollectiveMainloop::QK_ATOM_N; + static constexpr int QK_ATOM_K = CollectiveMainloop::QK_ATOM_K; + + static constexpr int QK_SG_M = CollectiveMainloop::QK_SG_M; + + static constexpr int Epilogue_BLK_N = get<1>(TileShapeOutput{}); + static constexpr int Epilogue_BLK_K = get<2>(TileShapeOutput{}); + + static constexpr int PV_ATOM_M = CollectiveMainloop::PV_ATOM_M; + static constexpr int PV_ATOM_N = CollectiveMainloop::PV_ATOM_N; + static constexpr int PV_ATOM_K = CollectiveMainloop::PV_ATOM_K; + + static constexpr auto Num_SGs = PV_ATOM_N * PV_ATOM_M * PV_ATOM_K; + static constexpr int Vec = CollectiveMainloop::Vec; + static constexpr int FragsM = CollectiveMainloop::FragsM; + // The FragsN here used for Creation of S matrix so we use the FragsN for S + // shape + static constexpr int FragsN = CollectiveMainloop::FragsNS; + + static constexpr int VSlicer = + get<1>(TileShapeOutput{}) / (get<1>(TileShapePV{}) * PV_ATOM_N); // ceil_div(FragsNOut,FragsNS); + using AccumeShape = + decltype(make_shape(Int{}, Int{}, get<1>(TileShapePV{}) / get<1>(MmaAtomShape()), Int{})); + + static constexpr bool is_var_len = CollectiveMainloop::is_var_len; + // Kernel level shared memory storage + struct SharedStorage { + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + EpilogueTensorStorage epilogue; + }; + + // Device side arguments + struct Arguments { + gemm::GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + SoftmaxArguments softmax{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + }; + + // Kernel entry point API + struct Params { + gemm::GemmUniversalMode mode; + ProblemShape problem_shape; + MainloopParams mainloop; + SoftmaxParams softmax; + EpilogueParams epilogue; + TileSchedulerParams scheduler; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the + // aliased type. 
+ static Params to_underlying_arguments(Arguments const& args, void* workspace) { + (void)workspace; + return { + args.mode, + args.problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), + CollectiveSoftmaxEpilogue::to_underlying_arguments(args.softmax), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace), + TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, TileShapeOutput{})}; + } + + static bool can_implement(Arguments const& args) { + bool mode_implementable = args.mode == gemm::GemmUniversalMode::kGemm or + (args.mode == gemm::GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + return mode_implementable; + } + + static int get_workspace_size(Arguments const& args) { + return 0; + } + + static cutlass::Status initialize_workspace( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::template get_grid_shape(params.scheduler); + } + + static dim3 get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + Shape get_sequence_length_shape(ProblemShape const& problem_shape, int const& batch) { + if constexpr (is_var_len) { + return cutlass::fmha::collective::apply_variable_length(select<3, 5>(problem_shape), batch); + } else { + return select<3, 5>(problem_shape); + } + } + + CUTLASS_DEVICE + void operator()(Params const& params, char* smem_buf) { + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + // Preconditions + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + // Separate out problem shape for convenience + + // "ProblemShape{} should be "); + auto batch = get<0>(params.problem_shape); + auto num_heads_q = get<1>(params.problem_shape); + auto num_heads_kv = get<2>(params.problem_shape); + + auto& head_size_qk = get<6>(params.problem_shape); + auto& head_size_vo = get<7>(params.problem_shape); + // Preconditions + static_assert( + cute::rank(StrideQ{}) == 3, + "StrideQ must be rank-3: [seq_len_qo, head_size_qk, batch * " + "num_heads_q]."); + static_assert( + cute::rank(StrideK{}) == 3, + "StrideK must be rank-3: [head_size_qk, seq_len_kv, batch * " + "num_heads_kv]."); + static_assert( + cute::rank(StrideV{}) == 3, + "StrideV must be rank-3: [seq_len_kv, head_size_vo, batch * " + "num_heads_kv]."); + + int thread_idx = int(ThreadIdxX()); + // int sub_group_id = thread_idx / SubgroupSize; + auto sub_group_id = get_sub_group_id(); + auto local_id = get_sub_group_local_id(); + + TileScheduler tile_scheduler{params.scheduler}; + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); // head_size_blk_idx, seq_len_blk_idx, + // batch_blk_idx, num_heads_blk_idx + + auto blk_m_coord = get<0>(blk_coord); // seq_len_blk_idx + auto blk_n_coord = 0; // nums_head_blk_idx + auto q_head_coord = get<1>(blk_coord); // q_heads_idx + auto batch_coord = get<2>(blk_coord); // batch_blk_idx + + // For variable sequence length case, batch is considered to be 1 (same + // as group gemm). For fixed sequence length case, the l_coord is the + // weighted sum of both batch_coord and num_heads_coord. Flash Attention + // implementation combines batch and num_heads to calculate the total + // batch_size. 
iff is_var_len: batch_size = num_heads (as each batch + // would have its own seq_len_qo and seq_len_kv) iff !is_var_len: + // batch_size = batch * num_heads + // auto blk_l_coord = q_head_coord; + + // Get problem shape for the current batch_blk_idx. For variable + // sequence length, it loads the sequence length from Global memory for + // the given batch_blk_idx and returns the appropriate problem_shape. + // For fixed sequence length, sequence_length_shape == select<3, 4, + // 5>(params.problem_shape). sequence_length_shape = [batch, + // num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, + // head_size_qk, head_size_vo] + auto sequence_length_shape = get_sequence_length_shape(params.problem_shape, batch_coord); + + auto [seq_len_qo, seq_len_kv_cache] = sequence_length_shape; + // int seq_len_kv_total = seq_len_kv_cache + seq_len_kv; + // For variable sequence length case, batch is considered to be 1 (same + // as group gemm). For fixed sequence length case, the l_coord is the + // weighted sum of both batch_coord and num_heads_coord. Flash Attention + // implementation combines batch and num_heads to calculate the total + // batch_size. iff is_var_len: batch_size = num_heads (as each batch + // would have its own seq_len_qo and seq_len_kv) iff !is_var_len: + // batch_size = batch * num_heads + + // Calculate the seq_len_idx (blk_m_coord * get<0>(TileShapeOutput{})) + // and check if it is still within bounds of the actual seq_len_qo + // (get<0>(sequence_length_shape)). + if (blk_m_coord * get<0>(TileShapeOutput{}) >= seq_len_qo) { + continue; + } + + const int seq_coord = + cute::min(seq_len_qo, (blk_m_coord * QK_BLK_M + (sub_group_id / PV_ATOM_N) * QK_SG_M) % seq_len_qo); + // auto offset = cute::min(seq_len_qo, seq_len_kv); //(2048, 1024) + // auto discard_seq_coord = seq_len_qo - offset; // 1024 + // auto full_tile_offset = seq_len_kv - offset; // 0 + + // const int seq_len = seq_len_kv; + // CausalMask + // ? full_tile_offset + + // cute::min(seq_len_kv, seq_coord - discard_seq_coord) + + // QK_SG_M + // : seq_len_kv; + + const int kv_splits_cache = cute::ceil_div(seq_len_kv_cache, QK_BLK_N); + const int kv_splits = kv_splits_cache; + + int tiles_per_page = params.mainloop.page_size / QK_BLK_N; + + Tensor mQ_mkl = cute::get_xe_tensor(make_shape(seq_len_qo, head_size_qk, 1)); //(m,k,l) + + Tensor mK_cache_nkl = cute::get_xe_tensor(make_shape(seq_len_kv_cache, head_size_qk, 1)); // (n_cache,k,l) + Tensor mV_cache_nkl = cute::get_xe_tensor(make_shape(head_size_vo, seq_len_kv_cache, 1)); // (n_cache,k,l) + + // block_size and head_size are the same size. So no coord is needed. + Tensor mQ_mk = mQ_mkl(_, _, 0); + + Tensor mK_cache_nk = mK_cache_nkl(_, _, 0); // (n_cache, k) + Tensor mV_cache_nk = mV_cache_nkl(_, _, 0); // (n_cache, k) + + auto gQ = local_tile(mQ_mk, TileShapeQK{}, make_coord(blk_m_coord, _, _), Step<_1, X, _1>{}); + + auto gK_cache = local_tile(mK_cache_nk, TileShapeQK{}, make_coord(_, _, _), Step{}); + auto gV_cache = local_tile(mV_cache_nk, TileShapeOutput{}, make_coord(_, blk_n_coord, _), Step{}); + + auto mainloop_params = CollectiveMainloop::get_updated_copies( + params.mainloop, params.problem_shape, sequence_length_shape, batch_coord, q_head_coord); + + // we limit the horizontal size to two subgroups; the empirical results + // show that reading two cachelines side by side gives better + // performance, and anything after that does not have an effect on + // performance.
// (64 here for float b float when possible and loop over + // to cover all the data needed) + auto tiled_prefetch_q = + cute::prefetch_selector, Int>, Num_SGs>( + mainloop_params.gmem_tiled_copy_q); + + auto tiled_prefetch_k_cache = + cute::prefetch_selector, Int>, Num_SGs>( + mainloop_params.gmem_tiled_copy_k_cache); + auto tiled_prefetch_v_cache = cute:: + prefetch_selector, Int>, Num_SGs>( + mainloop_params.gmem_tiled_copy_v_cache); + auto thr_prefetch_Q = tiled_prefetch_q.get_slice(thread_idx); + auto thr_prefetch_K_cache = tiled_prefetch_k_cache.get_slice(thread_idx); + auto thr_prefetch_V_cache = tiled_prefetch_v_cache.get_slice(thread_idx); + auto pQgQ = thr_prefetch_Q.partition_S(gQ); + + // assuming the copy function is the same; otherwise this needs to have its + // own tile_prefetch + auto pKgK_cache = thr_prefetch_K_cache.partition_S(gK_cache); + auto pVgV_cache = thr_prefetch_V_cache.partition_S(gV_cache); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<3>(pQgQ); i++) { + prefetch(tiled_prefetch_q, pQgQ(_, _, _, i)); + } + auto& prefetch_K = tiled_prefetch_k_cache; + auto& pKgK1_ = pKgK_cache; + + int cached_nblock = 0; + if constexpr (PagedKV) { + // int curr_batch_pages = ceil_div(seq_len_kv_cache, mainloop_params.page_size);// max_page_size_per_seq + // int batch_offset = is_var_len ? mainloop_params.num_pages_per_seq[batch_coord] : batch_coord * curr_batch_pages; + int batch_offset = batch_coord * mainloop_params.max_num_pages_per_seq; + cached_nblock = mainloop_params.ptr_page_table[batch_offset // page table for this batch + ] * tiles_per_page; // base block idx of physical page + } + // The head size for both cached and non-cached versions is the same + for (int j = 0; j < size<4>(pKgK1_); j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = cached_nblock; i < cached_nblock + DispatchPolicy::Stages; i++) { + prefetch(prefetch_K, pKgK1_(_, _, _, i, j)); + } + } + + // Allocate the tiled_mma and the accumulators for the (M,N) + // workgroup_shape + Tensor out_reg = make_tensor(AccumeShape{}); + + // There are 16 workitems and 16 max values per subgroup; each workitem + // contains 1 max and, cumulatively, they calculate the max per subgroup + ElementAccumulator max_reg{-INFINITY}; + // Each sum reg contains a 2d tensor of 8 x 2; this is the number of + // sequence-length rows processed per subgroup + Tensor sum_reg = make_tensor(Shape, Int>{}); + + clear(sum_reg); + clear(out_reg); + // Perform the collective scoped MMA + CollectiveMainloop collective_mma; + // When the causal mask is enabled, it is not possible to set the scope + // of the barrier to workgroup level, as the number of n blocks is + // different for each subgroup due to the triangular nature of the causal + // operation + static constexpr int barrier_scope = CausalMask ?
3 : 2; + + int q_start_coord = blk_m_coord * QK_BLK_M; + int q_end_coord = cute::min(q_start_coord + QK_BLK_M, seq_len_qo); + int seq_diff = seq_len_kv_cache - seq_len_qo; + + CUTLASS_PRAGMA_UNROLL + for (int split = 0; split < kv_splits; split++) { + barrier_arrive(barrier_scope); + + int kv_start_coord = split * QK_BLK_N; + + if constexpr (CausalMask) { + if (kv_start_coord >= q_end_coord + seq_diff) break; + } + + // // = 0, all KV is kv_cache + // 1) Load KV (performed inside mmaQK) + auto gK_ = gK_cache(_, _, cached_nblock, _); + auto gV_ = gV_cache(_, _, cached_nblock); + // 2) Create Tensor S + Tensor tSr = make_tensor(Shape, Int, Int>{}); + clear(tSr); + // 3) Perform GEMM S = Q*K + // Then modify layout to LayoutQ = ((seq_leq_q, group_head_q), + // head_size_qk, batch* num_heads_q / group_head_q), which can be merged + // into one gemm for (int i = 0; i < q_group_size; ++i) { + collective_mma.mmaQK(tSr, gQ, gK_, tSr, ceil_div(head_size_qk, QK_BLK_K), mainloop_params); + + if constexpr (LocalMask) { + // Sliding windows + // mask the elements of each tile where j - left > i || j + right < i + const int item_id = thread_idx % SubgroupSize; + int col_idx = item_id + split * cute::min(QK_BLK_N, seq_len_kv_cache); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < FragsN; n++, col_idx += get<1>(MmaAtomShape())) { // 4 + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < FragsM; m++) { // 2 + int row_idx = m * Vec + seq_coord; + int col_ref = seq_len_kv_cache - seq_len_qo; + // int col_ref = seq_len_kv_cache + seq_len_kv - seq_len_qo; + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < Vec; row++) { // 8 + bool left_mask = col_idx < cute::max(0, row + row_idx + col_ref - mainloop_params.window_left); + bool right_mask = + col_idx > cute::min(seq_len_kv_cache, row + row_idx + col_ref + mainloop_params.window_right); + if (left_mask || right_mask) { + tSr(row, m, n) = ElementAccumulator{-INFINITY}; + } + } + } + } + } + + if constexpr (PagedKV) { + // // if constexpr(!(CausalMask || LocalMask) && PagedKV) { + // // Processing Not divisible, mask padding + // const int item_id = thread_idx % SubgroupSize; + // int col_idx = item_id + split * cute::min(QK_BLK_N, + // seq_len_kv_cache + seq_len_kv); + // CUTLASS_PRAGMA_UNROLL + // for (int n = 0; n < FragsN; n++, col_idx += + // get<1>(MmaAtomShape())) { // 4 + // CUTLASS_PRAGMA_UNROLL + // for (int m = 0; m < FragsM; m++) { // 2 + // int row_idx = m * Vec + seq_coord; + // CUTLASS_PRAGMA_UNROLL + // for (int row = 0; row < Vec; row++) { // 8 + // if (col_idx >= seq_len_kv_cache + seq_len_kv || row_idx + + // row >= seq_len_qo) { + // tSr(row, m, n) = ElementAccumulator{-INFINITY}; + // } + // } + // } + // } + + int col_start = local_id + kv_start_coord; + int col_end = col_start + (FragsN - 1) * get<1>(MmaAtomShape()); + if (col_end >= seq_len_kv_cache) { + int col_idx = col_start; + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < FragsN; n++, col_idx += get<1>(MmaAtomShape())) { // 4 + if (col_idx >= seq_len_kv_cache) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < FragsM; m++) { // 2 + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < Vec; row++) { // 8 + tSr(row, m, n) = ElementAccumulator{-INFINITY}; + } + } + } + } + } + if constexpr (CausalMask) { + int row_start = q_start_coord + sub_group_id * QK_SG_M; + if (row_start + seq_diff < col_end) { + int col_idx = col_start; + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < FragsN; n++, col_idx += get<1>(MmaAtomShape())) { // 4 + if (col_idx > row_start + seq_diff) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m 
< FragsM; m++) { // 2 + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < Vec; row++) { // 8 + int row_idx = row_start + m * Vec + row; + if (row_idx + seq_diff < col_idx) tSr(row, m, n) = ElementAccumulator{-INFINITY}; + } + } + } + } + } + } + } + auto& tiled_prefetch_v_ = tiled_prefetch_v_cache; + auto& pVgV_ = pVgV_cache; + int v_prefetch_idx = cached_nblock; + for (int i = 0; i < size<1>(pVgV_); i++) { + prefetch(tiled_prefetch_v_, pVgV_(_, i, _, v_prefetch_idx)); + } + int next_cached_nblock = split + 1; + if constexpr (PagedKV) { + // int curr_batch_pages = ceil_div(seq_len_kv_cache, mainloop_params.page_size); + // int batch_offset = + // is_var_len ? mainloop_params.num_pages_per_seq[batch_coord] : batch_coord * curr_batch_pages; + int curr_batch_pages = mainloop_params.max_num_pages_per_seq; // max_page_size_per_seq + int batch_offset = batch_coord * curr_batch_pages; + int next_page_logical_idx = next_cached_nblock * QK_BLK_N / params.mainloop.page_size; + bool valid_page = next_page_logical_idx < curr_batch_pages; + // get physical page idx from page table + if (valid_page) { + next_cached_nblock = params.mainloop.ptr_page_table + [batch_offset + // page table for this batch + next_page_logical_idx // split (tile idx) to logical + // page idx + ] * tiles_per_page + // base block idx of physical page + next_cached_nblock % tiles_per_page; // offset within page + } else { + next_cached_nblock = curr_batch_pages * tiles_per_page; // push idx out of bounds to respect the + // boundary between batches + } + } + + // 4) Fused softmax + CollectiveSoftmaxEpilogue softmax(params.softmax); + softmax(split == 0, tSr, max_reg, sum_reg, out_reg); + + // 5) Perform GEMM O = S*V + collective_mma.template mmaPV(out_reg, tSr, gV_, out_reg, mainloop_params); + // ... prefetch next tile ... + // Prefetch the next Q tile + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<3>(pQgQ); i++) { + prefetch(tiled_prefetch_q, pQgQ(_, _, _, i)); + } + + cached_nblock = next_cached_nblock; + // Prefetch the next K tile + // there is no need to guard it with an if statement, as prefetch will + // ignore out-of-bounds reads + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<4>(pKgK_cache); j++) { + prefetch(tiled_prefetch_k_cache, pKgK_cache(_, _, _, cached_nblock, j)); + } + barrier_wait(barrier_scope); + } + + // Epilogue + auto epilogue_params = CollectiveEpilogue::template get_updated_copies( + params.epilogue, params.problem_shape, sequence_length_shape, batch_coord, q_head_coord); + CollectiveEpilogue epilogue{epilogue_params, shared_storage.epilogue}; + auto blk_coord_mnkl = make_coord(blk_m_coord, blk_n_coord, _, 0); + epilogue(params.problem_shape, sequence_length_shape, blk_coord_mnkl, out_reg, max_reg, sum_reg); + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::flash_attention::kernel diff --git a/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_epilogue.hpp b/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_epilogue.hpp new file mode 100644 index 0000000..94a5b66 --- /dev/null +++ b/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_epilogue.hpp @@ -0,0 +1,290 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/detail/layout.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// +namespace cutlass { +namespace flash_attention { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class FlashChunkPrefillEpilogue { + static_assert(cutlass::detail::dependent_false, "Could not find an epilogue specialization."); +}; + +template < + class MMAOperation_, + class TileShapeOutput_, + class SubgroupLayout_, + class ElementCompute_, + class ElementO_, + class StrideO_, + class ElementLSE_, + class CopyOpO_> +class FlashChunkPrefillEpilogue< + epilogue::IntelXeXMX16, + MMAOperation_, + TileShapeOutput_, + SubgroupLayout_, + ElementCompute_, + ElementO_, + StrideO_, + ElementLSE_, + CopyOpO_> { + public: + // + // Type Aliases + // + using DispatchPolicy = epilogue::IntelXeXMX16; + using ElementO = ElementO_; + using StrideO = StrideO_; + using ElementLSE = ElementLSE_; + using CopyOpO = CopyOpO_; + using SubgroupLayout = SubgroupLayout_; + using TileShapeOutput = TileShapeOutput_; + using TiledMmaOutput = + typename TiledMMAHelper, Layout, SubgroupLayout>::TiledMMA; + using GmemTiledCopyO = CopyOpO; + using ElementOutput = ElementO_; + using ElementCompute = ElementCompute_; + using ElementAccumulator = ElementCompute_; + using SubgroupTileShape = decltype(cute::shape_div(TileShapeOutput{}, (SubgroupLayout{}.shape()))); + + static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; + + 
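// Epilogue overview (descriptive comment): operator() reduces the per-row + // softmax sums across the subgroup, rescales the output accumulator by + // 1/sum (guarding against zero or NaN sums), converts to ElementOutput when + // it differs from the accumulator type, and stores the O tile through the + // tiled copy (XE_Copy_O) configured below. +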
static_assert(cute::rank(TileShapeOutput{}) == 3, "TileShapeOutput must be rank-3: [CTA_M_QO, CTA_N_VO, CTA_K_PV]"); + static_assert(cute::rank(StrideO{}) == 3, "StrideO must be rank-3: [seq_len_qo, head_size_vo, batch * num_heads]"); + + using CopyThreadShape = Shape<_1, Int>; + + using traits_store_O = Copy_Traits; + using atom_load_O = Copy_Atom; + using val_layout_load_O = decltype(make_layout(shape_div(typename traits_store_O::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_O = decltype(make_tiled_copy(atom_load_O{}, Layout{}, val_layout_load_O{})); + + private: + constexpr static bool is_destination_supported = not cute::is_void_v; + + public: + using EmptyType = cute::tuple<>; + + struct TensorStorageImpl : cute::tuple {}; + + struct SharedStorage { + using TensorStorage = TensorStorageImpl; + + TensorStorage tensors; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + + // Host side epilogue arguments + struct Arguments { + ElementO const* ptr_O; + StrideO dO; + }; + + // Device side epilogue params + struct Params { + XE_Copy_O xe_store_o; + }; + + // + // Methods + // + template + CUTLASS_DEVICE auto convert_type(Tensor const& tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + auto frag = convert_op(*reinterpret_cast*>(tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); + } + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, [[maybe_unused]] void* workspace) { + auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = + problem_shape; + auto tensorO = make_tensor( + make_gmem_ptr(static_cast(args.ptr_O)), + make_layout(make_shape(seq_len_qo, num_heads_q * head_size_vo, batch), args.dO)); + XE_Copy_O xe_store_o{XE_Copy_O{}.with(tensorO)}; + return { + xe_store_o, + }; + } + + template + static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status initialize_workspace( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace, + cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE static bool + can_implement(ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) { + return true; + } + + CUTLASS_HOST_DEVICE + FlashChunkPrefillEpilogue(Params const& params_, TensorStorage const&) : params(params_) {} + + template + CUTLASS_DEVICE void operator()( + ProblemShape problem_shape, + SequenceLengthShape sequence_length_shape, + TileCoord tile_coord, + FragOut& out, + FragMax const& max, + FragSum& sum) { + using namespace cute; + + static constexpr bool is_var_len = + cutlass::fmha::collective::is_variable_length_v>; + + using FragOutLayout = typename FragOut::layout_type; + + constexpr int Vec = shape<0>(FragOutLayout{}); + constexpr int FragsM = shape<1>(FragOutLayout{}); + constexpr int FragsN = size(select<2, 3>(shape(FragOutLayout{}))); + + auto sg = compat::get_nd_item<1>().get_sub_group(); + auto out_reg = make_tensor(static_cast(out).data(), Shape, Int, Int>{}); + + CUTLASS_PRAGMA_UNROLL + for (int y = 0; y < FragsM; y++) { + CUTLASS_PRAGMA_UNROLL + for (int x = 0; x < Vec; x++) { + int indx = y * Vec + x; + auto cur_sum = reduce_over_group(sg, sum(indx), sycl::plus<>()); + auto cur_scale = (cur_sum 
== 0.f || cur_sum != cur_sum) ? 1.0f : sycl::native::recip(cur_sum); + CUTLASS_PRAGMA_UNROLL + for (int z = 0; z < FragsN; z++) { + out_reg(x, y, z) *= cur_scale; + } + } + } + + // Indexing variables + auto [batch, num_heads_q, num_heads_kv, head_size_vo] = select<0, 1, 2, 7>(problem_shape); + auto [seq_len_qo] = select<0>(sequence_length_shape); + // Represent the full output tensor + Tensor mO_mnl = cute::get_xe_tensor(make_shape(seq_len_qo, head_size_vo, 1)); + + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord; + // Tile the output tensor per WG + Tensor g_wg_O = + local_tile(mO_mnl, select<0, 1>(TileShapeOutput{}), make_coord(m_coord, n_coord, 0)); // (BLK_M,BLK_N,m,n,l) + static constexpr auto ATOM_N = get<2>(typename TiledMmaOutput::ThrLayoutVMNK{}.shape()); + auto m_sg = get_sub_group_id() / ATOM_N; + auto n_sg = get_sub_group_id() % ATOM_N; + // Tile the output tensor per SG + Tensor gO = + local_tile(g_wg_O, SubgroupTileShape{}, make_coord(m_sg, n_sg, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) + auto thread_xe_store_o = params.xe_store_o.get_thread_slice(ThreadIdxX()); + Tensor tOgO = thread_xe_store_o.partition_D(gO); + + Tensor final_out_reg = make_fragment_like(out_reg); + // iff ElementOutput == ElementAccumulator, then convert_type doesn't do the right conversion + // so we call copy() which internally performs a static_cast op on the data. + // for ElementOutput == bf16 | fp16, convert_type calls relevant NumericConverter specialization. + if constexpr (cute::is_same_v) { + copy(out_reg, final_out_reg); + } else { + Tensor temp = convert_type(out_reg); + copy(temp, final_out_reg); + } + copy(params.xe_store_o, final_out_reg, tOgO); + } + + // SequenceLengthShapeType = Shape + // For Fixed Sequence Length, ProblemShapeType = Shape + // For Variable Sequence Length, ProblemShapeType = Shape + template + CUTLASS_DEVICE static constexpr Params get_updated_copies( + Params const& params, + ProblemShapeType const& problem_shape, + SequenceLengthShapeType const& sequence_length_shape, + int const& l_coord, + int const& q_head_coord) { + auto [num_heads_q, num_heads_kv, head_size_vo] = select<1, 2, 7>(problem_shape); + auto [seq_len_qo] = select<0>(sequence_length_shape); + int offset_o = 0; + if constexpr (VarLen) { + auto qo_cumulative_length = get<3>(problem_shape).cumulative_length; + offset_o = num_heads_q * head_size_vo * qo_cumulative_length[l_coord] + q_head_coord * head_size_vo; + } else { + offset_o = num_heads_q * head_size_vo * seq_len_qo * l_coord + q_head_coord * head_size_vo; + } + auto store_traits = static_cast(params.xe_store_o); + ElementO* base_ptr = (ElementO*)store_traits.base_ptr; + auto shape_o = make_shape(static_cast(seq_len_qo), num_heads_q * head_size_vo, 1); + StrideO stride_o = cutlass::make_cute_packed_stride(StrideO{}, shape_o); + auto tensorO = make_tensor(make_gmem_ptr(base_ptr + offset_o), make_layout(shape_o, stride_o)); + XE_Copy_O xe_store_o{XE_Copy_O{}.with(tensorO)}; + return Params{xe_store_o}; + } + + private: + Params const& params; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace flash_attention +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp b/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp new file mode 100644 index 0000000..a4e3df0 --- /dev/null +++ 
b/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp @@ -0,0 +1,467 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "fmha_fusion.hpp" + +//////////////////////////////////////////////////////////// +namespace {} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::flash_attention::collective { +using namespace cute; +//////////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class DispatchPolicy, + class ProblemShapeType_, + class ElementQ_, + class StrideQ_, + class ElementK_, + class StrideK_, + class ElementV_, + class StrideV_, + class MMAOperation_, + class TileShapeQK_, + class TileShapePV_, + class SubgroupLayout_, + class GmemTiledCopyQ_, + class GmemTiledCopyK_, + class GmemTiledCopyV_, + bool CausalMask_, + bool LocalMask_, + bool PagedKV_> +struct FlashChunkPrefillMma { + static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + class ProblemShapeType_, + class ElementQ_, + class StrideQ_, + class ElementK_, + class StrideK_, + class ElementV_, + class StrideV_, + class MMAOperation_, + class TileShapeQK_, + class TileShapePV_, + class SubgroupLayout_, + class GmemTiledCopyQ_, + class GmemTiledCopyK_, + class GmemTiledCopyV_, + bool CausalMask_, + bool LocalMask_, + bool PagedKV_> +struct FlashChunkPrefillMma< + gemm::MainloopIntelXeXMX16, + ProblemShapeType_, + ElementQ_, + StrideQ_, + ElementK_, + StrideK_, + ElementV_, + StrideV_, + MMAOperation_, + TileShapeQK_, + TileShapePV_, + SubgroupLayout_, + GmemTiledCopyQ_, + GmemTiledCopyK_, + GmemTiledCopyV_, + CausalMask_, + LocalMask_, + PagedKV_> { + // + // Type Aliases + // + using DispatchPolicy = gemm::MainloopIntelXeXMX16; + using TileShapeQK = TileShapeQK_; + using TileShapePV = TileShapePV_; + using SubgroupLayout = SubgroupLayout_; + using ProblemShapeType = ProblemShapeType_; + using ElementQ = ElementQ_; + using StrideQ = StrideQ_; + using ElementK = ElementK_; + using StrideK = StrideK_; + using ElementV = ElementV_; + using StrideV = StrideV_; + using GmemTiledCopyQ = GmemTiledCopyQ_; + using GmemTiledCopyK = GmemTiledCopyK_; + using GmemTiledCopyV = GmemTiledCopyV_; + using ArchTag = typename DispatchPolicy::ArchTag; + using MmaAtom = MMA_Atom; + + using TiledMmaQK = typename TiledMMAHelper, SubgroupLayout>::TiledMMA; + + using TiledMmaPV = typename TiledMMAHelper, SubgroupLayout>::TiledMMA; + using ElementAccumulator = typename TiledMmaQK::ValTypeC; + static constexpr bool CausalMask = CausalMask_; + static constexpr bool LocalMask = LocalMask_; + static constexpr bool PagedKV = PagedKV_; + + static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; + + using MmaAtomShape = typename MmaAtom::Shape_MNK; + + static constexpr auto PV_ATOM_M = decltype(get<0>(SubgroupLayout{}.shape()))::value; + static constexpr auto PV_ATOM_N = decltype(get<1>(SubgroupLayout{}.shape()))::value; + static constexpr auto PV_ATOM_K = decltype(get<2>(SubgroupLayout{}.shape()))::value; + + using SubgroupTileShapePV = 
decltype(cute::shape_div(TileShapePV{}, (SubgroupLayout{}.shape()))); + static constexpr auto QK_BLK_M = get<0>(TileShapeQK{}); + static constexpr auto QK_BLK_N = get<1>(TileShapeQK{}); + static constexpr auto QK_BLK_K = get<2>(TileShapeQK{}); + + // This TiledMma is only required to serve the specific tiling requirements + // for matrix K. This is due to the consumption of matrix K by all subgroups + // within a workgroup. + static constexpr auto QK_ATOM_M = PV_ATOM_M; // 8 + static constexpr auto QK_ATOM_N = PV_ATOM_N; // 1 + static constexpr auto QK_ATOM_K = PV_ATOM_K; // 1 + + using SubgroupTileShapeQK = + decltype(cute::shape_div(TileShapeQK{}, SubgroupLayout{}.shape())); // 128, 64, 32 / 16, 1, 1 = (8, 64, 32 ) + + static constexpr auto QK_SG_M = get<0>(SubgroupTileShapeQK{}); + static constexpr auto QK_SG_N = get<1>(SubgroupTileShapeQK{}); + static constexpr auto QK_SG_K = get<2>(SubgroupTileShapeQK{}); + + static constexpr bool is_var_len = + cutlass::fmha::collective::is_variable_length_v>; + + using FragsShapeS = decltype(cute::shape_div( + take<0, 2>(SubgroupTileShapeQK{}), take<0, 2>(MmaAtomShape()))); // 8, 64, 32 / 8, 16, 16 (1, 4) + static constexpr int Vec = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize; // 8 + static constexpr int FragsM = get<0>(FragsShapeS{}); + static constexpr int FragsNS = get<1>(FragsShapeS{}); // 4 + + static constexpr uint32_t MaxThreadsPerBlock = size(SubgroupLayout{}) * SubgroupSize; + using CopyThreadShape = Shape<_1, Int>; + + using traits_load_Q = Copy_Traits; + using atom_load_Q = Copy_Atom; + using val_layout_load_Q = decltype(make_layout(shape_div(typename traits_load_Q::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_Q = decltype(make_tiled_copy(atom_load_Q{}, Layout{}, val_layout_load_Q{})); + + using traits_load_K = Copy_Traits; + using atom_load_K = Copy_Atom; + using val_layout_load_K = decltype(make_layout(shape_div(typename traits_load_K::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_K = decltype(make_tiled_copy(atom_load_K{}, Layout{}, val_layout_load_K{})); + + using traits_load_V = Copy_Traits; + using atom_load_V = Copy_Atom; + using val_layout_load_V = decltype(make_layout(shape_div(typename traits_load_V::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_V = decltype(make_tiled_copy(atom_load_V{}, Layout{}, val_layout_load_V{})); + + // Host side kernel arguments + struct Arguments { + ElementQ const* ptr_Q; + StrideQ dQ; + ElementK const* ptr_K_cache; + StrideK dK_cache; + ElementV const* ptr_V_cache; + StrideV dV_cache; + // Paged KV Cache + int const* ptr_page_table; + int page_size; + int max_num_pages_per_seq; + int window_left; + int window_right; + }; + + struct Params { + XE_Copy_Q gmem_tiled_copy_q; + XE_Copy_K gmem_tiled_copy_k_cache; + XE_Copy_V gmem_tiled_copy_v_cache; + int const* ptr_page_table; + int page_size; + int max_num_pages_per_seq; + int window_left; + int window_right; + }; + + // + // Methods + // + + FlashChunkPrefillMma() = default; + + static constexpr Params + to_underlying_arguments(ProblemShapeType const& problem_shape, Arguments const& args, void* workspace) { + (void)workspace; + + auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = + problem_shape; + + auto tensorQ = make_tensor( + make_gmem_ptr(args.ptr_Q), make_layout(make_shape(seq_len_qo, num_heads_q * head_size_qk, batch), args.dQ)); + auto tensorK_cache = make_tensor( + make_gmem_ptr(args.ptr_K_cache), + make_layout(make_shape(seq_len_kv_cache, 
num_heads_kv * head_size_qk, batch), args.dK_cache)); + auto tensorV_cache = make_tensor( + make_gmem_ptr(args.ptr_V_cache), + make_layout(make_shape(num_heads_kv * head_size_vo, seq_len_kv_cache, batch), args.dV_cache)); + + XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)}; + XE_Copy_K copyK_cache{XE_Copy_K{}.with(tensorK_cache)}; + XE_Copy_V copyV_cache{XE_Copy_V{}.with(tensorV_cache)}; + + return Params{ + copyQ, + copyK_cache, + copyV_cache, + args.ptr_page_table, + args.page_size, + args.max_num_pages_per_seq, + args.window_left, + args.window_right}; + } + + template + CUTLASS_DEVICE void mmaQK( + FragQccum& accum, + TensorQ gQ, + TensorK gK, + FragSrc const& frag_src, + int const& k_tile_count, + Params const& params) { + auto& gmem_tiled_copy_k = params.gmem_tiled_copy_k_cache; + + int thread_idx = static_cast(ThreadIdxX()); + auto thr_copy_Q = params.gmem_tiled_copy_q.get_slice(thread_idx); + auto thr_copy_K = gmem_tiled_copy_k.get_slice(thread_idx); + // Instantiate the MMA object + TiledMmaQK tiled_mma; + // To make all threads in a warp have the same global tensors pass in the + // index of thread 0 in each warp + auto sg = compat::get_nd_item<1>().get_sub_group(); + auto first_thread_in_sg_idx = sg.get_group_id()[0] * DispatchPolicy::SubgroupSize; + auto thread_mma_q = tiled_mma.get_slice(first_thread_in_sg_idx); + auto thread_mma_k = tiled_mma.get_slice(0); + + Tensor tCgQ = thread_mma_q.partition_A(gQ); + Tensor tCgK = thread_mma_k.partition_B(gK); + + // Create fragments + // TODO(Codeplay): fix this, this is probably not general + Tensor tCrQ = make_tensor(make_fragment_layout(params.gmem_tiled_copy_q, take<0, 3>(tCgQ.shape()))); + Tensor tCrK = make_tensor(make_fragment_layout(gmem_tiled_copy_k, take<0, 3>(tCgK.shape()))); + + // Retile registers for copies + Tensor tQrQ = thr_copy_Q.retile_D(tCrQ); + Tensor tKrK = thr_copy_K.retile_D(tCrK); + + // Retile global tile for copies + Tensor tQgQ = thr_copy_Q.retile_S(tCgQ); + Tensor tKgK = thr_copy_K.retile_S(tCgK); + + // + // Mainloop + // + + for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) { + copy(params.gmem_tiled_copy_q, tQgQ(_, _, _, k_tile), tQrQ); + copy(gmem_tiled_copy_k, tKgK(_, _, _, k_tile), tKrK); + cute::gemm(tiled_mma, accum, tCrQ, tCrK, frag_src); +#if 0 +#define PRINT(x) \ + print(#x ": "); \ + print(x); \ + print("\n"); + if (cute::thread(0, 0)) { + print("======================= Q: \n"); + PRINT(gQ); + PRINT(tCrQ); + PRINT(tCgQ); + PRINT(tQrQ); + PRINT(tQgQ); + + print("===================== K :\n"); + PRINT(gK); + PRINT(tCrK); + PRINT(tCgK); + PRINT(tKrK); + PRINT(tKgK); + + print("===================== Config: \n"); + PRINT(MaxThreadsPerBlock); + PRINT(SubgroupTileShapeQK{}); + } +#undef PRINT +#endif + } + } + + template + CUTLASS_DEVICE auto convert_type(Tensor const& tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + auto frag = convert_op(*reinterpret_cast*>(tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); + } + + template + CUTLASS_DEVICE void mmaPV( + FragQccum& accum, FragS const& tSr, TensorV gV, FragSrc const& frag_src, Params const& params) { + auto& gmem_tiled_copy_v = params.gmem_tiled_copy_v_cache ; + + int thread_idx = static_cast(ThreadIdxX()); + // Instantiate the MMA object + TiledMmaPV tiled_mma; + // Tile GV to the shape of <64,64> and loop over the HeadSize/64 to avoid + // Register spill + Tensor gV_ = take<0, 3>(local_tile(gV, select<1, 
2>(TileShapePV{}), make_coord(_, _))); + auto sg = compat::get_nd_item<1>().get_sub_group(); + auto first_thread_in_sg_idx = sg.get_group_id()[0] * DispatchPolicy::SubgroupSize; + auto thread_mma = tiled_mma.get_slice(first_thread_in_sg_idx); + Tensor tCgV = thread_mma.partition_B(gV_); + Tensor tCrV = make_tensor(make_fragment_layout(gmem_tiled_copy_v, take<0, 3>(tCgV.shape()))); + + // Partition the copying of A and B tiles across the threads + auto gmem_thr_copy_V = gmem_tiled_copy_v.get_slice(thread_idx); + Tensor tVrV = gmem_thr_copy_V.retile_D(tCrV); + Tensor tVgV = gmem_thr_copy_V.retile_S(tCgV); + +#if CUTLASS_ENABLE_DEBUG_PRINTS +#define PRINT(x) \ + print(#x ": "); \ + print(x); \ + print("\n"); + if (cute::thread(LOG_THREAD, LOG_GROUP)) { + print("===================== V :\n"); + PRINT(gV); + PRINT(tCrV); + PRINT(tCgV); + PRINT(tVrV); + PRINT(tVgV); + + print("===================== Config: \n"); + PRINT(MaxThreadsPerBlock); + PRINT(SubgroupTileShapePV{}); + } +#undef PRINT +#endif + + // 7) Convert S to P (FP32 -> BF16) + Tensor tPr = convert_type(tSr); + // + // Mainloop + // + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tile_count; i++) { + copy(gmem_tiled_copy_v, tVgV(_, _, _, i), tVrV); + cute::gemm(tiled_mma, accum(_, _, _, i), tPr, tCrV, frag_src(_, _, _, i)); + } + } + + // SequenceLengthShape = Shape + // For Fixed Sequence Length, ProblemShape = Shape For Variable Sequence Length, ProblemShape = Shape + template + CUTLASS_DEVICE static constexpr Params get_updated_copies( + Params const& params, + ProblemShape const& problem_shape, + SequenceLengthShape const& sequence_length_shape, + int const& l_coord, + int const& q_head_coord = 0) { + auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = select<0, 1, 2, 6, 7>(problem_shape); + auto [seq_len_qo, seq_len_kv_cache] = sequence_length_shape; + auto q_group_size = num_heads_q / num_heads_kv; + auto kv_head_coord = q_head_coord / q_group_size; + int offset_q = 0, offset_k = 0, offset_v = 0, offset_k_cache = 0, offset_v_cache = 0; + int total_seq_len_kv_cache = 0; + if constexpr (is_var_len) { + auto qo_cumulative_length = get<3>(problem_shape).cumulative_length; + auto kv_cached_cumulative_length = get<5>(problem_shape).cumulative_length; + + offset_q = num_heads_q * head_size_qk * qo_cumulative_length[l_coord] + q_head_coord * head_size_qk; + + offset_k_cache = kv_head_coord * head_size_qk; + offset_v_cache = kv_head_coord * head_size_vo; + total_seq_len_kv_cache = get<5>(problem_shape).total_length; + } else { + + } + + auto q_traits = static_cast(params.gmem_tiled_copy_q); + const ElementQ* q_ptr = (const ElementQ*)q_traits.base_ptr; + auto k_traits_cache = static_cast(params.gmem_tiled_copy_k_cache); + const ElementK* k_cache_ptr = (const ElementK*)k_traits_cache.base_ptr; + auto v_traits_cache = static_cast(params.gmem_tiled_copy_v_cache); + const ElementV* v_cache_ptr = (const ElementV*)v_traits_cache.base_ptr; + // NHD format{batch, seq_len, head, dim_head} + // stride {seq_len*head*dim_head, head*dim_head, dim_head, 1} + auto shape_q = make_shape(static_cast(seq_len_qo), head_size_qk * num_heads_q, 1); + StrideQ stride_q = cutlass::make_cute_packed_stride(StrideQ{}, shape_q); + + auto shape_k_cache = make_shape( + static_cast(PagedKV ? total_seq_len_kv_cache : seq_len_kv_cache), head_size_qk * num_heads_kv, 1); + StrideK stride_k_cache = cutlass::make_cute_packed_stride(StrideK{}, shape_k_cache); + auto shape_v_cache = make_shape( + head_size_vo * num_heads_kv, static_cast(PagedKV ? 
total_seq_len_kv_cache : seq_len_kv_cache), 1); + StrideV stride_v_cache = cutlass::make_cute_packed_stride(StrideV{}, shape_v_cache); + auto tensorQ = make_tensor(make_gmem_ptr(q_ptr + offset_q), make_layout(shape_q, stride_q)); + auto tensorK_cache = + make_tensor(make_gmem_ptr(k_cache_ptr + offset_k_cache), make_layout(shape_k_cache, stride_k_cache)); + auto tensorV_cache = + make_tensor(make_gmem_ptr(v_cache_ptr + offset_v_cache), make_layout(shape_v_cache, stride_v_cache)); + XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)}; + XE_Copy_K copyK_cache{XE_Copy_K{}.with(tensorK_cache)}; + XE_Copy_V copyV_cache{XE_Copy_V{}.with(tensorV_cache)}; + return Params{ + copyQ, + copyK_cache, + copyV_cache, + params.ptr_page_table, + params.page_size, + params.max_num_pages_per_seq, + params.window_left, + params.window_right}; + } +}; + +} // namespace cutlass::flash_attention::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_softmax_epilogue.hpp b/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_softmax_epilogue.hpp new file mode 100644 index 0000000..c2bc020 --- /dev/null +++ b/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_softmax_epilogue.hpp @@ -0,0 +1,222 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing online softmax. 
+*/ + +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/detail/layout.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace flash_attention { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class FlashChunkPrefillSoftmaxEpilogue { + static_assert(cutlass::detail::dependent_false, "Could not find an epilogue specialization."); +}; + +template +class FlashChunkPrefillSoftmaxEpilogue { + public: + // + // Type Aliases + // + using DispatchPolicy = epilogue::IntelXeXMX16; + using Element = Element_; + + static constexpr bool CausalMask = CausalMask_; + static constexpr bool LocalMask = LocalMask_; + + using GmemTiledCopyOut = void; + + // Host side epilogue arguments + struct Arguments { + Element const scale; + }; + + // Device side epilogue params + using Params = Arguments; + + // + // Methods + // + + static constexpr Params to_underlying_arguments(Arguments const& args) { + constexpr double kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E + Element val = args.scale * static_cast(kLog2e); + return Params{val}; + } + + template + static size_t get_workspace_size() { + return 0; + } + + template + static cutlass::Status initialize_workspace() { + return Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE static bool can_implement() { + return true; + } + + CUTLASS_HOST_DEVICE + FlashChunkPrefillSoftmaxEpilogue(Params const& params_) : params(params_) {} + + template + CUTLASS_DEVICE void scale_exp_log2(FragAcc& frag_s, FragMax const& max, FragSum& sum) { + auto g = compat::get_nd_item<1>().get_sub_group(); + const auto max_scale = max * params.scale; + CUTLASS_PRAGMA_UNROLL + for (int indx = 0; indx < Vec * FragsM; indx++) { + const auto max_scale_bcast = group_broadcast(g, max_scale, indx); + CUTLASS_PRAGMA_UNROLL + for (int z = 0; z < FragsN; z++) { + auto base_indx = indx + (z * Vec * FragsM); + if constexpr (LocalMask) { + if ((std::isinf(max_scale_bcast) && max_scale_bcast < 0) || + (std::isinf(frag_s(base_indx)) && frag_s(base_indx) < 0)) { + frag_s(base_indx) = 0.f; + // continue; + } else { + Element eq = frag_s(base_indx) - max_scale_bcast; + frag_s(base_indx) = sycl::native::exp2(eq); + } + } else { + Element eq = frag_s(base_indx) - max_scale_bcast; + frag_s(base_indx) = sycl::native::exp2(eq); + } + sum(indx) += frag_s(base_indx); + } + } + } + + template + CUTLASS_DEVICE void reduce_max(FragSrc& src, FragMax& max) { + auto sg = compat::get_nd_item<1>().get_sub_group(); + CUTLASS_PRAGMA_UNROLL + for (int indx = 0; indx < Vec * FragsM; indx++) { + auto maxptr = group_broadcast(sg, max, indx); + CUTLASS_PRAGMA_UNROLL + for (int z = 0; z < FragsN; z++) { + auto base_indx = indx + (z * Vec * FragsM); + maxptr = sycl::max(maxptr, src(base_indx)); + src(base_indx) *= params.scale; + } + maxptr = reduce_over_group(sg, maxptr, sycl::maximum<>()); + if (indx == sg.get_local_id()[0]) { + max = maxptr; + } + } + } + + template + CUTLASS_DEVICE void operator()(bool is_first, FragAcc& frag_s, FragMax& max, FragSum& sum, FragOut& out) { + auto max_prev = max; + using FragAccLayout = typename FragAcc::layout_type; + using FragOutLayout = typename FragOut::layout_type; + constexpr int Vec = get<0>(FragAccLayout{}.shape()); + 
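// Vec (above) and the Frags* constants below give the per-subgroup fragment + // extents of the score and output accumulators. When this is not the first + // KV tile, the code below applies the online-softmax correction: the running + // sum and the output accumulator are rescaled by + // exp2(max_prev * scale - max * scale) (with a guard for fully masked rows + // when LocalMask is enabled) before the freshly exponentiated scores of the + // current tile are accumulated. +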
constexpr int FragsM = get<1>(FragAccLayout{}.shape()); + constexpr int FragsNAcc = get<2>(FragAccLayout{}.shape()); + constexpr int FragsNOut = size(select<2, 3>(FragOutLayout{}.shape())); + reduce_max(frag_s, max); + static_assert(Vec * FragsM % 8 == 0, " No. of attention rows per subgroup should be >= 1 MMA Atom worth of rows."); + if (!is_first) { + auto sg = compat::get_nd_item<1>().get_sub_group(); + Element max_scale{max * params.scale}; + Element exp_scale; + if constexpr (LocalMask) { + if ((std::isinf(max_scale) && max_scale < 0) || (std::isinf(max_prev) && max_prev < 0)) { + exp_scale = 0.f; + } else { + exp_scale = sycl::native::exp2(max_prev * params.scale - max_scale); + } + } else { + exp_scale = sycl::native::exp2(max_prev * params.scale - max_scale); + } + + CUTLASS_PRAGMA_UNROLL + for (int indx = 0; indx < Vec * FragsM; indx++) { + auto max_scale_bcast = group_broadcast(sg, max_scale, indx); + auto exp_scale_bcast = group_broadcast(sg, exp_scale, indx); + sum(indx) *= exp_scale_bcast; + CUTLASS_PRAGMA_UNROLL + for (int z = 0; z < FragsNAcc; z++) { + auto base_indx = indx + (z * Vec * FragsM); + if constexpr (LocalMask) { + if ((std::isinf(max_scale_bcast) && max_scale_bcast < 0) || + (std::isinf(frag_s(base_indx)) && frag_s(base_indx) < 0)) { + frag_s(base_indx) = 0.f; + // continue; + } else { + Element eq = frag_s(base_indx) - max_scale_bcast; + frag_s(base_indx) = sycl::native::exp2(eq); + } + } else { + Element eq = frag_s(base_indx) - max_scale_bcast; + frag_s(base_indx) = sycl::native::exp2(eq); + } + sum(indx) += frag_s(base_indx); + } + CUTLASS_PRAGMA_UNROLL + for (int z = 0; z < FragsNOut; z++) { + auto base_indx = indx + (z * Vec * FragsM); + out(base_indx) *= exp_scale_bcast; + } + } + } else { + scale_exp_log2(frag_s, max, sum); + } + } + Params params; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace flash_attention +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/torch_extension_sycl.cc b/src/torch_extension_sycl.cc index 340eddc..22cd0c0 100644 --- a/src/torch_extension_sycl.cc +++ b/src/torch_extension_sycl.cc @@ -63,16 +63,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "fwd(Tensor! q," " Tensor k," " Tensor v," - " Tensor? k_new," - " Tensor? v_new," " Tensor? q_v," " Tensor? cu_seqlens_q," " Tensor? cu_seqlens_k," - " Tensor? cu_seqlens_k_new," " int? max_seqlen_q," " int? max_seqlen_k," " Tensor? page_table," - " Tensor? num_pages," " Tensor? kv_batch_idx," " Tensor? leftpad_k," " Tensor? 
rotary_cos," From f183b23ef0bd334c57dc6ab2ff00b111c8ecd300 Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Fri, 3 Oct 2025 00:56:54 +0800 Subject: [PATCH 24/25] move check into Utils.h lint fix --- include/sgl_flash_kernel_ops.h | 18 +-- python/sgl_kernel/flash_attn.py | 16 -- src/sycl/TripleOps.cpp | 16 +- src/sycl/Utils.h | 5 +- src/sycl/chunked_prefill.cpp | 150 ++++++------------ .../chunk_prefill/xe_chunk_prefill.hpp | 13 +- .../xe_flash_attn_chunk_prefill_epilogue.hpp | 4 +- .../xe_flash_attn_chunk_prefill_mma.hpp | 11 +- ...sh_attn_chunk_prefill_softmax_epilogue.hpp | 30 ++-- 9 files changed, 96 insertions(+), 167 deletions(-) diff --git a/include/sgl_flash_kernel_ops.h b/include/sgl_flash_kernel_ops.h index a4b6f86..ed162d4 100644 --- a/include/sgl_flash_kernel_ops.h +++ b/include/sgl_flash_kernel_ops.h @@ -53,15 +53,15 @@ std::vector mha_fwd( std::optional& cu_seqlens_k_, // b+1 std::optional max_seqlen_q_, std::optional max_seqlen_k_, - std::optional& page_table_, // (b_k, max_num_pages_per_seq) - std::optional& kv_batch_idx_, // b. indices to index into the KV cache - std::optional& leftpad_k_, // b - std::optional& rotary_cos_, // seqlen_ro x (rotary_dim / 2) - std::optional& rotary_sin_, // seqlen_ro x (rotary_dim / 2) - std::optional& seqlens_rotary_, // b - std::optional& q_descale_, // (b, h_k), not (b, h) - std::optional& k_descale_, // (b, h_k) - std::optional& v_descale_, // (b, h_k) + std::optional& page_table_, // (b_k, max_num_pages_per_seq) + std::optional& kv_batch_idx_, // b. indices to index into the KV cache + std::optional& leftpad_k_, // b + std::optional& rotary_cos_, // seqlen_ro x (rotary_dim / 2) + std::optional& rotary_sin_, // seqlen_ro x (rotary_dim / 2) + std::optional& seqlens_rotary_, // b + std::optional& q_descale_, // (b, h_k), not (b, h) + std::optional& k_descale_, // (b, h_k) + std::optional& v_descale_, // (b, h_k) float const softmax_scale, bool is_causal, int window_size_left, diff --git a/python/sgl_kernel/flash_attn.py b/python/sgl_kernel/flash_attn.py index 14ce759..ccf735e 100644 --- a/python/sgl_kernel/flash_attn.py +++ b/python/sgl_kernel/flash_attn.py @@ -178,25 +178,9 @@ def flash_attn_with_kvcache( ) * q.size(1) max_seqlen_q = q.size(1) q = q.view(-1, q.size(-2), q.size(-1)).contiguous() - # if cu_seqlens_k_new is None and k is not None: # !is_varlen_k_new - # cu_seqlens_k_new = torch.arange( - # 0, k.size(0) + 1, dtype=torch.int, device=k.device - # ) - # elif k is None: - # cu_seqlens_k_new = torch.zeros_like( - # cu_seqlens_q, dtype=torch.int32, device=q.device - # ) if cache_seqlens is not None: max_seqlen_k = cache_seqlens.max().item() assert cache_seqlens.size(0) + 1 == cu_seqlens_q.size(0) - # max_page_size_per_seq = page_table.size(1) - # # will delete later - # num_pages_per_seq = torch.arange( - # 0, - # cache_seqlens.size(0) * max_page_size_per_seq, - # max_page_size_per_seq, - # device=cache_seqlens.device, - # ).to(torch.int32) cu_seqlens_k = torch.concat( ( torch.zeros(1, dtype=torch.int32, device=cache_seqlens.device), diff --git a/src/sycl/TripleOps.cpp b/src/sycl/TripleOps.cpp index 02999a4..3549f6b 100644 --- a/src/sycl/TripleOps.cpp +++ b/src/sycl/TripleOps.cpp @@ -88,8 +88,8 @@ struct op_and_mul_functor { template void get_config( - const Tensor& input, - const Tensor& out, + const at::Tensor& input, + const at::Tensor& out, int64_t& numel, int64_t& dim, int64_t& wg_size, @@ -111,7 +111,7 @@ void get_config( } template -void silu_and_mul_sycl(sycl::queue& q, Tensor& input, Tensor& out) { +void 
silu_and_mul_sycl(sycl::queue& q, at::Tensor& input, at::Tensor& out) { auto _input = reinterpret_cast(input.data_ptr()); auto _out = reinterpret_cast(out.data_ptr()); @@ -136,7 +136,7 @@ void silu_and_mul_sycl(sycl::queue& q, Tensor& input, Tensor& out) { return; } -void silu_and_mul(Tensor& out, Tensor& input) { +void silu_and_mul(at::Tensor& out, at::Tensor& input) { input = input.contiguous(); out = out.contiguous(); @@ -152,7 +152,7 @@ void silu_and_mul(Tensor& out, Tensor& input) { } template -void gelu_tanh_and_mul_sycl(sycl::queue& q, Tensor& input, Tensor& out) { +void gelu_tanh_and_mul_sycl(sycl::queue& q, at::Tensor& input, at::Tensor& out) { auto _input = reinterpret_cast(input.data_ptr()); auto _out = reinterpret_cast(out.data_ptr()); @@ -177,7 +177,7 @@ void gelu_tanh_and_mul_sycl(sycl::queue& q, Tensor& input, Tensor& out) { return; } -void gelu_tanh_and_mul(Tensor& out, Tensor& input) { +void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input) { input = input.contiguous(); out = out.contiguous(); @@ -193,7 +193,7 @@ void gelu_tanh_and_mul(Tensor& out, Tensor& input) { } template -void gelu_and_mul_sycl(sycl::queue& q, Tensor& input, Tensor& out) { +void gelu_and_mul_sycl(sycl::queue& q, at::Tensor& input, at::Tensor& out) { auto _input = reinterpret_cast(input.data_ptr()); auto _out = reinterpret_cast(out.data_ptr()); @@ -218,7 +218,7 @@ void gelu_and_mul_sycl(sycl::queue& q, Tensor& input, Tensor& out) { return; } -void gelu_and_mul(Tensor& out, Tensor& input) { +void gelu_and_mul(at::Tensor& out, at::Tensor& input) { input = input.contiguous(); out = out.contiguous(); diff --git a/src/sycl/Utils.h b/src/sycl/Utils.h index d9edbbf..bd3c194 100644 --- a/src/sycl/Utils.h +++ b/src/sycl/Utils.h @@ -7,7 +7,10 @@ #define SYCL_MAX_SUB_GROUP_SIZE dpcppMaxSubGroupSize() -using namespace at; +#define CHECK_DEVICE(x) TORCH_CHECK(x.is_xpu(), #x " must be on XPU") +#define CHECK_SHAPE(x, ...) \ + TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") using DeviceId = at::DeviceIndex; diff --git a/src/sycl/chunked_prefill.cpp b/src/sycl/chunked_prefill.cpp index 1fd23b7..d29733f 100644 --- a/src/sycl/chunked_prefill.cpp +++ b/src/sycl/chunked_prefill.cpp @@ -5,6 +5,7 @@ #include +#include "Utils.h" #include "cutlass/epilogue/collective/default_epilogue.hpp" #include "cutlass/util/device_memory.h" #include "cutlass/util/packed_stride.hpp" @@ -313,10 +314,10 @@ struct KernelRunner { {// static_cast(params.q_ptr), static_cast(params.q_ptr), stride_Q, - // static_cast(params.knew_ptr), - // stride_K, - // static_cast(params.vnew_ptr), - // stride_V, + // static_cast(params.knew_ptr), + // stride_K, + // static_cast(params.vnew_ptr), + // stride_V, static_cast(params.k_ptr), stride_K_cache, static_cast(params.v_ptr), @@ -440,11 +441,6 @@ struct FMHAConfig { } }; -#define CHECK_DEVICE(x) TORCH_CHECK(x.is_xpu(), #x " must be on XPU") -#define CHECK_SHAPE(x, ...) 
\ - TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") - inline int round_up_headdim(int head_size) { if (head_size <= 64) { return 64; @@ -475,15 +471,15 @@ std::vector mha_fwd( std::optional& cu_seqlens_k_, // b+1 std::optional max_seqlen_q_, std::optional max_seqlen_k_, - std::optional& page_table_, // (b_k, max_num_pages_per_seq) - std::optional& kv_batch_idx_, // b. indices to index into the KV cache - std::optional& leftpad_k_, // b - std::optional& rotary_cos_, // seqlen_ro x (rotary_dim / 2) - std::optional& rotary_sin_, // seqlen_ro x (rotary_dim / 2) - std::optional& seqlens_rotary_, // b - std::optional& q_descale_, // (b, h_k), not (b, h) - std::optional& k_descale_, // (b, h_k) - std::optional& v_descale_, // (b, h_k) + std::optional& page_table_, // (b_k, max_num_pages_per_seq) + std::optional& kv_batch_idx_, // b. indices to index into the KV cache + std::optional& leftpad_k_, // b + std::optional& rotary_cos_, // seqlen_ro x (rotary_dim / 2) + std::optional& rotary_sin_, // seqlen_ro x (rotary_dim / 2) + std::optional& seqlens_rotary_, // b + std::optional& q_descale_, // (b, h_k), not (b, h) + std::optional& k_descale_, // (b, h_k) + std::optional& v_descale_, // (b, h_k) const float softmax_scale_, bool is_causal, int window_size_left, @@ -720,60 +716,6 @@ std::vector mha_fwd( params.num_pages = num_pages; } - // if (k_new_.has_value()) { // This needs to be set before get_pagedkv_tma - // at::Tensor k_new, v_new; - // TORCH_CHECK(v_new_.has_value(), "If k_new is supplied, v_new must also be passed in"); - // TORCH_CHECK(seqlen_q <= seqlen_k, "If k_new is supplied, it must have seqlen <= the seqlen of the KV cache"); - // at::Tensor cu_seqlens_k_new; - // bool const is_varlen_k_new = k_new_.value().dim() == 3; - // if (is_varlen_k_new) { - // cu_seqlens_k_new = cu_seqlens_k_new_.value(); - // CHECK_DEVICE(cu_seqlens_k_new); - // CHECK_CONTIGUOUS(cu_seqlens_k_new); - // TORCH_CHECK(cu_seqlens_k_new.dtype() == torch::kInt32, "cu_seqlens_k_new must have dtype torch.int32"); - // } - // k_new = k_new_.value(); - // v_new = v_new_.value(); - // TORCH_CHECK(k_new.dtype() == q_type, "k_new must have the same dtype as query"); - // TORCH_CHECK(v_new.dtype() == q_type, "v_new must have the same dtype as query"); - // CHECK_DEVICE(k_new); - // CHECK_DEVICE(v_new); - // TORCH_CHECK(k_new.stride(-1) == 1, "k_new tensor must have contiguous last dimension"); - // TORCH_CHECK(v_new.stride(-1) == 1, "v_new tensor must have contiguous last dimension"); - // int seqlen_k_new = !is_varlen_k_new ? k_new.size(1) : 1; - // int total_k_new = !is_varlen_k_new ? batch_size * k_new.size(1) : k_new.size(0); - // if (!is_varlen_k_new) { - // CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size); - // CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size_v); - // } else { - // CHECK_SHAPE(k_new, total_k_new, num_heads_k, head_size); - // CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size_v); - // CHECK_SHAPE(cu_seqlens_k_new, batch_size + 1); - // } - // params.seqlen_knew = seqlen_k_new; - // params.total_knew = total_k_new; - // params.knew_ptr = k_new.data_ptr(); - // params.vnew_ptr = v_new.data_ptr(); - // // All stride are in elements, not bytes. 
- // params.knew_row_stride = k_new.stride(-3); - // params.vnew_row_stride = v_new.stride(-3); - // params.knew_head_stride = k_new.stride(-2); - // params.vnew_head_stride = v_new.stride(-2); - // if (!is_varlen_k_new) { - // params.knew_batch_stride = k_new.stride(0); - // params.vnew_batch_stride = v_new.stride(0); - // } - // if (is_varlen_k_new) { - // params.cu_seqlens_knew = static_cast(cu_seqlens_k_new.data_ptr()); - // } - // } else { - // TORCH_CHECK(cu_seqlens_k_new_.has_value(), "cu_seqlens_k_new all zeros"); - // params.seqlen_knew = 0; - // params.total_knew = 0; - // at::Tensor cu_seqlens_k_new = cu_seqlens_k_new_.value(); - // params.cu_seqlens_knew = static_cast(cu_seqlens_k_new.data_ptr()); - // } - if (q_v_.has_value()) { TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); TORCH_CHECK( @@ -850,37 +792,37 @@ std::vector mha_fwd( case 64: FMHAConfig< true, - Shape<_128, _64, _64>, - Shape<_128, _32, _64>, - Shape<_128, _64, _64>, - Layout, Stride<_1, _1, _1>>, + cute::Shape<_128, _64, _64>, + cute::Shape<_128, _32, _64>, + cute::Shape<_128, _64, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, PipelineStages>::run(params); break; case 96: FMHAConfig< true, - Shape<_128, _64, _32>, - Shape<_128, _32, _64>, - Shape<_128, _96, _64>, - Layout, Stride<_1, _1, _1>>, + cute::Shape<_128, _64, _32>, + cute::Shape<_128, _32, _64>, + cute::Shape<_128, _96, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, PipelineStages>::run(params); break; case 128: FMHAConfig< true, - Shape<_128, _64, _64>, - Shape<_128, _32, _64>, - Shape<_128, _128, _64>, - Layout, Stride<_1, _1, _1>>, + cute::Shape<_128, _64, _64>, + cute::Shape<_128, _32, _64>, + cute::Shape<_128, _128, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, PipelineStages>::run(params); break; case 192: FMHAConfig< true, - Shape<_256, _64, _64>, - Shape<_256, _32, _64>, - Shape<_256, _192, _64>, - Layout, Stride<_1, _1, _1>>, + cute::Shape<_256, _64, _64>, + cute::Shape<_256, _32, _64>, + cute::Shape<_256, _192, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, PipelineStages>::run(params); break; default: @@ -891,37 +833,37 @@ std::vector mha_fwd( case 64: FMHAConfig< false, - Shape<_128, _64, _64>, - Shape<_128, _32, _64>, - Shape<_128, _64, _64>, - Layout, Stride<_1, _1, _1>>, + cute::Shape<_128, _64, _64>, + cute::Shape<_128, _32, _64>, + cute::Shape<_128, _64, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, PipelineStages>::run(params); break; case 96: FMHAConfig< false, - Shape<_128, _64, _32>, - Shape<_128, _32, _64>, - Shape<_128, _96, _64>, - Layout, Stride<_1, _1, _1>>, + cute::Shape<_128, _64, _32>, + cute::Shape<_128, _32, _64>, + cute::Shape<_128, _96, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, PipelineStages>::run(params); break; case 128: FMHAConfig< false, - Shape<_128, _64, _64>, - Shape<_128, _32, _64>, - Shape<_128, _128, _64>, - Layout, Stride<_1, _1, _1>>, + cute::Shape<_128, _64, _64>, + cute::Shape<_128, _32, _64>, + cute::Shape<_128, _128, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, PipelineStages>::run(params); break; case 192: FMHAConfig< false, - Shape<_256, _64, _64>, - Shape<_256, _32, _64>, - Shape<_256, _192, _64>, - Layout, Stride<_1, _1, _1>>, + cute::Shape<_256, _64, _64>, + cute::Shape<_256, _32, _64>, + cute::Shape<_256, _192, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, PipelineStages>::run(params); break; default: diff --git a/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp b/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp index e944945..1384bca 
100644 --- a/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp +++ b/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp @@ -346,7 +346,7 @@ class FMHAPrefillChunk { auto mainloop_params = CollectiveMainloop::get_updated_copies( params.mainloop, params.problem_shape, sequence_length_shape, batch_coord, q_head_coord); - // we limit the horisontal size to two subgroup, the empirical resutls + // we limit the horizontal size to two subgroup, the empirical results // show that reading the two cacheline side by side in gives better // performance and anything after that does not have an effect on // performance. // (64 here for float b float when possible and loop over @@ -380,7 +380,8 @@ class FMHAPrefillChunk { int cached_nblock = 0; if constexpr (PagedKV) { // int curr_batch_pages = ceil_div(seq_len_kv_cache, mainloop_params.page_size);// max_page_size_per_seq - // int batch_offset = is_var_len ? mainloop_params.num_pages_per_seq[batch_coord] : batch_coord * curr_batch_pages; + // int batch_offset = is_var_len ? mainloop_params.num_pages_per_seq[batch_coord] : batch_coord * + // curr_batch_pages; int batch_offset = batch_coord * mainloop_params.max_num_pages_per_seq; cached_nblock = mainloop_params.ptr_page_table[batch_offset // page table for this batch ] * tiles_per_page; // base block idx of physical page @@ -397,11 +398,11 @@ class FMHAPrefillChunk { // workgroup_shape Tensor out_reg = make_tensor(AccumeShape{}); - // There are 16 workitem and 16 max per subgroup, each worktime containt 1 + // There are 16 workitem and 16 max per subgroup, each worktime contains 1 // max and cumulatively, they calculate the max per subgroup ElementAccumulator max_reg{-INFINITY}; - // The sum reg each contains a 2d tesnor for 8 x 2 This is number of - // sequence lenght process per subgroup + // The sum reg each contains a 2d tensor for 8 x 2 This is number of + // sequence length process per subgroup Tensor sum_reg = make_tensor(Shape, Int>{}); clear(sum_reg); @@ -570,7 +571,7 @@ class FMHAPrefillChunk { cached_nblock = next_cached_nblock; // Prefetch the next K tile - // there is no need to gaurd it with if statememt as prefetch will + // there is no need to guard it with if statement as prefetch will // ignore out of bound reading CUTLASS_PRAGMA_UNROLL for (int j = 0; j < size<4>(pKgK_cache); j++) { diff --git a/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_epilogue.hpp b/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_epilogue.hpp index 94a5b66..3437049 100644 --- a/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_epilogue.hpp +++ b/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_epilogue.hpp @@ -206,8 +206,8 @@ class FlashChunkPrefillEpilogue< for (int y = 0; y < FragsM; y++) { CUTLASS_PRAGMA_UNROLL for (int x = 0; x < Vec; x++) { - int indx = y * Vec + x; - auto cur_sum = reduce_over_group(sg, sum(indx), sycl::plus<>()); + int index = y * Vec + x; + auto cur_sum = reduce_over_group(sg, sum(index), sycl::plus<>()); auto cur_scale = (cur_sum == 0.f || cur_sum != cur_sum) ? 
1.0f : sycl::native::recip(cur_sum); CUTLASS_PRAGMA_UNROLL for (int z = 0; z < FragsN; z++) { diff --git a/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp b/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp index a4e3df0..4c21c3b 100644 --- a/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp +++ b/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp @@ -265,7 +265,7 @@ struct FlashChunkPrefillMma< FragSrc const& frag_src, int const& k_tile_count, Params const& params) { - auto& gmem_tiled_copy_k = params.gmem_tiled_copy_k_cache; + auto& gmem_tiled_copy_k = params.gmem_tiled_copy_k_cache; int thread_idx = static_cast(ThreadIdxX()); auto thr_copy_Q = params.gmem_tiled_copy_q.get_slice(thread_idx); @@ -342,9 +342,9 @@ struct FlashChunkPrefillMma< } template - CUTLASS_DEVICE void mmaPV( - FragQccum& accum, FragS const& tSr, TensorV gV, FragSrc const& frag_src, Params const& params) { - auto& gmem_tiled_copy_v = params.gmem_tiled_copy_v_cache ; + CUTLASS_DEVICE void + mmaPV(FragQccum& accum, FragS const& tSr, TensorV gV, FragSrc const& frag_src, Params const& params) { + auto& gmem_tiled_copy_v = params.gmem_tiled_copy_v_cache; int thread_idx = static_cast(ThreadIdxX()); // Instantiate the MMA object @@ -422,7 +422,6 @@ struct FlashChunkPrefillMma< offset_v_cache = kv_head_coord * head_size_vo; total_seq_len_kv_cache = get<5>(problem_shape).total_length; } else { - } auto q_traits = static_cast(params.gmem_tiled_copy_q); @@ -435,7 +434,7 @@ struct FlashChunkPrefillMma< // stride {seq_len*head*dim_head, head*dim_head, dim_head, 1} auto shape_q = make_shape(static_cast(seq_len_qo), head_size_qk * num_heads_q, 1); StrideQ stride_q = cutlass::make_cute_packed_stride(StrideQ{}, shape_q); - + auto shape_k_cache = make_shape( static_cast(PagedKV ? 
total_seq_len_kv_cache : seq_len_kv_cache), head_size_qk * num_heads_kv, 1); StrideK stride_k_cache = cutlass::make_cute_packed_stride(StrideK{}, shape_k_cache); diff --git a/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_softmax_epilogue.hpp b/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_softmax_epilogue.hpp index c2bc020..75f8931 100644 --- a/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_softmax_epilogue.hpp +++ b/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_softmax_epilogue.hpp @@ -110,11 +110,11 @@ class FlashChunkPrefillSoftmaxEpilogue().get_sub_group(); const auto max_scale = max * params.scale; CUTLASS_PRAGMA_UNROLL - for (int indx = 0; indx < Vec * FragsM; indx++) { - const auto max_scale_bcast = group_broadcast(g, max_scale, indx); + for (int index = 0; index < Vec * FragsM; index++) { + const auto max_scale_bcast = group_broadcast(g, max_scale, index); CUTLASS_PRAGMA_UNROLL for (int z = 0; z < FragsN; z++) { - auto base_indx = indx + (z * Vec * FragsM); + auto base_indx = index + (z * Vec * FragsM); if constexpr (LocalMask) { if ((std::isinf(max_scale_bcast) && max_scale_bcast < 0) || (std::isinf(frag_s(base_indx)) && frag_s(base_indx) < 0)) { @@ -128,7 +128,7 @@ class FlashChunkPrefillSoftmaxEpilogue().get_sub_group(); CUTLASS_PRAGMA_UNROLL - for (int indx = 0; indx < Vec * FragsM; indx++) { - auto maxptr = group_broadcast(sg, max, indx); + for (int index = 0; index < Vec * FragsM; index++) { + auto maxptr = group_broadcast(sg, max, index); CUTLASS_PRAGMA_UNROLL for (int z = 0; z < FragsN; z++) { - auto base_indx = indx + (z * Vec * FragsM); + auto base_indx = index + (z * Vec * FragsM); maxptr = sycl::max(maxptr, src(base_indx)); src(base_indx) *= params.scale; } maxptr = reduce_over_group(sg, maxptr, sycl::maximum<>()); - if (indx == sg.get_local_id()[0]) { + if (index == sg.get_local_id()[0]) { max = maxptr; } } @@ -178,13 +178,13 @@ class FlashChunkPrefillSoftmaxEpilogue Date: Fri, 3 Oct 2025 01:07:36 +0800 Subject: [PATCH 25/25] enable FA on PVC --- python/sgl_kernel/flash_attn.py | 3 +-- tests/test_flash_attention.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/python/sgl_kernel/flash_attn.py b/python/sgl_kernel/flash_attn.py index ccf735e..b93b374 100644 --- a/python/sgl_kernel/flash_attn.py +++ b/python/sgl_kernel/flash_attn.py @@ -19,8 +19,7 @@ def is_fa3_supported(device=None) -> bool: or torch.cuda.get_device_capability(device)[0] == 8 ) and (torch.version.cuda >= "12.3") elif torch.xpu.is_available(): - device_name = torch.xpu.get_device_properties(0).name - return "B580" in device_name or "e211" in device_name + return torch.xpu.get_device_properties().has_fp64 else: return False diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py index f9a0756..3224d1a 100644 --- a/tests/test_flash_attention.py +++ b/tests/test_flash_attention.py @@ -37,8 +37,7 @@ def is_fa3_supported(device=None) -> bool: or torch.cuda.get_device_capability(device)[0] == 8 ) and (torch.version.cuda >= "12.3") elif torch.xpu.is_available(): - device_name = torch.xpu.get_device_properties(0).name - return "B580" in device_name or "e211" in device_name + return torch.xpu.get_device_properties().has_fp64 else: return False
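
Note on the last hunk above: patch 25 replaces the device-name allowlist ("B580"/"e211") with a capability query, so any XPU device reporting FP64 support is treated as able to run the chunked-prefill flash-attention path. Below is a minimal, illustrative sketch of that check exercised in isolation; the helper name xpu_supports_fa is not part of the patched API, and the snippet assumes a PyTorch build with XPU support (torch.xpu available).

# Illustrative sketch only (not part of the patch): exercising the
# capability-based support check that patch 25 switches to.
import torch

def xpu_supports_fa(device=None) -> bool:
    # Mirrors the patched XPU branch: gate on FP64 capability instead of
    # matching device-name substrings such as "B580" or "e211".
    if not torch.xpu.is_available():
        return False
    return torch.xpu.get_device_properties(device).has_fp64

if __name__ == "__main__":
    print("FA chunked-prefill supported on this XPU:", xpu_supports_fa())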