diff --git a/CMakeLists.txt b/CMakeLists.txt index 005457e..6c0e6f3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -46,6 +46,7 @@ set(SYCL_SUPPORTED_ARCHS "intel_gpu_pvc;intel_gpu_bmg_g21") set(TORCH_SUPPORTED_VERSION_XPU "2.8.0") set(ENABLE_MOE_KERNEL OFF) +set(FA2_ENABLED ON) # # Try to find python package with an executable that exactly matches @@ -146,16 +147,65 @@ endif() if(VLLM_GPU_LANG STREQUAL "SYCL") set(VLLM_EXT_SRC + "csrc/xpu/cutlass_sycl_demo.cpp" "csrc/xpu/layernorm.cpp" "csrc/xpu/torch_bindings.cpp" ) include_directories("/usr/include") - set(CMPLR_ROOT $ENV{CMPLR_ROOT}) - set(CMAKE_CXX_COMPILER icpx) + list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/) + list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/sycl/) + list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/sycl/syclcompat/) + message(STATUS "VLLM_INCLUDE_DIR: ${VLLM_INCLUDE_DIR}") set(VLLM_EXTRA_INCLUDE_DIRECTORIES ${CMPLR_ROOT}/include/sycl) list(APPEND VLLM_GPU_FLAGS "-DVLLM_BUILD_XPU_OPS" ) - list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64") + list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64" "-Xspirv-translator" "-spirv-ext=+SPV_INTEL_split_barrier") list(APPEND VLLM_LINK_LIBRARIES "sycl" "OpenCL" "pthread" "m" "dl" "torch" ) + + + # add cutlass dependency + set(CUTLASS_ENABLE_HEADERS_ONLY "ON" CACHE BOOL "Enable only the header library") + + # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. + set(CUTLASS_REVISION "chunk_prefill_BMG" CACHE STRING "CUTLASS revision to use") + + # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided + FetchContent_Declare( + cutlass-sycl + GIT_REPOSITORY https://github.com/sunjiweiswift/cutlass-sycl.git + # Please keep this in sync with CUTLASS_REVISION line above. + GIT_TAG ${CUTLASS_REVISION} + GIT_PROGRESS TRUE + + # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. + # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. 
+ # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE + GIT_SHALLOW TRUE + ) + + # cutlass compilation flags + set(CUTLASS_ENABLE_SYCL "ON") + # set(DPCPP_SYCL_TARGET "intel_gpu_pvc;intel_gpu_bmg_g21" CACHE STRING "DPC++ SYCL target architectures") + set(CMAKE_EXPORT_COMPILE_COMMANDS "ON") + set(CUTLASS_ENABLE_BENCHMARKS "OFF") + # disable cuda + set(CUTLASS_ENABLE_GDC_FOR_SM100_DEFAULT OFF CACHE BOOL "DISABLE CUDA") + # list(APPEND CMAKE_CXX_FLAGS "-ftemplate-backtrace-limit=0 " ) + # list(APPEND CMAKE_CXX_FLAGS "-fdiagnostics-color=always " ) + + + FetchContent_MakeAvailable(cutlass-sycl) + set(CUTLASS_INCLUDE_DIR ${cutlass-sycl_SOURCE_DIR}/include CACHE PATH "CUTLASS Header Library") + set(CUTLASS_TOOLS_UTIL_INCLUDE_DIR ${cutlass-sycl_SOURCE_DIR}/tools/util/include CACHE INTERNAL "") + set(CUTLASS_APP_INCLUDE_DIR ${cutlass-sycl_SOURCE_DIR}/applications CACHE INTERNAL "") + message(STATUS "cutlass dir: ${CUTLASS_INCLUDE_DIR} and ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} and ${CUTLASS_APP_INCLUDE_DIR}") + + # header only library + list(APPEND VLLM_GPU_FLAGS "-DCUTLASS_ENABLE_SYCL") + list(APPEND VLLM_GPU_FLAGS "-DSYCL_INTEL_TARGET") + list(APPEND VLLM_GPU_FLAGS "-DCUTLASS_VERSIONS_GENERATED") + list(APPEND VLLM_GPU_FLAGS "-ftemplate-backtrace-limit=0") + list(APPEND VLLM_GPU_FLAGS "-fdiagnostics-color=always") + endif() message(STATUS "Enabling C extension.") @@ -169,9 +219,47 @@ define_gpu_extension_target( ARCHITECTURES ${VLLM_GPU_ARCHES} INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_APP_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${VLLM_INCLUDE_DIR} USE_SABI 3 WITH_SOABI) +# +# flash attention _C extension +# + +if (FA2_ENABLED) + message(STATUS "Enabling fa2 extension.") + file(GLOB FA2_GEN_SRCS "csrc/flash_attn/*.cpp") + + # list(APPEND VLLM_GPU_FLAGS "-ze-opt-large-register-file") + list(APPEND VLLM_GPU_FLAGS "-gline-tables-only") + list(APPEND VLLM_GPU_FLAGS "-O3") + list(APPEND VLLM_GPU_FLAGS "-DNDEBUG") + + define_gpu_extension_target( + _vllm_fa2_C + DESTINATION vllm_xpu_kernels + LANGUAGE ${VLLM_GPU_LANG} + SOURCES + csrc/flash_attn/flash_api.cpp + ${FA2_GEN_SRCS} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + LINK_FLAGS ${VLLM_GPU_LINK_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_APP_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${VLLM_INCLUDE_DIR} + USE_SABI 3 + WITH_SOABI) + + # target_include_directories(_vllm_fa2_C PRIVATE + # csrc/flash_attn + # csrc/flash_attn/src) +endif () + + # # _moe_C extension # @@ -192,6 +280,7 @@ if (ENABLE_MOE_KERNEL) ARCHITECTURES ${VLLM_GPU_ARCHES} INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_APP_INCLUDE_DIR} USE_SABI 3 WITH_SOABI) endif() diff --git a/cmake/toolchain.cmake b/cmake/toolchain.cmake new file mode 100644 index 0000000..f8cf8b9 --- /dev/null +++ b/cmake/toolchain.cmake @@ -0,0 +1,6 @@ +# use this file to set the compiler and flags for SYCL + +set(CMPLR_ROOT $ENV{CMPLR_ROOT}) +message(STATUS "CMPLR_ROOT: ${CMPLR_ROOT}") +set(CMAKE_CXX_COMPILER ${CMPLR_ROOT}/bin/icpx) +set(CMAKE_C_COMPILER ${CMPLR_ROOT}/bin/icx) \ No newline at end of file diff --git a/csrc/core/registration.h b/csrc/core/registration.h index 4d0ce1c..5f9cdeb 100644 --- a/csrc/core/registration.h +++ b/csrc/core/registration.h @@ -22,6 +22,7 @@ #define REGISTER_EXTENSION(NAME) 
\ PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \ static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \ - STRINGIFY(NAME), nullptr, 0, nullptr}; \ + STRINGIFY(NAME), nullptr, 0, nullptr, \ + nullptr, nullptr, nullptr, nullptr}; \ return PyModule_Create(&module); \ } diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp new file mode 100644 index 0000000..4a15d5e --- /dev/null +++ b/csrc/flash_attn/flash_api.cpp @@ -0,0 +1,82 @@ +#include "pytorch_shim.h" + +#include "core/registration.h" +#include "xpu/cutlass_kernels/chunk_prefill.hpp" +#include + +namespace FLASH_NAMESPACE { + +std::vector mha_varlen_fwd( + const at::Tensor& q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& k, // total_k x num_heads_k x head_size, total_k := + // \sum_{i=0}^{b} s_i or num_blocks x page_block_size + // x num_heads_k x head_size if there's a block_table. + const at::Tensor& v, // total_k x num_heads_k x head_size, total_k := + // \sum_{i=0}^{b} s_i or num_blocks x page_block_size + // x num_heads_k x head_size if there's a block_table. + std::optional& out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& seqused_k, // b. If given, only this many elements of each batch + // element's keys are used. + std::optional& leftpad_k_, // batch_size + at::Tensor& block_table_, // batch_size x max_num_blocks_per_seq + std::optional& alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + int max_seqlen_k, + float p_dropout, + float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool return_softmax, std::optional gen_) { + at::Tensor out; + if(out_.has_value()) { + out = *out_; + } + else { + out = torch::zeros_like(q); + } + + cutlass_chunk_prefill_impl( + q, + k, + v, + out, + block_table_, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + is_causal); + + if(return_softmax) { + auto softmax_lse = torch::empty_like(out); + return {out, softmax_lse}; + } + else { + at::Tensor softmax_lse; + return {out, softmax_lse}; + } +} +} // namespace FLASH_NAMESPACE + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { + ops.def( + "varlen_fwd(Tensor q, Tensor k, Tensor v, Tensor!? out, Tensor " + "cu_seqlens_q, " + "Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? leftpad_k, Tensor " + "block_table, Tensor? alibi_slopes, " + "int max_seqlen_q, int max_seqlen_k, float p_dropout, float " + "softmax_scale, bool zero_tensors, " + "bool is_causal, int window_size_left, int window_size_right, float " + "softcap, bool return_softmax, " + "Generator? gen) -> Tensor[]"); + ops.impl("varlen_fwd", torch::kXPU, + make_pytorch_shim(&FLASH_NAMESPACE::mha_varlen_fwd)); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) \ No newline at end of file diff --git a/csrc/flash_attn/pytorch_shim.h b/csrc/flash_attn/pytorch_shim.h new file mode 100644 index 0000000..82f5f5c --- /dev/null +++ b/csrc/flash_attn/pytorch_shim.h @@ -0,0 +1,110 @@ +#pragma once + +#include + +/** + * Unfortunately, the type signatures of the flash_attn ops are not compatible + * with the PyTorch library bindings. To get around that we use + * `make_pytorch_shim` which creates a lambda that exponses the API using + * PyTorch compatible types to the types, then converts them to the types + * expected by the flash_attn ops. 
This shims allows us to make minimal changes + * to `flash_api.cpp` making it easier to synchronize with upstream changes. + * + * The `pytorch_library_compatible_type` struct is used to map from the + * flash_attn ops types to a PyTorch library compatible one. The main issues is + * that the following types are not support by PyTorch library bindings: + * - `int` + * - `float` + * - `std::optional &` + * - `std::optional &` + * So we convert them to (respectively): + * - `int64_t` + * - `double` + * - `const std::optional&` + * - `const std::optional&` + */ + +template +struct pytorch_library_compatible_type { + using type = T; + static T convert_from_type(T arg) { return arg; } +}; + +template +using pytorch_library_compatible_type_t = + typename pytorch_library_compatible_type::type; + +template +T convert_from_pytorch_compatible_type( + pytorch_library_compatible_type_t arg) { + return pytorch_library_compatible_type::convert_from_type(arg); +} + +// Map `std::optional &` -> `const std::optional&` +// (NOTE: this is bit unsafe but non of the ops in flash_attn mutate +// the optional container) +template +struct pytorch_library_compatible_type&> { + using type = const std::optional&; + static std::optional& convert_from_type(const std::optional& arg) { + return const_cast&>(arg); + } +}; + +// Map `std::optional` -> +// `std::optional>` +// (NOTE: tested for `std::optional` -> `std::optional`) +template +struct pytorch_library_compatible_type> { + using type = std::optional>; + static std::optional> convert_from_type( + std::optional arg) { + return arg; + } +}; + +// Map `std::optional&` -> `const std::optional&` +template <> +struct pytorch_library_compatible_type&> { + using type = const std::optional&; + static std::optional& convert_from_type( + const std::optional& arg) { + return const_cast&>( + reinterpret_cast&>(arg)); + } +}; + +// Map `int` -> `int64_t` +template <> +struct pytorch_library_compatible_type { + using type = int64_t; + static int convert_from_type(int64_t arg) { + TORCH_CHECK(arg <= std::numeric_limits::max(), + "int64_t value is too large to be converted to int"); + TORCH_CHECK(arg >= std::numeric_limits::min(), + "int64_t value is too small to be converted to int"); + return arg; + } +}; + +// Map `float` -> `double` +template <> +struct pytorch_library_compatible_type { + using type = double; + static float convert_from_type(double arg) { + TORCH_CHECK(std::abs(arg) <= std::numeric_limits::max(), + "double value is too large to be converted to float"); + return arg; + } +}; + +// +// Shim Utils +// + +template +auto make_pytorch_shim(Ret (*fun)(Args... args)) { + return [fun](pytorch_library_compatible_type_t... 
args) { + return fun(convert_from_pytorch_compatible_type(args)...); + }; +} diff --git a/csrc/xpu/cutlass_kernels/chunk_prefill.hpp b/csrc/xpu/cutlass_kernels/chunk_prefill.hpp new file mode 100644 index 0000000..c96b8dd --- /dev/null +++ b/csrc/xpu/cutlass_kernels/chunk_prefill.hpp @@ -0,0 +1,331 @@ +#pragma once + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "flash_attention_v2/collective/fmha_fusion.hpp" +#include "flash_attention_v2/kernel/tile_scheduler_chunk_prefill.hpp" +#include "cutlass/util/packed_stride.hpp" +#include "flash_attention_v2/kernel/xe_chunk_prefill.hpp" +#include "flash_attention_v2/collective/xe_flash_attn_chunk_prefill_epilogue.hpp" +#include "flash_attention_v2/collective/xe_flash_attn_chunk_prefill_softmax_epilogue.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/sycl_event_manager.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" + +#include "utils.hpp" + +using namespace cute; + +struct chunk_prefill_args_t { + void* query; + void* key; + void* value; + void* out; + void* block_table; + void* num_blocks_per_seq; + void* cu_seqlens_q; + void* cu_seqlens_k; + int max_queries; + int max_keys; + int total_seqlen_q; + int total_seqlen_k; + float sm_scale; + int batch_size; + int num_heads_q; + int num_heads_k; + int head_size; + int max_blocks_per_seq; + int block_size; + bool is_causal; +}; + +template struct KernelLauncher { + using StrideQ = typename FMHAChunkPrefillKernel::StrideQ; + using StrideK = typename FMHAChunkPrefillKernel::StrideK; + using StrideV = typename FMHAChunkPrefillKernel::StrideV; + using StrideO = typename FMHAChunkPrefillKernel::StrideO; + + using ElementQ = typename FMHAChunkPrefillKernel::ElementQ; + using ElementK = typename FMHAChunkPrefillKernel::ElementK; + using ElementV = typename FMHAChunkPrefillKernel::ElementV; + using ElementAcc = typename FMHAChunkPrefillKernel::ElementAccumulator; + + using CollectiveEpilogue = typename FMHAChunkPrefillKernel::CollectiveEpilogue; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename FMHAChunkPrefillKernel::ProblemShape; + + /// Initialization + StrideQ stride_Q; + StrideK stride_K; + StrideV stride_V; + StrideK stride_K_cache; + StrideV stride_V_cache; + StrideO stride_O; + uint64_t seed = 0; + + ProblemShapeType initialize(const chunk_prefill_args_t &args) { + auto problem_shape = cute::make_tuple( + 1, + args.num_heads_q, + args.num_heads_k, + args.total_seqlen_q, + args.total_seqlen_k, + args.total_seqlen_k, + args.head_size, + args.head_size); + auto problem_shape_out = cute::make_tuple( + args.batch_size, + args.num_heads_q, + args.num_heads_k, + cutlass::fmha::collective::VariableLength{args.max_queries}, // cu_q + cutlass::fmha::collective::VariableLength{args.max_keys}, // cu_k + cutlass::fmha::collective::VariableLength{args.max_keys}, // cu_v + args.head_size, + args.head_size); + auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = problem_shape; + auto group_q_size = num_heads_q / num_heads_kv; + auto group_q_num = num_heads_q / group_q_size; + + 
stride_Q = cutlass::make_cute_packed_stride(StrideQ{}, cute::make_shape(seq_len_qo, num_heads_q * head_size_qk, batch)); + stride_K = cutlass::make_cute_packed_stride(StrideK{}, cute::make_shape(seq_len_kv, num_heads_kv * head_size_qk, batch)); + stride_V = cutlass::make_cute_packed_stride(StrideV{}, cute::make_shape(head_size_vo * num_heads_kv, seq_len_kv, batch)); + + stride_K_cache = cutlass::make_cute_packed_stride(StrideK{}, cute::make_shape(seq_len_kv_cache, num_heads_kv * head_size_qk, batch)); + stride_V_cache = cutlass::make_cute_packed_stride(StrideV{}, cute::make_shape(head_size_vo, seq_len_kv_cache, batch * num_heads_kv)); + + stride_O = cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo * group_q_size, group_q_num * head_size_vo, batch)); + + get<3>(problem_shape_out).cumulative_length = reinterpret_cast(args.cu_seqlens_q); + get<4>(problem_shape_out).cumulative_length = reinterpret_cast(args.cu_seqlens_k); + get<5>(problem_shape_out).cumulative_length = reinterpret_cast(args.cu_seqlens_k); + + return problem_shape_out; + } + + cutlass::Status run(const chunk_prefill_args_t &args, const cutlass::KernelHardwareInfo &hw_info) { + ProblemShapeType problem_size = initialize(args); + + typename FMHAChunkPrefillKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + { + reinterpret_cast(args.query), stride_Q, + reinterpret_cast(args.key), stride_K, + reinterpret_cast(args.value), stride_V, + reinterpret_cast(args.key), stride_K_cache, + reinterpret_cast(args.value), stride_V_cache, + static_cast(args.block_table), + args.block_size, + static_cast(args.num_blocks_per_seq) + }, + {args.sm_scale}, + {reinterpret_cast(args.out), stride_O}, + hw_info}; + + // Define device-global scratch memory + size_t workspace_size = FMHAChunkPrefillKernel::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + if (!FMHAChunkPrefillKernel::can_implement(arguments)) { + std::cout << "Invalid Problem Size: " << std::endl; + return cutlass::Status::kErrorInvalidProblem; + } + + // Initialize the workspace + FMHAChunkPrefillKernel::initialize_workspace(arguments, workspace.get()); + + // Convert host-side arguments to device-side arguments to be passed to the kernel + auto params = FMHAChunkPrefillKernel::to_underlying_arguments(arguments, workspace.get()); + + // Run the Flash Attention implementation. 
+ run(params); + + syclcompat::wait(); + return cutlass::Status::kSuccess; + } + + static void run(typename FMHAChunkPrefillKernel::Params params) { + dim3 const block = FMHAChunkPrefillKernel::get_block_shape(); + dim3 const grid = FMHAChunkPrefillKernel::get_grid_shape(params); + + // configure smem size and carveout + int smem_size = FMHAChunkPrefillKernel::SharedStorageSize; + + const auto sycl_block = syclcompat::dim3(block.x, block.y, block.z); + const auto sycl_grid = syclcompat::dim3(grid.x, grid.y, grid.z); + +// Launch parameters depend on whether SYCL compiler supports work-group scratch memory extension +#if !defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY) + using namespace syclcompat::experimental; + auto event = launch>( + launch_policy{sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)}, + kernel_properties{sycl_exp::sub_group_size}}, + params); +#else + syclcompat::experimental::launch_properties launch_props { + sycl::ext::oneapi::experimental::work_group_scratch_size(smem_size), + }; + syclcompat::experimental::kernel_properties kernel_props{ + sycl::ext::oneapi::experimental::sub_group_size + }; + syclcompat::experimental::launch_policy policy{sycl_grid, sycl_block, launch_props, kernel_props}; + auto event = syclcompat::experimental::launch>(policy, params); +#endif + + EventManager::getInstance().addEvent(event); + } +}; + +template struct FMHAKernel { + + template + static void run(const chunk_prefill_args_t &args) { + cutlass::KernelHardwareInfo hw_info; + + using LayoutQ = cutlass::layout::RowMajor; + using LayoutK = cutlass::layout::ColumnMajor; + using LayoutV = cutlass::layout::RowMajor; + using LayoutO = cutlass::layout::RowMajor; + + using ElementInputKV = ElementInputQ; + + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + using CollectiveEpilogue = cutlass::flash_attention::collective::FlashChunkPrefillEpilogue< + EpilogueDispatchPolicy, MMAOperation, TileShapeOutput, SubgroupLayout, ElementComputeEpilogue, ElementOutput, cutlass::gemm::TagToStrideC_t, ElementOutput, + GmemTiledCopyStore>; + using CollectiveSoftmaxEpilogue = cutlass::flash_attention::collective::FlashChunkPrefillSoftmaxEpilogue; + + using ProblemShapeRegular = cute::tuple; + using namespace cutlass::fmha::collective; + using ProblemShapeVarlen = cute::tuple; + using ProblemShapeType = std::conditional_t; + + // Mainloop + using CollectiveMainloop = cutlass::flash_attention::collective::FlashChunkPrefillMma< + GEMMDispatchPolicy, ProblemShapeType, ElementInputQ, cutlass::gemm::TagToStrideA_t, ElementInputKV, + cutlass::gemm::TagToStrideB_t, ElementInputKV, cutlass::gemm::TagToStrideB_t, MMAOperation, TileShapeQK, TileShapePV, SubgroupLayout, + GmemTiledCopyQ, // Q + GmemTiledCopyK, // K + GmemTiledCopyV, // V, + Causal, + PagedKV>; + + using FMHAChunkPrefillKernel = cutlass::flash_attention::kernel::FMHAPrefillChunk; + + KernelLauncher launcher; + + launcher.run(args, hw_info); + } + + static void dispatch(const chunk_prefill_args_t &args) { + if(args.is_causal) { + run(args); + } + else { + run(args); + } + } +}; + +template +void policy_dispatch( + CutlassType cuType, + const chunk_prefill_args_t& args) { + const int PipelineStages = 2; + if(cuType == CutlassType::half) { + FMHAKernel::dispatch(args); + } + else { + FMHAKernel::dispatch(args); + } +} + +void chunk_prefill_kernel( + CutlassType cuType, + const chunk_prefill_args_t& args) { + if(args.head_size == HEAD_SIZE_LIMIT_0) { + 
policy_dispatch(cuType, args); + } else if(args.head_size == HEAD_SIZE_LIMIT_1) { + policy_dispatch(cuType, args); + } + else if(args.head_size == HEAD_SIZE_LIMIT_2) { + policy_dispatch(cuType, args); + } +} + +void cutlass_chunk_prefill_impl( + const at::Tensor& query, // [seq_q, heads, head_size] + const at::Tensor& key_cache, // [num_block, block_size, heads, head_size] + const at::Tensor& value_cache, + at::Tensor& out, + const at::Tensor& block_table, + const at::Tensor& cu_seqlens_q, + const at::Tensor& cu_seqlens_k, + int max_seqlen_q, + int max_seqlen_k, + double sm_scale, + bool is_causal) { + int num_block = key_cache.size(0); + int block_size = key_cache.size(1); + int num_heads_q = query.size(1); + int num_heads_kv = key_cache.size(2); + int head_size = query.size(2); + int batch_size = cu_seqlens_q.numel() - 1; + int max_blocks_per_seq = block_table.size(1); + int total_seqlen_q = query.size(0); + int total_seqlen_k = num_block * block_size; + at::Tensor num_blocks_per_seq = torch::div(cu_seqlens_k, block_size); + + chunk_prefill_args_t args = { + query.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + out.data_ptr(), + block_table.data_ptr(), + num_blocks_per_seq.data_ptr(), + cu_seqlens_q.data_ptr(), + cu_seqlens_k.data_ptr(), + max_seqlen_q, + max_seqlen_k, + total_seqlen_q, + total_seqlen_k, + static_cast(sm_scale), + batch_size, + num_heads_q, + num_heads_kv, + head_size, + max_blocks_per_seq, + block_size, + is_causal + }; + CutlassType cuType = aten_to_Cutlass_dtype(query); + chunk_prefill_kernel(cuType, args); +} diff --git a/csrc/xpu/cutlass_kernels/utils.hpp b/csrc/xpu/cutlass_kernels/utils.hpp new file mode 100644 index 0000000..4715419 --- /dev/null +++ b/csrc/xpu/cutlass_kernels/utils.hpp @@ -0,0 +1,49 @@ +#pragma once +#include "torch/all.h" +#include + +#define HEAD_SIZE_LIMIT_0 64 +#define HEAD_SIZE_LIMIT_1 128 +#define HEAD_SIZE_LIMIT_2 256 +#define HEAD_SIZE_LIMIT_3 512 + +enum class CutlassType { + half, + bfloat16, +}; + +inline CutlassType aten_to_Cutlass_dtype(const at::Tensor& input) { + CutlassType cuType; + if (input.scalar_type() == torch::kHalf) { + cuType = CutlassType::half; + } else if (input.scalar_type() == torch::kBFloat16) { + cuType = CutlassType::bfloat16; + } else { + TORCH_INTERNAL_ASSERT( + false, + ""); + } + return cuType; +} + +using namespace cute; +struct chunk_policy_head64 { + using ShapeQK = Shape<_128, _64, _64>; + using ShapePV = Shape<_128, _32, _64>; + using ShapeOutPut = Shape<_128, _64, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; +}; + +struct chunk_policy_head128 { + using ShapeQK = Shape<_128, _64, _64>; + using ShapePV = Shape<_128, _32, _64>; + using ShapeOutPut = Shape<_128, _128, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; +}; + +struct chunk_policy_head256 { + using ShapeQK = Shape<_256, _64, _64>; + using ShapePV = Shape<_256, _32, _64>; + using ShapeOutPut = Shape<_256, _192, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; +}; \ No newline at end of file diff --git a/csrc/xpu/cutlass_sycl_demo.cpp b/csrc/xpu/cutlass_sycl_demo.cpp new file mode 100644 index 0000000..7254b65 --- /dev/null +++ b/csrc/xpu/cutlass_sycl_demo.cpp @@ -0,0 +1,520 @@ + + +#include +#include +#include + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/xe_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include 
"cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "helper.h" + +#include "cutlass/cutlass.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/sycl_tensor_fill.h" + +using namespace cute; + +/// Helper to initialize a block of device data +template +bool initialize_block(Element* block, std::size_t size, uint64_t seed = 2023) { + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = Element(2); + scope_min = Element(0); + } else if (bits_input <= 8) { + scope_max = Element(2); + scope_min = Element(-2); + } else { + scope_max = Element(8); + scope_min = Element(-8); + } + + cutlass::reference::device::BlockFillRandomUniform(block, size, seed, + scope_max, scope_min, 0); + + syclcompat::wait(); + return true; +} + +template +bool initialize_block(cutlass::DeviceAllocation& block, + uint64_t seed = 2023) { + return initialize_block(block.get(), block.size(), seed); +} + +template +void initialize_mixed_dtype_block( + cutlass::DeviceAllocation& block_device, + cutlass::DeviceAllocation& block_device_dq, uint64_t seed) { + static_assert(cute::sizeof_bits_v >= 8); + + std::ranlux24_base rng(std::random_device{}()); + rng.seed(seed); + + int bits_input = cute::sizeof_bits_v; + T1 scope_max, scope_min; + if (bits_input == 1) { + scope_max = T1(2); + scope_min = T1(0); + } else if (bits_input <= 8) { + scope_max = T1(2); + scope_min = T1(-2); + } else { + scope_max = T1(8); + scope_min = T1(-8); + } + + std::uniform_int_distribution<> dist(scope_min, scope_max); + + if constexpr (cute::sizeof_bits_v >= 8) { + auto block_host = std::vector(block_device.size()); + auto block_host_dq = std::vector(block_device.size()); + for (int i = 0; i < block_host.size(); ++i) { + block_host[i] = static_cast(dist(rng)); + block_host_dq[i] = static_cast(block_host[i]); + } + + block_device.copy_from_host(block_host.data()); + block_device_dq.copy_from_host(block_host_dq.data()); + } else { + static constexpr auto array_size = 1024; + + cute::array_subbyte block_host{}; + auto block_host_dq = std::vector(array_size); + + for (int i = 0; i < block_host.size(); ++i) { + block_host[i] = static_cast(dist(rng)); + block_host_dq[i] = static_cast(block_host[i].get()); + } + + static constexpr auto elements_per_byte = + cute::sizeof_bits_v / cute::sizeof_bits_v; + + int loop_cnt = block_device.size() / array_size; + for (int i = 0; i < loop_cnt; i++) { + cutlass::device_memory::copy_to_device( + block_device.get() + (i * array_size) / elements_per_byte, + raw_pointer_cast(block_host.begin()), array_size); + cutlass::device_memory::copy_to_device( + block_device_dq.get() + i * array_size, block_host_dq.data(), + array_size); + } + + auto tail_size = block_device.size() % array_size; + if (tail_size) { + cutlass::device_memory::copy_to_device( + block_device.get() + (loop_cnt * array_size) / elements_per_byte, + raw_pointer_cast(block_host.begin()), tail_size); + cutlass::device_memory::copy_to_device( + block_device_dq.get() + loop_cnt * array_size, block_host_dq.data(), + tail_size); + } + } +} + +template +inline bool is_close(T a, T b, float atol, float rtol) { + return std::abs((float)a - (float)b) <= atol + rtol * 
std::abs((float)b); +} + +// TODO(Codeplay): use on device initialisation for this +template +inline void random_fill(T* src, int seed, size_t N, float max, float min) { + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v) { + std::random_device rd; + std::mt19937 gen(seed); + std::uniform_real_distribution dis(min, max); + auto buff = std::vector(N); + + for (size_t i = 0; i < N; ++i) { + buff[i] = (T)(dis(gen)); + } + syclcompat::memcpy(src, buff.data(), N); + syclcompat::wait(); + } else { + assert(0 & "Not supported dtype"); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + bool help; + bool error; + + int m, n, k, l, iterations; + float alpha, beta; + + Options() + : help(false), + error(false), + m(5120), + n(4096), + k(4096), + l(1), + iterations(20), + alpha(1.f), + beta(0.f) {} + + // Parses the command line + void parse(int argc, char const** args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 5120); + cmd.get_cmd_line_argument("n", n, 4096); + cmd.get_cmd_line_argument("k", k, 4096); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations, 100); + } + + /// Prints the usage statement. + std::ostream& print_usage(std::ostream& out) const { + out << "BMG GEMM Example\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage " + "statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of " + "the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Iterations\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct ExampleRunner { + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementC = typename Gemm::ElementC; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation + block_ref_D; // Reference GEMM result for verification + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, 
ElementCompute alpha, + ElementCompute beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, alpha, ref_A, cutlass::ComplexTransform::kNone, ref_B, + cutlass::ComplexTransform::kNone, beta, ref_C, ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + // CUTLASS on SYCL uses the compatibility library syclcompat for e.g. + // default in-order queue + syclcompat::wait(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual( + block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + // Complete the stride by combining static layout info (StrideA) with + // runtime size info (M,K,L) + stride_A = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = + cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = + cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + block_A.reset(static_cast(M) * K * L); + block_B.reset(static_cast(K) * N * L); + block_C.reset(static_cast(M) * N * L); + block_D.reset(static_cast(M) * N * L); + block_ref_D.reset(static_cast(M) * N * L); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + } + + cutlass::Status run(const Options& options, + const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = + ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, + block_C.get(), + stride_C, + block_D.get(), + stride_D}, + hw_info}; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + if (gemm_op.can_implement(arguments) != cutlass::Status::kSuccess) { + std::cout << "Invalid Problem Size: " << options.m << 'x' << options.n + << 'x' << options.k << 'x' << options.l << std::endl; + std::exit(1); + } + + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); + + // Run the GEMM + CUTLASS_CHECK(gemm_op.run()); + + syclcompat::wait(); + + // Verify that the result is correct + bool passed = verify(problem_size, options.alpha, options.beta); + std::cout << "Disposition: " << (passed ? 
"Passed" : "Failed") << std::endl; + + if (!passed) return cutlass::Status::kErrorInternal; + + if (options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + gemm_op.run(); + } + syclcompat::wait(); + + float cute_time = timer.seconds() / options.iterations; + double tflops = + (2.0 * options.m * options.n * options.k * options.l) * 1e-12; + std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' + << options.k << 'x' << options.l << std::endl; + printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", + tflops / cute_time, cute_time * 1000); + } + + return cutlass::Status::kSuccess; + } +}; + +void cutlass_sycl_demo(torch::Tensor& a) { + // + // Parse options + // + // + std::cout << a.sizes() << std::endl; + + Options options; + + /* options.parse(argc, argv); */ + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a + // given device ID. This information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with + // multiple GPUs and wish to use a GPU other than that with device ID 0. + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); + + bool passed; + + // The code section below describes datatype for input, output matrices and + // computation between elements in input matrices. + using ElementAccumulator = float; // <- data type of accumulator + using ElementComputeEpilogue = float; // <- data type of epilogue operations + using ElementInputA = + bfloat16_t; // <- data type of elements in input matrix A + using ElementInputB = + bfloat16_t; // <- data type of elements in input matrix B + using ElementOutput = float; // <- data type of elements in output matrix D + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + // The 2D block copy operations used for the A and B matrices + using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; + using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + + // Workgroup-level tile + using TileShape = Shape<_256, _256, _32>; + + // A TiledMMA struct defines a tiling of an MMA atom over M, N and K, + // combining both additional hardware (sub-groups for Intel BMG) and + // iterations by each sub-group. + // + // The TiledMMAHelper struct defines a specific TiledMMA for a given MMA atom + // (XE_8x16x16_F32BF16BF16F32_TT), TileShape (<256, 256, 32>) and sub-group + // layout (8x4x1). The TiledMMA constructed using TiledMMAHelper has the + // property that each sub-group operates on a single contiguous chunk of the + // work-group TileShape. For this configuration, this implies that each + // sub-group operates on a contiguous 32x64x32 chunk (4x4x2 iterations). See + // 0t_mma_atom.md#TiledMMAs for more info. Sub-groups are arranged row-major + // (stride 4,1,0) for performance reasons. + using TiledMma = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32 + typename TiledMMAHelper< + MMA_Atom, Layout, + Layout, Stride<_4, _1, _0>>>::TiledMMA; + + // For Intel BMG, PipelineStages defines how many k-blocks ahead to prefetch + // from A and B. 
+ constexpr int PipelineStages = 2; + using GEMMDispatchPolicy = + cutlass::gemm::MainloopIntelXeXMX16; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + + // This is the 'default' epilogue operation (Linear Combination) which + // performs everything in: (D = alpha * (A*B) + beta * C) aside from the + // (A*B), which is handled by the GEMM. See 05_bmg_gemm_with_epilogues for + // more complex epilogue examples. + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination< + ElementOutput, ElementComputeEpilogue, ElementAccumulator, + ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>; + + // FusionCallbacks ties the EpilogueOp to an implementation (based on the + // dispatch policy/architecture) and defines the epilogue arguments. + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks< + EpilogueDispatchPolicy, EpilogueOp, TileShape, + decltype(tile_shape(TiledMma()))>; + // GEMM Epilogue - loads & stores C/D matrices, performs epilogue operations & + // load/stores any auxiliary data required + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, TileShape, ElementAccumulator, + cutlass::gemm::TagToStrideC_t, // Converts CUTLASS 2.x to + // CUTLASS 3.x representation + ElementOutput, + cutlass::gemm::TagToStrideC_t, // Converts CUTLASS 2.x to + // CUTLASS 3.x representation + FusionCallBacks, + XE_2D_U32x8x16_LD_N, // The copy atom used to load matrix C + void, void, + XE_2D_U32x8x16_ST_N, // The copy atom used to store matrix D + void, void>; + + // GEMM Mainloop - iteration over blocks in K dimension + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, TileShape, ElementInputA, + cutlass::gemm::TagToStrideA_t, // Converts CUTLASS 2.x to + // CUTLASS 3.x representation + ElementInputB, + cutlass::gemm::TagToStrideB_t, // Converts CUTLASS 2.x to + // CUTLASS 3.x representation + TiledMma, GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + // Define the whole kernel (mainloop and epilogue) + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Defer global problem shape definition to + // runtime + CollectiveMainloop, CollectiveEpilogue>; + + // The GemmUniversalAdapter wraps the defined GEMM kernel and handles the + // launch, and e.g. persistent scratch memory if required. + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + + CUTLASS_CHECK(runner.run(options, hw_info)); +} diff --git a/csrc/xpu/helper.h b/csrc/xpu/helper.h new file mode 100644 index 0000000..4bc345c --- /dev/null +++ b/csrc/xpu/helper.h @@ -0,0 +1,127 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#if defined(CUTLASS_ENABLE_SYCL) + #include "cutlass/util/sycl_timer.hpp" +#else + #include +#endif +#include + +/** + * Panic wrapper for unwinding CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) \ + << " at: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +/** + * Panic wrapper for unwinding CUDA runtime errors + */ +#define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ + << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +/** + * GPU timer for recording the elapsed time across kernel(s) launched in GPU + * stream + */ +struct GpuTimer { +#if defined(CUTLASS_ENABLE_SYCL) + using cudaStream_t = int; + SYCLTimer syclTimer; +#else + cudaEvent_t _start; + cudaEvent_t _stop; +#endif + cudaStream_t _stream_id; + + /// Constructor + GpuTimer() : _stream_id(0) { +#if !defined(CUTLASS_ENABLE_SYCL) + CUDA_CHECK(cudaEventCreate(&_start)); + CUDA_CHECK(cudaEventCreate(&_stop)); +#endif + } + + /// Destructor + ~GpuTimer() { +#if !defined(CUTLASS_ENABLE_SYCL) + CUDA_CHECK(cudaEventDestroy(_start)); + CUDA_CHECK(cudaEventDestroy(_stop)); +#endif + } + + /// Start the timer for a given stream (defaults to the default stream) + void start(cudaStream_t stream_id = 0) { + _stream_id = stream_id; +#if defined(CUTLASS_ENABLE_SYCL) + syclTimer.start(); +#else + CUDA_CHECK(cudaEventRecord(_start, _stream_id)); +#endif + } + + /// Stop the timer + void stop() { +#if defined(CUTLASS_ENABLE_SYCL) + syclTimer.stop(); +#else + CUDA_CHECK(cudaEventRecord(_stop, _stream_id)); +#endif + } + + /// Return the elapsed time (in milliseconds) + float elapsed_millis() { +#if defined(CUTLASS_ENABLE_SYCL) + return syclTimer.milliseconds(); +#else + float elapsed = 0.0; + CUDA_CHECK(cudaEventSynchronize(_stop)); + CUDA_CHECK(cudaEventElapsedTime(&elapsed, _start, _stop)); + return elapsed; +#endif + } +}; \ No newline at end of file diff --git a/csrc/xpu/mha.h b/csrc/xpu/mha.h new file mode 100644 index 0000000..a18cc0c --- /dev/null +++ b/csrc/xpu/mha.h @@ -0,0 +1,16 @@ +#pragma once + + +void cutlass_chunk_prefill_impl( + at::Tensor& query, // [seq_q, heads, head_size] + at::Tensor& key_cache, // [num_block, 
block_size, heads, head_size] + at::Tensor& value_cache, + at::Tensor& out, + at::Tensor& block_table, + at::Tensor& cu_seqlens_q, + at::Tensor& cu_seqlens_k, + int max_seqlen_q, + int max_seqlen_k, + double sm_scale, + bool is_causal +); \ No newline at end of file diff --git a/csrc/xpu/ops.h b/csrc/xpu/ops.h index acf62e6..7c3dbdc 100644 --- a/csrc/xpu/ops.h +++ b/csrc/xpu/ops.h @@ -6,4 +6,4 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double epsilon); void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, - torch::Tensor& weight, double epsilon); \ No newline at end of file + torch::Tensor& weight, double epsilon); diff --git a/csrc/xpu/torch_bindings.cpp b/csrc/xpu/torch_bindings.cpp index 39c193e..9c9e0a2 100644 --- a/csrc/xpu/torch_bindings.cpp +++ b/csrc/xpu/torch_bindings.cpp @@ -31,6 +31,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, " "float epsilon) -> ()"); ops.impl("fused_add_rms_norm", torch::kXPU, &fused_add_rms_norm); + + ops.def("cutlass_sycl_demo(Tensor a) -> ()"); + ops.impl("cutlass_sycl_demo", torch::kXPU, &cutlass_sycl_demo); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/setup.py b/setup.py index 346a13d..c7d8660 100644 --- a/setup.py +++ b/setup.py @@ -125,6 +125,7 @@ def configure(self, ext: CMakeExtension) -> None: cmake_args = [ '-DCMAKE_BUILD_TYPE={}'.format(cfg), '-DVLLM_TARGET_DEVICE={}'.format(VLLM_TARGET_DEVICE), + '-DCMAKE_TOOLCHAIN_FILE=cmake/toolchain.cmake' ] verbose = envs.VERBOSE @@ -258,6 +259,7 @@ def run(self): if _build_custom_ops(): ext_modules.append(CMakeExtension(name="vllm_xpu_kernels._C")) + ext_modules.append(CMakeExtension(name="vllm_xpu_kernels._vllm_fa2_C")) if ext_modules: cmdclass = {"build_ext": cmake_build_ext} diff --git a/tests/flash_attn/test_flash_attn_varlen_func.py b/tests/flash_attn/test_flash_attn_varlen_func.py new file mode 100644 index 0000000..fab86b4 --- /dev/null +++ b/tests/flash_attn/test_flash_attn_varlen_func.py @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func + +DTYPES = [torch.half, torch.bfloat16] + + +@pytest.mark.parametrize("dtype", DTYPES) +def test_flash_attn_varlen_func(dtype): + torch.set_default_device("xpu") + batch_size = 1 + seq_len = 4 + num_heads = 8 + head_dim = 16 + + q = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype) + k = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype) + v = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype) + + max_seqlen_q = seq_len + cu_seqlens_q = torch.tensor([0, seq_len], dtype=torch.int32) + max_seqlen_k = seq_len + cu_seqlens_k = cu_seqlens_q + + # Call the flash attention function + output = flash_attn_varlen_func(q, k, v, max_seqlen_q, cu_seqlens_q, + max_seqlen_k, cu_seqlens_k) + + assert output is not None + assert output.dtype == dtype + assert output.shape == (batch_size, seq_len, num_heads, head_dim) diff --git a/tests/test_cutlass_op.py b/tests/test_cutlass_op.py new file mode 100644 index 0000000..f575ae9 --- /dev/null +++ b/tests/test_cutlass_op.py @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +import vllm_xpu_kernels._C # noqa F401 + +DTYPES = [torch.half, torch.bfloat16] + + 
+@pytest.mark.parametrize("dtype", DTYPES) +@torch.inference_mode() +def test_cutlass_op(dtype: torch.dtype, ): + torch.set_default_device("xpu") + a = torch.zeros((2, 3), dtype=dtype, device="xpu") + torch.ops._C.cutlass_sycl_demo(a) diff --git a/vllm_xpu_kernels/__init__.py b/vllm_xpu_kernels/__init__.py new file mode 100644 index 0000000..8635cc2 --- /dev/null +++ b/vllm_xpu_kernels/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 + +from .flash_attn_interface import flash_attn_varlen_func # noqa: F401 diff --git a/vllm_xpu_kernels/flash_attn_interface.py b/vllm_xpu_kernels/flash_attn_interface.py new file mode 100644 index 0000000..33e7de8 --- /dev/null +++ b/vllm_xpu_kernels/flash_attn_interface.py @@ -0,0 +1,107 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import torch + +#isort: off +try: + from . import _vllm_fa2_C # noqa: F401 + FA2_UNAVAILABLE_REASON = None + FA2_AVAILABLE = True +except ImportError as e: + FA2_UNAVAILABLE_REASON = str(e) + FA2_AVAILABLE = False + +#isort: on + +DEFAULT_FA_VERSION = 2 + + +def maybe_contiguous(x): + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + + +def flash_attn_varlen_func( + q, + k, + v, + max_seqlen_q, + cu_seqlens_q, + max_seqlen_k, + cu_seqlens_k=None, # only used for non-paged prefill + seqused_k=None, + q_v=None, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size: Optional[list[int]] = None, + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + block_table=None, + return_softmax_lse=False, + out=None, + # FA3 Only + scheduler_metadata=None, + q_descale=None, + k_descale=None, + v_descale=None, + num_splits: int = 0, + # Version selector + fa_version: int = DEFAULT_FA_VERSION, +): + assert cu_seqlens_k is not None or seqused_k is not None, \ + "cu_seqlens_k or seqused_k must be provided" + assert cu_seqlens_k is None or seqused_k is None, \ + "cu_seqlens_k and seqused_k cannot be provided at the same time" + + if softmax_scale is None: + softmax_scale = q.shape[-1]**(-0.5) + # custom op does not support non-tuple input + real_window_size: tuple[int, int] + if window_size is None: + real_window_size = (-1, -1) + else: + assert len(window_size) == 2 + real_window_size = (window_size[0], window_size[1]) + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + + dummy_cu_seqlens_k = torch.empty_like(cu_seqlens_q) + + if fa_version == 2: + if scheduler_metadata is not None and q_descale is not None \ + and k_descale is not None and v_descale is not None: + raise NotImplementedError( + "FA2 does not support scheduler_metadata, q_descale, " + "k_descale, v_descale") + if num_splits > 1: + raise NotImplementedError("FA2 does not support num_splits > 1") + out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd( + q, + k, + v, + out, + cu_seqlens_q, + # cu_seqlens_k not used since we use seqused_k, but flash_api.cpp + # still wants it so we pass all zeros + dummy_cu_seqlens_k if cu_seqlens_k is None else cu_seqlens_k, + seqused_k, + None, + block_table, + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + False, + causal, + real_window_size[0], + real_window_size[1], + softcap, + return_softmax_lse and dropout_p > 0, + None, + ) + else: + raise NotImplementedError("not support yet") + return (out, softmax_lse) if return_softmax_lse else out
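For reference (not part of the patch above): a minimal Python sketch of the varlen packing convention that `csrc/flash_attn/flash_api.cpp` documents — Q/K/V packed to `total x num_heads x head_size` and `cu_seqlens_*` holding the `b+1` cumulative lengths — driven through the new `flash_attn_varlen_func` wrapper. The sequence lengths, head counts, and availability check below are illustrative assumptions that mirror `tests/flash_attn/test_flash_attn_varlen_func.py`; they are not a statement of the kernel's exact requirements.

```python
# Illustrative sketch only: shows the varlen packing convention documented in
# csrc/flash_attn/flash_api.cpp (packed Q/K/V plus b+1 cumulative sequence lengths).
# Sizes are made up; an XPU device and a built _vllm_fa2_C extension are assumed.
import torch

from vllm_xpu_kernels import flash_attn_interface as fai

if not fai.FA2_AVAILABLE:
    raise RuntimeError(f"FA2 extension not available: {fai.FA2_UNAVAILABLE_REASON}")

torch.set_default_device("xpu")
num_heads, head_size = 8, 64
seq_lens = [5, 3, 7]                  # per-sequence lengths, b = 3
total = sum(seq_lens)                 # total_q (== total_k in this sketch)

# Packed layout: total x num_heads x head_size.
q = torch.randn(total, num_heads, head_size, dtype=torch.half)
k = torch.randn(total, num_heads, head_size, dtype=torch.half)
v = torch.randn(total, num_heads, head_size, dtype=torch.half)

# cu_seqlens_* have b+1 entries: [0, 5, 8, 15] for the lengths above.
cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(torch.tensor(seq_lens, dtype=torch.int32), dim=0)

out = fai.flash_attn_varlen_func(
    q, k, v,
    max_seqlen_q=max(seq_lens),
    cu_seqlens_q=cu_seqlens,
    max_seqlen_k=max(seq_lens),
    cu_seqlens_k=cu_seqlens,
)
```

This mirrors the shape comments on `mha_varlen_fwd`; whether `block_table` and `seqused_k` must also be supplied depends on which kernel path is being exercised.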
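Also for reference, a hedged shape sketch for the paged-KV path: `cutlass_chunk_prefill_impl` documents `key_cache`/`value_cache` as `[num_blocks, block_size, num_heads_kv, head_size]`, `block_table` as `[batch_size, max_blocks_per_seq]`, and `total_seqlen_k = num_blocks * block_size`. Every concrete size below is invented for illustration and is not taken from the patch.

```python
# Illustrative shapes for the paged-KV layout documented in
# csrc/xpu/cutlass_kernels/chunk_prefill.hpp; every size below is an assumption.
import torch

torch.set_default_device("xpu")
num_blocks, block_size = 16, 64
num_heads_q, num_heads_kv, head_size = 8, 8, 64
batch_size, max_blocks_per_seq = 2, 8

# Packed queries: total_q x num_heads_q x head_size.
q = torch.randn(10, num_heads_q, head_size, dtype=torch.half)

# Paged KV cache: num_blocks x block_size x num_heads_kv x head_size.
key_cache = torch.randn(num_blocks, block_size, num_heads_kv, head_size,
                        dtype=torch.half)
value_cache = torch.randn_like(key_cache)

# block_table[i, j] is the physical block index backing logical block j of sequence i.
block_table = torch.randint(0, num_blocks, (batch_size, max_blocks_per_seq),
                            dtype=torch.int32)

# batch_size + 1 cumulative lengths for queries and keys.
cu_seqlens_q = torch.tensor([0, 6, 10], dtype=torch.int32)
cu_seqlens_k = torch.tensor([0, 70, 130], dtype=torch.int32)
```

These tensors would reach the kernel through the `block_table=` and `cu_seqlens_k=` (or `seqused_k=`) arguments of `flash_attn_varlen_func`, which forwards them to `torch.ops._vllm_fa2_C.varlen_fwd`.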