diff --git a/CMakeLists.txt b/CMakeLists.txt index 005457e..713438f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -46,6 +46,8 @@ set(SYCL_SUPPORTED_ARCHS "intel_gpu_pvc;intel_gpu_bmg_g21") set(TORCH_SUPPORTED_VERSION_XPU "2.8.0") set(ENABLE_MOE_KERNEL OFF) +set(FA2_ENABLED OFF) +set(FP8_ENABLED ON) # # Try to find python package with an executable that exactly matches @@ -146,16 +148,65 @@ endif() if(VLLM_GPU_LANG STREQUAL "SYCL") set(VLLM_EXT_SRC + "csrc/xpu/cutlass_sycl_demo.cpp" "csrc/xpu/layernorm.cpp" "csrc/xpu/torch_bindings.cpp" ) include_directories("/usr/include") - set(CMPLR_ROOT $ENV{CMPLR_ROOT}) - set(CMAKE_CXX_COMPILER icpx) + list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/) + list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/sycl/) + list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/sycl/syclcompat/) + message(STATUS "VLLM_INCLUDE_DIR: ${VLLM_INCLUDE_DIR}") set(VLLM_EXTRA_INCLUDE_DIRECTORIES ${CMPLR_ROOT}/include/sycl) list(APPEND VLLM_GPU_FLAGS "-DVLLM_BUILD_XPU_OPS" ) - list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64") + list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64" "-Xspirv-translator" "-spirv-ext=+SPV_INTEL_split_barrier") list(APPEND VLLM_LINK_LIBRARIES "sycl" "OpenCL" "pthread" "m" "dl" "torch" ) + + + # add cutlass dependency + set(CUTLASS_ENABLE_HEADERS_ONLY "ON" CACHE BOOL "Enable only the header library") + + # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. + set(CUTLASS_REVISION "scaled_mm" CACHE STRING "CUTLASS revision to use") + + # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided + FetchContent_Declare( + cutlass-sycl + GIT_REPOSITORY https://github.com/Liangliang-Ma/cutlass-sycl.git + # Please keep this in sync with CUTLASS_REVISION line above. + GIT_TAG ${CUTLASS_REVISION} + GIT_PROGRESS TRUE + + # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. + # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. 
+ # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE + GIT_SHALLOW TRUE + ) + + # cutlass compilation flags + set(CUTLASS_ENABLE_SYCL "ON") + # set(DPCPP_SYCL_TARGET "intel_gpu_pvc;intel_gpu_bmg_g21" CACHE STRING "DPC++ SYCL target architectures") + set(CMAKE_EXPORT_COMPILE_COMMANDS "ON") + set(CUTLASS_ENABLE_BENCHMARKS "OFF") + # disable cuda + set(CUTLASS_ENABLE_GDC_FOR_SM100_DEFAULT OFF CACHE BOOL "DISABLE CUDA") + # list(APPEND CMAKE_CXX_FLAGS "-ftemplate-backtrace-limit=0 " ) + # list(APPEND CMAKE_CXX_FLAGS "-fdiagnostics-color=always " ) + + + FetchContent_MakeAvailable(cutlass-sycl) + set(CUTLASS_INCLUDE_DIR ${cutlass-sycl_SOURCE_DIR}/include CACHE PATH "CUTLASS Header Library") + set(CUTLASS_TOOLS_UTIL_INCLUDE_DIR ${cutlass-sycl_SOURCE_DIR}/tools/util/include CACHE INTERNAL "") + set(CUTLASS_APP_INCLUDE_DIR ${cutlass-sycl_SOURCE_DIR}/applications CACHE INTERNAL "") + message(STATUS "cutlass dir: ${CUTLASS_INCLUDE_DIR} and ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} and ${CUTLASS_APP_INCLUDE_DIR}") + + # header only library + list(APPEND VLLM_GPU_FLAGS "-DCUTLASS_ENABLE_SYCL") + list(APPEND VLLM_GPU_FLAGS "-DSYCL_INTEL_TARGET") + list(APPEND VLLM_GPU_FLAGS "-DCUTLASS_VERSIONS_GENERATED") + list(APPEND VLLM_GPU_FLAGS "-ftemplate-backtrace-limit=0") + list(APPEND VLLM_GPU_FLAGS "-fdiagnostics-color=always") + endif() message(STATUS "Enabling C extension.") @@ -169,9 +220,82 @@ define_gpu_extension_target( ARCHITECTURES ${VLLM_GPU_ARCHES} INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_APP_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${VLLM_INCLUDE_DIR} USE_SABI 3 WITH_SOABI) +# +# flash attention _C extension +# + +if (FA2_ENABLED) + message(STATUS "Enabling fa2 extension.") + file(GLOB FA2_GEN_SRCS "csrc/flash_attn/*.cpp") + + # list(APPEND VLLM_GPU_FLAGS "-ze-opt-large-register-file") + list(APPEND VLLM_GPU_FLAGS "-gline-tables-only") + list(APPEND VLLM_GPU_FLAGS "-O3") + list(APPEND VLLM_GPU_FLAGS "-DNDEBUG") + + define_gpu_extension_target( + _vllm_fa2_C + DESTINATION vllm_xpu_kernels + LANGUAGE ${VLLM_GPU_LANG} + SOURCES + csrc/flash_attn/flash_api.cpp + ${FA2_GEN_SRCS} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + LINK_FLAGS ${VLLM_GPU_LINK_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_APP_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${VLLM_INCLUDE_DIR} + USE_SABI 3 + WITH_SOABI) + + # target_include_directories(_vllm_fa2_C PRIVATE + # csrc/flash_attn + # csrc/flash_attn/src) +endif () + +if (FP8_ENABLED) + message(STATUS "Enabling FP8 extension.") + file(GLOB FP8_GEN_SRCS "csrc/quantization/*.cpp") + + # list(APPEND VLLM_GPU_FLAGS "-ze-opt-large-register-file") + list(APPEND VLLM_GPU_FLAGS "-gline-tables-only") + list(APPEND VLLM_GPU_FLAGS "-O3") + list(APPEND VLLM_GPU_FLAGS "-DNDEBUG") + # XPU FLAGS + list(APPEND VLLM_GPU_FLAGS "-O3" "-DNDEBUG") + list(APPEND VLLM_GPU_FLAGS "-gline-tables-only") + list(APPEND VLLM_GPU_FLAGS "-fsycl" "-fsycl-targets=spir64_gen" "-ftemplate-backtrace-limit=10") + + list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64_gen") + list(APPEND VLLM_GPU_LINK_FLAGS -Xsycl-target-backend=spir64_gen "-device bmg-g21-a0 -internal_options -cl-intel-256-GRF-per-thread") + + define_gpu_extension_target( + _vllm_fp8_C + DESTINATION vllm_xpu_kernels + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${FP8_GEN_SRCS} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + LINK_FLAGS 
${VLLM_GPU_LINK_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_APP_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${VLLM_INCLUDE_DIR} + USE_SABI 3 + WITH_SOABI) + + # target_include_directories(_vllm_fa2_C PRIVATE + # csrc/flash_attn + # csrc/flash_attn/src) +endif () + # # _moe_C extension # @@ -192,6 +316,7 @@ if (ENABLE_MOE_KERNEL) ARCHITECTURES ${VLLM_GPU_ARCHES} INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_APP_INCLUDE_DIR} USE_SABI 3 WITH_SOABI) endif() diff --git a/cmake/toolchain.cmake b/cmake/toolchain.cmake new file mode 100644 index 0000000..f8cf8b9 --- /dev/null +++ b/cmake/toolchain.cmake @@ -0,0 +1,6 @@ +# use this file to set the compiler and flags for SYCL + +set(CMPLR_ROOT $ENV{CMPLR_ROOT}) +message(STATUS "CMPLR_ROOT: ${CMPLR_ROOT}") +set(CMAKE_CXX_COMPILER ${CMPLR_ROOT}/bin/icpx) +set(CMAKE_C_COMPILER ${CMPLR_ROOT}/bin/icx) \ No newline at end of file diff --git a/csrc/core/registration.h b/csrc/core/registration.h index 9b6d7ab..46d1713 100644 --- a/csrc/core/registration.h +++ b/csrc/core/registration.h @@ -1,4 +1,7 @@ #pragma once +#pragma push_macro("printf") +#undef printf + #include @@ -14,7 +17,7 @@ // A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME // could be a macro instead of a literal token. -#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \ +#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \ TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE) // REGISTER_EXTENSION allows the shared library to be loaded and initialized @@ -22,6 +25,9 @@ #define REGISTER_EXTENSION(NAME) \ PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \ static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \ - STRINGIFY(NAME), nullptr, 0, nullptr}; \ + STRINGIFY(NAME), nullptr, 0, nullptr, \ + nullptr, nullptr, nullptr, nullptr}; \ return PyModule_Create(&module); \ } + +#pragma pop_macro("printf") diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp new file mode 100644 index 0000000..4a15d5e --- /dev/null +++ b/csrc/flash_attn/flash_api.cpp @@ -0,0 +1,82 @@ +#include "pytorch_shim.h" + +#include "core/registration.h" +#include "xpu/cutlass_kernels/chunk_prefill.hpp" +#include + +namespace FLASH_NAMESPACE { + +std::vector mha_varlen_fwd( + const at::Tensor& q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& k, // total_k x num_heads_k x head_size, total_k := + // \sum_{i=0}^{b} s_i or num_blocks x page_block_size + // x num_heads_k x head_size if there's a block_table. + const at::Tensor& v, // total_k x num_heads_k x head_size, total_k := + // \sum_{i=0}^{b} s_i or num_blocks x page_block_size + // x num_heads_k x head_size if there's a block_table. + std::optional& out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& seqused_k, // b. If given, only this many elements of each batch + // element's keys are used. 
+    std::optional& leftpad_k_,  // batch_size
+    at::Tensor& block_table_,   // batch_size x max_num_blocks_per_seq
+    std::optional& alibi_slopes_,  // num_heads or b x num_heads
+    int max_seqlen_q,
+    int max_seqlen_k,
+    float p_dropout,
+    float softmax_scale,
+    const bool zero_tensors,
+    bool is_causal,
+    int window_size_left,
+    int window_size_right,
+    const float softcap,
+    const bool return_softmax, std::optional gen_) {
+  at::Tensor out;
+  if(out_.has_value()) {
+    out = *out_;
+  }
+  else {
+    out = torch::zeros_like(q);
+  }
+
+  cutlass_chunk_prefill_impl(
+      q,
+      k,
+      v,
+      out,
+      block_table_,
+      cu_seqlens_q,
+      cu_seqlens_k,
+      max_seqlen_q,
+      max_seqlen_k,
+      softmax_scale,
+      is_causal);
+
+  if(return_softmax) {
+    auto softmax_lse = torch::empty_like(out);
+    return {out, softmax_lse};
+  }
+  else {
+    at::Tensor softmax_lse;
+    return {out, softmax_lse};
+  }
+}
+} // namespace FLASH_NAMESPACE
+
+TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+  ops.def(
+      "varlen_fwd(Tensor q, Tensor k, Tensor v, Tensor!? out, Tensor "
+      "cu_seqlens_q, "
+      "Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? leftpad_k, Tensor "
+      "block_table, Tensor? alibi_slopes, "
+      "int max_seqlen_q, int max_seqlen_k, float p_dropout, float "
+      "softmax_scale, bool zero_tensors, "
+      "bool is_causal, int window_size_left, int window_size_right, float "
+      "softcap, bool return_softmax, "
+      "Generator? gen) -> Tensor[]");
+  ops.impl("varlen_fwd", torch::kXPU,
+           make_pytorch_shim(&FLASH_NAMESPACE::mha_varlen_fwd));
+}
+
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
\ No newline at end of file
diff --git a/csrc/flash_attn/pytorch_shim.h b/csrc/flash_attn/pytorch_shim.h
new file mode 100644
index 0000000..82f5f5c
--- /dev/null
+++ b/csrc/flash_attn/pytorch_shim.h
@@ -0,0 +1,110 @@
+#pragma once
+
+#include
+
+/**
+ * Unfortunately, the type signatures of the flash_attn ops are not compatible
+ * with the PyTorch library bindings. To get around that we use
+ * `make_pytorch_shim`, which creates a lambda that exposes the API using
+ * PyTorch-compatible types and then converts them to the types expected by
+ * the flash_attn ops. This shim allows us to make minimal changes to
+ * `flash_api.cpp`, making it easier to synchronize with upstream changes.
+ *
+ * The `pytorch_library_compatible_type` struct is used to map from the
+ * flash_attn ops types to a PyTorch library compatible one. The main issue is
+ * that the following types are not supported by PyTorch library bindings:
+ *  - `int`
+ *  - `float`
+ *  - `std::optional &`
+ *  - `std::optional &`
+ * So we convert them to (respectively):
+ *  - `int64_t`
+ *  - `double`
+ *  - `const std::optional&`
+ *  - `const std::optional&`
+ */
+
+template
+struct pytorch_library_compatible_type {
+  using type = T;
+  static T convert_from_type(T arg) { return arg; }
+};
+
+template
+using pytorch_library_compatible_type_t =
+    typename pytorch_library_compatible_type::type;
+
+template
+T convert_from_pytorch_compatible_type(
+    pytorch_library_compatible_type_t arg) {
+  return pytorch_library_compatible_type::convert_from_type(arg);
+}
+
+// Map `std::optional &` -> `const std::optional&`
+// (NOTE: this is a bit unsafe, but none of the ops in flash_attn mutate
+//  the optional container)
+template
+struct pytorch_library_compatible_type&> {
+  using type = const std::optional&;
+  static std::optional& convert_from_type(const std::optional& arg) {
+    return const_cast&>(arg);
+  }
+};
+
+// Map `std::optional` ->
+//  `std::optional>`
+//  (NOTE: tested for `std::optional` -> `std::optional`)
+template
+struct pytorch_library_compatible_type> {
+  using type = std::optional>;
+  static std::optional> convert_from_type(
+      std::optional arg) {
+    return arg;
+  }
+};
+
+// Map `std::optional&` -> `const std::optional&`
+template <>
+struct pytorch_library_compatible_type&> {
+  using type = const std::optional&;
+  static std::optional& convert_from_type(
+      const std::optional& arg) {
+    return const_cast&>(
+        reinterpret_cast&>(arg));
+  }
+};
+
+// Map `int` -> `int64_t`
+template <>
+struct pytorch_library_compatible_type {
+  using type = int64_t;
+  static int convert_from_type(int64_t arg) {
+    TORCH_CHECK(arg <= std::numeric_limits::max(),
+                "int64_t value is too large to be converted to int");
+    TORCH_CHECK(arg >= std::numeric_limits::min(),
+                "int64_t value is too small to be converted to int");
+    return arg;
+  }
+};
+
+// Map `float` -> `double`
+template <>
+struct pytorch_library_compatible_type {
+  using type = double;
+  static float convert_from_type(double arg) {
+    TORCH_CHECK(std::abs(arg) <= std::numeric_limits::max(),
+                "double value is too large to be converted to float");
+    return arg;
+  }
+};
+
+//
+// Shim Utils
+//
+
+template
+auto make_pytorch_shim(Ret (*fun)(Args... args)) {
+  return [fun](pytorch_library_compatible_type_t...
args) { + return fun(convert_from_pytorch_compatible_type(args)...); + }; +} diff --git a/csrc/quantization/base_gemm.h b/csrc/quantization/base_gemm.h new file mode 100644 index 0000000..55dcc45 --- /dev/null +++ b/csrc/quantization/base_gemm.h @@ -0,0 +1,256 @@ +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/xe_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/util/GPU_Clock.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "helper.h" +#include "sycl_common.hpp" + +namespace gpu::cutlass_kernel { +namespace basic_gemm { + +using namespace cute; +using bf16 = sycl::ext::oneapi::bfloat16; +using fp16 = sycl::half; + +struct Args { + int m, n, k, l; + float alpha, beta; + + Args() = default; + Args(int m, int n, int k, int l, float alpha, float beta) + : m(m), n(n), k(k), l(l), alpha(alpha), beta(beta) {} +}; + +template +struct GemmRunner { + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAcc = typename Gemm::ElementAccumulator; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + std::cout << M << " " << N << " " << K << " " << L << " " << std::endl; + + // Complete the stride by combining static layout info (StrideA) with + // runtime size info (M,K,L) + stride_A = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = + cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = + cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + // block_A.reset(static_cast(M) * K * L); + // block_B.reset(static_cast(K) * N * L); + // block_C.reset(static_cast(M) * N * L); + // block_D.reset(static_cast(M) * N * L); + // initialize_block(block_A, seed + 2023); + // initialize_block(block_B, seed + 2022); + // initialize_block(block_C, seed + 2021); + } + + cutlass::Status run( + 
sycl::queue* stream, + const Args& args, + const cutlass::KernelHardwareInfo& hw_info, + T* inputA, + T* inputB, + float* inputC, + float* res) { + std::cout << "into Runner" << std::endl; + std::cout << args.alpha << " " << args.beta << std::endl; + + ProblemShapeType problem_size = + ProblemShapeType{args.m, args.n, args.k, args.l}; + + initialize(problem_size); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {inputA, stride_A, inputB, stride_B}, + {{args.alpha, args.beta}, inputC, stride_C, res, stride_D}, + hw_info}; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + if (gemm_op.can_implement(arguments) != cutlass::Status::kSuccess) { + std::cout << "Invalid Problem Size: " << args.m << 'x' << args.n << 'x' + << args.k << 'x' << args.l << std::endl; + std::exit(1); + } + + std::cout << "into initialize" << std::endl; + + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get(), stream)); + + std::cout << "into gemm run" << std::endl; + + // Run the GEMM + CUTLASS_CHECK(gemm_op.run()); + + syclcompat::wait(); + + std::cout << "finish gemm run" << std::endl; + + return cutlass::Status::kSuccess; + } +}; + +template +void gemm_functor( + sycl::queue* stream, + void* inputA, + void* inputB, + void* inputC, + void* res, + const int m, + const int n, + const int k, + float alpha, + float beta) { + Args args(m, n, k, 1, alpha, beta); + + cutlass::KernelHardwareInfo hw_info; + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); + + using ElementAccumulator = float; + using ElementComputeEpilogue = float; + using ElementInputA = T; + using ElementInputB = T; + using ElementOutput = float; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; + using GmemTiledCopyB = XE_2D_U16x32x32_LD_N; + + using TileShape = Shape<_256, _256, _32>; + + using TiledMma = typename TiledMMAHelper< + MMA_Atom, + cute::Layout, + cute::Layout, Stride<_4, _1, _0>>>::TiledMMA; + + constexpr int PipelineStages = 2; + using GEMMDispatchPolicy = + cutlass::gemm::MainloopIntelXeXMX16; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination< + ElementOutput, + ElementComputeEpilogue, + ElementAccumulator, + ElementAccumulator, + cutlass::FloatRoundStyle::round_to_nearest>; + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks< + EpilogueDispatchPolicy, + EpilogueOp, + TileShape, + decltype(tile_shape(TiledMma()))>; + + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_LD_N, + void, + void, + XE_2D_U32x8x16_ST_N, + void, + void>; + + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementInputA, + cutlass::gemm::TagToStrideA_t, + ElementInputB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + void, + void, + cute::identity, + GmemTiledCopyB, + void, + void, + cute::identity>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + 
CollectiveMainloop, + CollectiveEpilogue>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + GemmRunner runner; + CUTLASS_CHECK(runner.run( + stream, + args, + hw_info, + reinterpret_cast(inputA), + reinterpret_cast(inputB), + reinterpret_cast(inputC), + reinterpret_cast(res))); +} + +} // namespace basic_gemm +} // namespace gpu::cutlass_kernel diff --git a/csrc/quantization/cutlass_kernels.cpp b/csrc/quantization/cutlass_kernels.cpp new file mode 100644 index 0000000..af696bd --- /dev/null +++ b/csrc/quantization/cutlass_kernels.cpp @@ -0,0 +1,88 @@ +#include "base_gemm.h" +#include "scaled_mm.h" + +#include +#include +#include + +#include +/* #include "pytorch_shim.h" */ + +#include "core/registration.h" +#include +#include "xpu/utils.h" + +namespace gpu::cutlass_kernel { + +at::Tensor basic_gemm_func( + const at::Tensor& inputA, + const at::Tensor& inputB, + const at::Tensor& inputC, + const at::Tensor& res, + double alpha, + double beta) { + int m = inputA.size(0); + int n = inputB.size(0); + int k = inputA.size(1); + + auto dpcpp_queue = vllm::xpu::vllmGetQueue(); + basic_gemm::gemm_functor( + &dpcpp_queue, + inputA.data_ptr(), + inputB.data_ptr(), + inputC.data_ptr(), + res.data_ptr(), + m, + n, + k, + alpha, + beta); + return res; +} + +at::Tensor scaled_mm_func( + at::Tensor& inputA, // [M, K], fp8_e4m3 + at::Tensor& inputB, // [N, K], fp8_e4m3 + at::Tensor& scaleA, // [M, K], half + at::Tensor& scaleB, // [N, K], half + at::Tensor& res, // [M, N], float + double alpha, + double beta) { + int m = inputA.size(0); + int n = inputB.size(0); + int k = inputA.size(1); + + auto dpcpp_queue = vllm::xpu::vllmGetQueue(); + if (inputA.scalar_type() != at::kFloat8_e4m3fn) { + std::cout << "error:wrong datatype" << std::endl; + return at::Tensor(); + } + + scaled_mm::kernel_functor( + &dpcpp_queue, + inputA.data_ptr(), + inputB.data_ptr(), + scaleA.data_ptr(), + scaleB.data_ptr(), + res.data_ptr(), + m, + n, + k, + alpha, + beta); + return res; +} + +} // namespace gpu::cutlass_kernel + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { + ops.def("cutlass_gemm(Tensor inputA, Tensor inputB, Tensor inputC, Tensor res, float alpha, float beta) -> Tensor"); + ops.impl("cutlass_gemm", torch::kXPU, gpu::cutlass_kernel::basic_gemm_func); + + ops.def("cutlass_scaled_mm(Tensor inputA, Tensor inputB, Tensor scaleA, Tensor scaleB, Tensor res, float alpha, float beta) -> Tensor"); + ops.impl("cutlass_scaled_mm", torch::kXPU, gpu::cutlass_kernel::scaled_mm_func); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME); + + diff --git a/csrc/quantization/helper.h b/csrc/quantization/helper.h new file mode 100644 index 0000000..9c9d03a --- /dev/null +++ b/csrc/quantization/helper.h @@ -0,0 +1,127 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#if defined(CUTLASS_ENABLE_SYCL) +#include "cutlass/util/sycl_timer.hpp" +#else +#include +#endif +#include + +/** + * Panic wrapper for unwinding CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) \ + << " at: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +/** + * Panic wrapper for unwinding CUDA runtime errors + */ +#define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ + << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +/** + * GPU timer for recording the elapsed time across kernel(s) launched in GPU + * stream + */ +struct GpuTimer { +#if defined(CUTLASS_ENABLE_SYCL) + using cudaStream_t = int; + SYCLTimer syclTimer; +#else + cudaEvent_t _start; + cudaEvent_t _stop; +#endif + cudaStream_t _stream_id; + + /// Constructor + GpuTimer() : _stream_id(0) { +#if !defined(CUTLASS_ENABLE_SYCL) + CUDA_CHECK(cudaEventCreate(&_start)); + CUDA_CHECK(cudaEventCreate(&_stop)); +#endif + } + + /// Destructor + ~GpuTimer() { +#if !defined(CUTLASS_ENABLE_SYCL) + CUDA_CHECK(cudaEventDestroy(_start)); + CUDA_CHECK(cudaEventDestroy(_stop)); +#endif + } + + /// Start the timer for a given stream (defaults to the default stream) + void start(cudaStream_t stream_id = 0) { + _stream_id = stream_id; +#if defined(CUTLASS_ENABLE_SYCL) + syclTimer.start(); +#else + CUDA_CHECK(cudaEventRecord(_start, _stream_id)); +#endif + } + + /// Stop the timer + void stop() { +#if defined(CUTLASS_ENABLE_SYCL) + syclTimer.stop(); +#else + CUDA_CHECK(cudaEventRecord(_stop, _stream_id)); +#endif + } + + /// Return the elapsed time (in milliseconds) + float elapsed_millis() { +#if defined(CUTLASS_ENABLE_SYCL) + return syclTimer.milliseconds(); +#else + float elapsed = 0.0; + CUDA_CHECK(cudaEventSynchronize(_stop)); + CUDA_CHECK(cudaEventElapsedTime(&elapsed, _start, _stop)); + return elapsed; +#endif + } +}; diff --git a/csrc/quantization/scaled_mm.h b/csrc/quantization/scaled_mm.h new file mode 100644 index 0000000..44919e1 --- /dev/null +++ b/csrc/quantization/scaled_mm.h @@ -0,0 +1,269 @@ +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/xe_epilogue.hpp" +#include 
"cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/util/GPU_Clock.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "helper.h" +#include "sycl_common.hpp" + +#include "scaled_mm/collective/xe_scaled_mm_mma_fp8.hpp" +#include "scaled_mm/kernel/xe_scaled_mm_fp8.hpp" + +namespace gpu::cutlass_kernel { +namespace scaled_mm { + +using namespace cute; + +struct Arguments { + int m, n, k, l; + float alpha, beta; + + Arguments() = default; + Arguments(int m, int n, int k, int l, float alpha, float beta) + : m(m), n(n), k(k), l(l), alpha(alpha), beta(beta) {} +}; + +template +struct GemmRunner { + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + using StrideScaleA = typename Gemm::GemmKernel::StrideA; + using StrideScaleB = typename Gemm::GemmKernel::StrideB; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAcc = typename Gemm::ElementAccumulator; + using ElementScaleA = cutlass::half_t; + using ElementScaleB = cutlass::half_t; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + StrideScaleA stride_scaleA; + StrideScaleB stride_scaleB; + + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_A = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = + cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = + cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + stride_scaleA = cutlass::make_cute_packed_stride( + StrideScaleA{}, cute::make_shape(M, K, L)); + stride_scaleB = cutlass::make_cute_packed_stride( + StrideScaleB{}, cute::make_shape(N, K, L)); + } + + cutlass::Status run( + sycl::queue* stream, + const Arguments& args, + const cutlass::KernelHardwareInfo& hw_info, + ElementA* inputA, + ElementB* inputB, + ElementScaleA* scaleA, + ElementScaleB* scaleB, + float* res) { + ProblemShapeType problem_size = + ProblemShapeType{args.m, args.n, args.k, args.l}; + + initialize(problem_size); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {inputA, + stride_A, + inputB, + stride_B, + scaleA, + stride_scaleA, + scaleB, + stride_scaleB}, + {{args.alpha, 
args.beta}, nullptr, stride_C, res, stride_D}, + hw_info}; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + if (args.n < 16) { + std::cout + << "Invalid Problem Size: N must be >= 16 for FP8 input with F16 MMA (XE_8x16x16_F32F16F16F32_TT). Got N=" + << args.n << std::endl; + std::exit(1); + } + + if (gemm_op.can_implement(arguments) != cutlass::Status::kSuccess) { + std::cout << "Invalid Problem Size: " << args.m << 'x' << args.n << 'x' + << args.k << 'x' << args.l << std::endl; + std::exit(1); + } + + gemm_op.initialize(arguments, workspace.get()); + + // Run the GEMM + gemm_op.run(stream); + + stream->throw_asynchronous(); + + return cutlass::Status::kSuccess; + } +}; + +/* +Several points: +1. scaled_A/B must be half, need to check if could be bf16, convert before +passing +2. scaled_A/B must be same shape as A, B, need to broadcast before kernel +3. inputC is not needed, need to check if empty pointer accpetable +4. output need to be converted from fp32 to bf16 + */ +void kernel_functor( + sycl::queue* stream, + void* inputA, + void* inputB, + void* scaleA, + void* scaleB, + void* res, + const int m, + const int n, + const int k, + float alpha, + float beta) { + Arguments args(m, n, k, 1, alpha, beta); + + cutlass::KernelHardwareInfo hw_info; + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); + + using ElementAccumulator = float; + using ElementComputeEpilogue = float; + using ElementInputA = cutlass::float_e4m3_t; + using ElementInputB = cutlass::float_e4m3_t; + using ElementOutput = float; + using ElementScale = cutlass::half_t; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using TileShape = Shape<_256, _256, _32>; + using GmemTiledCopyA = + XE_2D_U8x32x32_LD_N; // Note: This shape has to match the shape used for + // the scaling factors + using GmemTiledCopyB = + XE_2D_U8x32x32_LD_V; // Note: This shape has to match the shape used for + // the scaling factors + + using TiledMma = typename TiledMMAHelper< + MMA_Atom, + Layout, + Layout, Stride<_4, _1, _0>>>::TiledMMA; + + constexpr int PipelineStages = 2; + using GEMMDispatchPolicy = + cutlass::gemm::MainloopIntelScaledMMW8A8; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination< + ElementOutput, + ElementComputeEpilogue, + ElementAccumulator, + ElementAccumulator, + cutlass::FloatRoundStyle::round_to_nearest>; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks< + EpilogueDispatchPolicy, + EpilogueOp, + TileShape, + decltype(tile_shape(TiledMma()))>; + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_LD_N, + void, + void, + XE_2D_U32x8x16_ST_N, + void, + void>; + + // Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementInputA, + cutlass::gemm::TagToStrideA_t, + ElementInputB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + void, + void, + cute::identity, + GmemTiledCopyB, + void, + void, + cute::identity>; + + using GemmKernel = 
cutlass::gemm::kernel::GemmScaledMM< + Shape, + CollectiveMainloop, + CollectiveEpilogue>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + GemmRunner runner; + runner.run( + stream, + args, + hw_info, + reinterpret_cast(inputA), + reinterpret_cast(inputB), + reinterpret_cast(scaleA), + reinterpret_cast(scaleB), + reinterpret_cast(res)); +} + +} // namespace scaled_mm +} // namespace gpu::cutlass_kernel diff --git a/csrc/quantization/sycl_common.hpp b/csrc/quantization/sycl_common.hpp new file mode 100644 index 0000000..2de5bcd --- /dev/null +++ b/csrc/quantization/sycl_common.hpp @@ -0,0 +1,169 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/sycl_tensor_fill.h" + +/// Helper to initialize a block of device data +template +bool initialize_block(Element* block, std::size_t size, uint64_t seed = 2023) { + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = Element(2); + scope_min = Element(0); + } else if (bits_input <= 8) { + scope_max = Element(2); + scope_min = Element(-2); + } else { + scope_max = Element(8); + scope_min = Element(-8); + } + + cutlass::reference::device::BlockFillRandomUniform( + block, size, seed, scope_max, scope_min, 0); + + syclcompat::wait(); + return true; +} + +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed = 2023) { + return initialize_block(block.get(), block.size(), seed); +} + +template +void initialize_mixed_dtype_block( + cutlass::DeviceAllocation& block_device, + cutlass::DeviceAllocation& block_device_dq, + uint64_t seed) { + static_assert(cute::sizeof_bits_v >= 8); + + std::ranlux24_base rng(std::random_device{}()); + rng.seed(seed); + + int bits_input = cute::sizeof_bits_v; + T1 scope_max, scope_min; + if (bits_input == 1) { + scope_max = T1(2); + scope_min = T1(0); + } else if (bits_input <= 8) { + scope_max = T1(2); + scope_min = T1(-2); + } else { + scope_max = T1(8); + scope_min = T1(-8); + } + + std::uniform_int_distribution<> dist(scope_min, scope_max); + + if constexpr (cute::sizeof_bits_v >= 8) { + auto block_host = std::vector(block_device.size()); + auto block_host_dq = std::vector(block_device.size()); + for (int i = 0; i < block_host.size(); ++i) { + block_host[i] = static_cast(dist(rng)); + block_host_dq[i] = static_cast(block_host[i]); + } + + block_device.copy_from_host(block_host.data()); + block_device_dq.copy_from_host(block_host_dq.data()); + } else { + static constexpr auto array_size = 1024; + + cute::array_subbyte block_host{}; + auto block_host_dq = std::vector(array_size); + + for (int i = 0; i < block_host.size(); ++i) { + block_host[i] = static_cast(dist(rng)); + block_host_dq[i] = static_cast(block_host[i].get()); + } + + static constexpr auto elements_per_byte = + cute::sizeof_bits_v / cute::sizeof_bits_v; + + int loop_cnt = block_device.size() / array_size; + for (int i = 0; i < loop_cnt; i++) { + cutlass::device_memory::copy_to_device( + block_device.get() + (i * array_size) / elements_per_byte, + raw_pointer_cast(block_host.begin()), + array_size); + cutlass::device_memory::copy_to_device( + block_device_dq.get() + i * array_size, + block_host_dq.data(), + array_size); + } + + auto tail_size = block_device.size() % array_size; + if (tail_size) { + cutlass::device_memory::copy_to_device( + block_device.get() + (loop_cnt * array_size) / elements_per_byte, + raw_pointer_cast(block_host.begin()), + tail_size); + cutlass::device_memory::copy_to_device( + block_device_dq.get() + loop_cnt * array_size, + block_host_dq.data(), + tail_size); + } + } +} + +template +inline bool is_close(T a, T b, float atol, float rtol) { + return std::abs((float)a - (float)b) <= atol + rtol * std::abs((float)b); +} + +// TODO(Codeplay): use on device initialisation for this +template +inline void random_fill(T* src, int seed, size_t N, float max, float min) { + if constexpr ( + std::is_same_v || std::is_same_v || + std::is_same_v) { + std::random_device 
rd; + std::mt19937 gen(seed); + std::uniform_real_distribution dis(min, max); + auto buff = std::vector(N); + + for (size_t i = 0; i < N; ++i) { + buff[i] = (T)(dis(gen)); + } + syclcompat::memcpy(src, buff.data(), N); + syclcompat::wait(); + } else { + assert(0 & "Not supported dtype"); + } +} diff --git a/csrc/xpu/cutlass_kernels/chunk_prefill.hpp b/csrc/xpu/cutlass_kernels/chunk_prefill.hpp new file mode 100644 index 0000000..c96b8dd --- /dev/null +++ b/csrc/xpu/cutlass_kernels/chunk_prefill.hpp @@ -0,0 +1,331 @@ +#pragma once + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "flash_attention_v2/collective/fmha_fusion.hpp" +#include "flash_attention_v2/kernel/tile_scheduler_chunk_prefill.hpp" +#include "cutlass/util/packed_stride.hpp" +#include "flash_attention_v2/kernel/xe_chunk_prefill.hpp" +#include "flash_attention_v2/collective/xe_flash_attn_chunk_prefill_epilogue.hpp" +#include "flash_attention_v2/collective/xe_flash_attn_chunk_prefill_softmax_epilogue.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/sycl_event_manager.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" + +#include "utils.hpp" + +using namespace cute; + +struct chunk_prefill_args_t { + void* query; + void* key; + void* value; + void* out; + void* block_table; + void* num_blocks_per_seq; + void* cu_seqlens_q; + void* cu_seqlens_k; + int max_queries; + int max_keys; + int total_seqlen_q; + int total_seqlen_k; + float sm_scale; + int batch_size; + int num_heads_q; + int num_heads_k; + int head_size; + int max_blocks_per_seq; + int block_size; + bool is_causal; +}; + +template struct KernelLauncher { + using StrideQ = typename FMHAChunkPrefillKernel::StrideQ; + using StrideK = typename FMHAChunkPrefillKernel::StrideK; + using StrideV = typename FMHAChunkPrefillKernel::StrideV; + using StrideO = typename FMHAChunkPrefillKernel::StrideO; + + using ElementQ = typename FMHAChunkPrefillKernel::ElementQ; + using ElementK = typename FMHAChunkPrefillKernel::ElementK; + using ElementV = typename FMHAChunkPrefillKernel::ElementV; + using ElementAcc = typename FMHAChunkPrefillKernel::ElementAccumulator; + + using CollectiveEpilogue = typename FMHAChunkPrefillKernel::CollectiveEpilogue; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename FMHAChunkPrefillKernel::ProblemShape; + + /// Initialization + StrideQ stride_Q; + StrideK stride_K; + StrideV stride_V; + StrideK stride_K_cache; + StrideV stride_V_cache; + StrideO stride_O; + uint64_t seed = 0; + + ProblemShapeType initialize(const chunk_prefill_args_t &args) { + auto problem_shape = cute::make_tuple( + 1, + args.num_heads_q, + args.num_heads_k, + args.total_seqlen_q, + args.total_seqlen_k, + args.total_seqlen_k, + args.head_size, + args.head_size); + auto problem_shape_out = cute::make_tuple( + args.batch_size, + args.num_heads_q, + args.num_heads_k, + cutlass::fmha::collective::VariableLength{args.max_queries}, // cu_q + cutlass::fmha::collective::VariableLength{args.max_keys}, // cu_k + cutlass::fmha::collective::VariableLength{args.max_keys}, // cu_v + args.head_size, + args.head_size); + auto [batch, 
num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = problem_shape; + auto group_q_size = num_heads_q / num_heads_kv; + auto group_q_num = num_heads_q / group_q_size; + + stride_Q = cutlass::make_cute_packed_stride(StrideQ{}, cute::make_shape(seq_len_qo, num_heads_q * head_size_qk, batch)); + stride_K = cutlass::make_cute_packed_stride(StrideK{}, cute::make_shape(seq_len_kv, num_heads_kv * head_size_qk, batch)); + stride_V = cutlass::make_cute_packed_stride(StrideV{}, cute::make_shape(head_size_vo * num_heads_kv, seq_len_kv, batch)); + + stride_K_cache = cutlass::make_cute_packed_stride(StrideK{}, cute::make_shape(seq_len_kv_cache, num_heads_kv * head_size_qk, batch)); + stride_V_cache = cutlass::make_cute_packed_stride(StrideV{}, cute::make_shape(head_size_vo, seq_len_kv_cache, batch * num_heads_kv)); + + stride_O = cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo * group_q_size, group_q_num * head_size_vo, batch)); + + get<3>(problem_shape_out).cumulative_length = reinterpret_cast(args.cu_seqlens_q); + get<4>(problem_shape_out).cumulative_length = reinterpret_cast(args.cu_seqlens_k); + get<5>(problem_shape_out).cumulative_length = reinterpret_cast(args.cu_seqlens_k); + + return problem_shape_out; + } + + cutlass::Status run(const chunk_prefill_args_t &args, const cutlass::KernelHardwareInfo &hw_info) { + ProblemShapeType problem_size = initialize(args); + + typename FMHAChunkPrefillKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + { + reinterpret_cast(args.query), stride_Q, + reinterpret_cast(args.key), stride_K, + reinterpret_cast(args.value), stride_V, + reinterpret_cast(args.key), stride_K_cache, + reinterpret_cast(args.value), stride_V_cache, + static_cast(args.block_table), + args.block_size, + static_cast(args.num_blocks_per_seq) + }, + {args.sm_scale}, + {reinterpret_cast(args.out), stride_O}, + hw_info}; + + // Define device-global scratch memory + size_t workspace_size = FMHAChunkPrefillKernel::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + if (!FMHAChunkPrefillKernel::can_implement(arguments)) { + std::cout << "Invalid Problem Size: " << std::endl; + return cutlass::Status::kErrorInvalidProblem; + } + + // Initialize the workspace + FMHAChunkPrefillKernel::initialize_workspace(arguments, workspace.get()); + + // Convert host-side arguments to device-side arguments to be passed to the kernel + auto params = FMHAChunkPrefillKernel::to_underlying_arguments(arguments, workspace.get()); + + // Run the Flash Attention implementation. 
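+    // The static run(params) overload below derives the launch configuration
+    // from the kernel type (get_block_shape, get_grid_shape, SharedStorageSize)
+    // and submits the kernel through syclcompat; the syclcompat::wait() that
+    // follows makes this call synchronous with respect to the host.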
+ run(params); + + syclcompat::wait(); + return cutlass::Status::kSuccess; + } + + static void run(typename FMHAChunkPrefillKernel::Params params) { + dim3 const block = FMHAChunkPrefillKernel::get_block_shape(); + dim3 const grid = FMHAChunkPrefillKernel::get_grid_shape(params); + + // configure smem size and carveout + int smem_size = FMHAChunkPrefillKernel::SharedStorageSize; + + const auto sycl_block = syclcompat::dim3(block.x, block.y, block.z); + const auto sycl_grid = syclcompat::dim3(grid.x, grid.y, grid.z); + +// Launch parameters depend on whether SYCL compiler supports work-group scratch memory extension +#if !defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY) + using namespace syclcompat::experimental; + auto event = launch>( + launch_policy{sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)}, + kernel_properties{sycl_exp::sub_group_size}}, + params); +#else + syclcompat::experimental::launch_properties launch_props { + sycl::ext::oneapi::experimental::work_group_scratch_size(smem_size), + }; + syclcompat::experimental::kernel_properties kernel_props{ + sycl::ext::oneapi::experimental::sub_group_size + }; + syclcompat::experimental::launch_policy policy{sycl_grid, sycl_block, launch_props, kernel_props}; + auto event = syclcompat::experimental::launch>(policy, params); +#endif + + EventManager::getInstance().addEvent(event); + } +}; + +template struct FMHAKernel { + + template + static void run(const chunk_prefill_args_t &args) { + cutlass::KernelHardwareInfo hw_info; + + using LayoutQ = cutlass::layout::RowMajor; + using LayoutK = cutlass::layout::ColumnMajor; + using LayoutV = cutlass::layout::RowMajor; + using LayoutO = cutlass::layout::RowMajor; + + using ElementInputKV = ElementInputQ; + + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + using CollectiveEpilogue = cutlass::flash_attention::collective::FlashChunkPrefillEpilogue< + EpilogueDispatchPolicy, MMAOperation, TileShapeOutput, SubgroupLayout, ElementComputeEpilogue, ElementOutput, cutlass::gemm::TagToStrideC_t, ElementOutput, + GmemTiledCopyStore>; + using CollectiveSoftmaxEpilogue = cutlass::flash_attention::collective::FlashChunkPrefillSoftmaxEpilogue; + + using ProblemShapeRegular = cute::tuple; + using namespace cutlass::fmha::collective; + using ProblemShapeVarlen = cute::tuple; + using ProblemShapeType = std::conditional_t; + + // Mainloop + using CollectiveMainloop = cutlass::flash_attention::collective::FlashChunkPrefillMma< + GEMMDispatchPolicy, ProblemShapeType, ElementInputQ, cutlass::gemm::TagToStrideA_t, ElementInputKV, + cutlass::gemm::TagToStrideB_t, ElementInputKV, cutlass::gemm::TagToStrideB_t, MMAOperation, TileShapeQK, TileShapePV, SubgroupLayout, + GmemTiledCopyQ, // Q + GmemTiledCopyK, // K + GmemTiledCopyV, // V, + Causal, + PagedKV>; + + using FMHAChunkPrefillKernel = cutlass::flash_attention::kernel::FMHAPrefillChunk; + + KernelLauncher launcher; + + launcher.run(args, hw_info); + } + + static void dispatch(const chunk_prefill_args_t &args) { + if(args.is_causal) { + run(args); + } + else { + run(args); + } + } +}; + +template +void policy_dispatch( + CutlassType cuType, + const chunk_prefill_args_t& args) { + const int PipelineStages = 2; + if(cuType == CutlassType::half) { + FMHAKernel::dispatch(args); + } + else { + FMHAKernel::dispatch(args); + } +} + +void chunk_prefill_kernel( + CutlassType cuType, + const chunk_prefill_args_t& args) { + if(args.head_size == HEAD_SIZE_LIMIT_0) { + 
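+    // head_size == 64 path; the branches below handle 128 and 256 with the
+    // corresponding chunk_policy_head tile shapes from
+    // csrc/xpu/cutlass_kernels/utils.hpp. Any other head size currently falls
+    // through without launching a kernel.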
policy_dispatch(cuType, args); + } else if(args.head_size == HEAD_SIZE_LIMIT_1) { + policy_dispatch(cuType, args); + } + else if(args.head_size == HEAD_SIZE_LIMIT_2) { + policy_dispatch(cuType, args); + } +} + +void cutlass_chunk_prefill_impl( + const at::Tensor& query, // [seq_q, heads, head_size] + const at::Tensor& key_cache, // [num_block, block_size, heads, head_size] + const at::Tensor& value_cache, + at::Tensor& out, + const at::Tensor& block_table, + const at::Tensor& cu_seqlens_q, + const at::Tensor& cu_seqlens_k, + int max_seqlen_q, + int max_seqlen_k, + double sm_scale, + bool is_causal) { + int num_block = key_cache.size(0); + int block_size = key_cache.size(1); + int num_heads_q = query.size(1); + int num_heads_kv = key_cache.size(2); + int head_size = query.size(2); + int batch_size = cu_seqlens_q.numel() - 1; + int max_blocks_per_seq = block_table.size(1); + int total_seqlen_q = query.size(0); + int total_seqlen_k = num_block * block_size; + at::Tensor num_blocks_per_seq = torch::div(cu_seqlens_k, block_size); + + chunk_prefill_args_t args = { + query.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + out.data_ptr(), + block_table.data_ptr(), + num_blocks_per_seq.data_ptr(), + cu_seqlens_q.data_ptr(), + cu_seqlens_k.data_ptr(), + max_seqlen_q, + max_seqlen_k, + total_seqlen_q, + total_seqlen_k, + static_cast(sm_scale), + batch_size, + num_heads_q, + num_heads_kv, + head_size, + max_blocks_per_seq, + block_size, + is_causal + }; + CutlassType cuType = aten_to_Cutlass_dtype(query); + chunk_prefill_kernel(cuType, args); +} diff --git a/csrc/xpu/cutlass_kernels/utils.hpp b/csrc/xpu/cutlass_kernels/utils.hpp new file mode 100644 index 0000000..4715419 --- /dev/null +++ b/csrc/xpu/cutlass_kernels/utils.hpp @@ -0,0 +1,49 @@ +#pragma once +#include "torch/all.h" +#include + +#define HEAD_SIZE_LIMIT_0 64 +#define HEAD_SIZE_LIMIT_1 128 +#define HEAD_SIZE_LIMIT_2 256 +#define HEAD_SIZE_LIMIT_3 512 + +enum class CutlassType { + half, + bfloat16, +}; + +inline CutlassType aten_to_Cutlass_dtype(const at::Tensor& input) { + CutlassType cuType; + if (input.scalar_type() == torch::kHalf) { + cuType = CutlassType::half; + } else if (input.scalar_type() == torch::kBFloat16) { + cuType = CutlassType::bfloat16; + } else { + TORCH_INTERNAL_ASSERT( + false, + ""); + } + return cuType; +} + +using namespace cute; +struct chunk_policy_head64 { + using ShapeQK = Shape<_128, _64, _64>; + using ShapePV = Shape<_128, _32, _64>; + using ShapeOutPut = Shape<_128, _64, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; +}; + +struct chunk_policy_head128 { + using ShapeQK = Shape<_128, _64, _64>; + using ShapePV = Shape<_128, _32, _64>; + using ShapeOutPut = Shape<_128, _128, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; +}; + +struct chunk_policy_head256 { + using ShapeQK = Shape<_256, _64, _64>; + using ShapePV = Shape<_256, _32, _64>; + using ShapeOutPut = Shape<_256, _192, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; +}; \ No newline at end of file diff --git a/csrc/xpu/cutlass_sycl_demo.cpp b/csrc/xpu/cutlass_sycl_demo.cpp new file mode 100644 index 0000000..7254b65 --- /dev/null +++ b/csrc/xpu/cutlass_sycl_demo.cpp @@ -0,0 +1,520 @@ + + +#include +#include +#include + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/xe_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include 
"cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "helper.h" + +#include "cutlass/cutlass.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/sycl_tensor_fill.h" + +using namespace cute; + +/// Helper to initialize a block of device data +template +bool initialize_block(Element* block, std::size_t size, uint64_t seed = 2023) { + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = Element(2); + scope_min = Element(0); + } else if (bits_input <= 8) { + scope_max = Element(2); + scope_min = Element(-2); + } else { + scope_max = Element(8); + scope_min = Element(-8); + } + + cutlass::reference::device::BlockFillRandomUniform(block, size, seed, + scope_max, scope_min, 0); + + syclcompat::wait(); + return true; +} + +template +bool initialize_block(cutlass::DeviceAllocation& block, + uint64_t seed = 2023) { + return initialize_block(block.get(), block.size(), seed); +} + +template +void initialize_mixed_dtype_block( + cutlass::DeviceAllocation& block_device, + cutlass::DeviceAllocation& block_device_dq, uint64_t seed) { + static_assert(cute::sizeof_bits_v >= 8); + + std::ranlux24_base rng(std::random_device{}()); + rng.seed(seed); + + int bits_input = cute::sizeof_bits_v; + T1 scope_max, scope_min; + if (bits_input == 1) { + scope_max = T1(2); + scope_min = T1(0); + } else if (bits_input <= 8) { + scope_max = T1(2); + scope_min = T1(-2); + } else { + scope_max = T1(8); + scope_min = T1(-8); + } + + std::uniform_int_distribution<> dist(scope_min, scope_max); + + if constexpr (cute::sizeof_bits_v >= 8) { + auto block_host = std::vector(block_device.size()); + auto block_host_dq = std::vector(block_device.size()); + for (int i = 0; i < block_host.size(); ++i) { + block_host[i] = static_cast(dist(rng)); + block_host_dq[i] = static_cast(block_host[i]); + } + + block_device.copy_from_host(block_host.data()); + block_device_dq.copy_from_host(block_host_dq.data()); + } else { + static constexpr auto array_size = 1024; + + cute::array_subbyte block_host{}; + auto block_host_dq = std::vector(array_size); + + for (int i = 0; i < block_host.size(); ++i) { + block_host[i] = static_cast(dist(rng)); + block_host_dq[i] = static_cast(block_host[i].get()); + } + + static constexpr auto elements_per_byte = + cute::sizeof_bits_v / cute::sizeof_bits_v; + + int loop_cnt = block_device.size() / array_size; + for (int i = 0; i < loop_cnt; i++) { + cutlass::device_memory::copy_to_device( + block_device.get() + (i * array_size) / elements_per_byte, + raw_pointer_cast(block_host.begin()), array_size); + cutlass::device_memory::copy_to_device( + block_device_dq.get() + i * array_size, block_host_dq.data(), + array_size); + } + + auto tail_size = block_device.size() % array_size; + if (tail_size) { + cutlass::device_memory::copy_to_device( + block_device.get() + (loop_cnt * array_size) / elements_per_byte, + raw_pointer_cast(block_host.begin()), tail_size); + cutlass::device_memory::copy_to_device( + block_device_dq.get() + loop_cnt * array_size, block_host_dq.data(), + tail_size); + } + } +} + +template +inline bool is_close(T a, T b, float atol, float rtol) { + return std::abs((float)a - (float)b) <= atol + rtol * 
std::abs((float)b); +} + +// TODO(Codeplay): use on device initialisation for this +template +inline void random_fill(T* src, int seed, size_t N, float max, float min) { + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v) { + std::random_device rd; + std::mt19937 gen(seed); + std::uniform_real_distribution dis(min, max); + auto buff = std::vector(N); + + for (size_t i = 0; i < N; ++i) { + buff[i] = (T)(dis(gen)); + } + syclcompat::memcpy(src, buff.data(), N); + syclcompat::wait(); + } else { + assert(0 & "Not supported dtype"); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + bool help; + bool error; + + int m, n, k, l, iterations; + float alpha, beta; + + Options() + : help(false), + error(false), + m(5120), + n(4096), + k(4096), + l(1), + iterations(20), + alpha(1.f), + beta(0.f) {} + + // Parses the command line + void parse(int argc, char const** args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 5120); + cmd.get_cmd_line_argument("n", n, 4096); + cmd.get_cmd_line_argument("k", k, 4096); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations, 100); + } + + /// Prints the usage statement. + std::ostream& print_usage(std::ostream& out) const { + out << "BMG GEMM Example\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage " + "statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of " + "the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Iterations\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct ExampleRunner { + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementC = typename Gemm::ElementC; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation + block_ref_D; // Reference GEMM result for verification + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, 
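
The is_close helper accepts a result when |a - b| <= atol + rtol * |b|, i.e. a combined absolute and relative tolerance. A minimal Python restatement of the same predicate, included only as a reference for choosing tolerances in the Python tests:

def is_close(a: float, b: float, atol: float, rtol: float) -> bool:
    # same criterion as the C++ helper: |a - b| <= atol + rtol * |b|
    return abs(float(a) - float(b)) <= atol + rtol * abs(float(b))

# example: 1e-2 absolute slack plus 1% relative slack
assert is_close(1.004, 1.0, atol=1e-2, rtol=1e-2)
assert not is_close(1.5, 1.0, atol=1e-2, rtol=1e-2)
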
ElementCompute alpha, + ElementCompute beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, alpha, ref_A, cutlass::ComplexTransform::kNone, ref_B, + cutlass::ComplexTransform::kNone, beta, ref_C, ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + // CUTLASS on SYCL uses the compatibility library syclcompat for e.g. + // default in-order queue + syclcompat::wait(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual( + block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + // Complete the stride by combining static layout info (StrideA) with + // runtime size info (M,K,L) + stride_A = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = + cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = + cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + block_A.reset(static_cast(M) * K * L); + block_B.reset(static_cast(K) * N * L); + block_C.reset(static_cast(M) * N * L); + block_D.reset(static_cast(M) * N * L); + block_ref_D.reset(static_cast(M) * N * L); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + } + + cutlass::Status run(const Options& options, + const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = + ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, + block_C.get(), + stride_C, + block_D.get(), + stride_D}, + hw_info}; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + if (gemm_op.can_implement(arguments) != cutlass::Status::kSuccess) { + std::cout << "Invalid Problem Size: " << options.m << 'x' << options.n + << 'x' << options.k << 'x' << options.l << std::endl; + std::exit(1); + } + + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); + + // Run the GEMM + CUTLASS_CHECK(gemm_op.run()); + + syclcompat::wait(); + + // Verify that the result is correct + bool passed = verify(problem_size, options.alpha, options.beta); + std::cout << "Disposition: " << (passed ? 
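
verify() above recomputes D = alpha * A x B + beta * C with a reference device GEMM and then performs an exact block comparison. A tolerance-based PyTorch sketch of the same check, which can be handy when validating bf16 inputs against a host reference (tensor names, shapes, and tolerances here are illustrative, not taken from the patch):

import torch

def verify_gemm(A, B, C, D, alpha=1.0, beta=0.0, rtol=3e-2, atol=1e-2):
    # A: (L, M, K), B: (L, K, N), C/D: (L, M, N); accumulate in float32 like the kernel
    ref = alpha * (A.float() @ B.float()) + beta * C.float()
    return torch.allclose(D.float(), ref, rtol=rtol, atol=atol)

A = torch.randn(1, 64, 32, dtype=torch.bfloat16)
B = torch.randn(1, 32, 48, dtype=torch.bfloat16)
C = torch.randn(1, 64, 48, dtype=torch.float32)
D = A.float() @ B.float()  # stand-in for a kernel output computed with alpha=1, beta=0
print(verify_gemm(A, B, C, D))  # True
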
"Passed" : "Failed") << std::endl; + + if (!passed) return cutlass::Status::kErrorInternal; + + if (options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + gemm_op.run(); + } + syclcompat::wait(); + + float cute_time = timer.seconds() / options.iterations; + double tflops = + (2.0 * options.m * options.n * options.k * options.l) * 1e-12; + std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' + << options.k << 'x' << options.l << std::endl; + printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", + tflops / cute_time, cute_time * 1000); + } + + return cutlass::Status::kSuccess; + } +}; + +void cutlass_sycl_demo(torch::Tensor& a) { + // + // Parse options + // + // + std::cout << a.sizes() << std::endl; + + Options options; + + /* options.parse(argc, argv); */ + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a + // given device ID. This information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with + // multiple GPUs and wish to use a GPU other than that with device ID 0. + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); + + bool passed; + + // The code section below describes datatype for input, output matrices and + // computation between elements in input matrices. + using ElementAccumulator = float; // <- data type of accumulator + using ElementComputeEpilogue = float; // <- data type of epilogue operations + using ElementInputA = + bfloat16_t; // <- data type of elements in input matrix A + using ElementInputB = + bfloat16_t; // <- data type of elements in input matrix B + using ElementOutput = float; // <- data type of elements in output matrix D + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + // The 2D block copy operations used for the A and B matrices + using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; + using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + + // Workgroup-level tile + using TileShape = Shape<_256, _256, _32>; + + // A TiledMMA struct defines a tiling of an MMA atom over M, N and K, + // combining both additional hardware (sub-groups for Intel BMG) and + // iterations by each sub-group. + // + // The TiledMMAHelper struct defines a specific TiledMMA for a given MMA atom + // (XE_8x16x16_F32BF16BF16F32_TT), TileShape (<256, 256, 32>) and sub-group + // layout (8x4x1). The TiledMMA constructed using TiledMMAHelper has the + // property that each sub-group operates on a single contiguous chunk of the + // work-group TileShape. For this configuration, this implies that each + // sub-group operates on a contiguous 32x64x32 chunk (4x4x2 iterations). See + // 0t_mma_atom.md#TiledMMAs for more info. Sub-groups are arranged row-major + // (stride 4,1,0) for performance reasons. + using TiledMma = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32 + typename TiledMMAHelper< + MMA_Atom, Layout, + Layout, Stride<_4, _1, _0>>>::TiledMMA; + + // For Intel BMG, PipelineStages defines how many k-blocks ahead to prefetch + // from A and B. 
+ constexpr int PipelineStages = 2; + using GEMMDispatchPolicy = + cutlass::gemm::MainloopIntelXeXMX16; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + + // This is the 'default' epilogue operation (Linear Combination) which + // performs everything in: (D = alpha * (A*B) + beta * C) aside from the + // (A*B), which is handled by the GEMM. See 05_bmg_gemm_with_epilogues for + // more complex epilogue examples. + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination< + ElementOutput, ElementComputeEpilogue, ElementAccumulator, + ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>; + + // FusionCallbacks ties the EpilogueOp to an implementation (based on the + // dispatch policy/architecture) and defines the epilogue arguments. + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks< + EpilogueDispatchPolicy, EpilogueOp, TileShape, + decltype(tile_shape(TiledMma()))>; + // GEMM Epilogue - loads & stores C/D matrices, performs epilogue operations & + // load/stores any auxiliary data required + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, TileShape, ElementAccumulator, + cutlass::gemm::TagToStrideC_t, // Converts CUTLASS 2.x to + // CUTLASS 3.x representation + ElementOutput, + cutlass::gemm::TagToStrideC_t, // Converts CUTLASS 2.x to + // CUTLASS 3.x representation + FusionCallBacks, + XE_2D_U32x8x16_LD_N, // The copy atom used to load matrix C + void, void, + XE_2D_U32x8x16_ST_N, // The copy atom used to store matrix D + void, void>; + + // GEMM Mainloop - iteration over blocks in K dimension + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, TileShape, ElementInputA, + cutlass::gemm::TagToStrideA_t, // Converts CUTLASS 2.x to + // CUTLASS 3.x representation + ElementInputB, + cutlass::gemm::TagToStrideB_t, // Converts CUTLASS 2.x to + // CUTLASS 3.x representation + TiledMma, GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + // Define the whole kernel (mainloop and epilogue) + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Defer global problem shape definition to + // runtime + CollectiveMainloop, CollectiveEpilogue>; + + // The GemmUniversalAdapter wraps the defined GEMM kernel and handles the + // launch, and e.g. persistent scratch memory if required. + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + + CUTLASS_CHECK(runner.run(options, hw_info)); +} diff --git a/csrc/xpu/helper.h b/csrc/xpu/helper.h new file mode 100644 index 0000000..4bc345c --- /dev/null +++ b/csrc/xpu/helper.h @@ -0,0 +1,127 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
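
The mainloop configuration above follows the tiling described in the preceding comments: a <256, 256, 32> work-group tile split over an 8x4 arrangement of sub-groups, so each sub-group owns a contiguous 32x64x32 chunk, i.e. 4x4x2 iterations of the 8x16x16 MMA atom. A quick Python check of that bookkeeping (illustrative arithmetic only):

wg_m, wg_n, wg_k = 256, 256, 32       # work-group TileShape
sg_rows, sg_cols = 8, 4               # sub-group layout used by TiledMMAHelper
atom_m, atom_n, atom_k = 8, 16, 16    # XE_8x16x16_F32BF16BF16F32_TT MMA atom

sg_m, sg_n = wg_m // sg_rows, wg_n // sg_cols
print(sg_m, sg_n, wg_k)                                 # 32 64 32 per sub-group
print(sg_m // atom_m, sg_n // atom_n, wg_k // atom_k)   # 4 4 2 atom iterations
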
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#if defined(CUTLASS_ENABLE_SYCL) + #include "cutlass/util/sycl_timer.hpp" +#else + #include +#endif +#include + +/** + * Panic wrapper for unwinding CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) \ + << " at: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +/** + * Panic wrapper for unwinding CUDA runtime errors + */ +#define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ + << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +/** + * GPU timer for recording the elapsed time across kernel(s) launched in GPU + * stream + */ +struct GpuTimer { +#if defined(CUTLASS_ENABLE_SYCL) + using cudaStream_t = int; + SYCLTimer syclTimer; +#else + cudaEvent_t _start; + cudaEvent_t _stop; +#endif + cudaStream_t _stream_id; + + /// Constructor + GpuTimer() : _stream_id(0) { +#if !defined(CUTLASS_ENABLE_SYCL) + CUDA_CHECK(cudaEventCreate(&_start)); + CUDA_CHECK(cudaEventCreate(&_stop)); +#endif + } + + /// Destructor + ~GpuTimer() { +#if !defined(CUTLASS_ENABLE_SYCL) + CUDA_CHECK(cudaEventDestroy(_start)); + CUDA_CHECK(cudaEventDestroy(_stop)); +#endif + } + + /// Start the timer for a given stream (defaults to the default stream) + void start(cudaStream_t stream_id = 0) { + _stream_id = stream_id; +#if defined(CUTLASS_ENABLE_SYCL) + syclTimer.start(); +#else + CUDA_CHECK(cudaEventRecord(_start, _stream_id)); +#endif + } + + /// Stop the timer + void stop() { +#if defined(CUTLASS_ENABLE_SYCL) + syclTimer.stop(); +#else + CUDA_CHECK(cudaEventRecord(_stop, _stream_id)); +#endif + } + + /// Return the elapsed time (in milliseconds) + float elapsed_millis() { +#if defined(CUTLASS_ENABLE_SYCL) + return syclTimer.milliseconds(); +#else + float elapsed = 0.0; + CUDA_CHECK(cudaEventSynchronize(_stop)); + CUDA_CHECK(cudaEventElapsedTime(&elapsed, _start, _stop)); + return elapsed; +#endif + } +}; \ No newline at end of file diff --git a/csrc/xpu/mha.h b/csrc/xpu/mha.h new file mode 100644 index 0000000..a18cc0c --- /dev/null +++ b/csrc/xpu/mha.h @@ -0,0 +1,16 @@ +#pragma once + + +void cutlass_chunk_prefill_impl( + at::Tensor& query, // [seq_q, heads, head_size] + at::Tensor& key_cache, // [num_block, 
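
helper.h's GpuTimer wraps either CUDA events or the SYCL timer depending on CUTLASS_ENABLE_SYCL. When timing these extensions from Python, a comparable pattern is to synchronize the XPU queue around a wall-clock timer; a rough sketch, under the assumption that the installed PyTorch XPU build exposes torch.xpu.synchronize:

import time
import torch

class XpuTimer:
    # loose Python analogue of helper.h's GpuTimer
    def __enter__(self):
        torch.xpu.synchronize()
        self._t0 = time.perf_counter()
        return self

    def __exit__(self, *exc):
        torch.xpu.synchronize()
        self.elapsed_millis = (time.perf_counter() - self._t0) * 1e3

# with XpuTimer() as t:
#     run_kernel_under_test()
# print(t.elapsed_millis)
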
block_size, heads, head_size] + at::Tensor& value_cache, + at::Tensor& out, + at::Tensor& block_table, + at::Tensor& cu_seqlens_q, + at::Tensor& cu_seqlens_k, + int max_seqlen_q, + int max_seqlen_k, + double sm_scale, + bool is_causal +); \ No newline at end of file diff --git a/csrc/xpu/ops.h b/csrc/xpu/ops.h index 2741088..7c1dcf0 100644 --- a/csrc/xpu/ops.h +++ b/csrc/xpu/ops.h @@ -2,8 +2,10 @@ #include -void rms_norm(torch::Tensor &out, torch::Tensor &input, torch::Tensor &weight, +void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double epsilon); -void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual, - torch::Tensor &weight, double epsilon); \ No newline at end of file +void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, + torch::Tensor& weight, double epsilon); + +void cutlass_sycl_demo(torch::Tensor& a); \ No newline at end of file diff --git a/csrc/xpu/torch_bindings.cpp b/csrc/xpu/torch_bindings.cpp index 39c193e..9c9e0a2 100644 --- a/csrc/xpu/torch_bindings.cpp +++ b/csrc/xpu/torch_bindings.cpp @@ -31,6 +31,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, " "float epsilon) -> ()"); ops.impl("fused_add_rms_norm", torch::kXPU, &fused_add_rms_norm); + + ops.def("cutlass_sycl_demo(Tensor a) -> ()"); + ops.impl("cutlass_sycl_demo", torch::kXPU, &cutlass_sycl_demo); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/xpu/utils.h b/csrc/xpu/utils.h index 2781b48..9d55bfa 100644 --- a/csrc/xpu/utils.h +++ b/csrc/xpu/utils.h @@ -8,24 +8,27 @@ namespace vllm { namespace xpu { -static inline sycl::queue &vllmGetQueue() { +static inline sycl::queue& vllmGetQueue() { auto current_stream = c10::xpu::getCurrentXPUStream(); - auto &queue = current_stream.queue(); + auto& queue = current_stream.queue(); return queue; } -template struct SyclTypeTrait { +template +struct SyclTypeTrait { using Type = T; }; -template <> struct SyclTypeTrait { +template <> +struct SyclTypeTrait { using Type = sycl::half; }; -template <> struct SyclTypeTrait { +template <> +struct SyclTypeTrait { using Type = sycl::ext::oneapi::bfloat16; }; -} // namespace xpu +} // namespace xpu -} // namespace vllm +} // namespace vllm diff --git a/mll_build.sh b/mll_build.sh new file mode 100644 index 0000000..3e5ace0 --- /dev/null +++ b/mll_build.sh @@ -0,0 +1,2 @@ +python3 setup.py clean +VLLM_TARGET_DEVICE=xpu python3 setup.py develop diff --git a/setup.py b/setup.py index 346a13d..9d7f64f 100644 --- a/setup.py +++ b/setup.py @@ -125,6 +125,7 @@ def configure(self, ext: CMakeExtension) -> None: cmake_args = [ '-DCMAKE_BUILD_TYPE={}'.format(cfg), '-DVLLM_TARGET_DEVICE={}'.format(VLLM_TARGET_DEVICE), + '-DCMAKE_TOOLCHAIN_FILE=cmake/toolchain.cmake' ] verbose = envs.VERBOSE @@ -258,6 +259,8 @@ def run(self): if _build_custom_ops(): ext_modules.append(CMakeExtension(name="vllm_xpu_kernels._C")) + # ext_modules.append(CMakeExtension(name="vllm_xpu_kernels._vllm_fa2_C")) + ext_modules.append(CMakeExtension(name="vllm_xpu_kernels._vllm_fp8_C")) if ext_modules: cmdclass = {"build_ext": cmake_build_ext} diff --git a/tests/flash_attn/test.py b/tests/flash_attn/test.py new file mode 100644 index 0000000..dfa6539 --- /dev/null +++ b/tests/flash_attn/test.py @@ -0,0 +1,38 @@ +import pytest +import torch + +from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func + +DTYPES = [torch.half, torch.bfloat16] +dtype = torch.half + +torch.set_default_device("xpu") +batch_size = 1 +seq_len 
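
The schema added in torch_bindings.cpp above is what exposes the demo kernel to Python: once vllm_xpu_kernels._C is importable, the op is reachable through torch.ops, as tests/test_cutlass_op.py does. A minimal call (assumes an XPU device and a built _C extension):

import torch
import vllm_xpu_kernels._C  # noqa: F401  (registers the _C ops)

a = torch.zeros((2, 3), dtype=torch.bfloat16, device="xpu")
torch.ops._C.cutlass_sycl_demo(a)  # runs the BMG GEMM demo kernel
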
= 512 +num_heads = 8 +head_dim = 128 + +max_seqlen_q = seq_len +cu_seqlens_q = torch.tensor([0, seq_len], dtype=torch.int32) +max_seqlen_k = seq_len +cu_seqlens_k = cu_seqlens_q + +block_size = 128 +num_blocks = max_seqlen_q // block_size +max_num_blocks_per_seq = seq_len // block_size + +block_tables = torch.randint(0, num_blocks, (batch_size, max_num_blocks_per_seq), dtype=torch.int32) + +print(block_tables) +print(cu_seqlens_q) + +q = torch.randn(sum(cu_seqlens_q), num_heads, head_dim, dtype=dtype) +k = torch.randn(num_blocks, block_size, num_heads, head_dim, dtype=dtype) +v = torch.randn(num_blocks, block_size, num_heads, head_dim, dtype=dtype) + +# Call the flash attention function +output= flash_attn_varlen_func(q, k, v, max_seqlen_q, cu_seqlens_q, + max_seqlen_k, cu_seqlens_k, block_table=block_tables) + +assert output is not None +assert output.dtype == dtype diff --git a/tests/flash_attn/test_flash_attn_varlen_func.py b/tests/flash_attn/test_flash_attn_varlen_func.py new file mode 100644 index 0000000..fab86b4 --- /dev/null +++ b/tests/flash_attn/test_flash_attn_varlen_func.py @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func + +DTYPES = [torch.half, torch.bfloat16] + + +@pytest.mark.parametrize("dtype", DTYPES) +def test_flash_attn_varlen_func(dtype): + torch.set_default_device("xpu") + batch_size = 1 + seq_len = 4 + num_heads = 8 + head_dim = 16 + + q = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype) + k = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype) + v = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype) + + max_seqlen_q = seq_len + cu_seqlens_q = torch.tensor([0, seq_len], dtype=torch.int32) + max_seqlen_k = seq_len + cu_seqlens_k = cu_seqlens_q + + # Call the flash attention function + output = flash_attn_varlen_func(q, k, v, max_seqlen_q, cu_seqlens_q, + max_seqlen_k, cu_seqlens_k) + + assert output is not None + assert output.dtype == dtype + assert output.shape == (batch_size, seq_len, num_heads, head_dim) diff --git a/tests/quantization/test.py b/tests/quantization/test.py new file mode 100644 index 0000000..f5d3007 --- /dev/null +++ b/tests/quantization/test.py @@ -0,0 +1,152 @@ +import pytest +import torch +import torch.profiler +from math import ceil + +from vllm_xpu_kernels.scaled_mm_interface import cutlass_basic_gemm, cutlass_scaled_mm + +DTYPES = [torch.bfloat16] + +# @pytest.mark.parametrize("dtype", DTYPES) +def test_base_gemm(dtype): + # import ipdb + # ipdb.set_trace() + torch.set_default_device("xpu") + + m = 1024 + n = 2048 + k = 4096 + + inputA = torch.randn(m, k, dtype=dtype) + inputB = torch.randn(n, k, dtype=dtype) + inputC = torch.randn(m, n, dtype=dtype) + res = torch.empty_like(inputC) + + alpha = 1.0 + beta = 0.1 + ref_D = alpha * (inputA @ inputB.t()) + beta * inputC + cutlass_B = inputB.transpose(1, 0).contiguous().transpose(1, 0) + print("cutlassB ", cutlass_B.shape, cutlass_B.stride()) + cutlass_basic_gemm(inputA, cutlass_B, inputC, res, alpha, beta) + print(res) + print(ref_D) + + +def ref_scaled_mm(inputA_fp8, inputB_fp8, scaleA, scaleB, block_size): + m, k = inputA_fp8.shape + n = inputB_fp8.shape[0] + a_scale = scaleA + b_scale = scaleB + block_shape_n = block_size[0] + block_shape_k = block_size[1] + scale_n = b_scale.shape[0] + scale_k = b_scale.shape[1] + + a_scale = a_scale.unsqueeze(-1).repeat(1, 1, 
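
The paged test above hands the kernel a (num_blocks, block_size, heads, head_dim) KV cache plus a per-sequence block table, while batch size and total lengths are derived from the cu_seqlens tensors. A small sketch of gathering a contiguous K back out of the paged cache, e.g. for building a host reference (the helper name and shapes are illustrative):

import torch

def gather_paged_k(k_cache, block_tables, cu_seqlens_k):
    # k_cache: (num_blocks, block_size, H, D); block_tables: (batch, blocks_per_seq)
    seqs = []
    for b in range(block_tables.shape[0]):
        seq_len = int(cu_seqlens_k[b + 1] - cu_seqlens_k[b])
        blocks = k_cache[block_tables[b].long()]   # (blocks_per_seq, block_size, H, D)
        seqs.append(blocks.flatten(0, 1)[:seq_len])
    return torch.cat(seqs, dim=0)                  # (total_seqlen_k, H, D)

k_cache = torch.randn(4, 128, 8, 128, dtype=torch.half)
block_tables = torch.randint(0, 4, (1, 4), dtype=torch.int32)
cu_seqlens_k = torch.tensor([0, 512], dtype=torch.int32)
print(gather_paged_k(k_cache, block_tables, cu_seqlens_k).shape)  # (512, 8, 128)
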
block_shape_k) + a_scale = a_scale.reshape(m, scale_k * block_shape_k) + a_scale = a_scale[:, :k] + + b_scale = ( + b_scale.view(-1, 1) + .repeat(1, block_shape_n * block_shape_k) + .view(scale_n, scale_k, block_shape_n, block_shape_k) + .permute(0, 2, 1, 3) + .reshape(scale_n * block_shape_n, scale_k * block_shape_k) + ) + b_scale = b_scale[:n, :k] + a_full = a_scale.half() + b_full = b_scale.half() + + A32 = inputA_fp8.to(torch.float16) + B32 = inputB_fp8.to(torch.float16) + + res = (A32 * a_full) @ (B32 * b_full).T + return res + + +def test_scaled_mm(): + torch.set_default_device("xpu") + + m = 8192 + k = 7168 + n = 1536 + block_n, block_k = (16, 16) + + dtype = torch.float8_e4m3fn + + inputA = torch.randn(m, k, dtype=torch.float32).to(dtype) + inputB = torch.randn(n, k, dtype=torch.float32).to(dtype) + inputC = torch.randn(m, n, dtype=torch.float32) + res = torch.empty_like(inputC) + + scale_k = ceil(k / block_k) + scaleA = torch.rand((m, scale_k), dtype=torch.float16) + scale_n = ceil(n / block_n) + scaleB = torch.rand((scale_n, scale_k), dtype=torch.float16) + + out = cutlass_scaled_mm(inputA, inputB, scaleA, scaleB, [block_n, block_k]) + ref = ref_scaled_mm(inputA, inputB, scaleA, scaleB, [block_n, block_k]) + print(out) + print(ref) + + try: + torch.testing.assert_close(out.float(), ref.float(), rtol=5e-1, atol=1.5e-1) + print("a and b are close enough") + except AssertionError as e: + print("a and b are different") + print(e) + +def profile_scaled_mm(): + torch.set_default_device("xpu") + + m = 8192 + k = 7168 + n = 1536 + block_n, block_k = (16, 16) + + dtype = torch.float8_e4m3fn + + inputA = torch.randn(m, k, dtype=torch.float32).to(dtype) + inputB = torch.randn(n, k, dtype=torch.float32).to(dtype) + inputC = torch.randn(m, n, dtype=torch.float32) + res = torch.empty_like(inputC) + + scale_k = ceil(k / block_k) + scaleA = torch.rand((m, scale_k), dtype=torch.float16) + scale_n = ceil(n / block_n) + scaleB = torch.rand((scale_n, scale_k), dtype=torch.float16) + + for _ in range(5): + out = cutlass_scaled_mm(inputA, inputB, scaleA, scaleB, [block_n, block_k]) + + with torch.profiler.profile( + activities=[ + # torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.XPU + ], + # record_shapes=True, + with_stack=False, + # profile_memory=True, + with_flops=True, + ) as prof: + out = cutlass_scaled_mm(inputA, inputB, scaleA, scaleB, [block_n, block_k]) + + print(prof.key_averages().table( + sort_by="self_xpu_time_total", + row_limit=-1 + )) + + flops = 0 + flops += 2*m*n*k #mm + flops += m*k + n*k #dequante + flops /= 1e9 + mem = 0 + mem += m*k*1 + n*k*1 + m*k*2 + n*k*2 + m*n*4 + print(f"FLOPs: {flops:.3f} G") + print(f"Memory: {mem / 1e6:.3f} MB") + + +if __name__ == '__main__': + # test_base_gemm(DTYPES[0]) + # test_scaled_mm() + profile_scaled_mm() diff --git a/tests/test_cutlass_op.py b/tests/test_cutlass_op.py new file mode 100644 index 0000000..f575ae9 --- /dev/null +++ b/tests/test_cutlass_op.py @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +import vllm_xpu_kernels._C # noqa F401 + +DTYPES = [torch.half, torch.bfloat16] + + +@pytest.mark.parametrize("dtype", DTYPES) +@torch.inference_mode() +def test_cutlass_op(dtype: torch.dtype, ): + torch.set_default_device("xpu") + a = torch.zeros((2, 3), dtype=dtype, device="xpu") + torch.ops._C.cutlass_sycl_demo(a) diff --git a/vllm_xpu_kernels/__init__.py b/vllm_xpu_kernels/__init__.py new file mode 100644 
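
ref_scaled_mm above expands the per-block scales to full (m, k) and (n, k) grids with a reshape/permute sequence before dequantizing. The same expansion can be written with repeat_interleave, which is sometimes easier to audit; this is offered as an equivalent formulation under the same block sizes, not as a change to the patch:

import torch

def expand_block_scales(scale_a, scale_b, m, n, k, block_n, block_k):
    # scale_a: (m, ceil(k/block_k)); scale_b: (ceil(n/block_n), ceil(k/block_k))
    a_full = scale_a.repeat_interleave(block_k, dim=1)[:, :k]
    b_full = scale_b.repeat_interleave(block_n, dim=0) \
                    .repeat_interleave(block_k, dim=1)[:n, :k]
    return a_full, b_full

scale_a = torch.rand(8, 2, dtype=torch.float16)
scale_b = torch.rand(3, 2, dtype=torch.float16)
a_full, b_full = expand_block_scales(scale_a, scale_b, m=8, n=40, k=30,
                                     block_n=16, block_k=16)
print(a_full.shape, b_full.shape)  # torch.Size([8, 30]) torch.Size([40, 30])
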
index 0000000..1a94582 --- /dev/null +++ b/vllm_xpu_kernels/__init__.py @@ -0,0 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 + +from .flash_attn_interface import flash_attn_varlen_func # noqa: F401 +from .scaled_mm_interface import cutlass_basic_gemm, cutlass_scaled_mm diff --git a/vllm_xpu_kernels/flash_attn_interface.py b/vllm_xpu_kernels/flash_attn_interface.py new file mode 100644 index 0000000..33e7de8 --- /dev/null +++ b/vllm_xpu_kernels/flash_attn_interface.py @@ -0,0 +1,107 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import torch + +#isort: off +try: + from . import _vllm_fa2_C # noqa: F401 + FA2_UNAVAILABLE_REASON = None + FA2_AVAILABLE = True +except ImportError as e: + FA2_UNAVAILABLE_REASON = str(e) + FA2_AVAILABLE = False + +#isort: on + +DEFAULT_FA_VERSION = 2 + + +def maybe_contiguous(x): + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + + +def flash_attn_varlen_func( + q, + k, + v, + max_seqlen_q, + cu_seqlens_q, + max_seqlen_k, + cu_seqlens_k=None, # only used for non-paged prefill + seqused_k=None, + q_v=None, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size: Optional[list[int]] = None, + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + block_table=None, + return_softmax_lse=False, + out=None, + # FA3 Only + scheduler_metadata=None, + q_descale=None, + k_descale=None, + v_descale=None, + num_splits: int = 0, + # Version selector + fa_version: int = DEFAULT_FA_VERSION, +): + assert cu_seqlens_k is not None or seqused_k is not None, \ + "cu_seqlens_k or seqused_k must be provided" + assert cu_seqlens_k is None or seqused_k is None, \ + "cu_seqlens_k and seqused_k cannot be provided at the same time" + + if softmax_scale is None: + softmax_scale = q.shape[-1]**(-0.5) + # custom op does not support non-tuple input + real_window_size: tuple[int, int] + if window_size is None: + real_window_size = (-1, -1) + else: + assert len(window_size) == 2 + real_window_size = (window_size[0], window_size[1]) + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + + dummy_cu_seqlens_k = torch.empty_like(cu_seqlens_q) + + if fa_version == 2: + if scheduler_metadata is not None and q_descale is not None \ + and k_descale is not None and v_descale is not None: + raise NotImplementedError( + "FA2 does not support scheduler_metadata, q_descale, " + "k_descale, v_descale") + if num_splits > 1: + raise NotImplementedError("FA2 does not support num_splits > 1") + out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd( + q, + k, + v, + out, + cu_seqlens_q, + # cu_seqlens_k not used since we use seqused_k, but flash_api.cpp + # still wants it so we pass all zeros + dummy_cu_seqlens_k if cu_seqlens_k is None else cu_seqlens_k, + seqused_k, + None, + block_table, + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + False, + causal, + real_window_size[0], + real_window_size[1], + softcap, + return_softmax_lse and dropout_p > 0, + None, + ) + else: + raise NotImplementedError("not support yet") + return (out, softmax_lse) if return_softmax_lse else out diff --git a/vllm_xpu_kernels/scaled_mm_interface.py b/vllm_xpu_kernels/scaled_mm_interface.py new file mode 100644 index 0000000..bf39e46 --- /dev/null +++ b/vllm_xpu_kernels/scaled_mm_interface.py @@ -0,0 +1,65 @@ +import torch +from torch.nn.modules.utils import _pair +from torch import nn, Tensor +from typing import List + +from . 
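
flash_attn_varlen_func requires exactly one of cu_seqlens_k or seqused_k, fills in softmax_scale as head_size**-0.5 when omitted, and routes FA2 calls to torch.ops._vllm_fa2_C.varlen_fwd. A minimal non-paged call mirroring tests/flash_attn/test_flash_attn_varlen_func.py (assumes the _vllm_fa2_C extension is built and an XPU device is available):

import torch
from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func

torch.set_default_device("xpu")
seq_len, num_heads, head_dim = 4, 8, 16
q, k, v = (torch.randn(1, seq_len, num_heads, head_dim, dtype=torch.half)
           for _ in range(3))
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32)

out = flash_attn_varlen_func(q, k, v, seq_len, cu_seqlens, seq_len,
                             cu_seqlens_k=cu_seqlens, causal=True)
print(out.shape)  # (1, 4, 8, 16)
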
import _vllm_fp8_C + +def cutlass_basic_gemm(inputA, inputB, inputC, res, alpha=1.0, beta=0.0): + print("init cutlass_basic_gemm") + return torch.ops._vllm_fp8_C.cutlass_gemm(inputA, inputB, inputC, res, alpha, beta) + + +def cutlass_scaled_mm( + inputA, + inputB, + scaleA, + scaleB, + block_size: List[int], + output_dtype: torch.dtype = torch.float16, +): + m = inputA.shape[0] + n = inputB.shape[0] + k = inputA.shape[1] + + a_scale = scaleA + b_scale = scaleB + block_shape_n = block_size[0] + block_shape_k = block_size[1] + scale_n = b_scale.shape[0] + scale_k = b_scale.shape[1] + + a_scale = a_scale.unsqueeze(-1).repeat(1, 1, block_shape_k) + a_scale = a_scale.reshape(m, scale_k * block_shape_k) + a_scale = a_scale[:, :k] + + b_scale = ( + b_scale.view(-1, 1) + .repeat(1, block_shape_n * block_shape_k) + .view(scale_n, scale_k, block_shape_n, block_shape_k) + .permute(0, 2, 1, 3) + .reshape(scale_n * block_shape_n, scale_k * block_shape_k) + ) + b_scale = b_scale[:n, :k] + + inputA = inputA.contiguous() + inputB = inputB.transpose(1, 0).contiguous().transpose(1, 0) + a_scale = a_scale.contiguous() + b_scale = b_scale.transpose(1, 0).contiguous().transpose(1, 0) + + a_scale = a_scale.half() + b_scale = b_scale.half() + + res_float = torch.empty([m, n], dtype=torch.float, device=inputA.device) + + torch.ops._vllm_fp8_C.cutlass_scaled_mm( + inputA, + inputB, + a_scale, + b_scale, + res_float, + 1.0, + 0.0, + ) + res = res_float.to(dtype=output_dtype) + return res
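
To round off the interface above, a compact end-to-end call of cutlass_scaled_mm in the spirit of tests/quantization/test.py (shapes are illustrative; assumes the _vllm_fp8_C extension is built, an XPU device is present, and float8_e4m3fn is supported by the installed PyTorch):

import torch
from math import ceil
from vllm_xpu_kernels.scaled_mm_interface import cutlass_scaled_mm

torch.set_default_device("xpu")
m, n, k = 256, 128, 512
block_n, block_k = 16, 16

a = torch.randn(m, k, dtype=torch.float32).to(torch.float8_e4m3fn)
b = torch.randn(n, k, dtype=torch.float32).to(torch.float8_e4m3fn)
scale_a = torch.rand(m, ceil(k / block_k), dtype=torch.float16)
scale_b = torch.rand(ceil(n / block_n), ceil(k / block_k), dtype=torch.float16)

out = cutlass_scaled_mm(a, b, scale_a, scale_b, [block_n, block_k],
                        output_dtype=torch.bfloat16)
print(out.shape, out.dtype)  # torch.Size([256, 128]) torch.bfloat16
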