From d41cf57ba3a59660205efdb3ad3c1fdd1667e18e Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Fri, 1 Aug 2025 00:59:28 +0800 Subject: [PATCH 01/47] add flash attention interface Signed-off-by: Kunshang Ji --- CMakeLists.txt | 28 +++++ csrc/core/registration.h | 2 +- csrc/flash_attn/flash_api.cpp | 61 ++++++++++ csrc/flash_attn/pytorch_shim.h | 110 ++++++++++++++++++ csrc/xpu/ops.h | 6 +- csrc/xpu/utils.h | 17 +-- setup.py | 1 + .../flash_attn/test_flash_attn_varlen_func.py | 35 ++++++ vllm_xpu_kernels/__init__.py | 4 + vllm_xpu_kernels/flash_attn_interface.py | 98 ++++++++++++++++ 10 files changed, 351 insertions(+), 11 deletions(-) create mode 100644 csrc/flash_attn/flash_api.cpp create mode 100644 csrc/flash_attn/pytorch_shim.h create mode 100644 tests/flash_attn/test_flash_attn_varlen_func.py create mode 100644 vllm_xpu_kernels/__init__.py create mode 100644 vllm_xpu_kernels/flash_attn_interface.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 005457e..e02009b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -46,6 +46,7 @@ set(SYCL_SUPPORTED_ARCHS "intel_gpu_pvc;intel_gpu_bmg_g21") set(TORCH_SUPPORTED_VERSION_XPU "2.8.0") set(ENABLE_MOE_KERNEL OFF) +set(FA2_ENABLED ON) # # Try to find python package with an executable that exactly matches @@ -172,6 +173,33 @@ define_gpu_extension_target( USE_SABI 3 WITH_SOABI) +# +# flash attention _C extension +# + +if (FA2_ENABLED) + message(STATUS "Enabling fa2 extension.") + file(GLOB FA2_GEN_SRCS "csrc/flash_attn/*.cpp") + + define_gpu_extension_target( + _vllm_fa2_C + DESTINATION vllm_xpu_kernels + LANGUAGE ${VLLM_GPU_LANG} + SOURCES + csrc/flash_attn/flash_api.cpp + ${FA2_GEN_SRCS} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + LINK_FLAGS ${VLLM_GPU_LINK_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + USE_SABI 3 + WITH_SOABI) + + # target_include_directories(_vllm_fa2_C PRIVATE + # csrc/flash_attn + # csrc/flash_attn/src) +endif () + + # # _moe_C extension # diff --git a/csrc/core/registration.h b/csrc/core/registration.h index 9b6d7ab..4d0ce1c 100644 --- a/csrc/core/registration.h +++ b/csrc/core/registration.h @@ -14,7 +14,7 @@ // A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME // could be a macro instead of a literal token. -#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \ +#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \ TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE) // REGISTER_EXTENSION allows the shared library to be loaded and initialized diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp new file mode 100644 index 0000000..ad28dec --- /dev/null +++ b/csrc/flash_attn/flash_api.cpp @@ -0,0 +1,61 @@ +#include "pytorch_shim.h" + +#include "core/registration.h" +#include + +namespace FLASH_NAMESPACE { + +std::vector mha_varlen_fwd( + at::Tensor& + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& k, // total_k x num_heads_k x head_size, total_k := + // \sum_{i=0}^{b} s_i or num_blocks x page_block_size + // x num_heads_k x head_size if there's a block_table. + const at::Tensor& v, // total_k x num_heads_k x head_size, total_k := + // \sum_{i=0}^{b} s_i or num_blocks x page_block_size + // x num_heads_k x head_size if there's a block_table. + std::optional& + out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& + seqused_k, // b. If given, only this many elements of each batch + // element's keys are used. 
+ std::optional& leftpad_k_, // batch_size + std::optional& + block_table_, // batch_size x max_num_blocks_per_seq + std::optional& alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, const int max_seqlen_k, const float p_dropout, + const float softmax_scale, const bool zero_tensors, bool is_causal, + int window_size_left, int window_size_right, const float softcap, + const bool return_softmax, std::optional gen_) { + at::Tensor out; + out = torch::zeros_like(q); + + const auto sizes = q.sizes(); + const int total_q = q.sizes()[0]; + auto opts = q.options(); + int num_heads = sizes[1]; + + auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); + + return {out, softmax_lse}; +} +} // namespace FLASH_NAMESPACE + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { + ops.def( + "varlen_fwd(Tensor! q, Tensor k, Tensor v, Tensor!? out, Tensor " + "cu_seqlens_q, " + "Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? leftpad_k, Tensor? " + "block_table, Tensor? alibi_slopes, " + "int max_seqlen_q, int max_seqlen_k, float p_dropout, float " + "softmax_scale, bool zero_tensors, " + "bool is_causal, int window_size_left, int window_size_right, float " + "softcap, bool return_softmax, " + "Generator? gen) -> Tensor[]"); + ops.impl("varlen_fwd", torch::kXPU, + make_pytorch_shim(&FLASH_NAMESPACE::mha_varlen_fwd)); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) \ No newline at end of file diff --git a/csrc/flash_attn/pytorch_shim.h b/csrc/flash_attn/pytorch_shim.h new file mode 100644 index 0000000..82f5f5c --- /dev/null +++ b/csrc/flash_attn/pytorch_shim.h @@ -0,0 +1,110 @@ +#pragma once + +#include + +/** + * Unfortunately, the type signatures of the flash_attn ops are not compatible + * with the PyTorch library bindings. To get around that we use + * `make_pytorch_shim` which creates a lambda that exponses the API using + * PyTorch compatible types to the types, then converts them to the types + * expected by the flash_attn ops. This shims allows us to make minimal changes + * to `flash_api.cpp` making it easier to synchronize with upstream changes. + * + * The `pytorch_library_compatible_type` struct is used to map from the + * flash_attn ops types to a PyTorch library compatible one. 
The main issues is + * that the following types are not support by PyTorch library bindings: + * - `int` + * - `float` + * - `std::optional &` + * - `std::optional &` + * So we convert them to (respectively): + * - `int64_t` + * - `double` + * - `const std::optional&` + * - `const std::optional&` + */ + +template +struct pytorch_library_compatible_type { + using type = T; + static T convert_from_type(T arg) { return arg; } +}; + +template +using pytorch_library_compatible_type_t = + typename pytorch_library_compatible_type::type; + +template +T convert_from_pytorch_compatible_type( + pytorch_library_compatible_type_t arg) { + return pytorch_library_compatible_type::convert_from_type(arg); +} + +// Map `std::optional &` -> `const std::optional&` +// (NOTE: this is bit unsafe but non of the ops in flash_attn mutate +// the optional container) +template +struct pytorch_library_compatible_type&> { + using type = const std::optional&; + static std::optional& convert_from_type(const std::optional& arg) { + return const_cast&>(arg); + } +}; + +// Map `std::optional` -> +// `std::optional>` +// (NOTE: tested for `std::optional` -> `std::optional`) +template +struct pytorch_library_compatible_type> { + using type = std::optional>; + static std::optional> convert_from_type( + std::optional arg) { + return arg; + } +}; + +// Map `std::optional&` -> `const std::optional&` +template <> +struct pytorch_library_compatible_type&> { + using type = const std::optional&; + static std::optional& convert_from_type( + const std::optional& arg) { + return const_cast&>( + reinterpret_cast&>(arg)); + } +}; + +// Map `int` -> `int64_t` +template <> +struct pytorch_library_compatible_type { + using type = int64_t; + static int convert_from_type(int64_t arg) { + TORCH_CHECK(arg <= std::numeric_limits::max(), + "int64_t value is too large to be converted to int"); + TORCH_CHECK(arg >= std::numeric_limits::min(), + "int64_t value is too small to be converted to int"); + return arg; + } +}; + +// Map `float` -> `double` +template <> +struct pytorch_library_compatible_type { + using type = double; + static float convert_from_type(double arg) { + TORCH_CHECK(std::abs(arg) <= std::numeric_limits::max(), + "double value is too large to be converted to float"); + return arg; + } +}; + +// +// Shim Utils +// + +template +auto make_pytorch_shim(Ret (*fun)(Args... args)) { + return [fun](pytorch_library_compatible_type_t... 
args) { + return fun(convert_from_pytorch_compatible_type(args)...); + }; +} diff --git a/csrc/xpu/ops.h b/csrc/xpu/ops.h index 2741088..acf62e6 100644 --- a/csrc/xpu/ops.h +++ b/csrc/xpu/ops.h @@ -2,8 +2,8 @@ #include -void rms_norm(torch::Tensor &out, torch::Tensor &input, torch::Tensor &weight, +void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double epsilon); -void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual, - torch::Tensor &weight, double epsilon); \ No newline at end of file +void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, + torch::Tensor& weight, double epsilon); \ No newline at end of file diff --git a/csrc/xpu/utils.h b/csrc/xpu/utils.h index 2781b48..9d55bfa 100644 --- a/csrc/xpu/utils.h +++ b/csrc/xpu/utils.h @@ -8,24 +8,27 @@ namespace vllm { namespace xpu { -static inline sycl::queue &vllmGetQueue() { +static inline sycl::queue& vllmGetQueue() { auto current_stream = c10::xpu::getCurrentXPUStream(); - auto &queue = current_stream.queue(); + auto& queue = current_stream.queue(); return queue; } -template struct SyclTypeTrait { +template +struct SyclTypeTrait { using Type = T; }; -template <> struct SyclTypeTrait { +template <> +struct SyclTypeTrait { using Type = sycl::half; }; -template <> struct SyclTypeTrait { +template <> +struct SyclTypeTrait { using Type = sycl::ext::oneapi::bfloat16; }; -} // namespace xpu +} // namespace xpu -} // namespace vllm +} // namespace vllm diff --git a/setup.py b/setup.py index 346a13d..7f6dc32 100644 --- a/setup.py +++ b/setup.py @@ -258,6 +258,7 @@ def run(self): if _build_custom_ops(): ext_modules.append(CMakeExtension(name="vllm_xpu_kernels._C")) + ext_modules.append(CMakeExtension(name="vllm_xpu_kernels._vllm_fa2_C")) if ext_modules: cmdclass = {"build_ext": cmake_build_ext} diff --git a/tests/flash_attn/test_flash_attn_varlen_func.py b/tests/flash_attn/test_flash_attn_varlen_func.py new file mode 100644 index 0000000..fab86b4 --- /dev/null +++ b/tests/flash_attn/test_flash_attn_varlen_func.py @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func + +DTYPES = [torch.half, torch.bfloat16] + + +@pytest.mark.parametrize("dtype", DTYPES) +def test_flash_attn_varlen_func(dtype): + torch.set_default_device("xpu") + batch_size = 1 + seq_len = 4 + num_heads = 8 + head_dim = 16 + + q = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype) + k = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype) + v = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype) + + max_seqlen_q = seq_len + cu_seqlens_q = torch.tensor([0, seq_len], dtype=torch.int32) + max_seqlen_k = seq_len + cu_seqlens_k = cu_seqlens_q + + # Call the flash attention function + output = flash_attn_varlen_func(q, k, v, max_seqlen_q, cu_seqlens_q, + max_seqlen_k, cu_seqlens_k) + + assert output is not None + assert output.dtype == dtype + assert output.shape == (batch_size, seq_len, num_heads, head_dim) diff --git a/vllm_xpu_kernels/__init__.py b/vllm_xpu_kernels/__init__.py new file mode 100644 index 0000000..1de5cdd --- /dev/null +++ b/vllm_xpu_kernels/__init__.py @@ -0,0 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 +import vllm_xpu_kernels._vllm_fa2_C # noqa: F401 + +from .flash_attn_interface import flash_attn_varlen_func # noqa: F401 diff --git a/vllm_xpu_kernels/flash_attn_interface.py 
b/vllm_xpu_kernels/flash_attn_interface.py new file mode 100644 index 0000000..39d34aa --- /dev/null +++ b/vllm_xpu_kernels/flash_attn_interface.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import torch + +DEFAULT_FA_VERSION = 2 + + +def maybe_contiguous(x): + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + + +def flash_attn_varlen_func( + q, + k, + v, + max_seqlen_q, + cu_seqlens_q, + max_seqlen_k, + cu_seqlens_k=None, # only used for non-paged prefill + seqused_k=None, + q_v=None, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size: Optional[list[int]] = None, + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + block_table=None, + return_softmax_lse=False, + out=None, + # FA3 Only + scheduler_metadata=None, + q_descale=None, + k_descale=None, + v_descale=None, + num_splits: int = 0, + # Version selector + fa_version: int = DEFAULT_FA_VERSION, +): + assert cu_seqlens_k is not None or seqused_k is not None, \ + "cu_seqlens_k or seqused_k must be provided" + assert cu_seqlens_k is None or seqused_k is None, \ + "cu_seqlens_k and seqused_k cannot be provided at the same time" + assert block_table is None or seqused_k is not None, \ + "seqused_k must be provided if block_table is provided" + + if softmax_scale is None: + softmax_scale = q.shape[-1]**(-0.5) + # custom op does not support non-tuple input + real_window_size: tuple[int, int] + if window_size is None: + real_window_size = (-1, -1) + else: + assert len(window_size) == 2 + real_window_size = (window_size[0], window_size[1]) + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + + dummy_cu_seqlens_k = torch.empty_like(cu_seqlens_q) + + if fa_version == 2: + if scheduler_metadata is not None and q_descale is not None \ + and k_descale is not None and v_descale is not None: + raise NotImplementedError( + "FA2 does not support scheduler_metadata, q_descale, " + "k_descale, v_descale") + if num_splits > 1: + raise NotImplementedError("FA2 does not support num_splits > 1") + out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd( + q, + k, + v, + out, + cu_seqlens_q, + # cu_seqlens_k not used since we use seqused_k, but flash_api.cpp + # still wants it so we pass all zeros + dummy_cu_seqlens_k if cu_seqlens_k is None else cu_seqlens_k, + seqused_k, + None, + block_table, + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + False, + causal, + real_window_size[0], + real_window_size[1], + softcap, + return_softmax_lse and dropout_p > 0, + None, + ) + else: + raise NotImplementedError("not support yet") + return (out, softmax_lse) if return_softmax_lse else out From ce9f31dfe7869390191839b95ef61347adf3a9a2 Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Fri, 1 Aug 2025 15:57:09 +0800 Subject: [PATCH 02/47] update interface Signed-off-by: Kunshang Ji --- vllm_xpu_kernels/__init__.py | 1 - vllm_xpu_kernels/flash_attn_interface.py | 11 +++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/vllm_xpu_kernels/__init__.py b/vllm_xpu_kernels/__init__.py index 1de5cdd..8635cc2 100644 --- a/vllm_xpu_kernels/__init__.py +++ b/vllm_xpu_kernels/__init__.py @@ -1,4 +1,3 @@ # SPDX-License-Identifier: Apache-2.0 -import vllm_xpu_kernels._vllm_fa2_C # noqa: F401 from .flash_attn_interface import flash_attn_varlen_func # noqa: F401 diff --git a/vllm_xpu_kernels/flash_attn_interface.py b/vllm_xpu_kernels/flash_attn_interface.py index 39d34aa..de7b764 100644 --- 
a/vllm_xpu_kernels/flash_attn_interface.py +++ b/vllm_xpu_kernels/flash_attn_interface.py @@ -3,6 +3,17 @@ import torch +#isort: off +try: + from . import _vllm_fa2_C # noqa: F401 + FA2_UNAVAILABLE_REASON = None + FA2_AVAILABLE = True +except ImportError as e: + FA2_UNAVAILABLE_REASON = str(e) + FA2_AVAILABLE = False + +#isort: on + DEFAULT_FA_VERSION = 2 From fb6784fda3d8ab6ca126163ee5ca4aa8d7bc467a Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Mon, 4 Aug 2025 14:58:00 +0800 Subject: [PATCH 03/47] add cutlass deps (#1) * add cutlass Signed-off-by: Kunshang Ji * fix import Signed-off-by: Kunshang Ji --------- Signed-off-by: Kunshang Ji --- CMakeLists.txt | 55 +++- cmake/toolchain.cmake | 6 + csrc/xpu/cutlass_sycl_demo.cpp | 520 +++++++++++++++++++++++++++++++++ csrc/xpu/helper.h | 127 ++++++++ csrc/xpu/ops.h | 4 +- csrc/xpu/torch_bindings.cpp | 3 + setup.py | 1 + tests/test_cutlass_op.py | 17 ++ 8 files changed, 729 insertions(+), 4 deletions(-) create mode 100644 cmake/toolchain.cmake create mode 100644 csrc/xpu/cutlass_sycl_demo.cpp create mode 100644 csrc/xpu/helper.h create mode 100644 tests/test_cutlass_op.py diff --git a/CMakeLists.txt b/CMakeLists.txt index e02009b..e4ac468 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -147,16 +147,64 @@ endif() if(VLLM_GPU_LANG STREQUAL "SYCL") set(VLLM_EXT_SRC + "csrc/xpu/cutlass_sycl_demo.cpp" "csrc/xpu/layernorm.cpp" "csrc/xpu/torch_bindings.cpp" ) include_directories("/usr/include") - set(CMPLR_ROOT $ENV{CMPLR_ROOT}) - set(CMAKE_CXX_COMPILER icpx) + list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/) + list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/sycl/) + list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/sycl/syclcompat/) + message(STATUS "VLLM_INCLUDE_DIR: ${VLLM_INCLUDE_DIR}") set(VLLM_EXTRA_INCLUDE_DIRECTORIES ${CMPLR_ROOT}/include/sycl) list(APPEND VLLM_GPU_FLAGS "-DVLLM_BUILD_XPU_OPS" ) - list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64") + list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64" "-Xspirv-translator" "-spirv-ext=+SPV_INTEL_split_barrier") list(APPEND VLLM_LINK_LIBRARIES "sycl" "OpenCL" "pthread" "m" "dl" "torch" ) + + + # add cutlass dependency + set(CUTLASS_ENABLE_HEADERS_ONLY "ON" CACHE BOOL "Enable only the header library") + + # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. + set(CUTLASS_REVISION "v3.9-0.2" CACHE STRING "CUTLASS revision to use") + + # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided + FetchContent_Declare( + cutlass-sycl + GIT_REPOSITORY https://github.com/codeplaysoftware/cutlass-sycl.git + # Please keep this in sync with CUTLASS_REVISION line above. + GIT_TAG ${CUTLASS_REVISION} + GIT_PROGRESS TRUE + + # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. + # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. 
+ # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE + GIT_SHALLOW TRUE + ) + + # cutlass compilation flags + set(CUTLASS_ENABLE_SYCL "ON") + # set(DPCPP_SYCL_TARGET "intel_gpu_pvc;intel_gpu_bmg_g21" CACHE STRING "DPC++ SYCL target architectures") + set(CMAKE_EXPORT_COMPILE_COMMANDS "ON") + set(CUTLASS_ENABLE_BENCHMARKS "OFF") + # disable cuda + set(CUTLASS_ENABLE_GDC_FOR_SM100_DEFAULT OFF CACHE BOOL "DISABLE CUDA") + # list(APPEND CMAKE_CXX_FLAGS "-ftemplate-backtrace-limit=0 " ) + # list(APPEND CMAKE_CXX_FLAGS "-fdiagnostics-color=always " ) + + + FetchContent_MakeAvailable(cutlass-sycl) + set(CUTLASS_INCLUDE_DIR ${cutlass-sycl_SOURCE_DIR}/include CACHE PATH "CUTLASS Header Library") + set(CUTLASS_TOOLS_UTIL_INCLUDE_DIR ${cutlass-sycl_SOURCE_DIR}/tools/util/include CACHE INTERNAL "") + message(STATUS "cutlass dir: ${CUTLASS_INCLUDE_DIR} and ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}") + + # header only library + list(APPEND VLLM_GPU_FLAGS "-DCUTLASS_ENABLE_SYCL") + list(APPEND VLLM_GPU_FLAGS "-DSYCL_INTEL_TARGET") + list(APPEND VLLM_GPU_FLAGS "-DCUTLASS_VERSIONS_GENERATED") + list(APPEND VLLM_GPU_FLAGS "-ftemplate-backtrace-limit=0") + list(APPEND VLLM_GPU_FLAGS "-fdiagnostics-color=always") + endif() message(STATUS "Enabling C extension.") @@ -170,6 +218,7 @@ define_gpu_extension_target( ARCHITECTURES ${VLLM_GPU_ARCHES} INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${VLLM_INCLUDE_DIR} USE_SABI 3 WITH_SOABI) diff --git a/cmake/toolchain.cmake b/cmake/toolchain.cmake new file mode 100644 index 0000000..f8cf8b9 --- /dev/null +++ b/cmake/toolchain.cmake @@ -0,0 +1,6 @@ +# use this file to set the compiler and flags for SYCL + +set(CMPLR_ROOT $ENV{CMPLR_ROOT}) +message(STATUS "CMPLR_ROOT: ${CMPLR_ROOT}") +set(CMAKE_CXX_COMPILER ${CMPLR_ROOT}/bin/icpx) +set(CMAKE_C_COMPILER ${CMPLR_ROOT}/bin/icx) \ No newline at end of file diff --git a/csrc/xpu/cutlass_sycl_demo.cpp b/csrc/xpu/cutlass_sycl_demo.cpp new file mode 100644 index 0000000..7254b65 --- /dev/null +++ b/csrc/xpu/cutlass_sycl_demo.cpp @@ -0,0 +1,520 @@ + + +#include +#include +#include + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/xe_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "helper.h" + +#include "cutlass/cutlass.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/sycl_tensor_fill.h" + +using namespace cute; + +/// Helper to initialize a block of device data +template +bool initialize_block(Element* block, std::size_t size, uint64_t seed = 2023) { + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = Element(2); + scope_min = Element(0); + } else if (bits_input <= 8) { + scope_max = Element(2); + scope_min = Element(-2); + } else { + scope_max = Element(8); + scope_min = Element(-8); + } + + cutlass::reference::device::BlockFillRandomUniform(block, size, seed, + scope_max, scope_min, 0); + + 
syclcompat::wait(); + return true; +} + +template +bool initialize_block(cutlass::DeviceAllocation& block, + uint64_t seed = 2023) { + return initialize_block(block.get(), block.size(), seed); +} + +template +void initialize_mixed_dtype_block( + cutlass::DeviceAllocation& block_device, + cutlass::DeviceAllocation& block_device_dq, uint64_t seed) { + static_assert(cute::sizeof_bits_v >= 8); + + std::ranlux24_base rng(std::random_device{}()); + rng.seed(seed); + + int bits_input = cute::sizeof_bits_v; + T1 scope_max, scope_min; + if (bits_input == 1) { + scope_max = T1(2); + scope_min = T1(0); + } else if (bits_input <= 8) { + scope_max = T1(2); + scope_min = T1(-2); + } else { + scope_max = T1(8); + scope_min = T1(-8); + } + + std::uniform_int_distribution<> dist(scope_min, scope_max); + + if constexpr (cute::sizeof_bits_v >= 8) { + auto block_host = std::vector(block_device.size()); + auto block_host_dq = std::vector(block_device.size()); + for (int i = 0; i < block_host.size(); ++i) { + block_host[i] = static_cast(dist(rng)); + block_host_dq[i] = static_cast(block_host[i]); + } + + block_device.copy_from_host(block_host.data()); + block_device_dq.copy_from_host(block_host_dq.data()); + } else { + static constexpr auto array_size = 1024; + + cute::array_subbyte block_host{}; + auto block_host_dq = std::vector(array_size); + + for (int i = 0; i < block_host.size(); ++i) { + block_host[i] = static_cast(dist(rng)); + block_host_dq[i] = static_cast(block_host[i].get()); + } + + static constexpr auto elements_per_byte = + cute::sizeof_bits_v / cute::sizeof_bits_v; + + int loop_cnt = block_device.size() / array_size; + for (int i = 0; i < loop_cnt; i++) { + cutlass::device_memory::copy_to_device( + block_device.get() + (i * array_size) / elements_per_byte, + raw_pointer_cast(block_host.begin()), array_size); + cutlass::device_memory::copy_to_device( + block_device_dq.get() + i * array_size, block_host_dq.data(), + array_size); + } + + auto tail_size = block_device.size() % array_size; + if (tail_size) { + cutlass::device_memory::copy_to_device( + block_device.get() + (loop_cnt * array_size) / elements_per_byte, + raw_pointer_cast(block_host.begin()), tail_size); + cutlass::device_memory::copy_to_device( + block_device_dq.get() + loop_cnt * array_size, block_host_dq.data(), + tail_size); + } + } +} + +template +inline bool is_close(T a, T b, float atol, float rtol) { + return std::abs((float)a - (float)b) <= atol + rtol * std::abs((float)b); +} + +// TODO(Codeplay): use on device initialisation for this +template +inline void random_fill(T* src, int seed, size_t N, float max, float min) { + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v) { + std::random_device rd; + std::mt19937 gen(seed); + std::uniform_real_distribution dis(min, max); + auto buff = std::vector(N); + + for (size_t i = 0; i < N; ++i) { + buff[i] = (T)(dis(gen)); + } + syclcompat::memcpy(src, buff.data(), N); + syclcompat::wait(); + } else { + assert(0 & "Not supported dtype"); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + bool help; + bool error; + + int m, n, k, l, iterations; + float alpha, beta; + + Options() + : help(false), + error(false), + m(5120), + n(4096), + k(4096), + l(1), + iterations(20), + alpha(1.f), + beta(0.f) {} + + // Parses the command line + void parse(int argc, char const** args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + 
help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 5120); + cmd.get_cmd_line_argument("n", n, 4096); + cmd.get_cmd_line_argument("k", k, 4096); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations, 100); + } + + /// Prints the usage statement. + std::ostream& print_usage(std::ostream& out) const { + out << "BMG GEMM Example\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage " + "statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of " + "the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Iterations\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct ExampleRunner { + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementC = typename Gemm::ElementC; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation + block_ref_D; // Reference GEMM result for verification + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, + ElementCompute beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, alpha, ref_A, cutlass::ComplexTransform::kNone, ref_B, + cutlass::ComplexTransform::kNone, beta, ref_C, ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + // CUTLASS on SYCL uses the compatibility library syclcompat for e.g. 
+ // default in-order queue + syclcompat::wait(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual( + block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + // Complete the stride by combining static layout info (StrideA) with + // runtime size info (M,K,L) + stride_A = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = + cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = + cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + block_A.reset(static_cast(M) * K * L); + block_B.reset(static_cast(K) * N * L); + block_C.reset(static_cast(M) * N * L); + block_D.reset(static_cast(M) * N * L); + block_ref_D.reset(static_cast(M) * N * L); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + } + + cutlass::Status run(const Options& options, + const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = + ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, + block_C.get(), + stride_C, + block_D.get(), + stride_D}, + hw_info}; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + if (gemm_op.can_implement(arguments) != cutlass::Status::kSuccess) { + std::cout << "Invalid Problem Size: " << options.m << 'x' << options.n + << 'x' << options.k << 'x' << options.l << std::endl; + std::exit(1); + } + + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); + + // Run the GEMM + CUTLASS_CHECK(gemm_op.run()); + + syclcompat::wait(); + + // Verify that the result is correct + bool passed = verify(problem_size, options.alpha, options.beta); + std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if (!passed) return cutlass::Status::kErrorInternal; + + if (options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + gemm_op.run(); + } + syclcompat::wait(); + + float cute_time = timer.seconds() / options.iterations; + double tflops = + (2.0 * options.m * options.n * options.k * options.l) * 1e-12; + std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' + << options.k << 'x' << options.l << std::endl; + printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", + tflops / cute_time, cute_time * 1000); + } + + return cutlass::Status::kSuccess; + } +}; + +void cutlass_sycl_demo(torch::Tensor& a) { + // + // Parse options + // + // + std::cout << a.sizes() << std::endl; + + Options options; + + /* options.parse(argc, argv); */ + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return; + } + + if (options.error) { + std::cerr << "Aborting execution." 
<< std::endl; + return; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a + // given device ID. This information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with + // multiple GPUs and wish to use a GPU other than that with device ID 0. + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); + + bool passed; + + // The code section below describes datatype for input, output matrices and + // computation between elements in input matrices. + using ElementAccumulator = float; // <- data type of accumulator + using ElementComputeEpilogue = float; // <- data type of epilogue operations + using ElementInputA = + bfloat16_t; // <- data type of elements in input matrix A + using ElementInputB = + bfloat16_t; // <- data type of elements in input matrix B + using ElementOutput = float; // <- data type of elements in output matrix D + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + // The 2D block copy operations used for the A and B matrices + using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; + using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + + // Workgroup-level tile + using TileShape = Shape<_256, _256, _32>; + + // A TiledMMA struct defines a tiling of an MMA atom over M, N and K, + // combining both additional hardware (sub-groups for Intel BMG) and + // iterations by each sub-group. + // + // The TiledMMAHelper struct defines a specific TiledMMA for a given MMA atom + // (XE_8x16x16_F32BF16BF16F32_TT), TileShape (<256, 256, 32>) and sub-group + // layout (8x4x1). The TiledMMA constructed using TiledMMAHelper has the + // property that each sub-group operates on a single contiguous chunk of the + // work-group TileShape. For this configuration, this implies that each + // sub-group operates on a contiguous 32x64x32 chunk (4x4x2 iterations). See + // 0t_mma_atom.md#TiledMMAs for more info. Sub-groups are arranged row-major + // (stride 4,1,0) for performance reasons. + using TiledMma = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32 + typename TiledMMAHelper< + MMA_Atom, Layout, + Layout, Stride<_4, _1, _0>>>::TiledMMA; + + // For Intel BMG, PipelineStages defines how many k-blocks ahead to prefetch + // from A and B. + constexpr int PipelineStages = 2; + using GEMMDispatchPolicy = + cutlass::gemm::MainloopIntelXeXMX16; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + + // This is the 'default' epilogue operation (Linear Combination) which + // performs everything in: (D = alpha * (A*B) + beta * C) aside from the + // (A*B), which is handled by the GEMM. See 05_bmg_gemm_with_epilogues for + // more complex epilogue examples. + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination< + ElementOutput, ElementComputeEpilogue, ElementAccumulator, + ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>; + + // FusionCallbacks ties the EpilogueOp to an implementation (based on the + // dispatch policy/architecture) and defines the epilogue arguments. 
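+  // Editorial note (grounded in the types defined above): with ElementOutput and
+  // ElementComputeEpilogue both float, this fused epilogue evaluates
+  // D = alpha * acc + beta * C entirely in fp32 before the store; alpha and beta
+  // are supplied at runtime through the epilogue arguments built in
+  // ExampleRunner::run().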
+ using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks< + EpilogueDispatchPolicy, EpilogueOp, TileShape, + decltype(tile_shape(TiledMma()))>; + // GEMM Epilogue - loads & stores C/D matrices, performs epilogue operations & + // load/stores any auxiliary data required + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, TileShape, ElementAccumulator, + cutlass::gemm::TagToStrideC_t, // Converts CUTLASS 2.x to + // CUTLASS 3.x representation + ElementOutput, + cutlass::gemm::TagToStrideC_t, // Converts CUTLASS 2.x to + // CUTLASS 3.x representation + FusionCallBacks, + XE_2D_U32x8x16_LD_N, // The copy atom used to load matrix C + void, void, + XE_2D_U32x8x16_ST_N, // The copy atom used to store matrix D + void, void>; + + // GEMM Mainloop - iteration over blocks in K dimension + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, TileShape, ElementInputA, + cutlass::gemm::TagToStrideA_t, // Converts CUTLASS 2.x to + // CUTLASS 3.x representation + ElementInputB, + cutlass::gemm::TagToStrideB_t, // Converts CUTLASS 2.x to + // CUTLASS 3.x representation + TiledMma, GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + // Define the whole kernel (mainloop and epilogue) + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Defer global problem shape definition to + // runtime + CollectiveMainloop, CollectiveEpilogue>; + + // The GemmUniversalAdapter wraps the defined GEMM kernel and handles the + // launch, and e.g. persistent scratch memory if required. + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + + CUTLASS_CHECK(runner.run(options, hw_info)); +} diff --git a/csrc/xpu/helper.h b/csrc/xpu/helper.h new file mode 100644 index 0000000..4bc345c --- /dev/null +++ b/csrc/xpu/helper.h @@ -0,0 +1,127 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#if defined(CUTLASS_ENABLE_SYCL) + #include "cutlass/util/sycl_timer.hpp" +#else + #include +#endif +#include + +/** + * Panic wrapper for unwinding CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) \ + << " at: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +/** + * Panic wrapper for unwinding CUDA runtime errors + */ +#define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ + << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +/** + * GPU timer for recording the elapsed time across kernel(s) launched in GPU + * stream + */ +struct GpuTimer { +#if defined(CUTLASS_ENABLE_SYCL) + using cudaStream_t = int; + SYCLTimer syclTimer; +#else + cudaEvent_t _start; + cudaEvent_t _stop; +#endif + cudaStream_t _stream_id; + + /// Constructor + GpuTimer() : _stream_id(0) { +#if !defined(CUTLASS_ENABLE_SYCL) + CUDA_CHECK(cudaEventCreate(&_start)); + CUDA_CHECK(cudaEventCreate(&_stop)); +#endif + } + + /// Destructor + ~GpuTimer() { +#if !defined(CUTLASS_ENABLE_SYCL) + CUDA_CHECK(cudaEventDestroy(_start)); + CUDA_CHECK(cudaEventDestroy(_stop)); +#endif + } + + /// Start the timer for a given stream (defaults to the default stream) + void start(cudaStream_t stream_id = 0) { + _stream_id = stream_id; +#if defined(CUTLASS_ENABLE_SYCL) + syclTimer.start(); +#else + CUDA_CHECK(cudaEventRecord(_start, _stream_id)); +#endif + } + + /// Stop the timer + void stop() { +#if defined(CUTLASS_ENABLE_SYCL) + syclTimer.stop(); +#else + CUDA_CHECK(cudaEventRecord(_stop, _stream_id)); +#endif + } + + /// Return the elapsed time (in milliseconds) + float elapsed_millis() { +#if defined(CUTLASS_ENABLE_SYCL) + return syclTimer.milliseconds(); +#else + float elapsed = 0.0; + CUDA_CHECK(cudaEventSynchronize(_stop)); + CUDA_CHECK(cudaEventElapsedTime(&elapsed, _start, _stop)); + return elapsed; +#endif + } +}; \ No newline at end of file diff --git a/csrc/xpu/ops.h b/csrc/xpu/ops.h index acf62e6..7c1dcf0 100644 --- a/csrc/xpu/ops.h +++ b/csrc/xpu/ops.h @@ -6,4 +6,6 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double epsilon); void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, - torch::Tensor& weight, double epsilon); \ No newline at end of file + torch::Tensor& weight, double epsilon); + +void cutlass_sycl_demo(torch::Tensor& a); \ No newline at end of file diff --git a/csrc/xpu/torch_bindings.cpp b/csrc/xpu/torch_bindings.cpp index 39c193e..9c9e0a2 100644 --- a/csrc/xpu/torch_bindings.cpp +++ b/csrc/xpu/torch_bindings.cpp @@ -31,6 +31,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { 
"fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, " "float epsilon) -> ()"); ops.impl("fused_add_rms_norm", torch::kXPU, &fused_add_rms_norm); + + ops.def("cutlass_sycl_demo(Tensor a) -> ()"); + ops.impl("cutlass_sycl_demo", torch::kXPU, &cutlass_sycl_demo); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/setup.py b/setup.py index 7f6dc32..c7d8660 100644 --- a/setup.py +++ b/setup.py @@ -125,6 +125,7 @@ def configure(self, ext: CMakeExtension) -> None: cmake_args = [ '-DCMAKE_BUILD_TYPE={}'.format(cfg), '-DVLLM_TARGET_DEVICE={}'.format(VLLM_TARGET_DEVICE), + '-DCMAKE_TOOLCHAIN_FILE=cmake/toolchain.cmake' ] verbose = envs.VERBOSE diff --git a/tests/test_cutlass_op.py b/tests/test_cutlass_op.py new file mode 100644 index 0000000..f575ae9 --- /dev/null +++ b/tests/test_cutlass_op.py @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +import vllm_xpu_kernels._C # noqa F401 + +DTYPES = [torch.half, torch.bfloat16] + + +@pytest.mark.parametrize("dtype", DTYPES) +@torch.inference_mode() +def test_cutlass_op(dtype: torch.dtype, ): + torch.set_default_device("xpu") + a = torch.zeros((2, 3), dtype=dtype, device="xpu") + torch.ops._C.cutlass_sycl_demo(a) From ce27fa219e0a04265f4ff83a8f59fef5626facc6 Mon Sep 17 00:00:00 2001 From: Yizhou Wang Date: Thu, 7 Aug 2025 00:05:16 -0700 Subject: [PATCH 04/47] add chunk_prefill step<1> --- CMakeLists.txt | 12 +- csrc/flash_attn/flash_api.cpp | 53 +++- csrc/xpu/cutlass_kernels/chunk_prefill.hpp | 341 +++++++++++++++++++++ csrc/xpu/cutlass_kernels/utils.hpp | 26 ++ csrc/xpu/mha.h | 16 + 5 files changed, 430 insertions(+), 18 deletions(-) create mode 100644 csrc/xpu/cutlass_kernels/chunk_prefill.hpp create mode 100644 csrc/xpu/cutlass_kernels/utils.hpp create mode 100644 csrc/xpu/mha.h diff --git a/CMakeLists.txt b/CMakeLists.txt index e4ac468..d03540b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -150,6 +150,7 @@ if(VLLM_GPU_LANG STREQUAL "SYCL") "csrc/xpu/cutlass_sycl_demo.cpp" "csrc/xpu/layernorm.cpp" "csrc/xpu/torch_bindings.cpp" + "csrc/flash_attn/flash_api.cpp" ) include_directories("/usr/include") list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/) @@ -157,7 +158,7 @@ if(VLLM_GPU_LANG STREQUAL "SYCL") list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/sycl/syclcompat/) message(STATUS "VLLM_INCLUDE_DIR: ${VLLM_INCLUDE_DIR}") set(VLLM_EXTRA_INCLUDE_DIRECTORIES ${CMPLR_ROOT}/include/sycl) - list(APPEND VLLM_GPU_FLAGS "-DVLLM_BUILD_XPU_OPS" ) + list(APPEND VLLM_GPU_FLAGS "-fsycl" "-fsycl-targets=spir64_gen" "-ftemplate-backtrace-limit=0" "-fno-sycl-instrument-device-code" "-DVLLM_BUILD_XPU_OPS" ) list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64" "-Xspirv-translator" "-spirv-ext=+SPV_INTEL_split_barrier") list(APPEND VLLM_LINK_LIBRARIES "sycl" "OpenCL" "pthread" "m" "dl" "torch" ) @@ -166,12 +167,12 @@ if(VLLM_GPU_LANG STREQUAL "SYCL") set(CUTLASS_ENABLE_HEADERS_ONLY "ON" CACHE BOOL "Enable only the header library") # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. 
- set(CUTLASS_REVISION "v3.9-0.2" CACHE STRING "CUTLASS revision to use") + set(CUTLASS_REVISION "chunk_prefill_BMG" CACHE STRING "CUTLASS revision to use") # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided FetchContent_Declare( cutlass-sycl - GIT_REPOSITORY https://github.com/codeplaysoftware/cutlass-sycl.git + GIT_REPOSITORY https://github.com/sunjiweiswift/cutlass-sycl.git # Please keep this in sync with CUTLASS_REVISION line above. GIT_TAG ${CUTLASS_REVISION} GIT_PROGRESS TRUE @@ -196,7 +197,8 @@ if(VLLM_GPU_LANG STREQUAL "SYCL") FetchContent_MakeAvailable(cutlass-sycl) set(CUTLASS_INCLUDE_DIR ${cutlass-sycl_SOURCE_DIR}/include CACHE PATH "CUTLASS Header Library") set(CUTLASS_TOOLS_UTIL_INCLUDE_DIR ${cutlass-sycl_SOURCE_DIR}/tools/util/include CACHE INTERNAL "") - message(STATUS "cutlass dir: ${CUTLASS_INCLUDE_DIR} and ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}") + set(CUTLASS_APP_INCLUDE_DIR ${cutlass-sycl_SOURCE_DIR}/applications CACHE INTERNAL "") + message(STATUS "cutlass dir: ${CUTLASS_INCLUDE_DIR} and ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} and ${CUTLASS_APP_INCLUDE_DIR}") # header only library list(APPEND VLLM_GPU_FLAGS "-DCUTLASS_ENABLE_SYCL") @@ -218,6 +220,7 @@ define_gpu_extension_target( ARCHITECTURES ${VLLM_GPU_ARCHES} INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_APP_INCLUDE_DIR} INCLUDE_DIRECTORIES ${VLLM_INCLUDE_DIR} USE_SABI 3 WITH_SOABI) @@ -269,6 +272,7 @@ if (ENABLE_MOE_KERNEL) ARCHITECTURES ${VLLM_GPU_ARCHES} INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_APP_INCLUDE_DIR} USE_SABI 3 WITH_SOABI) endif() diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index ad28dec..4a9b417 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -1,45 +1,70 @@ #include "pytorch_shim.h" #include "core/registration.h" +#include "xpu/cutlass_kernels/chunk_prefill.hpp" #include namespace FLASH_NAMESPACE { std::vector mha_varlen_fwd( - at::Tensor& - q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor& k, // total_k x num_heads_k x head_size, total_k := // \sum_{i=0}^{b} s_i or num_blocks x page_block_size // x num_heads_k x head_size if there's a block_table. const at::Tensor& v, // total_k x num_heads_k x head_size, total_k := // \sum_{i=0}^{b} s_i or num_blocks x page_block_size // x num_heads_k x head_size if there's a block_table. - std::optional& - out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional& out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor& cu_seqlens_q, // b+1 const at::Tensor& cu_seqlens_k, // b+1 - std::optional& - seqused_k, // b. If given, only this many elements of each batch + std::optional& seqused_k, // b. If given, only this many elements of each batch // element's keys are used. 
std::optional& leftpad_k_, // batch_size - std::optional& - block_table_, // batch_size x max_num_blocks_per_seq + at::Tensor& block_table_, // batch_size x max_num_blocks_per_seq std::optional& alibi_slopes_, // num_heads or b x num_heads - int max_seqlen_q, const int max_seqlen_k, const float p_dropout, - const float softmax_scale, const bool zero_tensors, bool is_causal, - int window_size_left, int window_size_right, const float softcap, + int max_seqlen_q, + int max_seqlen_k, + float p_dropout, + float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, const bool return_softmax, std::optional gen_) { at::Tensor out; - out = torch::zeros_like(q); + if(out_.has_value()) { + out = *out_; + } + else { + out = torch::zeros_like(q); + } const auto sizes = q.sizes(); const int total_q = q.sizes()[0]; auto opts = q.options(); int num_heads = sizes[1]; - auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); + cutlass_chunk_prefill_impl( + q, + k, + v, + out, + block_table_, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + is_causal); - return {out, softmax_lse}; + if(return_softmax) { + auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); + return {out, softmax_lse}; + } + else { + return {out}; + } } } // namespace FLASH_NAMESPACE diff --git a/csrc/xpu/cutlass_kernels/chunk_prefill.hpp b/csrc/xpu/cutlass_kernels/chunk_prefill.hpp new file mode 100644 index 0000000..7adb27d --- /dev/null +++ b/csrc/xpu/cutlass_kernels/chunk_prefill.hpp @@ -0,0 +1,341 @@ +#pragma once + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "flash_attention_v2/collective/fmha_fusion.hpp" +#include "flash_attention_v2/kernel/tile_scheduler_chunk_prefill.hpp" +#include "cutlass/util/packed_stride.hpp" +#include "flash_attention_v2/kernel/xe_chunk_prefill.hpp" +#include "flash_attention_v2/collective/xe_flash_attn_chunk_prefill_epilogue.hpp" +#include "flash_attention_v2/collective/xe_flash_attn_chunk_prefill_softmax_epilogue.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/sycl_event_manager.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" + +#include "utils.hpp" + +using namespace cute; + +struct chunk_prefill_args_t { + void* query; + void* key; + void* value; + void* out; + void* block_table; + void* num_blocks_per_seq; + void* cu_seqlens_q; + void* cu_seqlens_k; + int max_queries; + int max_keys; + int total_seqlen_q; + int total_seqlen_k; + float sm_scale; + int batch_size; + int num_heads_q; + int num_heads_k; + int head_size; + int max_blocks_per_seq; + int block_size; + bool is_causal; +}; + +template struct KernelLauncher { + using StrideQ = typename FMHAChunkPrefillKernel::StrideQ; + using StrideK = typename FMHAChunkPrefillKernel::StrideK; + using StrideV = typename FMHAChunkPrefillKernel::StrideV; + using StrideO = typename FMHAChunkPrefillKernel::StrideO; + + using ElementQ = typename FMHAChunkPrefillKernel::ElementQ; + using ElementK = typename FMHAChunkPrefillKernel::ElementK; + using ElementV = typename FMHAChunkPrefillKernel::ElementV; + using ElementAcc = typename FMHAChunkPrefillKernel::ElementAccumulator; + + using CollectiveEpilogue = typename 
FMHAChunkPrefillKernel::CollectiveEpilogue; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename FMHAChunkPrefillKernel::ProblemShape; + + /// Initialization + StrideQ stride_Q; + StrideK stride_K; + StrideV stride_V; + StrideK stride_K_cache; + StrideV stride_V_cache; + StrideO stride_O; + uint64_t seed = 0; + + ProblemShapeType initialize(const chunk_prefill_args_t &args) { + auto problem_shape = cute::make_tuple( + 1, + args.num_heads_q, + args.num_heads_k, + args.total_seqlen_q, + args.total_seqlen_k, + args.total_seqlen_k, + args.head_size, + args.head_size); + auto problem_shape_out = cute::make_tuple( + args.batch_size, + args.num_heads_q, + args.num_heads_k, + cutlass::fmha::collective::VariableLength{args.max_queries}, + cutlass::fmha::collective::VariableLength{args.max_keys}, + cutlass::fmha::collective::VariableLength{args.max_keys}, + args.head_size, + args.head_size); + auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = problem_shape; + auto group_q_size = num_heads_q / num_heads_kv; + auto group_q_num = num_heads_q / group_q_size; + + stride_Q = cutlass::make_cute_packed_stride(StrideQ{}, cute::make_shape(seq_len_qo, num_heads_q * head_size_qk, batch)); + stride_K = cutlass::make_cute_packed_stride(StrideK{}, cute::make_shape(seq_len_kv, num_heads_kv * head_size_qk, batch)); + stride_V = cutlass::make_cute_packed_stride(StrideV{}, cute::make_shape(head_size_vo * num_heads_kv, seq_len_kv, batch)); + + stride_K_cache = cutlass::make_cute_packed_stride(StrideK{}, cute::make_shape(seq_len_kv_cache, num_heads_kv * head_size_qk, batch)); + stride_V_cache = cutlass::make_cute_packed_stride(StrideV{}, cute::make_shape(head_size_vo, seq_len_kv_cache, batch * num_heads_kv)); + + stride_O = cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo * group_q_size, group_q_num * head_size_vo, batch)); + + get<3>(problem_shape_out).cumulative_length = reinterpret_cast(args.cu_seqlens_q); + get<4>(problem_shape_out).cumulative_length = reinterpret_cast(args.cu_seqlens_k); + get<5>(problem_shape_out).cumulative_length = reinterpret_cast(args.cu_seqlens_k); + + return problem_shape_out; + } + + cutlass::Status run(const chunk_prefill_args_t &args, const cutlass::KernelHardwareInfo &hw_info) { + + ProblemShapeType problem_size = initialize(args); + + typename FMHAChunkPrefillKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + { + reinterpret_cast(args.query), stride_Q, + reinterpret_cast(args.key), stride_K, + reinterpret_cast(args.value), stride_V, + nullptr, stride_K_cache, + nullptr, stride_V_cache, + reinterpret_cast(args.block_table), + args.block_size, + reinterpret_cast(args.num_blocks_per_seq) + }, + {args.sm_scale}, + {reinterpret_cast(args.value), stride_O}, + hw_info}; + + // Define device-global scratch memory + size_t workspace_size = FMHAChunkPrefillKernel::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + if (!FMHAChunkPrefillKernel::can_implement(arguments)) { + std::cout << "Invalid Problem Size: " << std::endl; + return cutlass::Status::kErrorInvalidProblem; + } + + // Initialize the workspace + FMHAChunkPrefillKernel::initialize_workspace(arguments, workspace.get()); + + // Convert host-side arguments to device-side 
arguments to be passed to the kernel + auto params = FMHAChunkPrefillKernel::to_underlying_arguments(arguments, workspace.get()); + + // Run the Flash Attention implementation. + run(params); + + syclcompat::wait(); + } + + static void run(typename FMHAChunkPrefillKernel::Params params) { + dim3 const block = FMHAChunkPrefillKernel::get_block_shape(); + dim3 const grid = FMHAChunkPrefillKernel::get_grid_shape(params); + + // configure smem size and carveout + int smem_size = FMHAChunkPrefillKernel::SharedStorageSize; + + const auto sycl_block = syclcompat::dim3(block.x, block.y, block.z); + const auto sycl_grid = syclcompat::dim3(grid.x, grid.y, grid.z); + +// Launch parameters depend on whether SYCL compiler supports work-group scratch memory extension +#if !defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY) + using namespace syclcompat::experimental; + auto event = launch>( + launch_policy{sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)}, + kernel_properties{sycl_exp::sub_group_size}}, + params); +#else + syclcompat::experimental::launch_properties launch_props { + sycl::ext::oneapi::experimental::work_group_scratch_size(smem_size), + }; + syclcompat::experimental::kernel_properties kernel_props{ + sycl::ext::oneapi::experimental::sub_group_size + }; + syclcompat::experimental::launch_policy policy{sycl_grid, sycl_block, launch_props, kernel_props}; + auto event = syclcompat::experimental::launch>(policy, params); +#endif + + EventManager::getInstance().addEvent(event); + } +}; + +template struct FMHAKernel { + + template + static void run(const chunk_prefill_args_t &args) { + cutlass::KernelHardwareInfo hw_info; + + using LayoutQ = cutlass::layout::RowMajor; + using LayoutK = cutlass::layout::ColumnMajor; + using LayoutV = cutlass::layout::RowMajor; + using LayoutO = cutlass::layout::RowMajor; + + using ElementInputKV = ElementInputQ; + + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + using CollectiveEpilogue = cutlass::flash_attention::collective::FlashChunkPrefillEpilogue< + EpilogueDispatchPolicy, MMAOperation, TileShapeOutput, SubgroupLayout, ElementComputeEpilogue, ElementOutput, cutlass::gemm::TagToStrideC_t, ElementOutput, + GmemTiledCopyStore>; + using CollectiveSoftmaxEpilogue = cutlass::flash_attention::collective::FlashChunkPrefillSoftmaxEpilogue; + + using ProblemShapeRegular = cute::tuple; + using namespace cutlass::fmha::collective; + using ProblemShapeVarlen = cute::tuple; + using ProblemShapeType = std::conditional_t; + + // Mainloop + using CollectiveMainloop = cutlass::flash_attention::collective::FlashChunkPrefillMma< + GEMMDispatchPolicy, ProblemShapeType, ElementInputQ, cutlass::gemm::TagToStrideA_t, ElementInputKV, + cutlass::gemm::TagToStrideB_t, ElementInputKV, cutlass::gemm::TagToStrideB_t, MMAOperation, TileShapeQK, TileShapePV, SubgroupLayout, + GmemTiledCopyQ, // Q + GmemTiledCopyK, // K + GmemTiledCopyV, // V, + Causal, + PagedKV>; + + using FMHAChunkPrefillKernel = cutlass::flash_attention::kernel::FMHAPrefillChunk; + + KernelLauncher launcher; + + launcher.run(args, hw_info); + return 0; + } + + static void dispatch(const chunk_prefill_args_t &args) { + if(args.is_causal) { + run(args); + } + else { + run(args); + } + } +}; + +void chunk_prefill_kernel( + CutlassType cuType, + const chunk_prefill_args_t& args) { + const int PipelineStages = 2; + using ShapeQK = Shape<_128, _64, _64>; + using ShapePV = Shape<_128, _32, _64>; + using ShapeOutPut = Shape<_128, 
_64, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; + if(args.head_size == HEAD_SIZE_LIMIT_0) { + using ShapeQK = Shape<_128, _64, _64>; + using ShapePV = Shape<_128, _32, _64>; + using ShapeOutPut = Shape<_128, _64, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; + } else if(args.head_size == HEAD_SIZE_LIMIT_1) { + using ShapeQK = Shape<_128, _64, _64>; + using ShapePV = Shape<_128, _32, _64>; + using ShapeOutPut = Shape<_128, _128, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; + } + else if(args.head_size == HEAD_SIZE_LIMIT_2) { + using ShapeQK = Shape<_256, _64, _64>; + using ShapePV = Shape<_256, _32, _64>; + using ShapeOutPut = Shape<_256, _192, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; + } + + if(cuType == CutlassType::half) { + FMHAKernel::dispatch(args); + } + else { + FMHAKernel::dispatch(args); + } +} + +void cutlass_chunk_prefill_impl( + const at::Tensor& query, // [seq_q, heads, head_size] + const at::Tensor& key_cache, // [num_block, block_size, heads, head_size] + const at::Tensor& value_cache, + at::Tensor& out, + const at::Tensor& block_table, + const at::Tensor& cu_seqlens_q, + const at::Tensor& cu_seqlens_k, + int max_seqlen_q, + int max_seqlen_k, + double sm_scale, + bool is_causal) { + int num_block = key_cache.size(0); + int block_size = key_cache.size(1); + int num_heads_q = query.size(1); + int num_heads_kv = key_cache.size(2); + int head_size = query.size(2); + int batch_size = cu_seqlens_q.numel() - 1; + int max_blocks_per_seq = block_table.size(1); + int total_seqlen_q = query.size(0); + int total_seqlen_k = num_block * block_size; + at::Tensor num_blocks_per_seq = block_table.slice(0, 1) - block_table.slice(0, 0, -1); + num_blocks_per_seq = torch::div(num_blocks_per_seq, block_size, "ceil"); + + chunk_prefill_args_t args = { + query.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + out.data_ptr(), + block_table.data_ptr(), + num_blocks_per_seq.data_ptr(), + cu_seqlens_q.data_ptr(), + cu_seqlens_k.data_ptr(), + max_seqlen_q, + max_seqlen_k, + total_seqlen_q, + total_seqlen_k, + static_cast(sm_scale), + batch_size, + num_heads_q, + num_heads_kv, + head_size, + max_blocks_per_seq, + block_size, + is_causal + }; + + CutlassType cuType = aten_to_Cutlass_dtype(query); + chunk_prefill_kernel(cuType, args); +} diff --git a/csrc/xpu/cutlass_kernels/utils.hpp b/csrc/xpu/cutlass_kernels/utils.hpp new file mode 100644 index 0000000..503f329 --- /dev/null +++ b/csrc/xpu/cutlass_kernels/utils.hpp @@ -0,0 +1,26 @@ +#pragma once +#include "torch/all.h" + +#define HEAD_SIZE_LIMIT_0 64 +#define HEAD_SIZE_LIMIT_1 128 +#define HEAD_SIZE_LIMIT_2 256 +#define HEAD_SIZE_LIMIT_3 512 + +enum class CutlassType { + half, + bfloat16, +}; + +inline CutlassType aten_to_Cutlass_dtype(const at::Tensor& input) { + CutlassType cuType; + if (input.scalar_type() == torch::kHalf) { + cuType = CutlassType::half; + } else if (input.scalar_type() == torch::kBFloat16) { + cuType = CutlassType::bfloat16; + } else { + TORCH_INTERNAL_ASSERT( + false, + ""); + } + return cuType; +} diff --git a/csrc/xpu/mha.h b/csrc/xpu/mha.h new file mode 100644 index 0000000..a18cc0c --- /dev/null +++ b/csrc/xpu/mha.h @@ -0,0 +1,16 @@ +#pragma once + + +void cutlass_chunk_prefill_impl( + at::Tensor& query, // [seq_q, heads, head_size] + at::Tensor& key_cache, // [num_block, block_size, heads, head_size] + at::Tensor& value_cache, + at::Tensor& out, + at::Tensor& block_table, + at::Tensor& cu_seqlens_q, + at::Tensor& cu_seqlens_k, + int max_seqlen_q, + 
int max_seqlen_k, + double sm_scale, + bool is_causal +); \ No newline at end of file From ed0f846d7d91f8b8f1f21d0e5c23aeb1284f23f0 Mon Sep 17 00:00:00 2001 From: Yizhou Wang Date: Thu, 7 Aug 2025 00:13:11 -0700 Subject: [PATCH 05/47] fix register --- csrc/core/registration.h | 3 ++- csrc/xpu/cutlass_kernels/chunk_prefill.hpp | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/core/registration.h b/csrc/core/registration.h index 4d0ce1c..5f9cdeb 100644 --- a/csrc/core/registration.h +++ b/csrc/core/registration.h @@ -22,6 +22,7 @@ #define REGISTER_EXTENSION(NAME) \ PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \ static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \ - STRINGIFY(NAME), nullptr, 0, nullptr}; \ + STRINGIFY(NAME), nullptr, 0, nullptr, \ + nullptr, nullptr, nullptr, nullptr}; \ return PyModule_Create(&module); \ } diff --git a/csrc/xpu/cutlass_kernels/chunk_prefill.hpp b/csrc/xpu/cutlass_kernels/chunk_prefill.hpp index 7adb27d..72e3124 100644 --- a/csrc/xpu/cutlass_kernels/chunk_prefill.hpp +++ b/csrc/xpu/cutlass_kernels/chunk_prefill.hpp @@ -240,7 +240,6 @@ template launcher; launcher.run(args, hw_info); - return 0; } static void dispatch(const chunk_prefill_args_t &args) { From b02a5a876f7a6c21d145383b046a86d6d407fa5a Mon Sep 17 00:00:00 2001 From: Yizhou Wang Date: Thu, 7 Aug 2025 07:39:14 +0000 Subject: [PATCH 06/47] fix cmake --- CMakeLists.txt | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index d03540b..865455b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -150,7 +150,6 @@ if(VLLM_GPU_LANG STREQUAL "SYCL") "csrc/xpu/cutlass_sycl_demo.cpp" "csrc/xpu/layernorm.cpp" "csrc/xpu/torch_bindings.cpp" - "csrc/flash_attn/flash_api.cpp" ) include_directories("/usr/include") list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/) @@ -243,6 +242,10 @@ if (FA2_ENABLED) COMPILE_FLAGS ${VLLM_GPU_FLAGS} LINK_FLAGS ${VLLM_GPU_LINK_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_APP_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${VLLM_INCLUDE_DIR} USE_SABI 3 WITH_SOABI) From a4a76eef49b3bca71c1938954c9c82b69e9ba813 Mon Sep 17 00:00:00 2001 From: Yizhou Wang Date: Fri, 8 Aug 2025 06:30:39 +0000 Subject: [PATCH 07/47] debug msg --- csrc/flash_attn/flash_api.cpp | 14 +++++--------- csrc/xpu/cutlass_kernels/chunk_prefill.hpp | 17 ++++++++++++++--- vllm_xpu_kernels/flash_attn_interface.py | 2 -- 3 files changed, 19 insertions(+), 14 deletions(-) diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index 4a9b417..4a15d5e 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -40,11 +40,6 @@ std::vector mha_varlen_fwd( out = torch::zeros_like(q); } - const auto sizes = q.sizes(); - const int total_q = q.sizes()[0]; - auto opts = q.options(); - int num_heads = sizes[1]; - cutlass_chunk_prefill_impl( q, k, @@ -59,20 +54,21 @@ std::vector mha_varlen_fwd( is_causal); if(return_softmax) { - auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); + auto softmax_lse = torch::empty_like(out); return {out, softmax_lse}; } else { - return {out}; + at::Tensor softmax_lse; + return {out, softmax_lse}; } } } // namespace FLASH_NAMESPACE TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def( - "varlen_fwd(Tensor! q, Tensor k, Tensor v, Tensor!? out, Tensor " + "varlen_fwd(Tensor q, Tensor k, Tensor v, Tensor!? 
out, Tensor " "cu_seqlens_q, " - "Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? leftpad_k, Tensor? " + "Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? leftpad_k, Tensor " "block_table, Tensor? alibi_slopes, " "int max_seqlen_q, int max_seqlen_k, float p_dropout, float " "softmax_scale, bool zero_tensors, " diff --git a/csrc/xpu/cutlass_kernels/chunk_prefill.hpp b/csrc/xpu/cutlass_kernels/chunk_prefill.hpp index 72e3124..1a25d0e 100644 --- a/csrc/xpu/cutlass_kernels/chunk_prefill.hpp +++ b/csrc/xpu/cutlass_kernels/chunk_prefill.hpp @@ -113,7 +113,7 @@ template struct KernelLauncher { } cutlass::Status run(const chunk_prefill_args_t &args, const cutlass::KernelHardwareInfo &hw_info) { - + std::cout << "into launcher run" << std::endl; ProblemShapeType problem_size = initialize(args); typename FMHAChunkPrefillKernel::Arguments arguments{ @@ -152,9 +152,11 @@ template struct KernelLauncher { run(params); syclcompat::wait(); + return cutlass::Status::kSuccess; } static void run(typename FMHAChunkPrefillKernel::Params params) { + std::cout << "into final run" << std::endl; dim3 const block = FMHAChunkPrefillKernel::get_block_shape(); dim3 const grid = FMHAChunkPrefillKernel::get_grid_shape(params); @@ -166,12 +168,14 @@ template struct KernelLauncher { // Launch parameters depend on whether SYCL compiler supports work-group scratch memory extension #if !defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY) + std::cout << "into scratch mem" << std::endl; using namespace syclcompat::experimental; auto event = launch>( launch_policy{sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)}, kernel_properties{sycl_exp::sub_group_size}}, params); #else + std::cout << "into no scratch mem" << std::endl; syclcompat::experimental::launch_properties launch_props { sycl::ext::oneapi::experimental::work_group_scratch_size(smem_size), }; @@ -203,6 +207,7 @@ template static void run(const chunk_prefill_args_t &args) { + std::cout << "into FMHAKernel run" << std::endl; cutlass::KernelHardwareInfo hw_info; using LayoutQ = cutlass::layout::RowMajor; @@ -300,6 +305,10 @@ void cutlass_chunk_prefill_impl( int max_seqlen_k, double sm_scale, bool is_causal) { + std::cout << "into cutlass_chunk_prefill_impl" << std::endl; + std::cout << "query.size(): " << query.sizes().vec() << std::endl; + std::cout << "key_cache.size(): " << key_cache.sizes().vec() << std::endl; + std::cout << "block_table.size(): " << block_table.sizes().vec() << std::endl; int num_block = key_cache.size(0); int block_size = key_cache.size(1); int num_heads_q = query.size(1); @@ -309,8 +318,10 @@ void cutlass_chunk_prefill_impl( int max_blocks_per_seq = block_table.size(1); int total_seqlen_q = query.size(0); int total_seqlen_k = num_block * block_size; - at::Tensor num_blocks_per_seq = block_table.slice(0, 1) - block_table.slice(0, 0, -1); - num_blocks_per_seq = torch::div(num_blocks_per_seq, block_size, "ceil"); + at::Tensor num_blocks_per_seq = cu_seqlens_k.slice(0, 1) - cu_seqlens_k.slice(0, 0, -1); + std::cout << "cu_seqlens_k: " << cu_seqlens_k << std::endl; + num_blocks_per_seq = torch::div(num_blocks_per_seq, block_size); + std::cout << "num_blocks_per_seq: " << num_blocks_per_seq << std::endl; chunk_prefill_args_t args = { query.data_ptr(), diff --git a/vllm_xpu_kernels/flash_attn_interface.py b/vllm_xpu_kernels/flash_attn_interface.py index de7b764..33e7de8 100644 --- a/vllm_xpu_kernels/flash_attn_interface.py +++ b/vllm_xpu_kernels/flash_attn_interface.py @@ -55,8 +55,6 @@ def flash_attn_varlen_func( "cu_seqlens_k or seqused_k must 
be provided" assert cu_seqlens_k is None or seqused_k is None, \ "cu_seqlens_k and seqused_k cannot be provided at the same time" - assert block_table is None or seqused_k is not None, \ - "seqused_k must be provided if block_table is provided" if softmax_scale is None: softmax_scale = q.shape[-1]**(-0.5) From ee1b7195fc589d00dddb9997c2b870986c4fd377 Mon Sep 17 00:00:00 2001 From: Yizhou Wang Date: Mon, 11 Aug 2025 08:13:26 +0000 Subject: [PATCH 08/47] functional ready --- CMakeLists.txt | 7 +- csrc/xpu/cutlass_kernels/chunk_prefill.hpp | 76 ++++++++-------------- csrc/xpu/cutlass_kernels/utils.hpp | 23 +++++++ 3 files changed, 57 insertions(+), 49 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 865455b..6c0e6f3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -157,7 +157,7 @@ if(VLLM_GPU_LANG STREQUAL "SYCL") list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/sycl/syclcompat/) message(STATUS "VLLM_INCLUDE_DIR: ${VLLM_INCLUDE_DIR}") set(VLLM_EXTRA_INCLUDE_DIRECTORIES ${CMPLR_ROOT}/include/sycl) - list(APPEND VLLM_GPU_FLAGS "-fsycl" "-fsycl-targets=spir64_gen" "-ftemplate-backtrace-limit=0" "-fno-sycl-instrument-device-code" "-DVLLM_BUILD_XPU_OPS" ) + list(APPEND VLLM_GPU_FLAGS "-DVLLM_BUILD_XPU_OPS" ) list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64" "-Xspirv-translator" "-spirv-ext=+SPV_INTEL_split_barrier") list(APPEND VLLM_LINK_LIBRARIES "sycl" "OpenCL" "pthread" "m" "dl" "torch" ) @@ -232,6 +232,11 @@ if (FA2_ENABLED) message(STATUS "Enabling fa2 extension.") file(GLOB FA2_GEN_SRCS "csrc/flash_attn/*.cpp") + # list(APPEND VLLM_GPU_FLAGS "-ze-opt-large-register-file") + list(APPEND VLLM_GPU_FLAGS "-gline-tables-only") + list(APPEND VLLM_GPU_FLAGS "-O3") + list(APPEND VLLM_GPU_FLAGS "-DNDEBUG") + define_gpu_extension_target( _vllm_fa2_C DESTINATION vllm_xpu_kernels diff --git a/csrc/xpu/cutlass_kernels/chunk_prefill.hpp b/csrc/xpu/cutlass_kernels/chunk_prefill.hpp index 1a25d0e..c96b8dd 100644 --- a/csrc/xpu/cutlass_kernels/chunk_prefill.hpp +++ b/csrc/xpu/cutlass_kernels/chunk_prefill.hpp @@ -87,9 +87,9 @@ template struct KernelLauncher { args.batch_size, args.num_heads_q, args.num_heads_k, - cutlass::fmha::collective::VariableLength{args.max_queries}, - cutlass::fmha::collective::VariableLength{args.max_keys}, - cutlass::fmha::collective::VariableLength{args.max_keys}, + cutlass::fmha::collective::VariableLength{args.max_queries}, // cu_q + cutlass::fmha::collective::VariableLength{args.max_keys}, // cu_k + cutlass::fmha::collective::VariableLength{args.max_keys}, // cu_v args.head_size, args.head_size); auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = problem_shape; @@ -113,7 +113,6 @@ template struct KernelLauncher { } cutlass::Status run(const chunk_prefill_args_t &args, const cutlass::KernelHardwareInfo &hw_info) { - std::cout << "into launcher run" << std::endl; ProblemShapeType problem_size = initialize(args); typename FMHAChunkPrefillKernel::Arguments arguments{ @@ -123,14 +122,14 @@ template struct KernelLauncher { reinterpret_cast(args.query), stride_Q, reinterpret_cast(args.key), stride_K, reinterpret_cast(args.value), stride_V, - nullptr, stride_K_cache, - nullptr, stride_V_cache, - reinterpret_cast(args.block_table), + reinterpret_cast(args.key), stride_K_cache, + reinterpret_cast(args.value), stride_V_cache, + static_cast(args.block_table), args.block_size, - reinterpret_cast(args.num_blocks_per_seq) + static_cast(args.num_blocks_per_seq) }, {args.sm_scale}, - 
{reinterpret_cast(args.value), stride_O}, + {reinterpret_cast(args.out), stride_O}, hw_info}; // Define device-global scratch memory @@ -156,7 +155,6 @@ template struct KernelLauncher { } static void run(typename FMHAChunkPrefillKernel::Params params) { - std::cout << "into final run" << std::endl; dim3 const block = FMHAChunkPrefillKernel::get_block_shape(); dim3 const grid = FMHAChunkPrefillKernel::get_grid_shape(params); @@ -168,14 +166,12 @@ template struct KernelLauncher { // Launch parameters depend on whether SYCL compiler supports work-group scratch memory extension #if !defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY) - std::cout << "into scratch mem" << std::endl; using namespace syclcompat::experimental; auto event = launch>( launch_policy{sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)}, kernel_properties{sycl_exp::sub_group_size}}, params); #else - std::cout << "into no scratch mem" << std::endl; syclcompat::experimental::launch_properties launch_props { sycl::ext::oneapi::experimental::work_group_scratch_size(smem_size), }; @@ -207,7 +203,6 @@ template static void run(const chunk_prefill_args_t &args) { - std::cout << "into FMHAKernel run" << std::endl; cutlass::KernelHardwareInfo hw_info; using LayoutQ = cutlass::layout::RowMajor; @@ -257,39 +252,32 @@ template +void policy_dispatch( CutlassType cuType, const chunk_prefill_args_t& args) { const int PipelineStages = 2; - using ShapeQK = Shape<_128, _64, _64>; - using ShapePV = Shape<_128, _32, _64>; - using ShapeOutPut = Shape<_128, _64, _64>; - using SubgroupLayout = Layout, Stride<_1, _1, _1>>; - if(args.head_size == HEAD_SIZE_LIMIT_0) { - using ShapeQK = Shape<_128, _64, _64>; - using ShapePV = Shape<_128, _32, _64>; - using ShapeOutPut = Shape<_128, _64, _64>; - using SubgroupLayout = Layout, Stride<_1, _1, _1>>; - } else if(args.head_size == HEAD_SIZE_LIMIT_1) { - using ShapeQK = Shape<_128, _64, _64>; - using ShapePV = Shape<_128, _32, _64>; - using ShapeOutPut = Shape<_128, _128, _64>; - using SubgroupLayout = Layout, Stride<_1, _1, _1>>; + if(cuType == CutlassType::half) { + FMHAKernel::dispatch(args); } - else if(args.head_size == HEAD_SIZE_LIMIT_2) { - using ShapeQK = Shape<_256, _64, _64>; - using ShapePV = Shape<_256, _32, _64>; - using ShapeOutPut = Shape<_256, _192, _64>; - using SubgroupLayout = Layout, Stride<_1, _1, _1>>; + else { + FMHAKernel::dispatch(args); } +} - if(cuType == CutlassType::half) { - FMHAKernel::dispatch(args); +void chunk_prefill_kernel( + CutlassType cuType, + const chunk_prefill_args_t& args) { + if(args.head_size == HEAD_SIZE_LIMIT_0) { + policy_dispatch(cuType, args); + } else if(args.head_size == HEAD_SIZE_LIMIT_1) { + policy_dispatch(cuType, args); } - else { - FMHAKernel::dispatch(args); + else if(args.head_size == HEAD_SIZE_LIMIT_2) { + policy_dispatch(cuType, args); } } @@ -305,10 +293,6 @@ void cutlass_chunk_prefill_impl( int max_seqlen_k, double sm_scale, bool is_causal) { - std::cout << "into cutlass_chunk_prefill_impl" << std::endl; - std::cout << "query.size(): " << query.sizes().vec() << std::endl; - std::cout << "key_cache.size(): " << key_cache.sizes().vec() << std::endl; - std::cout << "block_table.size(): " << block_table.sizes().vec() << std::endl; int num_block = key_cache.size(0); int block_size = key_cache.size(1); int num_heads_q = query.size(1); @@ -318,10 +302,7 @@ void cutlass_chunk_prefill_impl( int max_blocks_per_seq = block_table.size(1); int total_seqlen_q = query.size(0); int total_seqlen_k = num_block * block_size; - at::Tensor 
num_blocks_per_seq = cu_seqlens_k.slice(0, 1) - cu_seqlens_k.slice(0, 0, -1); - std::cout << "cu_seqlens_k: " << cu_seqlens_k << std::endl; - num_blocks_per_seq = torch::div(num_blocks_per_seq, block_size); - std::cout << "num_blocks_per_seq: " << num_blocks_per_seq << std::endl; + at::Tensor num_blocks_per_seq = torch::div(cu_seqlens_k, block_size); chunk_prefill_args_t args = { query.data_ptr(), @@ -345,7 +326,6 @@ void cutlass_chunk_prefill_impl( block_size, is_causal }; - CutlassType cuType = aten_to_Cutlass_dtype(query); chunk_prefill_kernel(cuType, args); } diff --git a/csrc/xpu/cutlass_kernels/utils.hpp b/csrc/xpu/cutlass_kernels/utils.hpp index 503f329..4715419 100644 --- a/csrc/xpu/cutlass_kernels/utils.hpp +++ b/csrc/xpu/cutlass_kernels/utils.hpp @@ -1,5 +1,6 @@ #pragma once #include "torch/all.h" +#include #define HEAD_SIZE_LIMIT_0 64 #define HEAD_SIZE_LIMIT_1 128 @@ -24,3 +25,25 @@ inline CutlassType aten_to_Cutlass_dtype(const at::Tensor& input) { } return cuType; } + +using namespace cute; +struct chunk_policy_head64 { + using ShapeQK = Shape<_128, _64, _64>; + using ShapePV = Shape<_128, _32, _64>; + using ShapeOutPut = Shape<_128, _64, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; +}; + +struct chunk_policy_head128 { + using ShapeQK = Shape<_128, _64, _64>; + using ShapePV = Shape<_128, _32, _64>; + using ShapeOutPut = Shape<_128, _128, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; +}; + +struct chunk_policy_head256 { + using ShapeQK = Shape<_256, _64, _64>; + using ShapePV = Shape<_256, _32, _64>; + using ShapeOutPut = Shape<_256, _192, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; +}; \ No newline at end of file From 4ef938fbc7735b565f56c3fccb49bca37cbaa072 Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Wed, 20 Aug 2025 23:34:56 -0700 Subject: [PATCH 09/47] dev base Signed-off-by: Ma, Liangliang --- tests/flash_attn/test.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 tests/flash_attn/test.py diff --git a/tests/flash_attn/test.py b/tests/flash_attn/test.py new file mode 100644 index 0000000..dfa6539 --- /dev/null +++ b/tests/flash_attn/test.py @@ -0,0 +1,38 @@ +import pytest +import torch + +from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func + +DTYPES = [torch.half, torch.bfloat16] +dtype = torch.half + +torch.set_default_device("xpu") +batch_size = 1 +seq_len = 512 +num_heads = 8 +head_dim = 128 + +max_seqlen_q = seq_len +cu_seqlens_q = torch.tensor([0, seq_len], dtype=torch.int32) +max_seqlen_k = seq_len +cu_seqlens_k = cu_seqlens_q + +block_size = 128 +num_blocks = max_seqlen_q // block_size +max_num_blocks_per_seq = seq_len // block_size + +block_tables = torch.randint(0, num_blocks, (batch_size, max_num_blocks_per_seq), dtype=torch.int32) + +print(block_tables) +print(cu_seqlens_q) + +q = torch.randn(sum(cu_seqlens_q), num_heads, head_dim, dtype=dtype) +k = torch.randn(num_blocks, block_size, num_heads, head_dim, dtype=dtype) +v = torch.randn(num_blocks, block_size, num_heads, head_dim, dtype=dtype) + +# Call the flash attention function +output= flash_attn_varlen_func(q, k, v, max_seqlen_q, cu_seqlens_q, + max_seqlen_k, cu_seqlens_k, block_table=block_tables) + +assert output is not None +assert output.dtype == dtype From 480c72fabb4a2d916e11d77526bc0962ab784532 Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Thu, 21 Aug 2025 19:31:44 -0700 Subject: [PATCH 10/47] base of grouped_gemm_fp8 Signed-off-by: Ma, Liangliang --- CMakeLists.txt | 37 
+- csrc/quantization/grouped_gemm_fp8.cpp | 663 +++++++++++++++++++++++++ csrc/quantization/helper.h | 132 +++++ csrc/quantization/sycl_common.hpp | 149 ++++++ setup.py | 2 +- 5 files changed, 978 insertions(+), 5 deletions(-) create mode 100644 csrc/quantization/grouped_gemm_fp8.cpp create mode 100644 csrc/quantization/helper.h create mode 100644 csrc/quantization/sycl_common.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 6c0e6f3..60945c6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -46,7 +46,8 @@ set(SYCL_SUPPORTED_ARCHS "intel_gpu_pvc;intel_gpu_bmg_g21") set(TORCH_SUPPORTED_VERSION_XPU "2.8.0") set(ENABLE_MOE_KERNEL OFF) -set(FA2_ENABLED ON) +set(FA2_ENABLED OFF) +set(FP8_ENABLED ON) # # Try to find python package with an executable that exactly matches @@ -166,12 +167,12 @@ if(VLLM_GPU_LANG STREQUAL "SYCL") set(CUTLASS_ENABLE_HEADERS_ONLY "ON" CACHE BOOL "Enable only the header library") # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. - set(CUTLASS_REVISION "chunk_prefill_BMG" CACHE STRING "CUTLASS revision to use") + set(CUTLASS_REVISION "sycl-develop" CACHE STRING "CUTLASS revision to use") # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided FetchContent_Declare( cutlass-sycl - GIT_REPOSITORY https://github.com/sunjiweiswift/cutlass-sycl.git + GIT_REPOSITORY https://github.com/intel/cutlass-sycl.git # Please keep this in sync with CUTLASS_REVISION line above. GIT_TAG ${CUTLASS_REVISION} GIT_PROGRESS TRUE @@ -191,7 +192,7 @@ if(VLLM_GPU_LANG STREQUAL "SYCL") set(CUTLASS_ENABLE_GDC_FOR_SM100_DEFAULT OFF CACHE BOOL "DISABLE CUDA") # list(APPEND CMAKE_CXX_FLAGS "-ftemplate-backtrace-limit=0 " ) # list(APPEND CMAKE_CXX_FLAGS "-fdiagnostics-color=always " ) - + FetchContent_MakeAvailable(cutlass-sycl) set(CUTLASS_INCLUDE_DIR ${cutlass-sycl_SOURCE_DIR}/include CACHE PATH "CUTLASS Header Library") @@ -259,6 +260,34 @@ if (FA2_ENABLED) # csrc/flash_attn/src) endif () +if (FP8_ENABLED) + message(STATUS "Enabling FP8 extension.") + file(GLOB FP8_GEN_SRCS "csrc/quantization/*.cpp") + + # list(APPEND VLLM_GPU_FLAGS "-ze-opt-large-register-file") + list(APPEND VLLM_GPU_FLAGS "-gline-tables-only") + list(APPEND VLLM_GPU_FLAGS "-O3") + list(APPEND VLLM_GPU_FLAGS "-DNDEBUG") + + define_gpu_extension_target( + _vllm_fp8_C + DESTINATION vllm_xpu_kernels + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${FP8_GEN_SRCS} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + LINK_FLAGS ${VLLM_GPU_LINK_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_APP_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${VLLM_INCLUDE_DIR} + USE_SABI 3 + WITH_SOABI) + + # target_include_directories(_vllm_fa2_C PRIVATE + # csrc/flash_attn + # csrc/flash_attn/src) +endif () # # _moe_C extension diff --git a/csrc/quantization/grouped_gemm_fp8.cpp b/csrc/quantization/grouped_gemm_fp8.cpp new file mode 100644 index 0000000..5f241aa --- /dev/null +++ b/csrc/quantization/grouped_gemm_fp8.cpp @@ -0,0 +1,663 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief CUTLASS Intel BMG Group Gemm + + This file is almost a complete copy of 04_bmg_grouped_gemm, + except that it's used for FP8 (E5M2 & E4M3) datatype inputs. + + This example demonstrates fusing multiple GEMM operations into one kernel. + + Note that the scalar arguments to e.g. the standard 00_bmg_gemm example, have been + replaced with vector equivalents, as each individual GEMM has its own inputs and outputs, which + needn't be contiguous in memory. For example, where 00_bmg_gemm receives an `ElementA *` + defining Matrix A, grouped gemm receives a `ElementA **`, i.e. a pointer to pointers, each + pointing to a distinct Matrix A. Likewise, each individual GEMM operation may have its own alpha + and beta factors for linear combination. This example demonstrates two approaches: the user can + provide `options.alpha` and `options.beta`, in which case they will apply to all GEMMs; + otherwise, random values are generated per GEMM. + + Group GEMM scheduling (cutlass::gemm::GroupScheduler) is more complex than standard GEMM, + because each GEMM may have a unique size, only known at runtime. Thus, the scheduler will + distribute an a priori unknown number of tiles to each work-group. See + include/cutlass/gemm/kernel/xe_gemm_array_cooperative.hpp for implementation. + + Note that for simplicity, this example sets every GEMM in the group to the same shape. + + Verification for this example is a conventional GEMM kernel, executed iteratively per group. + + To build & run this example (from your build dir): + + $ ninja 09_bmg_grouped_gemm_fp8 + $ ./examples/sycl/09_bmg_grouped_gemm_fp8/09_bmg_grouped_gemm_fp8 + + Call with `--help` for information about available options. + + Note: the code may spill registers once compiled which will result in sub-optimal performance. This is because + of an issue inside Intel Graphics Compiler (IGC) related to VectorAliasBBThreshold being debugged internally. 
+ To avoid register spills, build the example by setting the environment variable: + $ export IGC_VectorAliasBBThreshold=10000 +*/ +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/xe_array_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "sycl_common.hpp" +#include "helper.h" + +#include + +using namespace cute; +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group + +using ElementAccumulator = float; // <- data type of accumulator +using ElementComputeEpilogue = float; // <- data type of epilogue operations +using ElementOutput = float; // <- data type of elements in output matrix D + +/////////////////////////////////////////////////////////////////////////////////////////////////// + + +// Command line options parsing +struct Options { + + bool error = false; + bool help = false; + + float alpha, beta; + int iterations; + int m, n, k, groups; + std::vector problem_sizes_host; + + Options() : error(false), help(false), alpha(FLT_MAX), beta(FLT_MAX), iterations(100), + m(5120), n(4096), k(4096), groups(2) { + problem_sizes_host.reserve(groups); + for(int i = 0; i < groups; i++) { + problem_sizes_host.push_back({m, n, k}); + } + } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 5120); + cmd.get_cmd_line_argument("n", n, 4096); + cmd.get_cmd_line_argument("k", k, 4096); + cmd.get_cmd_line_argument("groups", groups, 2); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations, 100); + + assert(groups > 0); + problem_sizes_host.clear(); + problem_sizes_host.reserve(groups); + for(int i = 0; i < groups; i++) { + problem_sizes_host.push_back({m, n, k}); + } + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "BMG Grouped GEMM\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM for all groups\n" + << " --n= Sets the N extent of the GEMM for all groups\n" + << " --k= Sets the K extent of the GEMM for all groups\n" + << " --groups= Sets the number of individual GEMM problems for Grouped GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "09_bmg_grouped_gemm_fp8" << " --m=5120 --n=4096 --k=4096 --groups=5 --alpha=2.5 --beta=0.5 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s, std::vector problem_sizes_host) const + { + // Number of real-valued multiply-adds + uint64_t fmas = uint64_t(); + + for (auto const & problem : problem_sizes_host) { + fmas += static_cast(get<0>(problem)) * + static_cast(get<1>(problem)) * + static_cast(get<2>(problem)); + } + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * uint64_t(fmas); + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class Gemm +> +struct ExampleRunner { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementAccumulator = ElementOutput; + + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + + // Host-side allocations + std::vector offset_A; + std::vector offset_B; + std::vector offset_C; + std::vector offset_D; + + std::vector stride_A_host; + std::vector stride_B_host; + std::vector stride_C_host; + std::vector stride_D_host; + + std::vector alpha_host; + std::vector beta_host; + + // Device-side allocations + cutlass::DeviceAllocation problem_sizes; + + // This example defines all matrices in a single allocation (e.g. block_A), but this is not a + // requirement. Matrix base pointers are read from device allocation (e.g. 
ptr_A) + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + + cutlass::DeviceAllocation ptr_A; + cutlass::DeviceAllocation ptr_B; + cutlass::DeviceAllocation ptr_C; + cutlass::DeviceAllocation ptr_D; + cutlass::DeviceAllocation ptr_ref_D; + + cutlass::DeviceAllocation stride_A; + cutlass::DeviceAllocation stride_B; + cutlass::DeviceAllocation stride_C; + cutlass::DeviceAllocation stride_D; + + // Note, this is an array of pointers to alpha and beta scaling values per group + cutlass::DeviceAllocation alpha_device; + cutlass::DeviceAllocation beta_device; + cutlass::DeviceAllocation block_alpha; + cutlass::DeviceAllocation block_beta; + + uint64_t seed = 0; + + // + // Methods + // + template + bool verify(const Options &options) { + bool passed = true; + // Verify against individual reference GEMMs + for (int32_t i = 0; i < options.groups; ++i) { + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + cutlass::DeviceAllocation block_A_fp16(block_A.size()); + cutlass::DeviceAllocation block_B_fp16(block_B.size()); + + // fp8 -> fp16 + convert_dtype( + block_A.get(), + block_A_fp16.get(), + block_A.size() + ); + convert_dtype( + block_B.get(), + block_B_fp16.get(), + block_B.size() + ); + + cutlass::TensorRef ref_A(block_A_fp16.get() + offset_A.at(i), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B_fp16.get() + offset_B.at(i), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get() + offset_C.at(i), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get() + offset_D.at(i), LayoutD::packed({M, N})); + + // + // Compute reference output + // + cutlass::reference::device::GemmComplex( + {M, N, K}, + alpha_host.at(i), + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta_host.at(i), + ref_C, + ref_D, + ElementAccumulator(0), + 1, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + // Wait for kernel to finish + syclcompat::wait(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + passed &= cutlass::reference::device::BlockCompareEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N); + if(!passed) + break; + } + return passed; + } + +/// Allocates device-side data +void allocate(const Options &options) { + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + + // Compute total allocation sizes across group + for (int32_t i = 0; i < options.groups; ++i) { + + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + // Offset into block allocation of each matrix base pointer + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + + int64_t elements_A = M * K; + int64_t elements_B = K * N; + int64_t elements_C = M * N; + int64_t elements_D = M * N; + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + + stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1})); + 
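+    // Every group in this example uses the same (M, N, K) (see Options), so the
+    // packed strides computed here are identical across groups; the per-group
+    // vectors are still populated so the general grouped-GEMM path, in which
+    // each group may have its own shape and strides, is exercised unchanged.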
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1})); + stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1})); + stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1})); + + } + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + block_ref_D.reset(total_elements_D); + block_alpha.reset(options.groups); + block_beta.reset(options.groups); +} + +/// Initialize operands to be used in the GEMM and reference GEMM +template +void initialize(const Options &options) { + + uint64_t seed = 2020; + + problem_sizes.reset(options.groups); + problem_sizes.copy_from_host(options.problem_sizes_host.data()); + + // + // Assign pointers + // + + std::vector ptr_A_host(options.groups); + std::vector ptr_B_host(options.groups); + std::vector ptr_C_host(options.groups); + std::vector ptr_D_host(options.groups); + std::vector ptr_alpha_host(options.groups); + std::vector ptr_beta_host(options.groups); + + // Compute offsets, alpha & beta over group on host + for (int32_t i = 0; i < options.groups; ++i) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + // Fill host vector of alpha & beta with random values if using per-group values + alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast(rand() % 5 + 1) : options.alpha); + beta_host.push_back((options.beta == FLT_MAX) ? static_cast(rand() % 5) : options.beta); + // Fill host ptr vectors with offset addresses into device alpha/beta blocks + ptr_alpha_host.at(i) = block_alpha.get() + i; + ptr_beta_host.at(i) = block_beta.get() + i; + } + + // Allocate device memory & copy from host + ptr_A.reset(options.groups); + // Per-group alpha and beta + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(options.groups); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_C.reset(options.groups); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(options.groups); + ptr_D.copy_from_host(ptr_D_host.data()); + + stride_A.reset(options.groups); + stride_A.copy_from_host(stride_A_host.data()); + + stride_B.reset(options.groups); + stride_B.copy_from_host(stride_B_host.data()); + + stride_C.reset(options.groups); + stride_C.copy_from_host(stride_C_host.data()); + + stride_D.reset(options.groups); + stride_D.copy_from_host(stride_D_host.data()); + + // Per-group alpha and beta ptrs + alpha_device.reset(options.groups); + alpha_device.copy_from_host(ptr_alpha_host.data()); + beta_device.reset(options.groups); + beta_device.copy_from_host(ptr_beta_host.data()); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + // Per-group alpha and beta values - note these are not directly passed to kernel - the pointers + // (alpha_device/beta_device) are passed instead + block_alpha.copy_from_host(alpha_host.data()); + block_beta.copy_from_host(beta_host.data()); +} + + /// Populates a Gemm::Arguments structure from the given commandline options + typename Gemm::Arguments args_from_options(const Options &options, const cutlass::KernelHardwareInfo& hw_info, bool host_problem_shapes_available = true) + { + typename Gemm::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args; + + if (options.alpha != FLT_MAX && options.beta != FLT_MAX) { + // If both 
alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches. + fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + // Single alpha and beta for all groups + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0}; + } + else { + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups. + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = alpha_device.get(); + fusion_args.beta_ptr_array = beta_device.get(); + // One alpha and beta per each group + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1}; + } + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerXeGroup::RasterOrderOptions; + + // Per-GEMM problem shape info may only exist on the device. + if (host_problem_shapes_available) { + arguments = typename Gemm::Arguments { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), options.problem_sizes_host.data()}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info, + {1, RasterOrderOptions::AlongN} + }; + } + else { + arguments = typename Gemm::Arguments { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), nullptr}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info, + {1, RasterOrderOptions::AlongN} + }; + } + + return arguments; + } + + template + cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info, bool host_problem_shapes_available = true) { + allocate(options); + initialize(options); + + Gemm gemm_op; + + auto arguments = args_from_options(options, hw_info, host_problem_shapes_available); + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(gemm_op.can_implement(arguments)); + + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); + + // Run the GEMM + CUTLASS_CHECK(gemm_op.run()); + + syclcompat::wait(); + + // Verify that the result is correct + bool passed = verify(options); + std::cout << "Disposition: " << (passed ? 
"Passed" : "Failed") << std::endl; + + if(!passed) return cutlass::Status::kErrorInternal; + + if (options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm_op.run()); + } + syclcompat::wait(); + + float cute_time = timer.seconds() * 1000; + double cute_average_time = double(cute_time) / double(options.iterations); + double gflops = options.gflops(cute_average_time / 1000.0, options.problem_sizes_host); + if constexpr (std::is_same_v) { + std::cout << "Datatype: float_e4m3_t"<< std::endl; + } else if constexpr (std::is_same_v) { + std::cout << "Datatype: float_e5m2_t"<< std::endl; + } else { + static_assert(cutlass::detail::dependent_false, "Not a valid fp8 datatype."); + } + std::cout << " Problem Sizes, Alpha, Beta " << std::endl; + for (int32_t i = 0; i < options.groups; ++i) { + std::cout << " " << options.problem_sizes_host.at(i); + std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl; + } + std::cout << " Groups : " << options.groups << std::endl; + std::cout << " Avg runtime : " << cute_average_time << " ms" << std::endl; + std::cout << " GFLOPS : " << gflops << std::endl; + } + + return cutlass::Status::kSuccess; + } + +}; + + +template +int launcher(Options& options) +{ + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using GmemTiledCopyA = XE_2D_U8x32x32_LD_V; + using GmemTiledCopyB = XE_2D_U8x32x32_LD_V; + + // Workgroup-level tile + using TileShape = Shape<_256, _256, _32>; + + using TiledMma = + typename TiledMMAHelper, Layout, + Layout, Stride<_4, _1, _0>>>::TiledMMA; + + constexpr int PipelineStages = 2; + // Dispatch to grouped gemm algorithm + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16GroupFP8; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group; + + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_LD_N, + void, void, + XE_2D_U32x8x16_ST_N, + void, void>; + +// Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementType, + cutlass::gemm::TagToStrideA_t, + ElementType, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::GroupScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + + CUTLASS_CHECK(runner.template run(options, hw_info)); + + 
return 0; +} + +int main(int argc, const char** argv) +{ + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + launcher(options); + launcher(options); + return 0; +} diff --git a/csrc/quantization/helper.h b/csrc/quantization/helper.h new file mode 100644 index 0000000..2b1a045 --- /dev/null +++ b/csrc/quantization/helper.h @@ -0,0 +1,132 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#if defined(CUTLASS_ENABLE_SYCL) +#include "cutlass/util/sycl_timer.hpp" +#else +#include +#endif +#include + +/** + * Panic wrapper for unwinding CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \ + << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + + +/** + * Panic wrapper for unwinding CUDA runtime errors + */ +#define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ + << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + + +/** + * GPU timer for recording the elapsed time across kernel(s) launched in GPU stream + */ +struct GpuTimer { +#if defined(CUTLASS_ENABLE_SYCL) + using cudaStream_t = int; + SYCLTimer syclTimer; +#else + cudaEvent_t _start; + cudaEvent_t _stop; +#endif + cudaStream_t _stream_id; + + /// Constructor + GpuTimer() : _stream_id(0) + { +#if !defined(CUTLASS_ENABLE_SYCL) + CUDA_CHECK(cudaEventCreate(&_start)); + CUDA_CHECK(cudaEventCreate(&_stop)); +#endif + } + + /// Destructor + ~GpuTimer() + { +#if !defined(CUTLASS_ENABLE_SYCL) + CUDA_CHECK(cudaEventDestroy(_start)); + CUDA_CHECK(cudaEventDestroy(_stop)); +#endif + } + + /// Start the timer for a given stream (defaults to the default stream) + void start(cudaStream_t stream_id = 0) + { + _stream_id = stream_id; +#if defined(CUTLASS_ENABLE_SYCL) + syclTimer.start(); +#else + CUDA_CHECK(cudaEventRecord(_start, _stream_id)); +#endif + } + + /// Stop the timer + void stop() + { +#if defined(CUTLASS_ENABLE_SYCL) + syclTimer.stop(); +#else + CUDA_CHECK(cudaEventRecord(_stop, _stream_id)); +#endif + } + + /// Return the elapsed time (in milliseconds) + float elapsed_millis() + { +#if defined(CUTLASS_ENABLE_SYCL) + return syclTimer.milliseconds(); +#else + float elapsed = 0.0; + CUDA_CHECK(cudaEventSynchronize(_stop)); + CUDA_CHECK(cudaEventElapsedTime(&elapsed, _start, _stop)); + return elapsed; +#endif + } +}; diff --git a/csrc/quantization/sycl_common.hpp b/csrc/quantization/sycl_common.hpp new file mode 100644 index 0000000..06fcd44 --- /dev/null +++ b/csrc/quantization/sycl_common.hpp @@ -0,0 +1,149 @@ +/*************************************************************************************************** +* Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/sycl_tensor_fill.h" + +/// Helper to initialize a block of device data +template +bool initialize_block(Element* block, std::size_t size, uint64_t seed=2023) { + + Element scope_max = Element(1 << cute::ceil_div(std::numeric_limits::digits, 4)); + Element scope_min = cute::is_signed::value ? Element(-scope_max) : Element(0); + + cutlass::reference::device::BlockFillRandomUniform( + block, size, seed, scope_max, scope_min, 0); + + syclcompat::wait(); + return true; +} + +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + return initialize_block(block.get(), block.size(), seed); +} + +template +void initialize_mixed_dtype_block(cutlass::DeviceAllocation& block_device, + cutlass::DeviceAllocation& block_device_dq, + uint64_t seed) { + static_assert(cute::sizeof_bits_v >= 8); + + std::ranlux24_base rng(std::random_device{}()); + rng.seed(seed); + + T1 scope_max = T1(1 << cute::ceil_div(std::numeric_limits::digits, 4)); + T1 scope_min = cute::is_signed::value ? 
T1(-scope_max) : T1(0); + + std::uniform_int_distribution<> dist(scope_min, scope_max); + + if constexpr (cute::sizeof_bits_v >= 8) { + auto block_host = std::vector(block_device.size()); + auto block_host_dq = std::vector(block_device.size()); + for (int i = 0; i < block_host.size(); ++i) { + block_host[i] = static_cast(dist(rng)); + block_host_dq[i] = static_cast(block_host[i]); + } + + block_device.copy_from_host(block_host.data()); + block_device_dq.copy_from_host(block_host_dq.data()); + } else { + static constexpr auto array_size = 1024; + + cute::array_subbyte block_host{}; + auto block_host_dq = std::vector(array_size); + + for (int i = 0; i < block_host.size(); ++i) { + block_host[i] = static_cast(dist(rng)); + block_host_dq[i] = static_cast(block_host[i].get()); + } + + static constexpr auto elements_per_byte = cute::sizeof_bits_v / cute::sizeof_bits_v; + + int loop_cnt = block_device.size() / array_size; + for (int i = 0; i < loop_cnt; i++) { + cutlass::device_memory::copy_to_device(block_device.get() + (i * array_size) / elements_per_byte, + raw_pointer_cast(block_host.begin()), + array_size / elements_per_byte); + cutlass::device_memory::copy_to_device(block_device_dq.get() + i * array_size, + block_host_dq.data(), + array_size); + } + + auto tail_size = block_device.size() % array_size; + if (tail_size) { + cutlass::device_memory::copy_to_device(block_device.get() + (loop_cnt * array_size) / elements_per_byte, + raw_pointer_cast(block_host.begin()), + tail_size / elements_per_byte); + cutlass::device_memory::copy_to_device(block_device_dq.get() + loop_cnt * array_size, + block_host_dq.data(), + tail_size); + } + } +} + +template +inline +bool is_close(T a, T b, float atol, float rtol) { + return std::abs((float)a - (float)b) <= atol + rtol * std::abs((float)b); +} + +// TODO(Codeplay): use on device initialisation for this +template +inline +void random_fill(T *src, int seed, size_t N, float max, float min) { + if constexpr(std::is_same_v || std::is_same_v || std::is_same_v) { + std::random_device rd; + std::mt19937 gen(seed); + std::uniform_real_distribution dis(min, max); + auto buff = std::vector(N); + + for (size_t i = 0; i < N; ++i) { + buff[i] = (T)(dis(gen)); + } + syclcompat::memcpy(src, buff.data(), N); + syclcompat::wait(); + } else { + assert(0 & "Not supported dtype"); + } +} + +template +void convert_dtype(const SrcT* d_src, DstT* d_dst, size_t size) { + syclcompat::get_default_queue().parallel_for(size, [=](auto indx) { + d_dst[indx] = static_cast(d_src[indx]); + }).wait(); +} diff --git a/setup.py b/setup.py index c7d8660..ec51ce4 100644 --- a/setup.py +++ b/setup.py @@ -259,7 +259,7 @@ def run(self): if _build_custom_ops(): ext_modules.append(CMakeExtension(name="vllm_xpu_kernels._C")) - ext_modules.append(CMakeExtension(name="vllm_xpu_kernels._vllm_fa2_C")) + ext_modules.append(CMakeExtension(name="vllm_xpu_kernels._vllm_fp8_C")) if ext_modules: cmdclass = {"build_ext": cmake_build_ext} From 24709b8db1a6e9f3fc106c21f4634e011c67711c Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Tue, 26 Aug 2025 11:18:34 +0000 Subject: [PATCH 11/47] update func Signed-off-by: Ma, Liangliang --- csrc/quantization/cutlass_kernels.cpp | 45 ++++++++++ ...rouped_gemm_fp8.cpp => grouped_gemm_fp8.h} | 89 +++++++++++++++++++ 2 files changed, 134 insertions(+) create mode 100644 csrc/quantization/cutlass_kernels.cpp rename csrc/quantization/{grouped_gemm_fp8.cpp => grouped_gemm_fp8.h} (89%) diff --git a/csrc/quantization/cutlass_kernels.cpp 
b/csrc/quantization/cutlass_kernels.cpp new file mode 100644 index 0000000..430e699 --- /dev/null +++ b/csrc/quantization/cutlass_kernels.cpp @@ -0,0 +1,45 @@ +#include "grouped_gemm_fp8.h" + +#include +#include +#include + +#include +/* #include "pytorch_shim.h" */ + +#include "core/registration.h" +#include +#include "xpu/utils.h" + +namespace gpu::cutlass_kernel { + +at::Tensor grouped_gemm_func( + at::Tensor& input, + at::Tensor& weight, + at::Tensor& mnks, // [M, K], half + at::Tensor& res + ) { + auto dpcpp_queue = vllm::xpu::vllmGetQueue(); + if (input.scalar_type() != at::kFloat8_e4m3fn) { + std::cout << "error:wrong datatype" << std::endl; + return at::Tensor(); + } + + grouped_gemm::kernel_functor( + &dpcpp_queue, + input.data_ptr(), + weight.data_ptr(), + mnks.data_ptr(), + res.data_ptr() + ); + return res; +} + +} // namespace gpu::cutlass_kernel + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { + ops.def("cutlass_grouped_gemm(Tensor input, Tensor weight, Tensor mnks) -> Tensor"); + ops.impl("cutlass_grouped_gemm", torch::kXPU, gpu::cutlass_kernel::grouped_gemm_func); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME); diff --git a/csrc/quantization/grouped_gemm_fp8.cpp b/csrc/quantization/grouped_gemm_fp8.h similarity index 89% rename from csrc/quantization/grouped_gemm_fp8.cpp rename to csrc/quantization/grouped_gemm_fp8.h index 5f241aa..8a6629f 100644 --- a/csrc/quantization/grouped_gemm_fp8.cpp +++ b/csrc/quantization/grouped_gemm_fp8.h @@ -98,6 +98,8 @@ using ElementOutput = float; // <- data type of elements in output matr /////////////////////////////////////////////////////////////////////////////////////////////////// +namespace gpu::cutlass_kernel { +namespace grouped_gemm { // Command line options parsing struct Options { @@ -661,3 +663,90 @@ int main(int argc, const char** argv) launcher(options); return 0; } + +void kernel_functor( + sycl::queue* stream, + void* input, + void* weight, + void* mnks, + void* res){ + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. 
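+  // Context for this shim (the call site lives in csrc/quantization/cutlass_kernels.cpp):
+  //   grouped_gemm::kernel_functor(&dpcpp_queue, input.data_ptr(), weight.data_ptr(),
+  //                                mnks.data_ptr(), res.data_ptr());
+  // i.e. `input`, `weight`, `mnks` and `res` arrive as raw device pointers. The exact
+  // encoding of `mnks` (per-group M/N/K) is still an assumption at this stage; only
+  // the pointer plumbing below is exercised so far.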
+ hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using GmemTiledCopyA = XE_2D_U8x32x32_LD_V; + using GmemTiledCopyB = XE_2D_U8x32x32_LD_V; + + // Workgroup-level tile + using TileShape = Shape<_256, _256, _32>; + + using TiledMma = + typename TiledMMAHelper, Layout, + Layout, Stride<_4, _1, _0>>>::TiledMMA; + + constexpr int PipelineStages = 2; + // Dispatch to grouped gemm algorithm + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16GroupFP8; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group; + + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_LD_N, + void, void, + XE_2D_U32x8x16_ST_N, + void, void>; + +// Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementType, + cutlass::gemm::TagToStrideA_t, + ElementType, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::GroupScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + + CUTLASS_CHECK(runner.template run(options, hw_info)); + +} + +} // namespace grouped_gemm +} // namespace gpu::cutlass_kernel From f5757a94678951156764a13b293fb637a5df6046 Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Fri, 29 Aug 2025 09:09:13 +0000 Subject: [PATCH 12/47] add test Signed-off-by: Ma, Liangliang --- mll_build.sh | 2 + tests/quantization/test_fused_moe.py | 67 +++++++++++++++++++++++++ vllm_xpu_kernels/__init__.py | 1 + vllm_xpu_kernels/fused_moe_interface.py | 10 ++++ 4 files changed, 80 insertions(+) create mode 100644 mll_build.sh create mode 100644 tests/quantization/test_fused_moe.py create mode 100644 vllm_xpu_kernels/fused_moe_interface.py diff --git a/mll_build.sh b/mll_build.sh new file mode 100644 index 0000000..3e5ace0 --- /dev/null +++ b/mll_build.sh @@ -0,0 +1,2 @@ +python3 setup.py clean +VLLM_TARGET_DEVICE=xpu python3 setup.py develop diff --git a/tests/quantization/test_fused_moe.py b/tests/quantization/test_fused_moe.py new file mode 100644 index 0000000..ba1194c --- /dev/null +++ b/tests/quantization/test_fused_moe.py @@ -0,0 +1,67 @@ +import pytest +import torch +from math import ceil +from typing import Callable, Optional, Union +from vllm_xpu_kernels.fused_moe_interface import cutlass_fused_moe + +NUM_EXPERTS = [8, 64, 192] +EP_SIZE = [1, 4] +TOP_KS = [2, 6] + +FUSED_MOE_MNK_FACTORS = [ + (1, 128, 128), + (1, 2048, 128), + (33, 2048, 128), + (222, 1024, 1024), + (32768, 128, 128), + (32768, 2048, 511), + (40000, 1024, 1024), +] + +FUSED_MOE_WN16_MNK_FACTORS = [ + (1, 128, 128), + (1, 1024, 1024), + (32, 2048, 128), + (32, 1024, 1024), + (222, 2048, 1024), +] + +DEVICE = "xpu" + +# @pytest.mark.parametrize("m,n,k", FUSED_MOE_MNK_FACTORS) +# 
@pytest.mark.parametrize("e", NUM_EXPERTS) +# @pytest.mark.parametrize("topk", TOP_KS) +# @pytest.mark.parametrize("ep_size", EP_SIZE) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +def test_fused_moe( + m: int, + n: int, + k: int, + e: int, + topk: int, + ep_size: int, + dtype: torch.dtype, + ): + # todo: seed + + # + # Setup test data + # + + a = torch.randn((m, k), device=DEVICE, dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device=DEVICE, dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device=DEVICE, dtype=dtype) / 10 + + score = torch.randn((m, e), device=DEVICE, dtype=dtype) + cutlass_fused_moe() + +if __name__ == "__main__": + test_fused_moe( + m = 33, + n = 2048, + k = 128, + e = 16, + topk = 2, + ep_size = 1, + dtype = torch.bfloat16 + ) diff --git a/vllm_xpu_kernels/__init__.py b/vllm_xpu_kernels/__init__.py index 8635cc2..e6fe0b1 100644 --- a/vllm_xpu_kernels/__init__.py +++ b/vllm_xpu_kernels/__init__.py @@ -1,3 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 from .flash_attn_interface import flash_attn_varlen_func # noqa: F401 +from .fused_moe_interface import cutlass_fused_moe diff --git a/vllm_xpu_kernels/fused_moe_interface.py b/vllm_xpu_kernels/fused_moe_interface.py new file mode 100644 index 0000000..2295e91 --- /dev/null +++ b/vllm_xpu_kernels/fused_moe_interface.py @@ -0,0 +1,10 @@ +import torch +from torch.nn.modules.utils import _pair +from torch import nn, Tensor +from typing import List + + +# from . import _vllm_fp8_C + +def cutlass_fused_moe(): + pass From 435e6df0ead87e2b7c5fe27d217187eb66023064 Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Fri, 29 Aug 2025 12:59:45 +0000 Subject: [PATCH 13/47] update functor Signed-off-by: Ma, Liangliang --- csrc/quantization/cutlass_kernels.cpp | 22 +- csrc/quantization/grouped_gemm.h | 351 +++++++++++ csrc/quantization/grouped_gemm_fp8.h | 752 ------------------------ tests/quantization/test_fused_moe.py | 30 +- vllm_xpu_kernels/fused_moe_interface.py | 49 +- 5 files changed, 433 insertions(+), 771 deletions(-) create mode 100644 csrc/quantization/grouped_gemm.h delete mode 100644 csrc/quantization/grouped_gemm_fp8.h diff --git a/csrc/quantization/cutlass_kernels.cpp b/csrc/quantization/cutlass_kernels.cpp index 430e699..9d13c84 100644 --- a/csrc/quantization/cutlass_kernels.cpp +++ b/csrc/quantization/cutlass_kernels.cpp @@ -13,15 +13,20 @@ namespace gpu::cutlass_kernel { +/* gemm2(group_A, w2, output, offset) */ + at::Tensor grouped_gemm_func( at::Tensor& input, at::Tensor& weight, - at::Tensor& mnks, // [M, K], half - at::Tensor& res + at::Tensor& res, + at::Tensor& offset, + int64_t hidden_size, + int64_t intermediate_size, + int64_t num_of_expert ) { auto dpcpp_queue = vllm::xpu::vllmGetQueue(); - if (input.scalar_type() != at::kFloat8_e4m3fn) { - std::cout << "error:wrong datatype" << std::endl; + if (input.scalar_type() != at::at::kBFloat16) { + std::cout << "error:wrong datatype, current only support bfloat16" << std::endl; return at::Tensor(); } @@ -29,8 +34,11 @@ at::Tensor grouped_gemm_func( &dpcpp_queue, input.data_ptr(), weight.data_ptr(), - mnks.data_ptr(), - res.data_ptr() + res.data_ptr(), + offset.data_ptr(), + hidden_size, + intermediate_size, + num_of_expert ); return res; } @@ -38,7 +46,7 @@ at::Tensor grouped_gemm_func( } // namespace gpu::cutlass_kernel TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { - ops.def("cutlass_grouped_gemm(Tensor input, Tensor weight, Tensor mnks) -> Tensor"); + ops.def("cutlass_grouped_gemm(Tensor input, Tensor weight, Tensor res, Tensor offset, int64_t 
hidden_size, int64_t intermediate_size, int64_t num_of_expert) -> Tensor"); ops.impl("cutlass_grouped_gemm", torch::kXPU, gpu::cutlass_kernel::grouped_gemm_func); } diff --git a/csrc/quantization/grouped_gemm.h b/csrc/quantization/grouped_gemm.h new file mode 100644 index 0000000..701d38b --- /dev/null +++ b/csrc/quantization/grouped_gemm.h @@ -0,0 +1,351 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief CUTLASS Intel BMG Group Gemm + + This file is almost a complete copy of 04_bmg_grouped_gemm, + except that it's used for FP8 (E5M2 & E4M3) datatype inputs. + + This example demonstrates fusing multiple GEMM operations into one kernel. + + Note that the scalar arguments to e.g. the standard 00_bmg_gemm example, have been + replaced with vector equivalents, as each individual GEMM has its own inputs and outputs, which + needn't be contiguous in memory. For example, where 00_bmg_gemm receives an `ElementA *` + defining Matrix A, grouped gemm receives a `ElementA **`, i.e. a pointer to pointers, each + pointing to a distinct Matrix A. Likewise, each individual GEMM operation may have its own alpha + and beta factors for linear combination. This example demonstrates two approaches: the user can + provide `options.alpha` and `options.beta`, in which case they will apply to all GEMMs; + otherwise, random values are generated per GEMM. + + Group GEMM scheduling (cutlass::gemm::GroupScheduler) is more complex than standard GEMM, + because each GEMM may have a unique size, only known at runtime. Thus, the scheduler will + distribute an a priori unknown number of tiles to each work-group. See + include/cutlass/gemm/kernel/xe_gemm_array_cooperative.hpp for implementation. 
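+
+  In this port the group sizes are meant to come from the MoE router rather than
+  the command line: each expert contributes one group whose M is the number of
+  tokens routed to it, while N and K are the projection dimensions. A rough
+  sketch of the assumed mapping (values illustrative only):
+
+    offset = {7, 0, 12, 5}, num_of_expert = 4, N = hidden_size, K = intermediate_size
+      => groups = 3, problem sizes = { {7, N, K}, {12, N, K}, {5, N, K} }
+
+  Experts that receive no tokens contribute no group.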
+ + Note that for simplicity, this example sets every GEMM in the group to the same shape. + + Verification for this example is a conventional GEMM kernel, executed iteratively per group. + + To build & run this example (from your build dir): + + $ ninja 09_bmg_grouped_gemm_fp8 + $ ./examples/sycl/09_bmg_grouped_gemm_fp8/09_bmg_grouped_gemm_fp8 + + Call with `--help` for information about available options. + + Note: the code may spill registers once compiled which will result in sub-optimal performance. This is because + of an issue inside Intel Graphics Compiler (IGC) related to VectorAliasBBThreshold being debugged internally. + To avoid register spills, build the example by setting the environment variable: + $ export IGC_VectorAliasBBThreshold=10000 +*/ +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/xe_array_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "sycl_common.hpp" +#include "helper.h" + +#include + +using namespace cute; +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group + +using ElementAccumulator = float; // <- data type of accumulator +using ElementComputeEpilogue = float; // <- data type of epilogue operations +using ElementA = bfloat16_t; // <- data type of elements in input matrix A +using ElementB = bfloat16_t; // <- data type of elements in input matrix B +using ElementOutput = float; // <- data type of elements in output matrix D + +/////////////////////////////////////////////////////////////////////////////////////////////////// + + +namespace gpu::cutlass_kernel { +namespace grouped_gemm { + +struct Options { + + bool error = false; + bool help = false; + + float alpha, beta; + int iterations; + int m, n, k, groups; + std::vector problem_sizes_host; + + int hidden_size; + int intermediate_size; + int* offset; + int num_of_expert; + + Options() : error(false), help(false), alpha(FLT_MAX), beta(FLT_MAX), iterations(100), + m(5120), n(4096), k(4096), groups(2) { + problem_sizes_host.reserve(groups); + for(int i = 0; i < groups; i++) { + problem_sizes_host.push_back({m, n, k}); + } + } + +}; + + +template +struct GroupedGemmRunner { + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + using StrideScaleA = typename Gemm::GemmKernel::StrideA; + using StrideScaleB = typename Gemm::GemmKernel::StrideB; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAcc = typename Gemm::ElementAccumulator; + using ElementScaleA = cutlass::half_t; + using ElementScaleB = cutlass::half_t; + using ElementOffset = int64_t; + + using CollectiveEpilogue = 
typename Gemm::CollectiveEpilogue; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + StrideScaleA stride_scaleA; + StrideScaleB stride_scaleB; + + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_A = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = + cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = + cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + stride_scaleA = cutlass::make_cute_packed_stride( + StrideScaleA{}, cute::make_shape(M, K, L)); + stride_scaleB = cutlass::make_cute_packed_stride( + StrideScaleB{}, cute::make_shape(N, K, L)); + } + + cutlass::Status run( + sycl::queue* stream, + const cutlass::KernelHardwareInfo& hw_info, + ElementA* inputA, + ElementB* inputB, + ElementOffset* offset, + ElementOutput* res, + int64_t hidden_size, + int64_t intermediate_size, + int64_t num_of_expert) { + + Options options(offset, hidden_size, intermediate_size, num_of_expert); + + initialize(problem_size); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {inputA, + stride_A, + inputB, + stride_B, + scaleA, + stride_scaleA, + scaleB, + stride_scaleB}, + {{args.alpha, args.beta}, nullptr, stride_C, res, stride_D}, + hw_info}; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + if (args.n < 16) { + std::cout + << "Invalid Problem Size: N must be >= 16 for FP8 input with F16 MMA (XE_8x16x16_F32F16F16F32_TT). Got N=" + << args.n << std::endl; + std::exit(1); + } + + if (gemm_op.can_implement(arguments) != cutlass::Status::kSuccess) { + std::cout << "Invalid Problem Size: " << args.m << 'x' << args.n << 'x' + << args.k << 'x' << args.l << std::endl; + std::exit(1); + } + + gemm_op.initialize(arguments, workspace.get()); + + // Run the GEMM + gemm_op.run(stream); + + stream->throw_asynchronous(); + + return cutlass::Status::kSuccess; + } +}; + +void kernel_functor( + sycl::queue* stream, + void* input, + void* weight, + void* res, + void* offset, + int64_t hidden_size, + int64_t intermediate_size, + int64_t num_of_expert){ + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. 
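+  // sm_count is taken from the SYCL device's compute-unit count; the persistent
+  // group scheduler uses it when deciding how many work-groups to launch, so it
+  // should reflect the device actually bound to `stream` (device 0 is assumed here).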
+ hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + using ElementAccumulator = float; + using ElementComputeEpilogue = float; + using ElementA = cutlass::bfloat16_t; + using ElementB = cutlass::bfloat16_t; + using ElementOffset = int64_t; + using ElementOutput = float; + using ElementScale = cutlass::bfloat16_t; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using TileShape = Shape<_256, _256, _32>; + using GmemTiledCopyA = + XE_2D_U16x32x32_LD_N; // Note: This shape has to match the shape used for + // the scaling factors + using GmemTiledCopyB = + XE_2D_U16x32x32_LD_V; // Note: This shape has to match the shape used for + // the scaling factors + + using TiledMma = + TiledMMA, + Layout, Stride<_4, _1, _0>>, + Tile, Stride<_1, _32, _8>>, + Layout, Stride<_1, _64, _16>>, _32>>; + + constexpr int PipelineStages = 2; + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16Group; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group; + + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_LD_N, + void, void, + XE_2D_U32x8x16_ST_N, + void, void>; + +// Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementA, + cutlass::gemm::TagToStrideA_t, + ElementB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::GroupScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + GroupedGemmRunner runner; + runner.run( + stream, + hw_info, + reinterpret_cast(input), + reinterpret_cast(weight), + reinterpret_cast(offset), + reinterpret_cast(res), + hidden_size, + intermediate_size, + num_of_expert); + +} + +} // namespace grouped_gemm +} // namespace gpu::cutlass_kernel diff --git a/csrc/quantization/grouped_gemm_fp8.h b/csrc/quantization/grouped_gemm_fp8.h deleted file mode 100644 index 8a6629f..0000000 --- a/csrc/quantization/grouped_gemm_fp8.h +++ /dev/null @@ -1,752 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. 
Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief CUTLASS Intel BMG Group Gemm - - This file is almost a complete copy of 04_bmg_grouped_gemm, - except that it's used for FP8 (E5M2 & E4M3) datatype inputs. - - This example demonstrates fusing multiple GEMM operations into one kernel. - - Note that the scalar arguments to e.g. the standard 00_bmg_gemm example, have been - replaced with vector equivalents, as each individual GEMM has its own inputs and outputs, which - needn't be contiguous in memory. For example, where 00_bmg_gemm receives an `ElementA *` - defining Matrix A, grouped gemm receives a `ElementA **`, i.e. a pointer to pointers, each - pointing to a distinct Matrix A. Likewise, each individual GEMM operation may have its own alpha - and beta factors for linear combination. This example demonstrates two approaches: the user can - provide `options.alpha` and `options.beta`, in which case they will apply to all GEMMs; - otherwise, random values are generated per GEMM. - - Group GEMM scheduling (cutlass::gemm::GroupScheduler) is more complex than standard GEMM, - because each GEMM may have a unique size, only known at runtime. Thus, the scheduler will - distribute an a priori unknown number of tiles to each work-group. See - include/cutlass/gemm/kernel/xe_gemm_array_cooperative.hpp for implementation. - - Note that for simplicity, this example sets every GEMM in the group to the same shape. - - Verification for this example is a conventional GEMM kernel, executed iteratively per group. - - To build & run this example (from your build dir): - - $ ninja 09_bmg_grouped_gemm_fp8 - $ ./examples/sycl/09_bmg_grouped_gemm_fp8/09_bmg_grouped_gemm_fp8 - - Call with `--help` for information about available options. - - Note: the code may spill registers once compiled which will result in sub-optimal performance. This is because - of an issue inside Intel Graphics Compiler (IGC) related to VectorAliasBBThreshold being debugged internally. 
- To avoid register spills, build the example by setting the environment variable: - $ export IGC_VectorAliasBBThreshold=10000 -*/ -#include "cutlass/epilogue/collective/default_epilogue.hpp" -#include "cutlass/epilogue/collective/xe_array_epilogue.hpp" -#include "cutlass/epilogue/fusion/xe_callbacks.hpp" -#include "cutlass/gemm/group_array_problem_shape.hpp" -#include "cutlass/gemm/device/gemm_universal.h" -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/collective/collective_mma.hpp" -#include "cutlass/util/GPU_Clock.hpp" - -#include -#include - -#include "cutlass/util/command_line.h" -#include "cutlass/util/device_memory.h" -#include "cutlass/util/packed_stride.hpp" -#include "cutlass/util/reference/device/gemm_complex.h" -#include "cutlass/util/reference/device/tensor_compare.h" -#include "sycl_common.hpp" -#include "helper.h" - -#include - -using namespace cute; -using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group - -using ElementAccumulator = float; // <- data type of accumulator -using ElementComputeEpilogue = float; // <- data type of epilogue operations -using ElementOutput = float; // <- data type of elements in output matrix D - -/////////////////////////////////////////////////////////////////////////////////////////////////// - - -namespace gpu::cutlass_kernel { -namespace grouped_gemm { -// Command line options parsing -struct Options { - - bool error = false; - bool help = false; - - float alpha, beta; - int iterations; - int m, n, k, groups; - std::vector problem_sizes_host; - - Options() : error(false), help(false), alpha(FLT_MAX), beta(FLT_MAX), iterations(100), - m(5120), n(4096), k(4096), groups(2) { - problem_sizes_host.reserve(groups); - for(int i = 0; i < groups; i++) { - problem_sizes_host.push_back({m, n, k}); - } - } - - // Parses the command line - void parse(int argc, char const **args) { - cutlass::CommandLine cmd(argc, args); - - if (cmd.check_cmd_line_flag("help")) { - help = true; - return; - } - - cmd.get_cmd_line_argument("m", m, 5120); - cmd.get_cmd_line_argument("n", n, 4096); - cmd.get_cmd_line_argument("k", k, 4096); - cmd.get_cmd_line_argument("groups", groups, 2); - cmd.get_cmd_line_argument("alpha", alpha, 1.f); - cmd.get_cmd_line_argument("beta", beta, 0.f); - cmd.get_cmd_line_argument("iterations", iterations, 100); - - assert(groups > 0); - problem_sizes_host.clear(); - problem_sizes_host.reserve(groups); - for(int i = 0; i < groups; i++) { - problem_sizes_host.push_back({m, n, k}); - } - } - - /// Prints the usage statement. 
- std::ostream & print_usage(std::ostream &out) const { - - out << "BMG Grouped GEMM\n\n" - << "Options:\n\n" - << " --help If specified, displays this usage statement\n\n" - << " --m= Sets the M extent of the GEMM for all groups\n" - << " --n= Sets the N extent of the GEMM for all groups\n" - << " --k= Sets the K extent of the GEMM for all groups\n" - << " --groups= Sets the number of individual GEMM problems for Grouped GEMM\n" - << " --alpha= Epilogue scalar alpha\n" - << " --beta= Epilogue scalar beta\n\n" - << " --iterations= Number of profiling iterations to perform\n\n"; - - out - << "\n\nExamples:\n\n" - << "$ " << "09_bmg_grouped_gemm_fp8" << " --m=5120 --n=4096 --k=4096 --groups=5 --alpha=2.5 --beta=0.5 \n\n"; - - return out; - } - - /// Compute performance in GFLOP/s - double gflops(double runtime_s, std::vector problem_sizes_host) const - { - // Number of real-valued multiply-adds - uint64_t fmas = uint64_t(); - - for (auto const & problem : problem_sizes_host) { - fmas += static_cast(get<0>(problem)) * - static_cast(get<1>(problem)) * - static_cast(get<2>(problem)); - } - // Two flops per multiply-add - uint64_t flop = uint64_t(2) * uint64_t(fmas); - double gflop = double(flop) / double(1.0e9); - return gflop / runtime_s; - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template < - class Gemm -> -struct ExampleRunner { - - using ElementA = typename Gemm::ElementA; - using ElementB = typename Gemm::ElementB; - using ElementC = typename Gemm::ElementC; - - using LayoutA = typename Gemm::LayoutA; - using LayoutB = typename Gemm::LayoutB; - using LayoutC = typename Gemm::LayoutC; - using LayoutD = typename Gemm::LayoutD; - - using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; - using ElementOutput = typename CollectiveEpilogue::ElementOutput; - using ElementAccumulator = ElementOutput; - - using StrideA = typename Gemm::GemmKernel::InternalStrideA; - using StrideB = typename Gemm::GemmKernel::InternalStrideB; - using StrideC = typename Gemm::GemmKernel::InternalStrideC; - using StrideD = typename Gemm::GemmKernel::InternalStrideD; - - // Host-side allocations - std::vector offset_A; - std::vector offset_B; - std::vector offset_C; - std::vector offset_D; - - std::vector stride_A_host; - std::vector stride_B_host; - std::vector stride_C_host; - std::vector stride_D_host; - - std::vector alpha_host; - std::vector beta_host; - - // Device-side allocations - cutlass::DeviceAllocation problem_sizes; - - // This example defines all matrices in a single allocation (e.g. block_A), but this is not a - // requirement. Matrix base pointers are read from device allocation (e.g. 
ptr_A) - cutlass::DeviceAllocation block_A; - cutlass::DeviceAllocation block_B; - cutlass::DeviceAllocation block_C; - cutlass::DeviceAllocation block_D; - cutlass::DeviceAllocation block_ref_D; - - cutlass::DeviceAllocation ptr_A; - cutlass::DeviceAllocation ptr_B; - cutlass::DeviceAllocation ptr_C; - cutlass::DeviceAllocation ptr_D; - cutlass::DeviceAllocation ptr_ref_D; - - cutlass::DeviceAllocation stride_A; - cutlass::DeviceAllocation stride_B; - cutlass::DeviceAllocation stride_C; - cutlass::DeviceAllocation stride_D; - - // Note, this is an array of pointers to alpha and beta scaling values per group - cutlass::DeviceAllocation alpha_device; - cutlass::DeviceAllocation beta_device; - cutlass::DeviceAllocation block_alpha; - cutlass::DeviceAllocation block_beta; - - uint64_t seed = 0; - - // - // Methods - // - template - bool verify(const Options &options) { - bool passed = true; - // Verify against individual reference GEMMs - for (int32_t i = 0; i < options.groups; ++i) { - auto problem = options.problem_sizes_host.at(i); - auto M = get<0>(problem); - auto N = get<1>(problem); - auto K = get<2>(problem); - - cutlass::DeviceAllocation block_A_fp16(block_A.size()); - cutlass::DeviceAllocation block_B_fp16(block_B.size()); - - // fp8 -> fp16 - convert_dtype( - block_A.get(), - block_A_fp16.get(), - block_A.size() - ); - convert_dtype( - block_B.get(), - block_B_fp16.get(), - block_B.size() - ); - - cutlass::TensorRef ref_A(block_A_fp16.get() + offset_A.at(i), LayoutA::packed({M, K})); - cutlass::TensorRef ref_B(block_B_fp16.get() + offset_B.at(i), LayoutB::packed({K, N})); - cutlass::TensorRef ref_C(block_C.get() + offset_C.at(i), LayoutC::packed({M, N})); - cutlass::TensorRef ref_D(block_ref_D.get() + offset_D.at(i), LayoutD::packed({M, N})); - - // - // Compute reference output - // - cutlass::reference::device::GemmComplex( - {M, N, K}, - alpha_host.at(i), - ref_A, - cutlass::ComplexTransform::kNone, - ref_B, - cutlass::ComplexTransform::kNone, - beta_host.at(i), - ref_C, - ref_D, - ElementAccumulator(0), - 1, // batch_count - M * K, // batch_stride_A - K * N, // batch_stride_B - M * N, // batch_stride_C - M * N // batch_stride_D - ); - - // Wait for kernel to finish - syclcompat::wait(); - - // Check if output from CUTLASS kernel and reference kernel are equal or not - passed &= cutlass::reference::device::BlockCompareEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N); - if(!passed) - break; - } - return passed; - } - -/// Allocates device-side data -void allocate(const Options &options) { - int64_t total_elements_A = 0; - int64_t total_elements_B = 0; - int64_t total_elements_C = 0; - int64_t total_elements_D = 0; - - // Compute total allocation sizes across group - for (int32_t i = 0; i < options.groups; ++i) { - - auto problem = options.problem_sizes_host.at(i); - auto M = get<0>(problem); - auto N = get<1>(problem); - auto K = get<2>(problem); - - // Offset into block allocation of each matrix base pointer - offset_A.push_back(total_elements_A); - offset_B.push_back(total_elements_B); - offset_C.push_back(total_elements_C); - offset_D.push_back(total_elements_D); - - int64_t elements_A = M * K; - int64_t elements_B = K * N; - int64_t elements_C = M * N; - int64_t elements_D = M * N; - - total_elements_A += elements_A; - total_elements_B += elements_B; - total_elements_C += elements_C; - total_elements_D += elements_D; - - stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1})); - 
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1})); - stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1})); - stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1})); - - } - - block_A.reset(total_elements_A); - block_B.reset(total_elements_B); - block_C.reset(total_elements_C); - block_D.reset(total_elements_D); - block_ref_D.reset(total_elements_D); - block_alpha.reset(options.groups); - block_beta.reset(options.groups); -} - -/// Initialize operands to be used in the GEMM and reference GEMM -template -void initialize(const Options &options) { - - uint64_t seed = 2020; - - problem_sizes.reset(options.groups); - problem_sizes.copy_from_host(options.problem_sizes_host.data()); - - // - // Assign pointers - // - - std::vector ptr_A_host(options.groups); - std::vector ptr_B_host(options.groups); - std::vector ptr_C_host(options.groups); - std::vector ptr_D_host(options.groups); - std::vector ptr_alpha_host(options.groups); - std::vector ptr_beta_host(options.groups); - - // Compute offsets, alpha & beta over group on host - for (int32_t i = 0; i < options.groups; ++i) { - ptr_A_host.at(i) = block_A.get() + offset_A.at(i); - ptr_B_host.at(i) = block_B.get() + offset_B.at(i); - ptr_C_host.at(i) = block_C.get() + offset_C.at(i); - ptr_D_host.at(i) = block_D.get() + offset_D.at(i); - // Fill host vector of alpha & beta with random values if using per-group values - alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast(rand() % 5 + 1) : options.alpha); - beta_host.push_back((options.beta == FLT_MAX) ? static_cast(rand() % 5) : options.beta); - // Fill host ptr vectors with offset addresses into device alpha/beta blocks - ptr_alpha_host.at(i) = block_alpha.get() + i; - ptr_beta_host.at(i) = block_beta.get() + i; - } - - // Allocate device memory & copy from host - ptr_A.reset(options.groups); - // Per-group alpha and beta - ptr_A.copy_from_host(ptr_A_host.data()); - - ptr_B.reset(options.groups); - ptr_B.copy_from_host(ptr_B_host.data()); - - ptr_C.reset(options.groups); - ptr_C.copy_from_host(ptr_C_host.data()); - - ptr_D.reset(options.groups); - ptr_D.copy_from_host(ptr_D_host.data()); - - stride_A.reset(options.groups); - stride_A.copy_from_host(stride_A_host.data()); - - stride_B.reset(options.groups); - stride_B.copy_from_host(stride_B_host.data()); - - stride_C.reset(options.groups); - stride_C.copy_from_host(stride_C_host.data()); - - stride_D.reset(options.groups); - stride_D.copy_from_host(stride_D_host.data()); - - // Per-group alpha and beta ptrs - alpha_device.reset(options.groups); - alpha_device.copy_from_host(ptr_alpha_host.data()); - beta_device.reset(options.groups); - beta_device.copy_from_host(ptr_beta_host.data()); - - initialize_block(block_A, seed + 2023); - initialize_block(block_B, seed + 2022); - initialize_block(block_C, seed + 2021); - // Per-group alpha and beta values - note these are not directly passed to kernel - the pointers - // (alpha_device/beta_device) are passed instead - block_alpha.copy_from_host(alpha_host.data()); - block_beta.copy_from_host(beta_host.data()); -} - - /// Populates a Gemm::Arguments structure from the given commandline options - typename Gemm::Arguments args_from_options(const Options &options, const cutlass::KernelHardwareInfo& hw_info, bool host_problem_shapes_available = true) - { - typename Gemm::Arguments arguments; - decltype(arguments.epilogue.thread) fusion_args; - - if (options.alpha != FLT_MAX && options.beta != FLT_MAX) { - // If both 
alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches. - fusion_args.alpha = options.alpha; - fusion_args.beta = options.beta; - fusion_args.alpha_ptr = nullptr; - fusion_args.beta_ptr = nullptr; - fusion_args.alpha_ptr_array = nullptr; - fusion_args.beta_ptr_array = nullptr; - // Single alpha and beta for all groups - fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0}; - fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0}; - } - else { - // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups. - fusion_args.alpha = 0; - fusion_args.beta = 0; - fusion_args.alpha_ptr = nullptr; - fusion_args.beta_ptr = nullptr; - fusion_args.alpha_ptr_array = alpha_device.get(); - fusion_args.beta_ptr_array = beta_device.get(); - // One alpha and beta per each group - fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1}; - fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1}; - } - using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerXeGroup::RasterOrderOptions; - - // Per-GEMM problem shape info may only exist on the device. - if (host_problem_shapes_available) { - arguments = typename Gemm::Arguments { - cutlass::gemm::GemmUniversalMode::kGrouped, - {options.groups, problem_sizes.get(), options.problem_sizes_host.data()}, - {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, - {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, - hw_info, - {1, RasterOrderOptions::AlongN} - }; - } - else { - arguments = typename Gemm::Arguments { - cutlass::gemm::GemmUniversalMode::kGrouped, - {options.groups, problem_sizes.get(), nullptr}, - {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, - {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, - hw_info, - {1, RasterOrderOptions::AlongN} - }; - } - - return arguments; - } - - template - cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info, bool host_problem_shapes_available = true) { - allocate(options); - initialize(options); - - Gemm gemm_op; - - auto arguments = args_from_options(options, hw_info, host_problem_shapes_available); - - size_t workspace_size = Gemm::get_workspace_size(arguments); - cutlass::device_memory::allocation workspace(workspace_size); - - CUTLASS_CHECK(gemm_op.can_implement(arguments)); - - CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); - - // Run the GEMM - CUTLASS_CHECK(gemm_op.run()); - - syclcompat::wait(); - - // Verify that the result is correct - bool passed = verify(options); - std::cout << "Disposition: " << (passed ? 
"Passed" : "Failed") << std::endl; - - if(!passed) return cutlass::Status::kErrorInternal; - - if (options.iterations > 0) { - GPU_Clock timer; - timer.start(); - for (int iter = 0; iter < options.iterations; ++iter) { - CUTLASS_CHECK(gemm_op.run()); - } - syclcompat::wait(); - - float cute_time = timer.seconds() * 1000; - double cute_average_time = double(cute_time) / double(options.iterations); - double gflops = options.gflops(cute_average_time / 1000.0, options.problem_sizes_host); - if constexpr (std::is_same_v) { - std::cout << "Datatype: float_e4m3_t"<< std::endl; - } else if constexpr (std::is_same_v) { - std::cout << "Datatype: float_e5m2_t"<< std::endl; - } else { - static_assert(cutlass::detail::dependent_false, "Not a valid fp8 datatype."); - } - std::cout << " Problem Sizes, Alpha, Beta " << std::endl; - for (int32_t i = 0; i < options.groups; ++i) { - std::cout << " " << options.problem_sizes_host.at(i); - std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl; - } - std::cout << " Groups : " << options.groups << std::endl; - std::cout << " Avg runtime : " << cute_average_time << " ms" << std::endl; - std::cout << " GFLOPS : " << gflops << std::endl; - } - - return cutlass::Status::kSuccess; - } - -}; - - -template -int launcher(Options& options) -{ - // - // Run examples - // - - // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This - // information is used by the underlying kernel. - cutlass::KernelHardwareInfo hw_info; - - // Change device_id to another value if you are running on a machine with multiple GPUs and wish - // to use a GPU other than that with device ID 0. - hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); - - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::RowMajor; - using LayoutC = cutlass::layout::RowMajor; - using LayoutD = cutlass::layout::RowMajor; - - using GmemTiledCopyA = XE_2D_U8x32x32_LD_V; - using GmemTiledCopyB = XE_2D_U8x32x32_LD_V; - - // Workgroup-level tile - using TileShape = Shape<_256, _256, _32>; - - using TiledMma = - typename TiledMMAHelper, Layout, - Layout, Stride<_4, _1, _0>>>::TiledMMA; - - constexpr int PipelineStages = 2; - // Dispatch to grouped gemm algorithm - using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16GroupFP8; - using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group; - - using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; - - using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; - using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< - EpilogueDispatchPolicy, - TileShape, - ElementAccumulator, - cutlass::gemm::TagToStrideC_t, - ElementOutput, - cutlass::gemm::TagToStrideC_t, - FusionCallBacks, - XE_2D_U32x8x16_LD_N, - void, void, - XE_2D_U32x8x16_ST_N, - void, void>; - -// Mainloop - using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< - GEMMDispatchPolicy, - TileShape, - ElementType, - cutlass::gemm::TagToStrideA_t, - ElementType, - cutlass::gemm::TagToStrideB_t, - TiledMma, - GmemTiledCopyA, void, void, cute::identity, // A - GmemTiledCopyB, void, void, cute::identity // B - >; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - ProblemShape, - CollectiveMainloop, - CollectiveEpilogue, - cutlass::gemm::GroupScheduler - >; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - ExampleRunner runner; - - CUTLASS_CHECK(runner.template run(options, hw_info)); - - 
return 0; -} - -int main(int argc, const char** argv) -{ - // - // Parse options - // - - Options options; - - options.parse(argc, argv); - - if (options.help) { - options.print_usage(std::cout) << std::endl; - return 0; - } - - if (options.error) { - std::cerr << "Aborting execution." << std::endl; - return -1; - } - - launcher(options); - launcher(options); - return 0; -} - -void kernel_functor( - sycl::queue* stream, - void* input, - void* weight, - void* mnks, - void* res){ - // - // Run examples - // - - // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This - // information is used by the underlying kernel. - cutlass::KernelHardwareInfo hw_info; - - // Change device_id to another value if you are running on a machine with multiple GPUs and wish - // to use a GPU other than that with device ID 0. - hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); - - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::RowMajor; - using LayoutC = cutlass::layout::RowMajor; - using LayoutD = cutlass::layout::RowMajor; - - using GmemTiledCopyA = XE_2D_U8x32x32_LD_V; - using GmemTiledCopyB = XE_2D_U8x32x32_LD_V; - - // Workgroup-level tile - using TileShape = Shape<_256, _256, _32>; - - using TiledMma = - typename TiledMMAHelper, Layout, - Layout, Stride<_4, _1, _0>>>::TiledMMA; - - constexpr int PipelineStages = 2; - // Dispatch to grouped gemm algorithm - using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16GroupFP8; - using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group; - - using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; - - using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; - using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< - EpilogueDispatchPolicy, - TileShape, - ElementAccumulator, - cutlass::gemm::TagToStrideC_t, - ElementOutput, - cutlass::gemm::TagToStrideC_t, - FusionCallBacks, - XE_2D_U32x8x16_LD_N, - void, void, - XE_2D_U32x8x16_ST_N, - void, void>; - -// Mainloop - using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< - GEMMDispatchPolicy, - TileShape, - ElementType, - cutlass::gemm::TagToStrideA_t, - ElementType, - cutlass::gemm::TagToStrideB_t, - TiledMma, - GmemTiledCopyA, void, void, cute::identity, // A - GmemTiledCopyB, void, void, cute::identity // B - >; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - ProblemShape, - CollectiveMainloop, - CollectiveEpilogue, - cutlass::gemm::GroupScheduler - >; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - ExampleRunner runner; - - CUTLASS_CHECK(runner.template run(options, hw_info)); - -} - -} // namespace grouped_gemm -} // namespace gpu::cutlass_kernel diff --git a/tests/quantization/test_fused_moe.py b/tests/quantization/test_fused_moe.py index ba1194c..e28abbc 100644 --- a/tests/quantization/test_fused_moe.py +++ b/tests/quantization/test_fused_moe.py @@ -34,9 +34,9 @@ # @pytest.mark.parametrize("ep_size", EP_SIZE) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) def test_fused_moe( - m: int, - n: int, - k: int, + m: int, # num of tokens + n: int, # intermediate_size + k: int, # hidden_size e: int, topk: int, ep_size: int, @@ -44,16 +44,28 @@ def test_fused_moe( ): # todo: seed - # # Setup test data - # - a = torch.randn((m, k), device=DEVICE, dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device=DEVICE, dtype=dtype) / 10 + w13 = torch.randn((e, 2 * n, k), device=DEVICE, 
dtype=dtype) / 10
     w2 = torch.randn((e, k, n), device=DEVICE, dtype=dtype) / 10
 
-    score = torch.randn((m, e), device=DEVICE, dtype=dtype)
-    cutlass_fused_moe()
+    # moe gate
+    scores = torch.randn((m, e), device=DEVICE, dtype=dtype)
+    expert_scores, expert_indices = torch.topk(scores, k=topk, dim=-1, sorted=False)
+    flat_expert_indices = expert_indices.view(-1)
+    flat_expert_weights = expert_scores.view(-1, 1)
+
+    cutlass_fused_moe(hidden_states=a,
+                      w13=w13,
+                      w2=w2,
+                      topk_weights=flat_expert_weights,
+                      topk_ids=flat_expert_indices,
+                      n_experts_per_token=topk,
+                      inplace=True,
+                      activation="silu",
+                      num_experts=e)
+
+    print("result", a)
 
 if __name__ == "__main__":
     test_fused_moe(
diff --git a/vllm_xpu_kernels/fused_moe_interface.py b/vllm_xpu_kernels/fused_moe_interface.py
index 2295e91..00a79a7 100644
--- a/vllm_xpu_kernels/fused_moe_interface.py
+++ b/vllm_xpu_kernels/fused_moe_interface.py
@@ -2,9 +2,52 @@
 from torch.nn.modules.utils import _pair
 from torch import nn, Tensor
 from typing import List
-
+import numpy
 # from . import _vllm_fp8_C
 
+def ref_gemm1(x, weight):
+    w1, w3 = torch.split(weight, weight.shape[0] // 2, dim=0)
+    act_fn = torch.nn.SiLU()
+    gate = act_fn(x @ w1.T)
+    up = x @ w3.T
+    return gate * up
+
+
+
+def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_per_token, inplace, activation, num_experts):
+    token_cnt, hidden_size = list(hidden_states.shape)
+    intermediate_size = list(w2.shape)[-1]
+    expert_cache = hidden_states.new_empty((token_cnt, intermediate_size))
+    idxs = topk_ids.argsort()
+    counts = topk_ids.bincount().cpu().numpy()
+    tokens_per_expert = counts.cumsum()
+    num_per_tok = n_experts_per_token
+    token_idxs = idxs // num_per_tok
+    experts_intermediate = []
+
+    for expert_id, end_idx in enumerate(tokens_per_expert):
+        start_idx = 0 if expert_id == 0 else tokens_per_expert[expert_id - 1]
+        if start_idx == end_idx:
+            continue
+
+        expert_w13 = w13[expert_id, :, :]
+        exp_token_idxs = token_idxs[start_idx:end_idx]
+        expert_tokens = hidden_states[exp_token_idxs]
+        expert_out = ref_gemm1(expert_tokens, expert_w13)
+        # expert_out.mul_(topk_weights[idxs[start_idx:end_idx]])
+        # expert_cache.scatter_reduce_(
+        #     0,
+        #     exp_token_idxs.view(-1, 1).repeat(1, hidden_states.shape[-1]),
+        #     expert_out,
+        #     reduce='sum'
+        # )
+        experts_intermediate.append(expert_out)
+
+    group_A = torch.cat(experts_intermediate, dim=0).contiguous()
+    output = torch.empty_like(hidden_states)
+    offset = tokens_per_expert
+    print(group_A.shape)
+    print(output.shape)
+    print(tokens_per_expert)
+    # gemm2(group_A, w2, output, offset)
 
-def cutlass_fused_moe():
-    pass
From f76fb97625f20dfa1e4771dd92aa9cddf28225f0 Mon Sep 17 00:00:00 2001
From: "Ma, Liangliang"
Date: Sat, 30 Aug 2025 10:09:15 +0000
Subject: [PATCH 14/47] update grouped_gemm

Signed-off-by: Ma, Liangliang
---
 csrc/quantization/grouped_gemm.h | 312 +++++++++++++++++++++++--------
 1 file changed, 238 insertions(+), 74 deletions(-)

diff --git a/csrc/quantization/grouped_gemm.h b/csrc/quantization/grouped_gemm.h
index 701d38b..309ba14 100644
--- a/csrc/quantization/grouped_gemm.h
+++ b/csrc/quantization/grouped_gemm.h
@@ -113,19 +113,26 @@ struct Options {
   int m, n, k, groups;
   std::vector problem_sizes_host;
 
-  int hidden_size;
-  int intermediate_size;
+  int* offset;
   int num_of_expert;
 
-  Options() : error(false), help(false), alpha(FLT_MAX), beta(FLT_MAX), iterations(100),
-    m(5120), n(4096), k(4096), groups(2) {
-    problem_sizes_host.reserve(groups);
-    for(int i = 0; i < groups; i++) {
-     
problem_sizes_host.push_back({m, n, k}); + Options(int* offset, int N, int K, int ne): + num_of_expert(ne), n(N), k(K), error(false), help(false), alpha(FLT_MAX), beta(FLT_MAX), iterations(100) { + int group_cnt = 0; + for (int i = 0; i < num_of_expert; ++i){ + if (offset[i] != 0){ + group_cnt++; + } } - } - + + problem_sizes_host.reserve(group_cnt); + for (int i = 0; i < num_of_expert; ++i){ + if (offset[i] != 0){ + problem_sizes_host.push_back({offset[i], n, k}); + } + } + groups = group_cnt; }; @@ -158,82 +165,238 @@ struct GroupedGemmRunner { using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; - StrideA stride_A; - StrideB stride_B; - StrideC stride_C; - StrideD stride_D; - StrideScaleA stride_scaleA; - StrideScaleB stride_scaleB; - - void initialize(const ProblemShapeType& problem_size) { - auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - auto [M, N, K, L] = problem_shape_MNKL; - - stride_A = - cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); - stride_B = - cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); - stride_C = - cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); - stride_D = - cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); - stride_scaleA = cutlass::make_cute_packed_stride( - StrideScaleA{}, cute::make_shape(M, K, L)); - stride_scaleB = cutlass::make_cute_packed_stride( - StrideScaleB{}, cute::make_shape(N, K, L)); + // Host-side allocations + std::vector offset_A; + std::vector offset_B; + std::vector offset_C; + std::vector offset_D; + + std::vector stride_A_host; + std::vector stride_B_host; + std::vector stride_C_host; + std::vector stride_D_host; + + std::vector alpha_host; + std::vector beta_host; + + // Device-side allocations + cutlass::DeviceAllocation problem_sizes; + + // This example defines all matrices in a single allocation (e.g. block_A), but this is not a + // requirement. Matrix base pointers are read from device allocation (e.g. 
ptr_A) + cutlass::DeviceAllocation ptr_A; + cutlass::DeviceAllocation ptr_B; + cutlass::DeviceAllocation ptr_C; + cutlass::DeviceAllocation ptr_D; + cutlass::DeviceAllocation ptr_ref_D; + + cutlass::DeviceAllocation stride_A; + cutlass::DeviceAllocation stride_B; + cutlass::DeviceAllocation stride_C; + cutlass::DeviceAllocation stride_D; + + cutlass::DeviceAllocation block_C; + // Note, this is an array of pointers to alpha and beta scaling values per group + cutlass::DeviceAllocation alpha_device; + cutlass::DeviceAllocation beta_device; + cutlass::DeviceAllocation block_alpha; + cutlass::DeviceAllocation block_beta; + + /// Allocates device-side data +void allocate(const Options &options, int* offset) { + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + + // Compute total allocation sizes across group + for (int32_t i = 0; i < options.num_of_expert; ++i) { + if (offset[i] == 0){ + total_elements_B += options.n * options.k; + continue + } + + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + // Offset into block allocation of each matrix base pointer + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + + int64_t elements_A = M * K; + int64_t elements_B = K * N; + int64_t elements_C = M * N; + int64_t elements_D = M * N; + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + + stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1})); + stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1})); + stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1})); + stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1})); } + block_C.reset(total_elements_C); + block_alpha.reset(options.groups); + block_beta.reset(options.groups); +} + + void initialize(const Options &options, ElementA * block_A, ElementB * block_B, + ElementD* block_D) { + problem_sizes.reset(options.groups); + problem_sizes.copy_from_host(options.problem_sizes_host.data()); + + std::vector ptr_A_host(options.groups); + std::vector ptr_B_host(options.groups); + std::vector ptr_C_host(options.groups); + std::vector ptr_D_host(options.groups); + std::vector ptr_alpha_host(options.groups); + std::vector ptr_beta_host(options.groups); + + // Compute offsets, alpha & beta over group on host + for (int32_t i = 0; i < options.groups; ++i) { + ptr_A_host.at(i) = block_A + offset_A.at(i); + ptr_B_host.at(i) = block_B + offset_B.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D + offset_D.at(i); + // Fill host vector of alpha & beta with random values if using per-group values + alpha_host.push_back(static_cast((rand() % 5) + 1)); + beta_host.push_back(static_cast(rand() % 5)); + // Fill host ptr vectors with offset addresses into device alpha/beta blocks + ptr_alpha_host.at(i) = block_alpha.get() + i; + ptr_beta_host.at(i) = block_beta.get() + i; + } + + // Allocate device memory & copy from host + ptr_A.reset(options.groups); + // Per-group alpha and beta + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(options.groups); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_C.reset(options.groups); + ptr_C.copy_from_host(ptr_C_host.data()); + + 
ptr_D.reset(options.groups); + ptr_D.copy_from_host(ptr_D_host.data()); + + stride_A.reset(options.groups); + stride_A.copy_from_host(stride_A_host.data()); + + stride_B.reset(options.groups); + stride_B.copy_from_host(stride_B_host.data()); + + stride_C.reset(options.groups); + stride_C.copy_from_host(stride_C_host.data()); + + stride_D.reset(options.groups); + stride_D.copy_from_host(stride_D_host.data()); + + // Per-group alpha and beta ptrs + alpha_device.reset(options.groups); + alpha_device.copy_from_host(ptr_alpha_host.data()); + beta_device.reset(options.groups); + beta_device.copy_from_host(ptr_beta_host.data()); + + + initialize_block(block_C, 666 + 2025); + // Per-group alpha and beta values - note these are not directly passed to kernel - the pointers + // (alpha_device/beta_device) are passed instead + block_alpha.copy_from_host(alpha_host.data()); + block_beta.copy_from_host(beta_host.data()); + + } + + /// Populates a Gemm::Arguments structure from the given commandline options + typename Gemm::Arguments args_from_options(const Options &options, const cutlass::KernelHardwareInfo& hw_info, bool host_problem_shapes_available = true) + { + typename Gemm::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args; + + if (options.alpha != FLT_MAX && options.beta != FLT_MAX) { + // If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches. + fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + // Single alpha and beta for all groups + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0}; + } + else { + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups. + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = alpha_device.get(); + fusion_args.beta_ptr_array = beta_device.get(); + // One alpha and beta per each group + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1}; + } + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerXeGroup::RasterOrderOptions; + + // Per-GEMM problem shape info may only exist on the device. 
+ if (host_problem_shapes_available) { + arguments = typename Gemm::Arguments { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), options.problem_sizes_host.data()}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info, + {1, RasterOrderOptions::AlongN} + }; + } + else { + arguments = typename Gemm::Arguments { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), nullptr}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info, + {1, RasterOrderOptions::AlongN} + }; + } + + return arguments; + } + cutlass::Status run( + const Options& options, sycl::queue* stream, const cutlass::KernelHardwareInfo& hw_info, ElementA* inputA, ElementB* inputB, ElementOffset* offset, - ElementOutput* res, - int64_t hidden_size, - int64_t intermediate_size, - int64_t num_of_expert) { - - Options options(offset, hidden_size, intermediate_size, num_of_expert); - - initialize(problem_size); - - typename Gemm::GemmKernel::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - problem_size, - {inputA, - stride_A, - inputB, - stride_B, - scaleA, - stride_scaleA, - scaleB, - stride_scaleB}, - {{args.alpha, args.beta}, nullptr, stride_C, res, stride_D}, - hw_info}; - + ElementOutput* res) { + + allocate(options, offset); + initialize(options, inputA, inputB, res); Gemm gemm_op; + auto arguments = args_from_options(options, hw_info, true); + size_t workspace_size = Gemm::get_workspace_size(arguments); cutlass::device_memory::allocation workspace(workspace_size); - if (args.n < 16) { - std::cout - << "Invalid Problem Size: N must be >= 16 for FP8 input with F16 MMA (XE_8x16x16_F32F16F16F32_TT). Got N=" - << args.n << std::endl; - std::exit(1); - } + CUTLASS_CHECK(gemm_op.can_implement(arguments)); - if (gemm_op.can_implement(arguments) != cutlass::Status::kSuccess) { - std::cout << "Invalid Problem Size: " << args.m << 'x' << args.n << 'x' - << args.k << 'x' << args.l << std::endl; - std::exit(1); - } - - gemm_op.initialize(arguments, workspace.get()); + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); // Run the GEMM - gemm_op.run(stream); + CUTLASS_CHECK(gemm_op.run()); + stream->throw_asynchronous(); @@ -253,7 +416,9 @@ void kernel_functor( // // Run examples // - + + auto offset_ptr = reinterpret_cast offset; + Options options(offset_ptr, hidden_size, intermediate_size, num_of_expert); // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This // information is used by the underlying kernel. 
cutlass::KernelHardwareInfo hw_info; @@ -335,15 +500,14 @@ void kernel_functor( GroupedGemmRunner runner; runner.run( + options, stream, hw_info, reinterpret_cast(input), reinterpret_cast(weight), reinterpret_cast(offset), - reinterpret_cast(res), - hidden_size, - intermediate_size, - num_of_expert); + reinterpret_cast(res) + ); } From 9408e946580f9f77928a25752945ea3922b24717 Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Sun, 31 Aug 2025 13:51:31 +0000 Subject: [PATCH 15/47] build ready Signed-off-by: Ma, Liangliang --- CMakeLists.txt | 2 +- csrc/core/registration.h | 4 ++- csrc/quantization/cutlass_kernels.cpp | 15 ++++----- csrc/quantization/grouped_gemm.h | 31 ++++++++++-------- tests/quantization/test_fused_moe.py | 2 +- vllm_xpu_kernels/fused_moe_interface.py | 42 ++++++++++++++++--------- 6 files changed, 58 insertions(+), 38 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 60945c6..56c0ae2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -167,7 +167,7 @@ if(VLLM_GPU_LANG STREQUAL "SYCL") set(CUTLASS_ENABLE_HEADERS_ONLY "ON" CACHE BOOL "Enable only the header library") # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. - set(CUTLASS_REVISION "sycl-develop" CACHE STRING "CUTLASS revision to use") + set(CUTLASS_REVISION "main" CACHE STRING "CUTLASS revision to use") # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided FetchContent_Declare( diff --git a/csrc/core/registration.h b/csrc/core/registration.h index 5f9cdeb..e5386ea 100644 --- a/csrc/core/registration.h +++ b/csrc/core/registration.h @@ -1,5 +1,6 @@ #pragma once - +#pragma push_macro("printf") +#undef printf #include #define _CONCAT(A, B) A##B @@ -26,3 +27,4 @@ nullptr, nullptr, nullptr, nullptr}; \ return PyModule_Create(&module); \ } +#pragma pop_macro("printf") diff --git a/csrc/quantization/cutlass_kernels.cpp b/csrc/quantization/cutlass_kernels.cpp index 9d13c84..02f5bd4 100644 --- a/csrc/quantization/cutlass_kernels.cpp +++ b/csrc/quantization/cutlass_kernels.cpp @@ -1,15 +1,16 @@ -#include "grouped_gemm_fp8.h" -#include -#include -#include -#include +// #include +// #include +// #include + +// #include /* #include "pytorch_shim.h" */ #include "core/registration.h" #include #include "xpu/utils.h" +#include "grouped_gemm.h" namespace gpu::cutlass_kernel { @@ -25,7 +26,7 @@ at::Tensor grouped_gemm_func( int64_t num_of_expert ) { auto dpcpp_queue = vllm::xpu::vllmGetQueue(); - if (input.scalar_type() != at::at::kBFloat16) { + if (input.scalar_type() != at::kBFloat16) { std::cout << "error:wrong datatype, current only support bfloat16" << std::endl; return at::Tensor(); } @@ -46,7 +47,7 @@ at::Tensor grouped_gemm_func( } // namespace gpu::cutlass_kernel TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { - ops.def("cutlass_grouped_gemm(Tensor input, Tensor weight, Tensor res, Tensor offset, int64_t hidden_size, int64_t intermediate_size, int64_t num_of_expert) -> Tensor"); + ops.def("cutlass_grouped_gemm(Tensor input, Tensor weight, Tensor res, Tensor offset, int hidden_size, int intermediate_size, int num_of_expert) -> Tensor"); ops.impl("cutlass_grouped_gemm", torch::kXPU, gpu::cutlass_kernel::grouped_gemm_func); } diff --git a/csrc/quantization/grouped_gemm.h b/csrc/quantization/grouped_gemm.h index 309ba14..e36be61 100644 --- a/csrc/quantization/grouped_gemm.h +++ b/csrc/quantization/grouped_gemm.h @@ -66,6 +66,9 @@ To avoid register spills, build the example by setting the environment variable: $ export 
IGC_VectorAliasBBThreshold=10000 */ + +#pragma once + #include "cutlass/epilogue/collective/default_epilogue.hpp" #include "cutlass/epilogue/collective/xe_array_epilogue.hpp" #include "cutlass/epilogue/fusion/xe_callbacks.hpp" @@ -114,10 +117,9 @@ struct Options { std::vector problem_sizes_host; - int* offset; int num_of_expert; - Options(int* offset, int N, int K, int ne): + Options(int64_t * offset, int N, int K, int ne): num_of_expert(ne), n(N), k(K), error(false), help(false), alpha(FLT_MAX), beta(FLT_MAX), iterations(100) { int group_cnt = 0; for (int i = 0; i < num_of_expert; ++i){ @@ -129,19 +131,22 @@ struct Options { problem_sizes_host.reserve(group_cnt); for (int i = 0; i < num_of_expert; ++i){ if (offset[i] != 0){ - problem_sizes_host.push_back({offset[i], n, k}); + problem_sizes_host.push_back({static_cast(offset[i]), n, k}); } } groups = group_cnt; + + } }; template struct GroupedGemmRunner { - using StrideA = typename Gemm::GemmKernel::StrideA; - using StrideB = typename Gemm::GemmKernel::StrideB; - using StrideC = typename Gemm::GemmKernel::StrideC; - using StrideD = typename Gemm::GemmKernel::StrideD; + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + using StrideScaleA = typename Gemm::GemmKernel::StrideA; using StrideScaleB = typename Gemm::GemmKernel::StrideB; @@ -160,8 +165,8 @@ struct GroupedGemmRunner { using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; using ElementOutput = typename CollectiveEpilogue::ElementOutput; - using ElementCompute = typename CollectiveEpilogue::ElementCompute; - using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + using ElementAccumulator = ElementOutput; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; @@ -203,7 +208,7 @@ struct GroupedGemmRunner { cutlass::DeviceAllocation block_beta; /// Allocates device-side data -void allocate(const Options &options, int* offset) { +void allocate(const Options &options, int64_t* offset) { int64_t total_elements_A = 0; int64_t total_elements_B = 0; int64_t total_elements_C = 0; @@ -213,7 +218,7 @@ void allocate(const Options &options, int* offset) { for (int32_t i = 0; i < options.num_of_expert; ++i) { if (offset[i] == 0){ total_elements_B += options.n * options.k; - continue + continue; } auto problem = options.problem_sizes_host.at(i); @@ -248,7 +253,7 @@ void allocate(const Options &options, int* offset) { } void initialize(const Options &options, ElementA * block_A, ElementB * block_B, - ElementD* block_D) { + ElementOutput* block_D) { problem_sizes.reset(options.groups); problem_sizes.copy_from_host(options.problem_sizes_host.data()); @@ -417,7 +422,7 @@ void kernel_functor( // Run examples // - auto offset_ptr = reinterpret_cast offset; + auto offset_ptr = reinterpret_cast(offset); Options options(offset_ptr, hidden_size, intermediate_size, num_of_expert); // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This // information is used by the underlying kernel. 
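For reference, the `cutlass_grouped_gemm` schema registered above takes a row-contiguous activation buffer plus a CPU int64 `offset` tensor holding per-expert token counts; experts with a zero count contribute no output rows, mirroring the group_cnt filtering in Options(). The following is a minimal pure-PyTorch sketch of the intended per-expert semantics, not the kernel itself; names are illustrative, and it assumes input of shape [total_tokens, k] in bfloat16, weight of shape [num_experts, n, k] in bfloat16, and res of shape [total_tokens, n] in fp32, matching the shapes used by the tests further below.

import torch

def grouped_gemm_reference(input, weight, res, offset):
    # offset[e] = number of rows routed to expert e; zero-count experts are skipped
    row = 0
    for e, tokens in enumerate(offset.tolist()):
        if tokens == 0:
            continue
        a = input[row:row + tokens, :]      # this expert's activation slice
        b = weight[e, :, :]                 # [n, k] weight of expert e
        res[row:row + tokens, :] = (a @ b.T).float()  # per-expert GEMM, fp32 output
        row += tokens
    return res

The kernel output is expected to agree with this loop up to accumulation-order differences, which is what the grouped-gemm test below checks with a loose allclose tolerance.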
diff --git a/tests/quantization/test_fused_moe.py b/tests/quantization/test_fused_moe.py index e28abbc..250de63 100644 --- a/tests/quantization/test_fused_moe.py +++ b/tests/quantization/test_fused_moe.py @@ -56,7 +56,7 @@ def test_fused_moe( flat_expert_weights = expert_scores.view(-1, 1) cutlass_fused_moe(hidden_states=a, - w13=w1, + w13=w13, w2=w2, topk_weights=flat_expert_weights, topk_ids=flat_expert_indices, diff --git a/vllm_xpu_kernels/fused_moe_interface.py b/vllm_xpu_kernels/fused_moe_interface.py index 00a79a7..38c4809 100644 --- a/vllm_xpu_kernels/fused_moe_interface.py +++ b/vllm_xpu_kernels/fused_moe_interface.py @@ -4,11 +4,11 @@ from typing import List import numpy -# from . import _vllm_fp8_C -def ref_gemm1(x, weight): - w1, w3 = torch.split(w13, list(weight.shape)[0]/2, dim=0) +from . import _vllm_fp8_C +def ref_gemm1(x, w13): + w1, w3 = torch.split(w13, int(list(w13.shape)[0]/2), dim=0) act_fn = torch.nn.SiLU() - gate = (x @ w1.T).silu() + gate = act_fn(x @ w1.T) up = x @ w3.T return gate * up @@ -17,22 +17,27 @@ def ref_gemm1(x, weight): def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_per_token, inplace, activation, num_experts): token_cnt, hidden_size = list(hidden_states.shape) intermediate_size = list(w2.shape)[-1] - expert_cache = torch.empty_like(hidden_states, shape=(token_cnt, intermediate_size)) + expert_cache = torch.empty((token_cnt, intermediate_size), + dtype=hidden_states.dtype, + device=hidden_states.device) + idxs = topk_ids.argsort() - counts = topk_ids.bincount().cpu().numpy() + counts = topk_ids.to(torch.long).bincount().cpu().numpy() tokens_per_expert = counts.cumsum() num_per_tok = n_experts_per_token token_idxs = idxs // num_per_tok - experts_intermeidate = [] + experts_intermediate = [] + print("tokens_per_expert", tokens_per_expert) for expert_id, end_idx in enumerate(tokens_per_expert): + print("expert id: ", expert_id) start_idx = 0 if expert_id == 0 else tokens_per_expert[expert_id - 1] if start_idx == end_idx: continue expert_w13 = w13[expert_id, :, :] exp_token_idxs = token_idxs[start_idx:end_idx] - expert_tokens = x[exp_token_idxs] + expert_tokens = hidden_states[exp_token_idxs] expert_out = ref_gemm1(expert_tokens, expert_w13) # expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) # expert_cache.scatter_reduce_( @@ -43,11 +48,18 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ # ) experts_intermediate.append(expert_out) - group_A = torch.stack(experts_intermediate, dim=0).contiguous() - output = torch.emtpy_like(x) - offset = tokens_per_expert - print(group_A.shape) - print(output.shape) - print(tokens_per_expert) - # gemm2(group_A, w2, output, offset) + group_A = torch.cat(experts_intermediate, dim=0).contiguous() + output = torch.empty((list(group_A.shape)[0], hidden_size), dtype=hidden_states.dtype, device=hidden_states.device) + offset = torch.tensor(tokens_per_expert, dtype=torch.int64, device='xpu') + print("groupA [num_tokens, inter]:", group_A.shape) + print("weight2 [expert, hidden, inter(ld)]", w2.shape) + print("output [num_tokens, hidden]", output.shape) + print("offset [expert]", offset.shape) + # cutlass_grouped_gemm(Tensor input, Tensor weight, Tensor res, Tensor offset, int64_t hidden_size, int64_t intermediate_size, int64_t num_of_expert) -> Tensor + print("enter kernel") + + + offset = torch.tensor([6, 12, 8, 4, 3, 6, 12, 8, 4, 3, 0, 0, 0, 0, 0, 0] ,dtype=torch.int64, device='xpu' ) + hidden_states = 
torch.ops._vllm_fp8_C.cutlass_grouped_gemm(group_A, w2, output, offset, hidden_size, intermediate_size, num_experts) + print(hidden_states) From 439cf3cbb616874bc372f89ecdf932c82f433c07 Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Mon, 1 Sep 2025 07:48:09 +0000 Subject: [PATCH 16/47] base integration done Signed-off-by: Ma, Liangliang --- csrc/quantization/grouped_gemm.h | 18 ++++++++++-------- tests/quantization/test_fused_moe.py | 2 +- vllm_xpu_kernels/fused_moe_interface.py | 12 +++++++++--- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/csrc/quantization/grouped_gemm.h b/csrc/quantization/grouped_gemm.h index e36be61..c293703 100644 --- a/csrc/quantization/grouped_gemm.h +++ b/csrc/quantization/grouped_gemm.h @@ -121,13 +121,14 @@ struct Options { Options(int64_t * offset, int N, int K, int ne): num_of_expert(ne), n(N), k(K), error(false), help(false), alpha(FLT_MAX), beta(FLT_MAX), iterations(100) { + std::cout << "init options" << std::endl; int group_cnt = 0; for (int i = 0; i < num_of_expert; ++i){ if (offset[i] != 0){ group_cnt++; } } - + std::cout << "group_cnt: " << group_cnt << std::endl; problem_sizes_host.reserve(group_cnt); for (int i = 0; i < num_of_expert; ++i){ if (offset[i] != 0){ @@ -135,7 +136,7 @@ struct Options { } } groups = group_cnt; - + std::cout << "finish options" << std::endl; } }; @@ -193,7 +194,6 @@ struct GroupedGemmRunner { cutlass::DeviceAllocation ptr_B; cutlass::DeviceAllocation ptr_C; cutlass::DeviceAllocation ptr_D; - cutlass::DeviceAllocation ptr_ref_D; cutlass::DeviceAllocation stride_A; cutlass::DeviceAllocation stride_B; @@ -271,8 +271,8 @@ void allocate(const Options &options, int64_t* offset) { ptr_C_host.at(i) = block_C.get() + offset_C.at(i); ptr_D_host.at(i) = block_D + offset_D.at(i); // Fill host vector of alpha & beta with random values if using per-group values - alpha_host.push_back(static_cast((rand() % 5) + 1)); - beta_host.push_back(static_cast(rand() % 5)); + alpha_host.push_back(static_cast(1)); + beta_host.push_back(static_cast(0)); // Fill host ptr vectors with offset addresses into device alpha/beta blocks ptr_alpha_host.at(i) = block_alpha.get() + i; ptr_beta_host.at(i) = block_beta.get() + i; @@ -311,7 +311,7 @@ void allocate(const Options &options, int64_t* offset) { beta_device.copy_from_host(ptr_beta_host.data()); - initialize_block(block_C, 666 + 2025); + initialize_block(block_C, 0); // Per-group alpha and beta values - note these are not directly passed to kernel - the pointers // (alpha_device/beta_device) are passed instead block_alpha.copy_from_host(alpha_host.data()); @@ -385,6 +385,7 @@ void allocate(const Options &options, int64_t* offset) { ElementB* inputB, ElementOffset* offset, ElementOutput* res) { + std::cout << "enter run" << std::endl; allocate(options, offset); initialize(options, inputA, inputB, res); @@ -398,7 +399,8 @@ void allocate(const Options &options, int64_t* offset) { CUTLASS_CHECK(gemm_op.can_implement(arguments)); CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); - + + std::cout << "before run kernel" << std::endl; // Run the GEMM CUTLASS_CHECK(gemm_op.run()); @@ -421,7 +423,7 @@ void kernel_functor( // // Run examples // - + std::cout << "enter functor" << std::endl; auto offset_ptr = reinterpret_cast(offset); Options options(offset_ptr, hidden_size, intermediate_size, num_of_expert); // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. 
This diff --git a/tests/quantization/test_fused_moe.py b/tests/quantization/test_fused_moe.py index 250de63..c248fa6 100644 --- a/tests/quantization/test_fused_moe.py +++ b/tests/quantization/test_fused_moe.py @@ -65,7 +65,7 @@ def test_fused_moe( activation="silu", num_experts=e) - print("result", hidden_states) + # print("result", a, a.shape) if __name__ == "__main__": test_fused_moe( diff --git a/vllm_xpu_kernels/fused_moe_interface.py b/vllm_xpu_kernels/fused_moe_interface.py index 38c4809..d9ab659 100644 --- a/vllm_xpu_kernels/fused_moe_interface.py +++ b/vllm_xpu_kernels/fused_moe_interface.py @@ -58,8 +58,14 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ # cutlass_grouped_gemm(Tensor input, Tensor weight, Tensor res, Tensor offset, int64_t hidden_size, int64_t intermediate_size, int64_t num_of_expert) -> Tensor print("enter kernel") - - offset = torch.tensor([6, 12, 8, 4, 3, 6, 12, 8, 4, 3, 0, 0, 0, 0, 0, 0] ,dtype=torch.int64, device='xpu' ) + ## tests only + num_experts = 2 + hidden_size = 4096 + intermediate_size = 4096 + group_A = torch.ones((2048, intermediate_size), dtype=torch.bfloat16, device="xpu") + w2 = torch.ones((num_experts, hidden_size, intermediate_size), dtype=torch.bfloat16, device="xpu") + output = torch.zeros((2048, hidden_size), dtype=torch.float32, device="xpu") + offset = torch.tensor([1024, 1024] ,dtype=torch.int64, device="cpu" ) hidden_states = torch.ops._vllm_fp8_C.cutlass_grouped_gemm(group_A, w2, output, offset, hidden_size, intermediate_size, num_experts) - print(hidden_states) + print(hidden_states, hidden_states.shape) From 48abd9f55104e631a96cc5faed7d107a27cae23f Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Tue, 2 Sep 2025 09:04:46 +0000 Subject: [PATCH 17/47] grouped gemm base ready Signed-off-by: Ma, Liangliang --- csrc/quantization/grouped_gemm.h | 32 +++++++++++++++------- tests/quantization/test_fused_moe.py | 35 ++++++++++++++++++++++++- vllm_xpu_kernels/__init__.py | 2 +- vllm_xpu_kernels/fused_moe_interface.py | 24 +++++++---------- 4 files changed, 66 insertions(+), 27 deletions(-) diff --git a/csrc/quantization/grouped_gemm.h b/csrc/quantization/grouped_gemm.h index c293703..568a0cc 100644 --- a/csrc/quantization/grouped_gemm.h +++ b/csrc/quantization/grouped_gemm.h @@ -99,7 +99,7 @@ using ElementComputeEpilogue = float; // <- data type of epilogue operations using ElementA = bfloat16_t; // <- data type of elements in input matrix A using ElementB = bfloat16_t; // <- data type of elements in input matrix B using ElementOutput = float; // <- data type of elements in output matrix D - +bool debug = false; /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -121,14 +121,15 @@ struct Options { Options(int64_t * offset, int N, int K, int ne): num_of_expert(ne), n(N), k(K), error(false), help(false), alpha(FLT_MAX), beta(FLT_MAX), iterations(100) { - std::cout << "init options" << std::endl; + if (debug) { + std::cout << "Options()" << std::endl; + } int group_cnt = 0; for (int i = 0; i < num_of_expert; ++i){ if (offset[i] != 0){ group_cnt++; } } - std::cout << "group_cnt: " << group_cnt << std::endl; problem_sizes_host.reserve(group_cnt); for (int i = 0; i < num_of_expert; ++i){ if (offset[i] != 0){ @@ -136,7 +137,6 @@ struct Options { } } groups = group_cnt; - std::cout << "finish options" << std::endl; } }; @@ -209,17 +209,23 @@ struct GroupedGemmRunner { /// Allocates device-side data void allocate(const Options &options, int64_t* offset) { + if 
(debug){ + std::cout << "void allocate()" << std::endl; + } int64_t total_elements_A = 0; int64_t total_elements_B = 0; int64_t total_elements_C = 0; int64_t total_elements_D = 0; - + + int offset_iter = 0; // Compute total allocation sizes across group - for (int32_t i = 0; i < options.num_of_expert; ++i) { - if (offset[i] == 0){ + for (int32_t i = 0; i < options.groups; ++i) { + while (offset[offset_iter] == 0){ total_elements_B += options.n * options.k; + offset_iter++; continue; } + offset_iter++; auto problem = options.problem_sizes_host.at(i); auto M = get<0>(problem); @@ -254,6 +260,9 @@ void allocate(const Options &options, int64_t* offset) { void initialize(const Options &options, ElementA * block_A, ElementB * block_B, ElementOutput* block_D) { + if (debug){ + std::cout << "void initialize()" << std::endl; + } problem_sizes.reset(options.groups); problem_sizes.copy_from_host(options.problem_sizes_host.data()); @@ -385,7 +394,9 @@ void allocate(const Options &options, int64_t* offset) { ElementB* inputB, ElementOffset* offset, ElementOutput* res) { - std::cout << "enter run" << std::endl; + if (debug){ + std::cout << "enter run" << std::endl; + } allocate(options, offset); initialize(options, inputA, inputB, res); @@ -400,7 +411,9 @@ void allocate(const Options &options, int64_t* offset) { CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); - std::cout << "before run kernel" << std::endl; + if (debug){ + std::cout << "before run kernel" << std::endl; + } // Run the GEMM CUTLASS_CHECK(gemm_op.run()); @@ -423,7 +436,6 @@ void kernel_functor( // // Run examples // - std::cout << "enter functor" << std::endl; auto offset_ptr = reinterpret_cast(offset); Options options(offset_ptr, hidden_size, intermediate_size, num_of_expert); // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. 
This diff --git a/tests/quantization/test_fused_moe.py b/tests/quantization/test_fused_moe.py index c248fa6..e047645 100644 --- a/tests/quantization/test_fused_moe.py +++ b/tests/quantization/test_fused_moe.py @@ -2,7 +2,7 @@ import torch from math import ceil from typing import Callable, Optional, Union -from vllm_xpu_kernels.fused_moe_interface import cutlass_fused_moe +from vllm_xpu_kernels.fused_moe_interface import cutlass_fused_moe, cutlass_grouped_gemm NUM_EXPERTS = [8, 64, 192] EP_SIZE = [1, 4] @@ -28,6 +28,38 @@ DEVICE = "xpu" +def test_grouped_gemm(num_experts, n, k, token_per_group): + # input + input_A = torch.randn((sum(token_per_group), k), dtype=torch.bfloat16, device="xpu") + + # weight + input_B = torch.randn((num_experts, n, k), dtype=torch.bfloat16, device="xpu") + input_B = input_B.transpose(-1, -2).contiguous().transpose(-1, -2) + + # output offset + output = torch.empty((sum(token_per_group), n), dtype=torch.float32, device="xpu") + offset = torch.tensor(token_per_group, dtype=torch.int64, device="cpu" ) + + cutlass_grouped_gemm(input_A, input_B, output, offset, n, k, num_experts) + + # ref gg + ref = [] + pre_token_sum = 0 + for i in range(num_experts): + cur_token_num = token_per_group[i] + if cur_token_num == 0: + continue + input = input_A[pre_token_sum:pre_token_sum + cur_token_num, :] + weight = input_B[i, :, :] + expert_output = input @ weight.T + ref.append(expert_output) + pre_token_sum += cur_token_num + ref = torch.cat(ref, dim=0).float() + + print(torch.allclose(output, ref, rtol=1, atol=1)) + max_diff = (output - ref).abs().max() + print("Max absolute difference:", max_diff) + # @pytest.mark.parametrize("m,n,k", FUSED_MOE_MNK_FACTORS) # @pytest.mark.parametrize("e", NUM_EXPERTS) # @pytest.mark.parametrize("topk", TOP_KS) @@ -77,3 +109,4 @@ def test_fused_moe( ep_size = 1, dtype = torch.bfloat16 ) + # test_grouped_gemm(num_experts=16, n=5120, k=8192, token_per_group=[1,2,6,8,12,0,1,5,1,2,6,8,12,0,1,5]) diff --git a/vllm_xpu_kernels/__init__.py b/vllm_xpu_kernels/__init__.py index e6fe0b1..293c6e5 100644 --- a/vllm_xpu_kernels/__init__.py +++ b/vllm_xpu_kernels/__init__.py @@ -1,4 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 from .flash_attn_interface import flash_attn_varlen_func # noqa: F401 -from .fused_moe_interface import cutlass_fused_moe +from .fused_moe_interface import cutlass_fused_moe, cutlass_grouped_gemm diff --git a/vllm_xpu_kernels/fused_moe_interface.py b/vllm_xpu_kernels/fused_moe_interface.py index d9ab659..cb01d8a 100644 --- a/vllm_xpu_kernels/fused_moe_interface.py +++ b/vllm_xpu_kernels/fused_moe_interface.py @@ -5,6 +5,10 @@ import numpy from . 
import _vllm_fp8_C + +def cutlass_grouped_gemm(input_A, input_B, output, offset, n, k, num_experts): + torch.ops._vllm_fp8_C.cutlass_grouped_gemm(input_A, w2, output, offset, n, k, num_experts) + def ref_gemm1(x, w13): w1, w3 = torch.split(w13, int(list(w13.shape)[0]/2), dim=0) act_fn = torch.nn.SiLU() @@ -12,8 +16,6 @@ def ref_gemm1(x, w13): up = x @ w3.T return gate * up - - def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_per_token, inplace, activation, num_experts): token_cnt, hidden_size = list(hidden_states.shape) intermediate_size = list(w2.shape)[-1] @@ -28,10 +30,11 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ token_idxs = idxs // num_per_tok experts_intermediate = [] print("tokens_per_expert", tokens_per_expert) - + offset = [] for expert_id, end_idx in enumerate(tokens_per_expert): print("expert id: ", expert_id) start_idx = 0 if expert_id == 0 else tokens_per_expert[expert_id - 1] + offset.append(end_idx - start_idx) if start_idx == end_idx: continue @@ -49,8 +52,9 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ experts_intermediate.append(expert_out) group_A = torch.cat(experts_intermediate, dim=0).contiguous() - output = torch.empty((list(group_A.shape)[0], hidden_size), dtype=hidden_states.dtype, device=hidden_states.device) - offset = torch.tensor(tokens_per_expert, dtype=torch.int64, device='xpu') + output = torch.empty((list(group_A.shape)[0], hidden_size), dtype=torch.float32, device=hidden_states.device) + # offset = torch.tensor(tokens_per_expert, dtype=torch.int64, device='xpu') + offset = torch.tensor(offset, dtype=torch.int64, device='cpu') print("groupA [num_tokens, inter]:", group_A.shape) print("weight2 [expert, hidden, inter(ld)]", w2.shape) print("output [num_tokens, hidden]", output.shape) @@ -58,14 +62,4 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ # cutlass_grouped_gemm(Tensor input, Tensor weight, Tensor res, Tensor offset, int64_t hidden_size, int64_t intermediate_size, int64_t num_of_expert) -> Tensor print("enter kernel") - ## tests only - num_experts = 2 - hidden_size = 4096 - intermediate_size = 4096 - group_A = torch.ones((2048, intermediate_size), dtype=torch.bfloat16, device="xpu") - w2 = torch.ones((num_experts, hidden_size, intermediate_size), dtype=torch.bfloat16, device="xpu") - output = torch.zeros((2048, hidden_size), dtype=torch.float32, device="xpu") - offset = torch.tensor([1024, 1024] ,dtype=torch.int64, device="cpu" ) - hidden_states = torch.ops._vllm_fp8_C.cutlass_grouped_gemm(group_A, w2, output, offset, hidden_size, intermediate_size, num_experts) - print(hidden_states, hidden_states.shape) From 67eeb47af1b1b44f5337b2173326c68c5b66bfc1 Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Wed, 3 Sep 2025 12:05:14 +0000 Subject: [PATCH 18/47] gemm2 use cutlass grouped_mm Signed-off-by: Ma, Liangliang --- tests/quantization/test_fused_moe.py | 84 +++++++++++++++++++++---- vllm_xpu_kernels/fused_moe_interface.py | 37 ++++++++--- 2 files changed, 100 insertions(+), 21 deletions(-) diff --git a/tests/quantization/test_fused_moe.py b/tests/quantization/test_fused_moe.py index e047645..7844954 100644 --- a/tests/quantization/test_fused_moe.py +++ b/tests/quantization/test_fused_moe.py @@ -60,6 +60,45 @@ def test_grouped_gemm(num_experts, n, k, token_per_group): max_diff = (output - ref).abs().max() print("Max absolute difference:", max_diff) +def ref_fused_moe(x, + w13, + w2, + flat_expert_weights, + 
flat_expert_indices, + num_per_tok, + activation, + num_experts): + + expert_cache = torch.zeros_like(x).float() + idxs = flat_expert_indices.argsort() + counts = flat_expert_indices.bincount().cpu().numpy() + tokens_per_expert = counts.cumsum() + token_idxs = idxs // num_per_tok + for expert_id, end_idx in enumerate(tokens_per_expert): + start_idx = 0 if expert_id == 0 else tokens_per_expert[expert_id - 1] + if start_idx == end_idx: + continue + + exp_token_idxs = token_idxs[start_idx:end_idx] + expert_tokens = x[exp_token_idxs] + + expert_w13 = w13[expert_id, :, :] + w1, w3 = torch.split(expert_w13, int(list(expert_w13.shape)[0]/2), dim=0) + act_fn = torch.nn.SiLU() + gate = act_fn(expert_tokens @ w1.T) + up = expert_tokens @ w3.T + expert_out = (gate * up) @ w2[expert_id, :, :].T + expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) + expert_cache.scatter_reduce_( + 0, + exp_token_idxs.view(-1, 1).repeat(1, x.shape[-1]), + expert_out.float(), + reduce='sum' + ) + + return expert_cache + + # @pytest.mark.parametrize("m,n,k", FUSED_MOE_MNK_FACTORS) # @pytest.mark.parametrize("e", NUM_EXPERTS) # @pytest.mark.parametrize("topk", TOP_KS) @@ -75,7 +114,7 @@ def test_fused_moe( dtype: torch.dtype, ): # todo: seed - + verbose = False # Setup test data a = torch.randn((m, k), device=DEVICE, dtype=dtype) / 10 w13 = torch.randn((e, 2 * n, k), device=DEVICE, dtype=dtype) / 10 @@ -83,21 +122,40 @@ def test_fused_moe( # moe gate scores = torch.randn((m, e), device=DEVICE, dtype=dtype) - expert_indices, expert_scores = torch.topk(scores, k=topk, dim=-1, sorted=False) + expert_scores, expert_indices = torch.topk(scores, k=topk, dim=-1, sorted=False) + + if verbose: + print("expert_indices: ", expert_indices, expert_indices.shape) + print("expert_scores: ", expert_scores, expert_scores.shape) + flat_expert_indices = expert_indices.view(-1) flat_expert_weights = expert_scores.view(-1, 1) - cutlass_fused_moe(hidden_states=a, - w13=w13, - w2=w2, - topk_weights=flat_expert_weights, - topk_ids=flat_expert_indices, - n_experts_per_token=topk, - inplace=True, - activation="silu", - num_experts=e) - - # print("result", a, a.shape) + out = cutlass_fused_moe(hidden_states=a, + w13=w13, + w2=w2, + topk_weights=flat_expert_weights, + topk_ids=flat_expert_indices, + n_experts_per_token=topk, + activation="silu", + num_experts=e) + + ref_out = ref_fused_moe(a, + w13, + w2, + flat_expert_weights, + flat_expert_indices, + topk, + "silu", + e) + + print("ref result", ref_out, ref_out.shape) + print("kernel result", out, out.shape) + print(torch.allclose(out, ref_out, rtol=1, atol=1)) + max_diff = (out - ref_out).abs().max() + print("Max absolute difference:", max_diff) + + if __name__ == "__main__": test_fused_moe( diff --git a/vllm_xpu_kernels/fused_moe_interface.py b/vllm_xpu_kernels/fused_moe_interface.py index cb01d8a..4bb2d42 100644 --- a/vllm_xpu_kernels/fused_moe_interface.py +++ b/vllm_xpu_kernels/fused_moe_interface.py @@ -7,7 +7,7 @@ from . 
import _vllm_fp8_C def cutlass_grouped_gemm(input_A, input_B, output, offset, n, k, num_experts): - torch.ops._vllm_fp8_C.cutlass_grouped_gemm(input_A, w2, output, offset, n, k, num_experts) + torch.ops._vllm_fp8_C.cutlass_grouped_gemm(input_A, input_B, output, offset, n, k, num_experts) def ref_gemm1(x, w13): w1, w3 = torch.split(w13, int(list(w13.shape)[0]/2), dim=0) @@ -16,11 +16,11 @@ def ref_gemm1(x, w13): up = x @ w3.T return gate * up -def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_per_token, inplace, activation, num_experts): +def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_per_token, activation, num_experts): token_cnt, hidden_size = list(hidden_states.shape) intermediate_size = list(w2.shape)[-1] - expert_cache = torch.empty((token_cnt, intermediate_size), - dtype=hidden_states.dtype, + expert_cache = torch.empty((token_cnt, hidden_size), + dtype=torch.float32, device=hidden_states.device) idxs = topk_ids.argsort() @@ -32,7 +32,6 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ print("tokens_per_expert", tokens_per_expert) offset = [] for expert_id, end_idx in enumerate(tokens_per_expert): - print("expert id: ", expert_id) start_idx = 0 if expert_id == 0 else tokens_per_expert[expert_id - 1] offset.append(end_idx - start_idx) if start_idx == end_idx: @@ -58,8 +57,30 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ print("groupA [num_tokens, inter]:", group_A.shape) print("weight2 [expert, hidden, inter(ld)]", w2.shape) print("output [num_tokens, hidden]", output.shape) - print("offset [expert]", offset.shape) - # cutlass_grouped_gemm(Tensor input, Tensor weight, Tensor res, Tensor offset, int64_t hidden_size, int64_t intermediate_size, int64_t num_of_expert) -> Tensor - print("enter kernel") + print("offset [expert]", offset, offset.shape) + w2 = w2.transpose(-1, -2).contiguous().transpose(-1, -2) + cutlass_grouped_gemm(input_A=group_A, + input_B=w2, + output=output, + offset=offset, + n=hidden_size, + k=intermediate_size, + num_experts=num_experts) + for expert_id, end_idx in enumerate(tokens_per_expert): + start_idx = 0 if expert_id == 0 else tokens_per_expert[expert_id - 1] + if start_idx == end_idx: + continue + + exp_token_idxs = token_idxs[start_idx:end_idx] + expert_tokens = hidden_states[exp_token_idxs] + expert_out = output[start_idx:end_idx] + expert_out.mul_(topk_weights[idxs[start_idx:end_idx]]) + expert_cache.scatter_reduce_( + 0, + exp_token_idxs.view(-1, 1).repeat(1, hidden_size), + expert_out, + reduce='sum' + ) + return expert_cache From a62752fe8cf93eb10ed45db758b04e29bbd1b2ff Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Wed, 3 Sep 2025 12:55:59 +0000 Subject: [PATCH 19/47] gemm1 use cutlass group_mm Signed-off-by: Ma, Liangliang --- tests/quantization/test_fused_moe.py | 3 +- vllm_xpu_kernels/fused_moe_interface.py | 54 ++++++++++++------------- 2 files changed, 29 insertions(+), 28 deletions(-) diff --git a/tests/quantization/test_fused_moe.py b/tests/quantization/test_fused_moe.py index 7844954..79c9266 100644 --- a/tests/quantization/test_fused_moe.py +++ b/tests/quantization/test_fused_moe.py @@ -85,7 +85,8 @@ def ref_fused_moe(x, expert_w13 = w13[expert_id, :, :] w1, w3 = torch.split(expert_w13, int(list(expert_w13.shape)[0]/2), dim=0) act_fn = torch.nn.SiLU() - gate = act_fn(expert_tokens @ w1.T) + gemm1 = expert_tokens @ w1.T + gate = act_fn(gemm1) up = expert_tokens @ w3.T expert_out = (gate * up) @ 
w2[expert_id, :, :].T expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) diff --git a/vllm_xpu_kernels/fused_moe_interface.py b/vllm_xpu_kernels/fused_moe_interface.py index 4bb2d42..2d209e7 100644 --- a/vllm_xpu_kernels/fused_moe_interface.py +++ b/vllm_xpu_kernels/fused_moe_interface.py @@ -9,13 +9,6 @@ def cutlass_grouped_gemm(input_A, input_B, output, offset, n, k, num_experts): torch.ops._vllm_fp8_C.cutlass_grouped_gemm(input_A, input_B, output, offset, n, k, num_experts) -def ref_gemm1(x, w13): - w1, w3 = torch.split(w13, int(list(w13.shape)[0]/2), dim=0) - act_fn = torch.nn.SiLU() - gate = act_fn(x @ w1.T) - up = x @ w3.T - return gate * up - def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_per_token, activation, num_experts): token_cnt, hidden_size = list(hidden_states.shape) intermediate_size = list(w2.shape)[-1] @@ -23,41 +16,47 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ dtype=torch.float32, device=hidden_states.device) + # map token to experts idxs = topk_ids.argsort() counts = topk_ids.to(torch.long).bincount().cpu().numpy() tokens_per_expert = counts.cumsum() num_per_tok = n_experts_per_token token_idxs = idxs // num_per_tok - experts_intermediate = [] - print("tokens_per_expert", tokens_per_expert) + grouped_input_A = [] offset = [] for expert_id, end_idx in enumerate(tokens_per_expert): start_idx = 0 if expert_id == 0 else tokens_per_expert[expert_id - 1] offset.append(end_idx - start_idx) if start_idx == end_idx: continue - - expert_w13 = w13[expert_id, :, :] exp_token_idxs = token_idxs[start_idx:end_idx] expert_tokens = hidden_states[exp_token_idxs] - expert_out = ref_gemm1(expert_tokens, expert_w13) - # expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) - # expert_cache.scatter_reduce_( - # 0, - # exp_token_idxs.view(-1, 1).repeat(1, x.shape[-1]), - # expert_out, - # reduce='sum' - # ) - experts_intermediate.append(expert_out) + grouped_input_A.append(expert_tokens) - group_A = torch.cat(experts_intermediate, dim=0).contiguous() - output = torch.empty((list(group_A.shape)[0], hidden_size), dtype=torch.float32, device=hidden_states.device) - # offset = torch.tensor(tokens_per_expert, dtype=torch.int64, device='xpu') + total_input_size = token_cnt * num_per_tok + + # gemm1 + input_A = torch.cat(grouped_input_A, dim=0).contiguous() + input_B = w13.transpose(-1, -2).contiguous().transpose(-1, -2) + assert(list(input_A.shape)[0] == total_input_size) + gemm1_output = torch.empty((total_input_size, 2*intermediate_size), dtype=torch.float32,device=hidden_states.device) offset = torch.tensor(offset, dtype=torch.int64, device='cpu') - print("groupA [num_tokens, inter]:", group_A.shape) - print("weight2 [expert, hidden, inter(ld)]", w2.shape) - print("output [num_tokens, hidden]", output.shape) - print("offset [expert]", offset, offset.shape) + cutlass_grouped_gemm(input_A=input_A, + input_B=input_B, + output=gemm1_output, + offset=offset, + n=2*intermediate_size, + k=hidden_size, + num_experts=num_experts) + # act + gate, up = torch.split(gemm1_output, intermediate_size, dim=1) + act = torch.nn.SiLU() + act_output = act(gate) * up + + + # gemm 2 + group_A = act_output.to(torch.bfloat16).contiguous() + output = torch.empty((list(group_A.shape)[0], hidden_size), dtype=torch.float32, device=hidden_states.device) w2 = w2.transpose(-1, -2).contiguous().transpose(-1, -2) cutlass_grouped_gemm(input_A=group_A, input_B=w2, @@ -67,6 +66,7 @@ def cutlass_fused_moe(hidden_states, w13, w2, 
topk_weights, topk_ids, n_experts_ k=intermediate_size, num_experts=num_experts) + # apply scores for expert_id, end_idx in enumerate(tokens_per_expert): start_idx = 0 if expert_id == 0 else tokens_per_expert[expert_id - 1] if start_idx == end_idx: From cfb724b720d13eceb4038b91e7ebb05edd1eb3db Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Thu, 4 Sep 2025 06:27:13 +0000 Subject: [PATCH 20/47] rm flash_attn in this pr Signed-off-by: Ma, Liangliang --- CMakeLists.txt | 38 +- .../cutlass_kernels.cpp | 0 .../grouped_gemm.h | 0 .../helper.h | 0 .../sycl_common.hpp | 0 csrc/flash_attn/flash_api.cpp | 82 --- csrc/flash_attn/pytorch_shim.h | 110 ---- csrc/xpu/cutlass_kernels/chunk_prefill.hpp | 331 ----------- csrc/xpu/cutlass_kernels/utils.hpp | 49 -- csrc/xpu/cutlass_sycl_demo.cpp | 520 ------------------ csrc/xpu/helper.h | 127 ----- csrc/xpu/mha.h | 16 - csrc/xpu/ops.h | 2 - csrc/xpu/torch_bindings.cpp | 3 - .../test_fused_moe.py | 0 tests/flash_attn/test.py | 38 -- .../flash_attn/test_flash_attn_varlen_func.py | 35 -- tests/test_cutlass_op.py | 17 - vllm_xpu_kernels/__init__.py | 1 - vllm_xpu_kernels/flash_attn_interface.py | 107 ---- 20 files changed, 1 insertion(+), 1475 deletions(-) rename csrc/{quantization => cutlass_backend}/cutlass_kernels.cpp (100%) rename csrc/{quantization => cutlass_backend}/grouped_gemm.h (100%) rename csrc/{quantization => cutlass_backend}/helper.h (100%) rename csrc/{quantization => cutlass_backend}/sycl_common.hpp (100%) delete mode 100644 csrc/flash_attn/flash_api.cpp delete mode 100644 csrc/flash_attn/pytorch_shim.h delete mode 100644 csrc/xpu/cutlass_kernels/chunk_prefill.hpp delete mode 100644 csrc/xpu/cutlass_kernels/utils.hpp delete mode 100644 csrc/xpu/cutlass_sycl_demo.cpp delete mode 100644 csrc/xpu/helper.h delete mode 100644 csrc/xpu/mha.h rename tests/{quantization => cutlass}/test_fused_moe.py (100%) delete mode 100644 tests/flash_attn/test.py delete mode 100644 tests/flash_attn/test_flash_attn_varlen_func.py delete mode 100644 tests/test_cutlass_op.py delete mode 100644 vllm_xpu_kernels/flash_attn_interface.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 56c0ae2..98dbc64 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -148,7 +148,6 @@ endif() if(VLLM_GPU_LANG STREQUAL "SYCL") set(VLLM_EXT_SRC - "csrc/xpu/cutlass_sycl_demo.cpp" "csrc/xpu/layernorm.cpp" "csrc/xpu/torch_bindings.cpp" ) @@ -225,44 +224,9 @@ define_gpu_extension_target( USE_SABI 3 WITH_SOABI) -# -# flash attention _C extension -# - -if (FA2_ENABLED) - message(STATUS "Enabling fa2 extension.") - file(GLOB FA2_GEN_SRCS "csrc/flash_attn/*.cpp") - - # list(APPEND VLLM_GPU_FLAGS "-ze-opt-large-register-file") - list(APPEND VLLM_GPU_FLAGS "-gline-tables-only") - list(APPEND VLLM_GPU_FLAGS "-O3") - list(APPEND VLLM_GPU_FLAGS "-DNDEBUG") - - define_gpu_extension_target( - _vllm_fa2_C - DESTINATION vllm_xpu_kernels - LANGUAGE ${VLLM_GPU_LANG} - SOURCES - csrc/flash_attn/flash_api.cpp - ${FA2_GEN_SRCS} - COMPILE_FLAGS ${VLLM_GPU_FLAGS} - LINK_FLAGS ${VLLM_GPU_LINK_FLAGS} - ARCHITECTURES ${VLLM_GPU_ARCHES} - INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} - INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} - INCLUDE_DIRECTORIES ${CUTLASS_APP_INCLUDE_DIR} - INCLUDE_DIRECTORIES ${VLLM_INCLUDE_DIR} - USE_SABI 3 - WITH_SOABI) - - # target_include_directories(_vllm_fa2_C PRIVATE - # csrc/flash_attn - # csrc/flash_attn/src) -endif () - if (FP8_ENABLED) message(STATUS "Enabling FP8 extension.") - file(GLOB FP8_GEN_SRCS "csrc/quantization/*.cpp") + file(GLOB FP8_GEN_SRCS 
"csrc/cutlass_backend/*.cpp") # list(APPEND VLLM_GPU_FLAGS "-ze-opt-large-register-file") list(APPEND VLLM_GPU_FLAGS "-gline-tables-only") diff --git a/csrc/quantization/cutlass_kernels.cpp b/csrc/cutlass_backend/cutlass_kernels.cpp similarity index 100% rename from csrc/quantization/cutlass_kernels.cpp rename to csrc/cutlass_backend/cutlass_kernels.cpp diff --git a/csrc/quantization/grouped_gemm.h b/csrc/cutlass_backend/grouped_gemm.h similarity index 100% rename from csrc/quantization/grouped_gemm.h rename to csrc/cutlass_backend/grouped_gemm.h diff --git a/csrc/quantization/helper.h b/csrc/cutlass_backend/helper.h similarity index 100% rename from csrc/quantization/helper.h rename to csrc/cutlass_backend/helper.h diff --git a/csrc/quantization/sycl_common.hpp b/csrc/cutlass_backend/sycl_common.hpp similarity index 100% rename from csrc/quantization/sycl_common.hpp rename to csrc/cutlass_backend/sycl_common.hpp diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp deleted file mode 100644 index 4a15d5e..0000000 --- a/csrc/flash_attn/flash_api.cpp +++ /dev/null @@ -1,82 +0,0 @@ -#include "pytorch_shim.h" - -#include "core/registration.h" -#include "xpu/cutlass_kernels/chunk_prefill.hpp" -#include - -namespace FLASH_NAMESPACE { - -std::vector mha_varlen_fwd( - const at::Tensor& q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - const at::Tensor& k, // total_k x num_heads_k x head_size, total_k := - // \sum_{i=0}^{b} s_i or num_blocks x page_block_size - // x num_heads_k x head_size if there's a block_table. - const at::Tensor& v, // total_k x num_heads_k x head_size, total_k := - // \sum_{i=0}^{b} s_i or num_blocks x page_block_size - // x num_heads_k x head_size if there's a block_table. - std::optional& out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor& cu_seqlens_q, // b+1 - const at::Tensor& cu_seqlens_k, // b+1 - std::optional& seqused_k, // b. If given, only this many elements of each batch - // element's keys are used. - std::optional& leftpad_k_, // batch_size - at::Tensor& block_table_, // batch_size x max_num_blocks_per_seq - std::optional& alibi_slopes_, // num_heads or b x num_heads - int max_seqlen_q, - int max_seqlen_k, - float p_dropout, - float softmax_scale, - const bool zero_tensors, - bool is_causal, - int window_size_left, - int window_size_right, - const float softcap, - const bool return_softmax, std::optional gen_) { - at::Tensor out; - if(out_.has_value()) { - out = *out_; - } - else { - out = torch::zeros_like(q); - } - - cutlass_chunk_prefill_impl( - q, - k, - v, - out, - block_table_, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - softmax_scale, - is_causal); - - if(return_softmax) { - auto softmax_lse = torch::empty_like(out); - return {out, softmax_lse}; - } - else { - at::Tensor softmax_lse; - return {out, softmax_lse}; - } -} -} // namespace FLASH_NAMESPACE - -TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { - ops.def( - "varlen_fwd(Tensor q, Tensor k, Tensor v, Tensor!? out, Tensor " - "cu_seqlens_q, " - "Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? leftpad_k, Tensor " - "block_table, Tensor? alibi_slopes, " - "int max_seqlen_q, int max_seqlen_k, float p_dropout, float " - "softmax_scale, bool zero_tensors, " - "bool is_causal, int window_size_left, int window_size_right, float " - "softcap, bool return_softmax, " - "Generator? 
gen) -> Tensor[]"); - ops.impl("varlen_fwd", torch::kXPU, - make_pytorch_shim(&FLASH_NAMESPACE::mha_varlen_fwd)); -} - -REGISTER_EXTENSION(TORCH_EXTENSION_NAME) \ No newline at end of file diff --git a/csrc/flash_attn/pytorch_shim.h b/csrc/flash_attn/pytorch_shim.h deleted file mode 100644 index 82f5f5c..0000000 --- a/csrc/flash_attn/pytorch_shim.h +++ /dev/null @@ -1,110 +0,0 @@ -#pragma once - -#include - -/** - * Unfortunately, the type signatures of the flash_attn ops are not compatible - * with the PyTorch library bindings. To get around that we use - * `make_pytorch_shim` which creates a lambda that exponses the API using - * PyTorch compatible types to the types, then converts them to the types - * expected by the flash_attn ops. This shims allows us to make minimal changes - * to `flash_api.cpp` making it easier to synchronize with upstream changes. - * - * The `pytorch_library_compatible_type` struct is used to map from the - * flash_attn ops types to a PyTorch library compatible one. The main issues is - * that the following types are not support by PyTorch library bindings: - * - `int` - * - `float` - * - `std::optional &` - * - `std::optional &` - * So we convert them to (respectively): - * - `int64_t` - * - `double` - * - `const std::optional&` - * - `const std::optional&` - */ - -template -struct pytorch_library_compatible_type { - using type = T; - static T convert_from_type(T arg) { return arg; } -}; - -template -using pytorch_library_compatible_type_t = - typename pytorch_library_compatible_type::type; - -template -T convert_from_pytorch_compatible_type( - pytorch_library_compatible_type_t arg) { - return pytorch_library_compatible_type::convert_from_type(arg); -} - -// Map `std::optional &` -> `const std::optional&` -// (NOTE: this is bit unsafe but non of the ops in flash_attn mutate -// the optional container) -template -struct pytorch_library_compatible_type&> { - using type = const std::optional&; - static std::optional& convert_from_type(const std::optional& arg) { - return const_cast&>(arg); - } -}; - -// Map `std::optional` -> -// `std::optional>` -// (NOTE: tested for `std::optional` -> `std::optional`) -template -struct pytorch_library_compatible_type> { - using type = std::optional>; - static std::optional> convert_from_type( - std::optional arg) { - return arg; - } -}; - -// Map `std::optional&` -> `const std::optional&` -template <> -struct pytorch_library_compatible_type&> { - using type = const std::optional&; - static std::optional& convert_from_type( - const std::optional& arg) { - return const_cast&>( - reinterpret_cast&>(arg)); - } -}; - -// Map `int` -> `int64_t` -template <> -struct pytorch_library_compatible_type { - using type = int64_t; - static int convert_from_type(int64_t arg) { - TORCH_CHECK(arg <= std::numeric_limits::max(), - "int64_t value is too large to be converted to int"); - TORCH_CHECK(arg >= std::numeric_limits::min(), - "int64_t value is too small to be converted to int"); - return arg; - } -}; - -// Map `float` -> `double` -template <> -struct pytorch_library_compatible_type { - using type = double; - static float convert_from_type(double arg) { - TORCH_CHECK(std::abs(arg) <= std::numeric_limits::max(), - "double value is too large to be converted to float"); - return arg; - } -}; - -// -// Shim Utils -// - -template -auto make_pytorch_shim(Ret (*fun)(Args... args)) { - return [fun](pytorch_library_compatible_type_t... 
args) { - return fun(convert_from_pytorch_compatible_type(args)...); - }; -} diff --git a/csrc/xpu/cutlass_kernels/chunk_prefill.hpp b/csrc/xpu/cutlass_kernels/chunk_prefill.hpp deleted file mode 100644 index c96b8dd..0000000 --- a/csrc/xpu/cutlass_kernels/chunk_prefill.hpp +++ /dev/null @@ -1,331 +0,0 @@ -#pragma once - -#include "cutlass/epilogue/collective/default_epilogue.hpp" -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "flash_attention_v2/collective/fmha_fusion.hpp" -#include "flash_attention_v2/kernel/tile_scheduler_chunk_prefill.hpp" -#include "cutlass/util/packed_stride.hpp" -#include "flash_attention_v2/kernel/xe_chunk_prefill.hpp" -#include "flash_attention_v2/collective/xe_flash_attn_chunk_prefill_epilogue.hpp" -#include "flash_attention_v2/collective/xe_flash_attn_chunk_prefill_softmax_epilogue.hpp" -#include "cutlass/util/GPU_Clock.hpp" -#include "cutlass/util/sycl_event_manager.hpp" - -#include -#include - -#include "cutlass/util/command_line.h" -#include "cutlass/util/device_memory.h" -#include "cutlass/util/reference/device/gemm_complex.h" -#include "cutlass/util/reference/device/tensor_compare.h" - -#include "utils.hpp" - -using namespace cute; - -struct chunk_prefill_args_t { - void* query; - void* key; - void* value; - void* out; - void* block_table; - void* num_blocks_per_seq; - void* cu_seqlens_q; - void* cu_seqlens_k; - int max_queries; - int max_keys; - int total_seqlen_q; - int total_seqlen_k; - float sm_scale; - int batch_size; - int num_heads_q; - int num_heads_k; - int head_size; - int max_blocks_per_seq; - int block_size; - bool is_causal; -}; - -template struct KernelLauncher { - using StrideQ = typename FMHAChunkPrefillKernel::StrideQ; - using StrideK = typename FMHAChunkPrefillKernel::StrideK; - using StrideV = typename FMHAChunkPrefillKernel::StrideV; - using StrideO = typename FMHAChunkPrefillKernel::StrideO; - - using ElementQ = typename FMHAChunkPrefillKernel::ElementQ; - using ElementK = typename FMHAChunkPrefillKernel::ElementK; - using ElementV = typename FMHAChunkPrefillKernel::ElementV; - using ElementAcc = typename FMHAChunkPrefillKernel::ElementAccumulator; - - using CollectiveEpilogue = typename FMHAChunkPrefillKernel::CollectiveEpilogue; - using ElementOutput = typename CollectiveEpilogue::ElementOutput; - using ElementCompute = typename CollectiveEpilogue::ElementCompute; - using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; - - using ProblemShapeType = typename FMHAChunkPrefillKernel::ProblemShape; - - /// Initialization - StrideQ stride_Q; - StrideK stride_K; - StrideV stride_V; - StrideK stride_K_cache; - StrideV stride_V_cache; - StrideO stride_O; - uint64_t seed = 0; - - ProblemShapeType initialize(const chunk_prefill_args_t &args) { - auto problem_shape = cute::make_tuple( - 1, - args.num_heads_q, - args.num_heads_k, - args.total_seqlen_q, - args.total_seqlen_k, - args.total_seqlen_k, - args.head_size, - args.head_size); - auto problem_shape_out = cute::make_tuple( - args.batch_size, - args.num_heads_q, - args.num_heads_k, - cutlass::fmha::collective::VariableLength{args.max_queries}, // cu_q - cutlass::fmha::collective::VariableLength{args.max_keys}, // cu_k - cutlass::fmha::collective::VariableLength{args.max_keys}, // cu_v - args.head_size, - args.head_size); - auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = problem_shape; - auto group_q_size = num_heads_q / num_heads_kv; - auto group_q_num = num_heads_q / group_q_size; - - 
stride_Q = cutlass::make_cute_packed_stride(StrideQ{}, cute::make_shape(seq_len_qo, num_heads_q * head_size_qk, batch)); - stride_K = cutlass::make_cute_packed_stride(StrideK{}, cute::make_shape(seq_len_kv, num_heads_kv * head_size_qk, batch)); - stride_V = cutlass::make_cute_packed_stride(StrideV{}, cute::make_shape(head_size_vo * num_heads_kv, seq_len_kv, batch)); - - stride_K_cache = cutlass::make_cute_packed_stride(StrideK{}, cute::make_shape(seq_len_kv_cache, num_heads_kv * head_size_qk, batch)); - stride_V_cache = cutlass::make_cute_packed_stride(StrideV{}, cute::make_shape(head_size_vo, seq_len_kv_cache, batch * num_heads_kv)); - - stride_O = cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo * group_q_size, group_q_num * head_size_vo, batch)); - - get<3>(problem_shape_out).cumulative_length = reinterpret_cast(args.cu_seqlens_q); - get<4>(problem_shape_out).cumulative_length = reinterpret_cast(args.cu_seqlens_k); - get<5>(problem_shape_out).cumulative_length = reinterpret_cast(args.cu_seqlens_k); - - return problem_shape_out; - } - - cutlass::Status run(const chunk_prefill_args_t &args, const cutlass::KernelHardwareInfo &hw_info) { - ProblemShapeType problem_size = initialize(args); - - typename FMHAChunkPrefillKernel::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - problem_size, - { - reinterpret_cast(args.query), stride_Q, - reinterpret_cast(args.key), stride_K, - reinterpret_cast(args.value), stride_V, - reinterpret_cast(args.key), stride_K_cache, - reinterpret_cast(args.value), stride_V_cache, - static_cast(args.block_table), - args.block_size, - static_cast(args.num_blocks_per_seq) - }, - {args.sm_scale}, - {reinterpret_cast(args.out), stride_O}, - hw_info}; - - // Define device-global scratch memory - size_t workspace_size = FMHAChunkPrefillKernel::get_workspace_size(arguments); - cutlass::device_memory::allocation workspace(workspace_size); - - if (!FMHAChunkPrefillKernel::can_implement(arguments)) { - std::cout << "Invalid Problem Size: " << std::endl; - return cutlass::Status::kErrorInvalidProblem; - } - - // Initialize the workspace - FMHAChunkPrefillKernel::initialize_workspace(arguments, workspace.get()); - - // Convert host-side arguments to device-side arguments to be passed to the kernel - auto params = FMHAChunkPrefillKernel::to_underlying_arguments(arguments, workspace.get()); - - // Run the Flash Attention implementation. 
- run(params); - - syclcompat::wait(); - return cutlass::Status::kSuccess; - } - - static void run(typename FMHAChunkPrefillKernel::Params params) { - dim3 const block = FMHAChunkPrefillKernel::get_block_shape(); - dim3 const grid = FMHAChunkPrefillKernel::get_grid_shape(params); - - // configure smem size and carveout - int smem_size = FMHAChunkPrefillKernel::SharedStorageSize; - - const auto sycl_block = syclcompat::dim3(block.x, block.y, block.z); - const auto sycl_grid = syclcompat::dim3(grid.x, grid.y, grid.z); - -// Launch parameters depend on whether SYCL compiler supports work-group scratch memory extension -#if !defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY) - using namespace syclcompat::experimental; - auto event = launch>( - launch_policy{sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)}, - kernel_properties{sycl_exp::sub_group_size}}, - params); -#else - syclcompat::experimental::launch_properties launch_props { - sycl::ext::oneapi::experimental::work_group_scratch_size(smem_size), - }; - syclcompat::experimental::kernel_properties kernel_props{ - sycl::ext::oneapi::experimental::sub_group_size - }; - syclcompat::experimental::launch_policy policy{sycl_grid, sycl_block, launch_props, kernel_props}; - auto event = syclcompat::experimental::launch>(policy, params); -#endif - - EventManager::getInstance().addEvent(event); - } -}; - -template struct FMHAKernel { - - template - static void run(const chunk_prefill_args_t &args) { - cutlass::KernelHardwareInfo hw_info; - - using LayoutQ = cutlass::layout::RowMajor; - using LayoutK = cutlass::layout::ColumnMajor; - using LayoutV = cutlass::layout::RowMajor; - using LayoutO = cutlass::layout::RowMajor; - - using ElementInputKV = ElementInputQ; - - using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; - using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; - using CollectiveEpilogue = cutlass::flash_attention::collective::FlashChunkPrefillEpilogue< - EpilogueDispatchPolicy, MMAOperation, TileShapeOutput, SubgroupLayout, ElementComputeEpilogue, ElementOutput, cutlass::gemm::TagToStrideC_t, ElementOutput, - GmemTiledCopyStore>; - using CollectiveSoftmaxEpilogue = cutlass::flash_attention::collective::FlashChunkPrefillSoftmaxEpilogue; - - using ProblemShapeRegular = cute::tuple; - using namespace cutlass::fmha::collective; - using ProblemShapeVarlen = cute::tuple; - using ProblemShapeType = std::conditional_t; - - // Mainloop - using CollectiveMainloop = cutlass::flash_attention::collective::FlashChunkPrefillMma< - GEMMDispatchPolicy, ProblemShapeType, ElementInputQ, cutlass::gemm::TagToStrideA_t, ElementInputKV, - cutlass::gemm::TagToStrideB_t, ElementInputKV, cutlass::gemm::TagToStrideB_t, MMAOperation, TileShapeQK, TileShapePV, SubgroupLayout, - GmemTiledCopyQ, // Q - GmemTiledCopyK, // K - GmemTiledCopyV, // V, - Causal, - PagedKV>; - - using FMHAChunkPrefillKernel = cutlass::flash_attention::kernel::FMHAPrefillChunk; - - KernelLauncher launcher; - - launcher.run(args, hw_info); - } - - static void dispatch(const chunk_prefill_args_t &args) { - if(args.is_causal) { - run(args); - } - else { - run(args); - } - } -}; - -template -void policy_dispatch( - CutlassType cuType, - const chunk_prefill_args_t& args) { - const int PipelineStages = 2; - if(cuType == CutlassType::half) { - FMHAKernel::dispatch(args); - } - else { - FMHAKernel::dispatch(args); - } -} - -void chunk_prefill_kernel( - CutlassType cuType, - const chunk_prefill_args_t& args) { - if(args.head_size == HEAD_SIZE_LIMIT_0) { - 
policy_dispatch(cuType, args); - } else if(args.head_size == HEAD_SIZE_LIMIT_1) { - policy_dispatch(cuType, args); - } - else if(args.head_size == HEAD_SIZE_LIMIT_2) { - policy_dispatch(cuType, args); - } -} - -void cutlass_chunk_prefill_impl( - const at::Tensor& query, // [seq_q, heads, head_size] - const at::Tensor& key_cache, // [num_block, block_size, heads, head_size] - const at::Tensor& value_cache, - at::Tensor& out, - const at::Tensor& block_table, - const at::Tensor& cu_seqlens_q, - const at::Tensor& cu_seqlens_k, - int max_seqlen_q, - int max_seqlen_k, - double sm_scale, - bool is_causal) { - int num_block = key_cache.size(0); - int block_size = key_cache.size(1); - int num_heads_q = query.size(1); - int num_heads_kv = key_cache.size(2); - int head_size = query.size(2); - int batch_size = cu_seqlens_q.numel() - 1; - int max_blocks_per_seq = block_table.size(1); - int total_seqlen_q = query.size(0); - int total_seqlen_k = num_block * block_size; - at::Tensor num_blocks_per_seq = torch::div(cu_seqlens_k, block_size); - - chunk_prefill_args_t args = { - query.data_ptr(), - key_cache.data_ptr(), - value_cache.data_ptr(), - out.data_ptr(), - block_table.data_ptr(), - num_blocks_per_seq.data_ptr(), - cu_seqlens_q.data_ptr(), - cu_seqlens_k.data_ptr(), - max_seqlen_q, - max_seqlen_k, - total_seqlen_q, - total_seqlen_k, - static_cast(sm_scale), - batch_size, - num_heads_q, - num_heads_kv, - head_size, - max_blocks_per_seq, - block_size, - is_causal - }; - CutlassType cuType = aten_to_Cutlass_dtype(query); - chunk_prefill_kernel(cuType, args); -} diff --git a/csrc/xpu/cutlass_kernels/utils.hpp b/csrc/xpu/cutlass_kernels/utils.hpp deleted file mode 100644 index 4715419..0000000 --- a/csrc/xpu/cutlass_kernels/utils.hpp +++ /dev/null @@ -1,49 +0,0 @@ -#pragma once -#include "torch/all.h" -#include - -#define HEAD_SIZE_LIMIT_0 64 -#define HEAD_SIZE_LIMIT_1 128 -#define HEAD_SIZE_LIMIT_2 256 -#define HEAD_SIZE_LIMIT_3 512 - -enum class CutlassType { - half, - bfloat16, -}; - -inline CutlassType aten_to_Cutlass_dtype(const at::Tensor& input) { - CutlassType cuType; - if (input.scalar_type() == torch::kHalf) { - cuType = CutlassType::half; - } else if (input.scalar_type() == torch::kBFloat16) { - cuType = CutlassType::bfloat16; - } else { - TORCH_INTERNAL_ASSERT( - false, - ""); - } - return cuType; -} - -using namespace cute; -struct chunk_policy_head64 { - using ShapeQK = Shape<_128, _64, _64>; - using ShapePV = Shape<_128, _32, _64>; - using ShapeOutPut = Shape<_128, _64, _64>; - using SubgroupLayout = Layout, Stride<_1, _1, _1>>; -}; - -struct chunk_policy_head128 { - using ShapeQK = Shape<_128, _64, _64>; - using ShapePV = Shape<_128, _32, _64>; - using ShapeOutPut = Shape<_128, _128, _64>; - using SubgroupLayout = Layout, Stride<_1, _1, _1>>; -}; - -struct chunk_policy_head256 { - using ShapeQK = Shape<_256, _64, _64>; - using ShapePV = Shape<_256, _32, _64>; - using ShapeOutPut = Shape<_256, _192, _64>; - using SubgroupLayout = Layout, Stride<_1, _1, _1>>; -}; \ No newline at end of file diff --git a/csrc/xpu/cutlass_sycl_demo.cpp b/csrc/xpu/cutlass_sycl_demo.cpp deleted file mode 100644 index 7254b65..0000000 --- a/csrc/xpu/cutlass_sycl_demo.cpp +++ /dev/null @@ -1,520 +0,0 @@ - - -#include -#include -#include - -#include "cutlass/epilogue/collective/default_epilogue.hpp" -#include "cutlass/epilogue/collective/xe_epilogue.hpp" -#include "cutlass/epilogue/fusion/xe_callbacks.hpp" -#include "cutlass/gemm/device/gemm_universal.h" -#include "cutlass/gemm/device/gemm_universal_adapter.h" 
-#include "cutlass/gemm/collective/collective_mma.hpp" -#include "cutlass/util/GPU_Clock.hpp" - -#include -#include - -#include "cutlass/util/command_line.h" -#include "cutlass/util/device_memory.h" -#include "cutlass/util/packed_stride.hpp" -#include "cutlass/util/reference/device/gemm_complex.h" -#include "cutlass/util/reference/device/tensor_compare.h" -#include "helper.h" - -#include "cutlass/cutlass.h" -#include "cutlass/util/device_memory.h" -#include "cutlass/util/reference/device/sycl_tensor_fill.h" - -using namespace cute; - -/// Helper to initialize a block of device data -template -bool initialize_block(Element* block, std::size_t size, uint64_t seed = 2023) { - Element scope_max, scope_min; - int bits_input = cutlass::sizeof_bits::value; - - if (bits_input == 1) { - scope_max = Element(2); - scope_min = Element(0); - } else if (bits_input <= 8) { - scope_max = Element(2); - scope_min = Element(-2); - } else { - scope_max = Element(8); - scope_min = Element(-8); - } - - cutlass::reference::device::BlockFillRandomUniform(block, size, seed, - scope_max, scope_min, 0); - - syclcompat::wait(); - return true; -} - -template -bool initialize_block(cutlass::DeviceAllocation& block, - uint64_t seed = 2023) { - return initialize_block(block.get(), block.size(), seed); -} - -template -void initialize_mixed_dtype_block( - cutlass::DeviceAllocation& block_device, - cutlass::DeviceAllocation& block_device_dq, uint64_t seed) { - static_assert(cute::sizeof_bits_v >= 8); - - std::ranlux24_base rng(std::random_device{}()); - rng.seed(seed); - - int bits_input = cute::sizeof_bits_v; - T1 scope_max, scope_min; - if (bits_input == 1) { - scope_max = T1(2); - scope_min = T1(0); - } else if (bits_input <= 8) { - scope_max = T1(2); - scope_min = T1(-2); - } else { - scope_max = T1(8); - scope_min = T1(-8); - } - - std::uniform_int_distribution<> dist(scope_min, scope_max); - - if constexpr (cute::sizeof_bits_v >= 8) { - auto block_host = std::vector(block_device.size()); - auto block_host_dq = std::vector(block_device.size()); - for (int i = 0; i < block_host.size(); ++i) { - block_host[i] = static_cast(dist(rng)); - block_host_dq[i] = static_cast(block_host[i]); - } - - block_device.copy_from_host(block_host.data()); - block_device_dq.copy_from_host(block_host_dq.data()); - } else { - static constexpr auto array_size = 1024; - - cute::array_subbyte block_host{}; - auto block_host_dq = std::vector(array_size); - - for (int i = 0; i < block_host.size(); ++i) { - block_host[i] = static_cast(dist(rng)); - block_host_dq[i] = static_cast(block_host[i].get()); - } - - static constexpr auto elements_per_byte = - cute::sizeof_bits_v / cute::sizeof_bits_v; - - int loop_cnt = block_device.size() / array_size; - for (int i = 0; i < loop_cnt; i++) { - cutlass::device_memory::copy_to_device( - block_device.get() + (i * array_size) / elements_per_byte, - raw_pointer_cast(block_host.begin()), array_size); - cutlass::device_memory::copy_to_device( - block_device_dq.get() + i * array_size, block_host_dq.data(), - array_size); - } - - auto tail_size = block_device.size() % array_size; - if (tail_size) { - cutlass::device_memory::copy_to_device( - block_device.get() + (loop_cnt * array_size) / elements_per_byte, - raw_pointer_cast(block_host.begin()), tail_size); - cutlass::device_memory::copy_to_device( - block_device_dq.get() + loop_cnt * array_size, block_host_dq.data(), - tail_size); - } - } -} - -template -inline bool is_close(T a, T b, float atol, float rtol) { - return std::abs((float)a - (float)b) <= atol + rtol 
* std::abs((float)b); -} - -// TODO(Codeplay): use on device initialisation for this -template -inline void random_fill(T* src, int seed, size_t N, float max, float min) { - if constexpr (std::is_same_v || - std::is_same_v || - std::is_same_v) { - std::random_device rd; - std::mt19937 gen(seed); - std::uniform_real_distribution dis(min, max); - auto buff = std::vector(N); - - for (size_t i = 0; i < N; ++i) { - buff[i] = (T)(dis(gen)); - } - syclcompat::memcpy(src, buff.data(), N); - syclcompat::wait(); - } else { - assert(0 & "Not supported dtype"); - } -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -// Command line options parsing -struct Options { - bool help; - bool error; - - int m, n, k, l, iterations; - float alpha, beta; - - Options() - : help(false), - error(false), - m(5120), - n(4096), - k(4096), - l(1), - iterations(20), - alpha(1.f), - beta(0.f) {} - - // Parses the command line - void parse(int argc, char const** args) { - cutlass::CommandLine cmd(argc, args); - - if (cmd.check_cmd_line_flag("help")) { - help = true; - return; - } - - cmd.get_cmd_line_argument("m", m, 5120); - cmd.get_cmd_line_argument("n", n, 4096); - cmd.get_cmd_line_argument("k", k, 4096); - cmd.get_cmd_line_argument("l", l, 1); - cmd.get_cmd_line_argument("alpha", alpha, 1.f); - cmd.get_cmd_line_argument("beta", beta, 0.f); - cmd.get_cmd_line_argument("iterations", iterations, 100); - } - - /// Prints the usage statement. - std::ostream& print_usage(std::ostream& out) const { - out << "BMG GEMM Example\n\n" - << "Options:\n\n" - << " --help If specified, displays this usage " - "statement\n\n" - << " --m= Sets the M extent of the GEMM\n" - << " --n= Sets the N extent of the GEMM\n" - << " --k= Sets the K extent of the GEMM\n" - << " --l= Sets the L extent (batch count) of " - "the GEMM\n" - << " --alpha= Epilogue scalar alpha\n" - << " --beta= Epilogue scalar beta\n\n" - << " --iterations= Iterations\n\n"; - - return out; - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct ExampleRunner { - using StrideA = typename Gemm::GemmKernel::StrideA; - using StrideB = typename Gemm::GemmKernel::StrideB; - using StrideC = typename Gemm::GemmKernel::StrideC; - using StrideD = typename Gemm::GemmKernel::StrideD; - - using LayoutA = typename Gemm::LayoutA; - using LayoutB = typename Gemm::LayoutB; - using LayoutC = typename Gemm::LayoutC; - using LayoutD = typename Gemm::LayoutD; - - using ElementA = typename Gemm::ElementA; - using ElementB = typename Gemm::ElementB; - using ElementAcc = typename Gemm::ElementAccumulator; - - using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; - using ElementC = typename Gemm::ElementC; - using ElementOutput = typename CollectiveEpilogue::ElementOutput; - using ElementCompute = typename CollectiveEpilogue::ElementCompute; - using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; - - using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; - - // - // Data members - // - - /// Initialization - StrideA stride_A; - StrideB stride_B; - StrideC stride_C; - StrideD stride_D; - uint64_t seed = 0; - - cutlass::DeviceAllocation block_A; - cutlass::DeviceAllocation block_B; - cutlass::DeviceAllocation block_C; - cutlass::DeviceAllocation block_D; - cutlass::DeviceAllocation - block_ref_D; // Reference GEMM result for verification - - // - // Methods - // - - bool verify(const ProblemShapeType& problem_size, 
ElementCompute alpha, - ElementCompute beta) { - auto [M, N, K, L] = problem_size; - - cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); - cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); - cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); - cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); - - cutlass::reference::device::GemmComplex( - {M, N, K}, alpha, ref_A, cutlass::ComplexTransform::kNone, ref_B, - cutlass::ComplexTransform::kNone, beta, ref_C, ref_D, - ElementAccumulator(0), - L, // batch_count - M * K, // batch_stride_A - K * N, // batch_stride_B - M * N, // batch_stride_C - M * N // batch_stride_D - ); - - // CUTLASS on SYCL uses the compatibility library syclcompat for e.g. - // default in-order queue - syclcompat::wait(); - - // Check if output from CUTLASS kernel and reference kernel are equal or not - bool passed = cutlass::reference::device::BlockCompareEqual( - block_ref_D.get(), block_D.get(), block_D.size()); - - return passed; - } - - /// Initialize operands to be used in the GEMM and reference GEMM - void initialize(const ProblemShapeType& problem_size) { - auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - auto [M, N, K, L] = problem_shape_MNKL; - - // Complete the stride by combining static layout info (StrideA) with - // runtime size info (M,K,L) - stride_A = - cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); - stride_B = - cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); - stride_C = - cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); - stride_D = - cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); - - block_A.reset(static_cast(M) * K * L); - block_B.reset(static_cast(K) * N * L); - block_C.reset(static_cast(M) * N * L); - block_D.reset(static_cast(M) * N * L); - block_ref_D.reset(static_cast(M) * N * L); - - initialize_block(block_A, seed + 2023); - initialize_block(block_B, seed + 2022); - initialize_block(block_C, seed + 2021); - } - - cutlass::Status run(const Options& options, - const cutlass::KernelHardwareInfo& hw_info) { - ProblemShapeType problem_size = - ProblemShapeType{options.m, options.n, options.k, options.l}; - - initialize(problem_size); - - typename Gemm::GemmKernel::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - problem_size, - {block_A.get(), stride_A, block_B.get(), stride_B}, - {{options.alpha, options.beta}, - block_C.get(), - stride_C, - block_D.get(), - stride_D}, - hw_info}; - - Gemm gemm_op; - - size_t workspace_size = Gemm::get_workspace_size(arguments); - cutlass::device_memory::allocation workspace(workspace_size); - - if (gemm_op.can_implement(arguments) != cutlass::Status::kSuccess) { - std::cout << "Invalid Problem Size: " << options.m << 'x' << options.n - << 'x' << options.k << 'x' << options.l << std::endl; - std::exit(1); - } - - CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); - - // Run the GEMM - CUTLASS_CHECK(gemm_op.run()); - - syclcompat::wait(); - - // Verify that the result is correct - bool passed = verify(problem_size, options.alpha, options.beta); - std::cout << "Disposition: " << (passed ? 
"Passed" : "Failed") << std::endl; - - if (!passed) return cutlass::Status::kErrorInternal; - - if (options.iterations > 0) { - GPU_Clock timer; - timer.start(); - for (int i = 0; i < options.iterations; ++i) { - gemm_op.run(); - } - syclcompat::wait(); - - float cute_time = timer.seconds() / options.iterations; - double tflops = - (2.0 * options.m * options.n * options.k * options.l) * 1e-12; - std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' - << options.k << 'x' << options.l << std::endl; - printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", - tflops / cute_time, cute_time * 1000); - } - - return cutlass::Status::kSuccess; - } -}; - -void cutlass_sycl_demo(torch::Tensor& a) { - // - // Parse options - // - // - std::cout << a.sizes() << std::endl; - - Options options; - - /* options.parse(argc, argv); */ - - if (options.help) { - options.print_usage(std::cout) << std::endl; - return; - } - - if (options.error) { - std::cerr << "Aborting execution." << std::endl; - return; - } - - // - // Run examples - // - - // The KernelHardwareInfo struct holds the number of EUs on the GPU with a - // given device ID. This information is used by the underlying kernel. - cutlass::KernelHardwareInfo hw_info; - - // Change device_id to another value if you are running on a machine with - // multiple GPUs and wish to use a GPU other than that with device ID 0. - hw_info.sm_count = - cutlass::KernelHardwareInfo::query_device_multiprocessor_count( - hw_info.device_id); - - bool passed; - - // The code section below describes datatype for input, output matrices and - // computation between elements in input matrices. - using ElementAccumulator = float; // <- data type of accumulator - using ElementComputeEpilogue = float; // <- data type of epilogue operations - using ElementInputA = - bfloat16_t; // <- data type of elements in input matrix A - using ElementInputB = - bfloat16_t; // <- data type of elements in input matrix B - using ElementOutput = float; // <- data type of elements in output matrix D - - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::RowMajor; - using LayoutC = cutlass::layout::RowMajor; - using LayoutD = cutlass::layout::RowMajor; - - // The 2D block copy operations used for the A and B matrices - using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; - using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; - - // Workgroup-level tile - using TileShape = Shape<_256, _256, _32>; - - // A TiledMMA struct defines a tiling of an MMA atom over M, N and K, - // combining both additional hardware (sub-groups for Intel BMG) and - // iterations by each sub-group. - // - // The TiledMMAHelper struct defines a specific TiledMMA for a given MMA atom - // (XE_8x16x16_F32BF16BF16F32_TT), TileShape (<256, 256, 32>) and sub-group - // layout (8x4x1). The TiledMMA constructed using TiledMMAHelper has the - // property that each sub-group operates on a single contiguous chunk of the - // work-group TileShape. For this configuration, this implies that each - // sub-group operates on a contiguous 32x64x32 chunk (4x4x2 iterations). See - // 0t_mma_atom.md#TiledMMAs for more info. Sub-groups are arranged row-major - // (stride 4,1,0) for performance reasons. - using TiledMma = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32 - typename TiledMMAHelper< - MMA_Atom, Layout, - Layout, Stride<_4, _1, _0>>>::TiledMMA; - - // For Intel BMG, PipelineStages defines how many k-blocks ahead to prefetch - // from A and B. 
- constexpr int PipelineStages = 2; - using GEMMDispatchPolicy = - cutlass::gemm::MainloopIntelXeXMX16; - using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; - - // This is the 'default' epilogue operation (Linear Combination) which - // performs everything in: (D = alpha * (A*B) + beta * C) aside from the - // (A*B), which is handled by the GEMM. See 05_bmg_gemm_with_epilogues for - // more complex epilogue examples. - using EpilogueOp = cutlass::epilogue::fusion::LinearCombination< - ElementOutput, ElementComputeEpilogue, ElementAccumulator, - ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>; - - // FusionCallbacks ties the EpilogueOp to an implementation (based on the - // dispatch policy/architecture) and defines the epilogue arguments. - using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks< - EpilogueDispatchPolicy, EpilogueOp, TileShape, - decltype(tile_shape(TiledMma()))>; - // GEMM Epilogue - loads & stores C/D matrices, performs epilogue operations & - // load/stores any auxiliary data required - using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< - EpilogueDispatchPolicy, TileShape, ElementAccumulator, - cutlass::gemm::TagToStrideC_t, // Converts CUTLASS 2.x to - // CUTLASS 3.x representation - ElementOutput, - cutlass::gemm::TagToStrideC_t, // Converts CUTLASS 2.x to - // CUTLASS 3.x representation - FusionCallBacks, - XE_2D_U32x8x16_LD_N, // The copy atom used to load matrix C - void, void, - XE_2D_U32x8x16_ST_N, // The copy atom used to store matrix D - void, void>; - - // GEMM Mainloop - iteration over blocks in K dimension - using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< - GEMMDispatchPolicy, TileShape, ElementInputA, - cutlass::gemm::TagToStrideA_t, // Converts CUTLASS 2.x to - // CUTLASS 3.x representation - ElementInputB, - cutlass::gemm::TagToStrideB_t, // Converts CUTLASS 2.x to - // CUTLASS 3.x representation - TiledMma, GmemTiledCopyA, void, void, cute::identity, // A - GmemTiledCopyB, void, void, cute::identity // B - >; - - // Define the whole kernel (mainloop and epilogue) - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, // Defer global problem shape definition to - // runtime - CollectiveMainloop, CollectiveEpilogue>; - - // The GemmUniversalAdapter wraps the defined GEMM kernel and handles the - // launch, and e.g. persistent scratch memory if required. - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - ExampleRunner runner; - - CUTLASS_CHECK(runner.run(options, hw_info)); -} diff --git a/csrc/xpu/helper.h b/csrc/xpu/helper.h deleted file mode 100644 index 4bc345c..0000000 --- a/csrc/xpu/helper.h +++ /dev/null @@ -1,127 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights - *reserved. SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. 
Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#if defined(CUTLASS_ENABLE_SYCL) - #include "cutlass/util/sycl_timer.hpp" -#else - #include -#endif -#include - -/** - * Panic wrapper for unwinding CUTLASS errors - */ -#define CUTLASS_CHECK(status) \ - { \ - cutlass::Status error = status; \ - if (error != cutlass::Status::kSuccess) { \ - std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) \ - << " at: " << __LINE__ << std::endl; \ - exit(EXIT_FAILURE); \ - } \ - } - -/** - * Panic wrapper for unwinding CUDA runtime errors - */ -#define CUDA_CHECK(status) \ - { \ - cudaError_t error = status; \ - if (error != cudaSuccess) { \ - std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ - << " at line: " << __LINE__ << std::endl; \ - exit(EXIT_FAILURE); \ - } \ - } - -/** - * GPU timer for recording the elapsed time across kernel(s) launched in GPU - * stream - */ -struct GpuTimer { -#if defined(CUTLASS_ENABLE_SYCL) - using cudaStream_t = int; - SYCLTimer syclTimer; -#else - cudaEvent_t _start; - cudaEvent_t _stop; -#endif - cudaStream_t _stream_id; - - /// Constructor - GpuTimer() : _stream_id(0) { -#if !defined(CUTLASS_ENABLE_SYCL) - CUDA_CHECK(cudaEventCreate(&_start)); - CUDA_CHECK(cudaEventCreate(&_stop)); -#endif - } - - /// Destructor - ~GpuTimer() { -#if !defined(CUTLASS_ENABLE_SYCL) - CUDA_CHECK(cudaEventDestroy(_start)); - CUDA_CHECK(cudaEventDestroy(_stop)); -#endif - } - - /// Start the timer for a given stream (defaults to the default stream) - void start(cudaStream_t stream_id = 0) { - _stream_id = stream_id; -#if defined(CUTLASS_ENABLE_SYCL) - syclTimer.start(); -#else - CUDA_CHECK(cudaEventRecord(_start, _stream_id)); -#endif - } - - /// Stop the timer - void stop() { -#if defined(CUTLASS_ENABLE_SYCL) - syclTimer.stop(); -#else - CUDA_CHECK(cudaEventRecord(_stop, _stream_id)); -#endif - } - - /// Return the elapsed time (in milliseconds) - float elapsed_millis() { -#if defined(CUTLASS_ENABLE_SYCL) - return syclTimer.milliseconds(); -#else - float elapsed = 0.0; - CUDA_CHECK(cudaEventSynchronize(_stop)); - CUDA_CHECK(cudaEventElapsedTime(&elapsed, _start, _stop)); - return elapsed; -#endif - } -}; \ No newline at end of file diff --git a/csrc/xpu/mha.h b/csrc/xpu/mha.h deleted file mode 100644 index a18cc0c..0000000 --- a/csrc/xpu/mha.h +++ /dev/null @@ -1,16 +0,0 @@ -#pragma once - - -void cutlass_chunk_prefill_impl( - at::Tensor& query, // [seq_q, heads, head_size] - at::Tensor& key_cache, // [num_block, 
block_size, heads, head_size] - at::Tensor& value_cache, - at::Tensor& out, - at::Tensor& block_table, - at::Tensor& cu_seqlens_q, - at::Tensor& cu_seqlens_k, - int max_seqlen_q, - int max_seqlen_k, - double sm_scale, - bool is_causal -); \ No newline at end of file diff --git a/csrc/xpu/ops.h b/csrc/xpu/ops.h index 7c1dcf0..7c3dbdc 100644 --- a/csrc/xpu/ops.h +++ b/csrc/xpu/ops.h @@ -7,5 +7,3 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, double epsilon); - -void cutlass_sycl_demo(torch::Tensor& a); \ No newline at end of file diff --git a/csrc/xpu/torch_bindings.cpp b/csrc/xpu/torch_bindings.cpp index 9c9e0a2..39c193e 100644 --- a/csrc/xpu/torch_bindings.cpp +++ b/csrc/xpu/torch_bindings.cpp @@ -31,9 +31,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, " "float epsilon) -> ()"); ops.impl("fused_add_rms_norm", torch::kXPU, &fused_add_rms_norm); - - ops.def("cutlass_sycl_demo(Tensor a) -> ()"); - ops.impl("cutlass_sycl_demo", torch::kXPU, &cutlass_sycl_demo); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/tests/quantization/test_fused_moe.py b/tests/cutlass/test_fused_moe.py similarity index 100% rename from tests/quantization/test_fused_moe.py rename to tests/cutlass/test_fused_moe.py diff --git a/tests/flash_attn/test.py b/tests/flash_attn/test.py deleted file mode 100644 index dfa6539..0000000 --- a/tests/flash_attn/test.py +++ /dev/null @@ -1,38 +0,0 @@ -import pytest -import torch - -from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func - -DTYPES = [torch.half, torch.bfloat16] -dtype = torch.half - -torch.set_default_device("xpu") -batch_size = 1 -seq_len = 512 -num_heads = 8 -head_dim = 128 - -max_seqlen_q = seq_len -cu_seqlens_q = torch.tensor([0, seq_len], dtype=torch.int32) -max_seqlen_k = seq_len -cu_seqlens_k = cu_seqlens_q - -block_size = 128 -num_blocks = max_seqlen_q // block_size -max_num_blocks_per_seq = seq_len // block_size - -block_tables = torch.randint(0, num_blocks, (batch_size, max_num_blocks_per_seq), dtype=torch.int32) - -print(block_tables) -print(cu_seqlens_q) - -q = torch.randn(sum(cu_seqlens_q), num_heads, head_dim, dtype=dtype) -k = torch.randn(num_blocks, block_size, num_heads, head_dim, dtype=dtype) -v = torch.randn(num_blocks, block_size, num_heads, head_dim, dtype=dtype) - -# Call the flash attention function -output= flash_attn_varlen_func(q, k, v, max_seqlen_q, cu_seqlens_q, - max_seqlen_k, cu_seqlens_k, block_table=block_tables) - -assert output is not None -assert output.dtype == dtype diff --git a/tests/flash_attn/test_flash_attn_varlen_func.py b/tests/flash_attn/test_flash_attn_varlen_func.py deleted file mode 100644 index fab86b4..0000000 --- a/tests/flash_attn/test_flash_attn_varlen_func.py +++ /dev/null @@ -1,35 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest -import torch - -from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func - -DTYPES = [torch.half, torch.bfloat16] - - -@pytest.mark.parametrize("dtype", DTYPES) -def test_flash_attn_varlen_func(dtype): - torch.set_default_device("xpu") - batch_size = 1 - seq_len = 4 - num_heads = 8 - head_dim = 16 - - q = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype) - k = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype) - v = 
torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype) - - max_seqlen_q = seq_len - cu_seqlens_q = torch.tensor([0, seq_len], dtype=torch.int32) - max_seqlen_k = seq_len - cu_seqlens_k = cu_seqlens_q - - # Call the flash attention function - output = flash_attn_varlen_func(q, k, v, max_seqlen_q, cu_seqlens_q, - max_seqlen_k, cu_seqlens_k) - - assert output is not None - assert output.dtype == dtype - assert output.shape == (batch_size, seq_len, num_heads, head_dim) diff --git a/tests/test_cutlass_op.py b/tests/test_cutlass_op.py deleted file mode 100644 index f575ae9..0000000 --- a/tests/test_cutlass_op.py +++ /dev/null @@ -1,17 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest -import torch - -import vllm_xpu_kernels._C # noqa F401 - -DTYPES = [torch.half, torch.bfloat16] - - -@pytest.mark.parametrize("dtype", DTYPES) -@torch.inference_mode() -def test_cutlass_op(dtype: torch.dtype, ): - torch.set_default_device("xpu") - a = torch.zeros((2, 3), dtype=dtype, device="xpu") - torch.ops._C.cutlass_sycl_demo(a) diff --git a/vllm_xpu_kernels/__init__.py b/vllm_xpu_kernels/__init__.py index 293c6e5..6714e50 100644 --- a/vllm_xpu_kernels/__init__.py +++ b/vllm_xpu_kernels/__init__.py @@ -1,4 +1,3 @@ # SPDX-License-Identifier: Apache-2.0 -from .flash_attn_interface import flash_attn_varlen_func # noqa: F401 from .fused_moe_interface import cutlass_fused_moe, cutlass_grouped_gemm diff --git a/vllm_xpu_kernels/flash_attn_interface.py b/vllm_xpu_kernels/flash_attn_interface.py deleted file mode 100644 index 33e7de8..0000000 --- a/vllm_xpu_kernels/flash_attn_interface.py +++ /dev/null @@ -1,107 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -from typing import Optional - -import torch - -#isort: off -try: - from . 
import _vllm_fa2_C # noqa: F401 - FA2_UNAVAILABLE_REASON = None - FA2_AVAILABLE = True -except ImportError as e: - FA2_UNAVAILABLE_REASON = str(e) - FA2_AVAILABLE = False - -#isort: on - -DEFAULT_FA_VERSION = 2 - - -def maybe_contiguous(x): - return x.contiguous() if x is not None and x.stride(-1) != 1 else x - - -def flash_attn_varlen_func( - q, - k, - v, - max_seqlen_q, - cu_seqlens_q, - max_seqlen_k, - cu_seqlens_k=None, # only used for non-paged prefill - seqused_k=None, - q_v=None, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size: Optional[list[int]] = None, - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - block_table=None, - return_softmax_lse=False, - out=None, - # FA3 Only - scheduler_metadata=None, - q_descale=None, - k_descale=None, - v_descale=None, - num_splits: int = 0, - # Version selector - fa_version: int = DEFAULT_FA_VERSION, -): - assert cu_seqlens_k is not None or seqused_k is not None, \ - "cu_seqlens_k or seqused_k must be provided" - assert cu_seqlens_k is None or seqused_k is None, \ - "cu_seqlens_k and seqused_k cannot be provided at the same time" - - if softmax_scale is None: - softmax_scale = q.shape[-1]**(-0.5) - # custom op does not support non-tuple input - real_window_size: tuple[int, int] - if window_size is None: - real_window_size = (-1, -1) - else: - assert len(window_size) == 2 - real_window_size = (window_size[0], window_size[1]) - q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - - dummy_cu_seqlens_k = torch.empty_like(cu_seqlens_q) - - if fa_version == 2: - if scheduler_metadata is not None and q_descale is not None \ - and k_descale is not None and v_descale is not None: - raise NotImplementedError( - "FA2 does not support scheduler_metadata, q_descale, " - "k_descale, v_descale") - if num_splits > 1: - raise NotImplementedError("FA2 does not support num_splits > 1") - out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd( - q, - k, - v, - out, - cu_seqlens_q, - # cu_seqlens_k not used since we use seqused_k, but flash_api.cpp - # still wants it so we pass all zeros - dummy_cu_seqlens_k if cu_seqlens_k is None else cu_seqlens_k, - seqused_k, - None, - block_table, - alibi_slopes, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - False, - causal, - real_window_size[0], - real_window_size[1], - softcap, - return_softmax_lse and dropout_p > 0, - None, - ) - else: - raise NotImplementedError("not support yet") - return (out, softmax_lse) if return_softmax_lse else out From f7518e0e42e3685bbb6062d68554b12ebd85d6a2 Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Thu, 4 Sep 2025 06:46:22 +0000 Subject: [PATCH 21/47] rebase CMakeLists Signed-off-by: Ma, Liangliang --- CMakeLists.txt | 136 ++++++++++-------- .../cutlass_backend/cutlass_kernels.cpp | 0 csrc/{ => xpu}/cutlass_backend/grouped_gemm.h | 0 csrc/{ => xpu}/cutlass_backend/helper.h | 0 .../{ => xpu}/cutlass_backend/sycl_common.hpp | 0 setup.py | 2 +- 6 files changed, 75 insertions(+), 63 deletions(-) rename csrc/{ => xpu}/cutlass_backend/cutlass_kernels.cpp (100%) rename csrc/{ => xpu}/cutlass_backend/grouped_gemm.h (100%) rename csrc/{ => xpu}/cutlass_backend/helper.h (100%) rename csrc/{ => xpu}/cutlass_backend/sycl_common.hpp (100%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 98dbc64..49e6be9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,6 +20,8 @@ message(STATUS "Target device: ${VLLM_TARGET_DEVICE}") include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) +list(APPEND 
CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules) + # Suppress potential warnings about unused manually-specified variables set(ignoreMe "${VLLM_PYTHON_PATH}") @@ -45,10 +47,6 @@ set(SYCL_SUPPORTED_ARCHS "intel_gpu_pvc;intel_gpu_bmg_g21") # set(TORCH_SUPPORTED_VERSION_XPU "2.8.0") -set(ENABLE_MOE_KERNEL OFF) -set(FA2_ENABLED OFF) -set(FP8_ENABLED ON) - # # Try to find python package with an executable that exactly matches # `VLLM_PYTHON_EXECUTABLE` and is one of the supported versions. @@ -70,6 +68,7 @@ append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path") # Import torch cmake configuration. find_package(Torch REQUIRED) +find_package(oneDNN REQUIRED) # # Forward the non-CUDA device extensions to external CMake scripts. @@ -134,7 +133,8 @@ if(VLLM_GPU_LANG STREQUAL "SYCL") # # TODO: check SYCL flags set(CMAKE_${VLLM_GPU_LANG}_FLAGS "${CMAKE_${VLLM_GPU_LANG}_FLAGS}") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") + set(SYCL_FIRST_HEADER "${CMAKE_CURRENT_SOURCE_DIR}/csrc/sycl_first.h") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -include ${SYCL_FIRST_HEADER}") endif() # @@ -145,11 +145,45 @@ endif() # # _C extension # - if(VLLM_GPU_LANG STREQUAL "SYCL") set(VLLM_EXT_SRC - "csrc/xpu/layernorm.cpp" + "csrc/cache.cpp" + "csrc/layernorm.cpp" + "csrc/activation.cpp" + "csrc/pos_encoding_kernels.cpp" + "csrc/torch_bindings.cpp" + "csrc/quantization/fp8/fp8_quant.cpp" + ) + include_directories("/usr/include") + set(CMPLR_ROOT $ENV{CMPLR_ROOT}) + set(CMAKE_CXX_COMPILER icpx) + set(VLLM_EXTRA_INCLUDE_DIRECTORIES ${CMPLR_ROOT}/include/sycl) + list(APPEND VLLM_GPU_FLAGS "-DVLLM_BUILD_XPU_OPS" ) + list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64") + list(APPEND VLLM_LINK_LIBRARIES "sycl" "OpenCL" "pthread" "m" "dl" "torch" ) +endif() + +message(STATUS "Enabling C extension.") +define_gpu_extension_target( + _C + DESTINATION vllm_xpu_kernels + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${VLLM_EXT_SRC} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + LINK_FLAGS ${VLLM_GPU_LINK_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + USE_SABI 3 + WITH_SOABI) + +# +# xpu only ops/kernels, implemented with cutlass/onednn/sycl. 
+# +if(VLLM_GPU_LANG STREQUAL "SYCL") + set(VLLM_EXT_XPU_SRC "csrc/xpu/torch_bindings.cpp" + "csrc/cutlass_backend/*.cpp" ) include_directories("/usr/include") list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/) @@ -160,8 +194,6 @@ if(VLLM_GPU_LANG STREQUAL "SYCL") list(APPEND VLLM_GPU_FLAGS "-DVLLM_BUILD_XPU_OPS" ) list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64" "-Xspirv-translator" "-spirv-ext=+SPV_INTEL_split_barrier") list(APPEND VLLM_LINK_LIBRARIES "sycl" "OpenCL" "pthread" "m" "dl" "torch" ) - - # add cutlass dependency set(CUTLASS_ENABLE_HEADERS_ONLY "ON" CACHE BOOL "Enable only the header library") @@ -205,15 +237,29 @@ if(VLLM_GPU_LANG STREQUAL "SYCL") list(APPEND VLLM_GPU_FLAGS "-DCUTLASS_VERSIONS_GENERATED") list(APPEND VLLM_GPU_FLAGS "-ftemplate-backtrace-limit=0") list(APPEND VLLM_GPU_FLAGS "-fdiagnostics-color=always") + list(APPEND VLLM_GPU_FLAGS "-gline-tables-only") + list(APPEND VLLM_GPU_FLAGS "-O3") + list(APPEND VLLM_GPU_FLAGS "-DNDEBUG") +endif() +if(ONEDNN_FOUND) + set(_ONEDNN_SRC) + file(GLOB _ONEDNN_SRC csrc/xpu/onednn/*.cpp) + list(APPEND VLLM_EXT_XPU_SRC + ${_ONEDNN_SRC} + ) + include_directories(${ONEDNN_INCLUDE_DIR}) + link_libraries(${ONEDNN_LIBRARY}) endif() -message(STATUS "Enabling C extension.") + + + define_gpu_extension_target( - _C + _xpu_C DESTINATION vllm_xpu_kernels LANGUAGE ${VLLM_GPU_LANG} - SOURCES ${VLLM_EXT_SRC} + SOURCES ${VLLM_EXT_XPU_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} LINK_FLAGS ${VLLM_GPU_LINK_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} @@ -224,57 +270,23 @@ define_gpu_extension_target( USE_SABI 3 WITH_SOABI) -if (FP8_ENABLED) - message(STATUS "Enabling FP8 extension.") - file(GLOB FP8_GEN_SRCS "csrc/cutlass_backend/*.cpp") - - # list(APPEND VLLM_GPU_FLAGS "-ze-opt-large-register-file") - list(APPEND VLLM_GPU_FLAGS "-gline-tables-only") - list(APPEND VLLM_GPU_FLAGS "-O3") - list(APPEND VLLM_GPU_FLAGS "-DNDEBUG") - - define_gpu_extension_target( - _vllm_fp8_C - DESTINATION vllm_xpu_kernels - LANGUAGE ${VLLM_GPU_LANG} - SOURCES ${FP8_GEN_SRCS} - COMPILE_FLAGS ${VLLM_GPU_FLAGS} - LINK_FLAGS ${VLLM_GPU_LINK_FLAGS} - ARCHITECTURES ${VLLM_GPU_ARCHES} - INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} - INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} - INCLUDE_DIRECTORIES ${CUTLASS_APP_INCLUDE_DIR} - INCLUDE_DIRECTORIES ${VLLM_INCLUDE_DIR} - USE_SABI 3 - WITH_SOABI) - - # target_include_directories(_vllm_fa2_C PRIVATE - # csrc/flash_attn - # csrc/flash_attn/src) -endif () - # # _moe_C extension # -# TODO: add this as a placeholder for now. 
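A usage note on the targets defined above: both _C and _xpu_C are installed into the vllm_xpu_kernels package, and importing the compiled module is what registers its TORCH_LIBRARY operators with PyTorch. A minimal sketch of the consumption pattern used by the tests and Python wrappers in this series; it assumes the wheel has already been built, and the grouped-GEMM op shown is the one registered in csrc/xpu/torch_bindings.cpp later in the series.

import torch

# Importing the extension module is enough to register its ops; the module
# objects themselves are not used directly afterwards.
import vllm_xpu_kernels._C      # noqa: F401  -> ops appear under torch.ops._C
import vllm_xpu_kernels._xpu_C  # noqa: F401  -> ops appear under torch.ops._xpu_C

# Registered ops are then reached through the torch.ops namespace, e.g.:
grouped_gemm = torch.ops._xpu_C.cutlass_grouped_gemm
print(grouped_gemm)  # an OpOverloadPacket once the extension is built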
- -if (ENABLE_MOE_KERNEL) - set(VLLM_MOE_EXT_SRC - "csrc/moe/torch_bindings.cpp" - ) - - message(STATUS "Enabling moe extension.") - define_gpu_extension_target( - _moe_C - DESTINATION vllm_xpu_kernels - LANGUAGE ${VLLM_GPU_LANG} - SOURCES ${VLLM_MOE_EXT_SRC} - COMPILE_FLAGS ${VLLM_GPU_FLAGS} - ARCHITECTURES ${VLLM_GPU_ARCHES} - INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} - INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} - INCLUDE_DIRECTORIES ${CUTLASS_APP_INCLUDE_DIR} - USE_SABI 3 - WITH_SOABI) -endif() +set(VLLM_MOE_EXT_SRC + "csrc/moe/torch_bindings.cpp" + "csrc/moe/moe_align_sum_kernels.cpp") +message(STATUS "Enabling moe extension.") +define_gpu_extension_target( + _moe_C + DESTINATION vllm_xpu_kernels + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${VLLM_MOE_EXT_SRC} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + LINK_FLAGS ${VLLM_GPU_LINK_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + USE_SABI 3 + WITH_SOABI) diff --git a/csrc/cutlass_backend/cutlass_kernels.cpp b/csrc/xpu/cutlass_backend/cutlass_kernels.cpp similarity index 100% rename from csrc/cutlass_backend/cutlass_kernels.cpp rename to csrc/xpu/cutlass_backend/cutlass_kernels.cpp diff --git a/csrc/cutlass_backend/grouped_gemm.h b/csrc/xpu/cutlass_backend/grouped_gemm.h similarity index 100% rename from csrc/cutlass_backend/grouped_gemm.h rename to csrc/xpu/cutlass_backend/grouped_gemm.h diff --git a/csrc/cutlass_backend/helper.h b/csrc/xpu/cutlass_backend/helper.h similarity index 100% rename from csrc/cutlass_backend/helper.h rename to csrc/xpu/cutlass_backend/helper.h diff --git a/csrc/cutlass_backend/sycl_common.hpp b/csrc/xpu/cutlass_backend/sycl_common.hpp similarity index 100% rename from csrc/cutlass_backend/sycl_common.hpp rename to csrc/xpu/cutlass_backend/sycl_common.hpp diff --git a/setup.py b/setup.py index ec51ce4..0c8b824 100644 --- a/setup.py +++ b/setup.py @@ -259,7 +259,7 @@ def run(self): if _build_custom_ops(): ext_modules.append(CMakeExtension(name="vllm_xpu_kernels._C")) - ext_modules.append(CMakeExtension(name="vllm_xpu_kernels._vllm_fp8_C")) + ext_modules.append(CMakeExtension(name="vllm_xpu_kernels._xpu_C")) if ext_modules: cmdclass = {"build_ext": cmake_build_ext} From 083bde5a7577e94a210374830830b8c0c5299bd0 Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Thu, 4 Sep 2025 06:56:34 +0000 Subject: [PATCH 22/47] use main Cmakes Signed-off-by: Ma, Liangliang --- CMakeLists.txt | 60 +++----------------------------------------------- 1 file changed, 3 insertions(+), 57 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 49e6be9..d530a41 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -183,63 +183,14 @@ define_gpu_extension_target( if(VLLM_GPU_LANG STREQUAL "SYCL") set(VLLM_EXT_XPU_SRC "csrc/xpu/torch_bindings.cpp" - "csrc/cutlass_backend/*.cpp" ) include_directories("/usr/include") - list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/) - list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/sycl/) - list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/sycl/syclcompat/) - message(STATUS "VLLM_INCLUDE_DIR: ${VLLM_INCLUDE_DIR}") + set(CMPLR_ROOT $ENV{CMPLR_ROOT}) + set(CMAKE_CXX_COMPILER icpx) set(VLLM_EXTRA_INCLUDE_DIRECTORIES ${CMPLR_ROOT}/include/sycl) list(APPEND VLLM_GPU_FLAGS "-DVLLM_BUILD_XPU_OPS" ) - list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64" "-Xspirv-translator" "-spirv-ext=+SPV_INTEL_split_barrier") + list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64") list(APPEND 
VLLM_LINK_LIBRARIES "sycl" "OpenCL" "pthread" "m" "dl" "torch" ) - # add cutlass dependency - set(CUTLASS_ENABLE_HEADERS_ONLY "ON" CACHE BOOL "Enable only the header library") - - # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. - set(CUTLASS_REVISION "main" CACHE STRING "CUTLASS revision to use") - - # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided - FetchContent_Declare( - cutlass-sycl - GIT_REPOSITORY https://github.com/intel/cutlass-sycl.git - # Please keep this in sync with CUTLASS_REVISION line above. - GIT_TAG ${CUTLASS_REVISION} - GIT_PROGRESS TRUE - - # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. - # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. - # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE - GIT_SHALLOW TRUE - ) - - # cutlass compilation flags - set(CUTLASS_ENABLE_SYCL "ON") - # set(DPCPP_SYCL_TARGET "intel_gpu_pvc;intel_gpu_bmg_g21" CACHE STRING "DPC++ SYCL target architectures") - set(CMAKE_EXPORT_COMPILE_COMMANDS "ON") - set(CUTLASS_ENABLE_BENCHMARKS "OFF") - # disable cuda - set(CUTLASS_ENABLE_GDC_FOR_SM100_DEFAULT OFF CACHE BOOL "DISABLE CUDA") - # list(APPEND CMAKE_CXX_FLAGS "-ftemplate-backtrace-limit=0 " ) - # list(APPEND CMAKE_CXX_FLAGS "-fdiagnostics-color=always " ) - - - FetchContent_MakeAvailable(cutlass-sycl) - set(CUTLASS_INCLUDE_DIR ${cutlass-sycl_SOURCE_DIR}/include CACHE PATH "CUTLASS Header Library") - set(CUTLASS_TOOLS_UTIL_INCLUDE_DIR ${cutlass-sycl_SOURCE_DIR}/tools/util/include CACHE INTERNAL "") - set(CUTLASS_APP_INCLUDE_DIR ${cutlass-sycl_SOURCE_DIR}/applications CACHE INTERNAL "") - message(STATUS "cutlass dir: ${CUTLASS_INCLUDE_DIR} and ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} and ${CUTLASS_APP_INCLUDE_DIR}") - - # header only library - list(APPEND VLLM_GPU_FLAGS "-DCUTLASS_ENABLE_SYCL") - list(APPEND VLLM_GPU_FLAGS "-DSYCL_INTEL_TARGET") - list(APPEND VLLM_GPU_FLAGS "-DCUTLASS_VERSIONS_GENERATED") - list(APPEND VLLM_GPU_FLAGS "-ftemplate-backtrace-limit=0") - list(APPEND VLLM_GPU_FLAGS "-fdiagnostics-color=always") - list(APPEND VLLM_GPU_FLAGS "-gline-tables-only") - list(APPEND VLLM_GPU_FLAGS "-O3") - list(APPEND VLLM_GPU_FLAGS "-DNDEBUG") endif() if(ONEDNN_FOUND) @@ -252,9 +203,6 @@ if(ONEDNN_FOUND) link_libraries(${ONEDNN_LIBRARY}) endif() - - - define_gpu_extension_target( _xpu_C DESTINATION vllm_xpu_kernels @@ -265,8 +213,6 @@ define_gpu_extension_target( ARCHITECTURES ${VLLM_GPU_ARCHES} INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} - INCLUDE_DIRECTORIES ${CUTLASS_APP_INCLUDE_DIR} - INCLUDE_DIRECTORIES ${VLLM_INCLUDE_DIR} USE_SABI 3 WITH_SOABI) From 48a4808d074d361ec792c23d7d2ee9856def81c8 Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Thu, 4 Sep 2025 06:59:38 +0000 Subject: [PATCH 23/47] use main setup Signed-off-by: Ma, Liangliang --- setup.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 0c8b824..358723c 100644 --- a/setup.py +++ b/setup.py @@ -3,6 +3,7 @@ import importlib.util import logging import os +import shutil import subprocess import sys from pathlib import Path @@ -119,13 +120,12 @@ def configure(self, ext: CMakeExtension) -> None: # Select the build type. 
# Note: optimization level + debug info are set by the build type - default_cfg = "Debug" if self.debug else "RelWithDebInfo" + default_cfg = "Debug" if self.debug else "Release" cfg = envs.CMAKE_BUILD_TYPE or default_cfg cmake_args = [ '-DCMAKE_BUILD_TYPE={}'.format(cfg), '-DVLLM_TARGET_DEVICE={}'.format(VLLM_TARGET_DEVICE), - '-DCMAKE_TOOLCHAIN_FILE=cmake/toolchain.cmake' ] verbose = envs.VERBOSE @@ -176,9 +176,21 @@ def configure(self, ext: CMakeExtension) -> None: else: # Default build tool to whatever cmake picks. build_tool = [] + my_env = os.environ.copy() + icx_path = shutil.which('icx') + icpx_path = shutil.which('icpx') + build_option_gpu = { + "BUILD_MODULE_TYPE": "GPU", + "CMAKE_C_COMPILER": f"{icx_path}", + "CMAKE_CXX_COMPILER": f"{icpx_path}", + } + for key, value in build_option_gpu.items(): + if value is not None: + cmake_args.append("-D{}={}".format(key, value)) subprocess.check_call( ['cmake', ext.cmake_lists_dir, *build_tool, *cmake_args], - cwd=self.build_temp) + cwd=self.build_temp, + env=my_env) def build_extensions(self) -> None: # Ensure that CMake is present and working @@ -259,6 +271,7 @@ def run(self): if _build_custom_ops(): ext_modules.append(CMakeExtension(name="vllm_xpu_kernels._C")) + ext_modules.append(CMakeExtension(name="vllm_xpu_kernels._moe_C")) ext_modules.append(CMakeExtension(name="vllm_xpu_kernels._xpu_C")) if ext_modules: From 22d1ade5a00c2070a403213e568afab10d330ffc Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Thu, 4 Sep 2025 07:01:42 +0000 Subject: [PATCH 24/47] mv utils Signed-off-by: Ma, Liangliang --- csrc/{xpu => }/utils.h | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename csrc/{xpu => }/utils.h (100%) diff --git a/csrc/xpu/utils.h b/csrc/utils.h similarity index 100% rename from csrc/xpu/utils.h rename to csrc/utils.h From 1c7f46de47d0791638577b8ef46c2b7933892826 Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Thu, 4 Sep 2025 08:08:28 +0000 Subject: [PATCH 25/47] finish rebase Signed-off-by: Ma, Liangliang --- CMakeLists.txt | 65 +++++++++++++++++-- ...{cutlass_kernels.cpp => cutlass_kernels.h} | 8 +-- csrc/xpu/ops.h | 2 + csrc/xpu/torch_bindings.cpp | 4 ++ setup.py | 1 + vllm_xpu_kernels/__init__.py | 1 - vllm_xpu_kernels/fused_moe_interface.py | 5 +- 7 files changed, 71 insertions(+), 15 deletions(-) rename csrc/xpu/cutlass_backend/{cutlass_kernels.cpp => cutlass_kernels.h} (70%) diff --git a/CMakeLists.txt b/CMakeLists.txt index d530a41..ed65698 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -156,7 +156,6 @@ if(VLLM_GPU_LANG STREQUAL "SYCL") ) include_directories("/usr/include") set(CMPLR_ROOT $ENV{CMPLR_ROOT}) - set(CMAKE_CXX_COMPILER icpx) set(VLLM_EXTRA_INCLUDE_DIRECTORIES ${CMPLR_ROOT}/include/sycl) list(APPEND VLLM_GPU_FLAGS "-DVLLM_BUILD_XPU_OPS" ) list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64") @@ -180,17 +179,68 @@ define_gpu_extension_target( # # xpu only ops/kernels, implemented with cutlass/onednn/sycl. 
# +file(GLOB CUTLASS_BACKEND_SRCS + csrc/xpu/cutlass_backend/*.cpp +) + if(VLLM_GPU_LANG STREQUAL "SYCL") set(VLLM_EXT_XPU_SRC "csrc/xpu/torch_bindings.cpp" + ${CUTLASS_BACKEND_SRCS} ) include_directories("/usr/include") - set(CMPLR_ROOT $ENV{CMPLR_ROOT}) - set(CMAKE_CXX_COMPILER icpx) + list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/) + list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/sycl/) + list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/sycl/syclcompat/) + message(STATUS "VLLM_INCLUDE_DIR: ${VLLM_INCLUDE_DIR}") set(VLLM_EXTRA_INCLUDE_DIRECTORIES ${CMPLR_ROOT}/include/sycl) list(APPEND VLLM_GPU_FLAGS "-DVLLM_BUILD_XPU_OPS" ) - list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64") + list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64" "-Xspirv-translator" "-spirv-ext=+SPV_INTEL_split_barrier") list(APPEND VLLM_LINK_LIBRARIES "sycl" "OpenCL" "pthread" "m" "dl" "torch" ) + # add cutlass dependency + set(CUTLASS_ENABLE_HEADERS_ONLY "ON" CACHE BOOL "Enable only the header library") + + # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. + set(CUTLASS_REVISION "main" CACHE STRING "CUTLASS revision to use") + + # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided + FetchContent_Declare( + cutlass-sycl + GIT_REPOSITORY https://github.com/intel/cutlass-sycl.git + # Please keep this in sync with CUTLASS_REVISION line above. + GIT_TAG ${CUTLASS_REVISION} + GIT_PROGRESS TRUE + + # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. + # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. + # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE + GIT_SHALLOW TRUE + ) + + # cutlass compilation flags + set(CUTLASS_ENABLE_SYCL "ON") + # set(DPCPP_SYCL_TARGET "intel_gpu_pvc;intel_gpu_bmg_g21" CACHE STRING "DPC++ SYCL target architectures") + set(CMAKE_EXPORT_COMPILE_COMMANDS "ON") + set(CUTLASS_ENABLE_BENCHMARKS "OFF") + # disable cuda + set(CUTLASS_ENABLE_GDC_FOR_SM100_DEFAULT OFF CACHE BOOL "DISABLE CUDA") + # list(APPEND CMAKE_CXX_FLAGS "-ftemplate-backtrace-limit=0 " ) + # list(APPEND CMAKE_CXX_FLAGS "-fdiagnostics-color=always " ) + + + FetchContent_MakeAvailable(cutlass-sycl) + set(CUTLASS_INCLUDE_DIR ${cutlass-sycl_SOURCE_DIR}/include CACHE PATH "CUTLASS Header Library") + set(CUTLASS_TOOLS_UTIL_INCLUDE_DIR ${cutlass-sycl_SOURCE_DIR}/tools/util/include CACHE INTERNAL "") + set(CUTLASS_APP_INCLUDE_DIR ${cutlass-sycl_SOURCE_DIR}/applications CACHE INTERNAL "") + message(STATUS "cutlass dir: ${CUTLASS_INCLUDE_DIR} and ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} and ${CUTLASS_APP_INCLUDE_DIR}") + + # header only library + list(APPEND VLLM_GPU_FLAGS "-DCUTLASS_ENABLE_SYCL") + list(APPEND VLLM_GPU_FLAGS "-DSYCL_INTEL_TARGET") + list(APPEND VLLM_GPU_FLAGS "-DCUTLASS_VERSIONS_GENERATED") + list(APPEND VLLM_GPU_FLAGS "-ftemplate-backtrace-limit=0") + list(APPEND VLLM_GPU_FLAGS "-fdiagnostics-color=always") + endif() if(ONEDNN_FOUND) @@ -203,6 +253,11 @@ if(ONEDNN_FOUND) link_libraries(${ONEDNN_LIBRARY}) endif() + + +list(APPEND VLLM_GPU_FLAGS "-gline-tables-only") +list(APPEND VLLM_GPU_FLAGS "-O3") +list(APPEND VLLM_GPU_FLAGS "-DNDEBUG") define_gpu_extension_target( _xpu_C DESTINATION vllm_xpu_kernels @@ -213,6 +268,8 @@ define_gpu_extension_target( ARCHITECTURES ${VLLM_GPU_ARCHES} INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + 
INCLUDE_DIRECTORIES ${CUTLASS_APP_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${VLLM_INCLUDE_DIR} USE_SABI 3 WITH_SOABI) diff --git a/csrc/xpu/cutlass_backend/cutlass_kernels.cpp b/csrc/xpu/cutlass_backend/cutlass_kernels.h similarity index 70% rename from csrc/xpu/cutlass_backend/cutlass_kernels.cpp rename to csrc/xpu/cutlass_backend/cutlass_kernels.h index 02f5bd4..1a0ccfe 100644 --- a/csrc/xpu/cutlass_backend/cutlass_kernels.cpp +++ b/csrc/xpu/cutlass_backend/cutlass_kernels.h @@ -7,10 +7,9 @@ // #include /* #include "pytorch_shim.h" */ -#include "core/registration.h" #include -#include "xpu/utils.h" #include "grouped_gemm.h" +#include "utils.h" namespace gpu::cutlass_kernel { @@ -46,9 +45,4 @@ at::Tensor grouped_gemm_func( } // namespace gpu::cutlass_kernel -TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { - ops.def("cutlass_grouped_gemm(Tensor input, Tensor weight, Tensor res, Tensor offset, int hidden_size, int intermediate_size, int num_of_expert) -> Tensor"); - ops.impl("cutlass_grouped_gemm", torch::kXPU, gpu::cutlass_kernel::grouped_gemm_func); -} -REGISTER_EXTENSION(TORCH_EXTENSION_NAME); diff --git a/csrc/xpu/ops.h b/csrc/xpu/ops.h index e3d1b66..9292994 100644 --- a/csrc/xpu/ops.h +++ b/csrc/xpu/ops.h @@ -6,3 +6,5 @@ torch::Tensor fp8_gemm_w8a16(const torch::Tensor& A, const torch::Tensor& B, bool trans_B, const std::optional& B_scale_, const std::optional& bias_); + +torch::Tensor cutlass_grouped_gemm(torch::Tensor input, torch::Tensor weight, torch::Tensor res, torch::Tensor offset, int64_t hidden_size, int64_t intermediate_size, int64_t num_of_expert); diff --git a/csrc/xpu/torch_bindings.cpp b/csrc/xpu/torch_bindings.cpp index 5906365..eed9bce 100644 --- a/csrc/xpu/torch_bindings.cpp +++ b/csrc/xpu/torch_bindings.cpp @@ -1,5 +1,6 @@ #include "core/registration.h" #include "xpu/ops.h" +#include "xpu/cutlass_backend/cutlass_kernels.h" #include #include @@ -11,6 +12,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, xpu_ops) { "fp8_gemm_w8a16(Tensor! A, Tensor! B, bool trans_B, Tensor? B_scale_, " "Tensor? bias_) -> Tensor"); xpu_ops.impl("fp8_gemm_w8a16", torch::kXPU, &fp8_gemm_w8a16); + + xpu_ops.def("cutlass_grouped_gemm(Tensor input, Tensor weight, Tensor res, Tensor offset, int hidden_size, int intermediate_size, int num_of_expert) -> Tensor"); + xpu_ops.impl("cutlass_grouped_gemm", torch::kXPU, gpu::cutlass_kernel::grouped_gemm_func); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/setup.py b/setup.py index 358723c..ac88aeb 100644 --- a/setup.py +++ b/setup.py @@ -126,6 +126,7 @@ def configure(self, ext: CMakeExtension) -> None: cmake_args = [ '-DCMAKE_BUILD_TYPE={}'.format(cfg), '-DVLLM_TARGET_DEVICE={}'.format(VLLM_TARGET_DEVICE), + '-DCMAKE_TOOLCHAIN_FILE=cmake/toolchain.cmake' ] verbose = envs.VERBOSE diff --git a/vllm_xpu_kernels/__init__.py b/vllm_xpu_kernels/__init__.py index 6714e50..db43924 100644 --- a/vllm_xpu_kernels/__init__.py +++ b/vllm_xpu_kernels/__init__.py @@ -1,3 +1,2 @@ # SPDX-License-Identifier: Apache-2.0 -from .fused_moe_interface import cutlass_fused_moe, cutlass_grouped_gemm diff --git a/vllm_xpu_kernels/fused_moe_interface.py b/vllm_xpu_kernels/fused_moe_interface.py index 2d209e7..7e7250b 100644 --- a/vllm_xpu_kernels/fused_moe_interface.py +++ b/vllm_xpu_kernels/fused_moe_interface.py @@ -3,11 +3,10 @@ from torch import nn, Tensor from typing import List import numpy - -from . 
import _vllm_fp8_C +import vllm_xpu_kernels._xpu_C def cutlass_grouped_gemm(input_A, input_B, output, offset, n, k, num_experts): - torch.ops._vllm_fp8_C.cutlass_grouped_gemm(input_A, input_B, output, offset, n, k, num_experts) + torch.ops._xpu_C.cutlass_grouped_gemm(input_A, input_B, output, offset, n, k, num_experts) def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_per_token, activation, num_experts): token_cnt, hidden_size = list(hidden_states.shape) From df0b915ac556c7f37f526495d5a24f5968c4af95 Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Fri, 5 Sep 2025 08:45:41 +0000 Subject: [PATCH 26/47] add profile and change to col-maj Signed-off-by: Ma, Liangliang --- csrc/xpu/cutlass_backend/grouped_gemm.h | 37 +++++++++++++++++++++++-- vllm_xpu_kernels/fused_moe_interface.py | 4 +-- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/csrc/xpu/cutlass_backend/grouped_gemm.h b/csrc/xpu/cutlass_backend/grouped_gemm.h index 568a0cc..fa089ac 100644 --- a/csrc/xpu/cutlass_backend/grouped_gemm.h +++ b/csrc/xpu/cutlass_backend/grouped_gemm.h @@ -100,6 +100,7 @@ using ElementA = bfloat16_t; // <- data type of elements in input matri using ElementB = bfloat16_t; // <- data type of elements in input matrix B using ElementOutput = float; // <- data type of elements in output matrix D bool debug = false; +bool collect_gflops = false; /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -138,6 +139,24 @@ struct Options { } groups = group_cnt; } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s, std::vector problem_sizes_host) const + { + // Number of real-valued multiply-adds + uint64_t fmas = uint64_t(); + + for (auto const & problem : problem_sizes_host) { + fmas += static_cast(get<0>(problem)) * + static_cast(get<1>(problem)) * + static_cast(get<2>(problem)); + } + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * uint64_t(fmas); + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } + }; @@ -416,9 +435,23 @@ void allocate(const Options &options, int64_t* offset) { } // Run the GEMM CUTLASS_CHECK(gemm_op.run()); + + if (collect_gflops){ + GPU_Clock timer; + timer.start(); + for (int iter = 0; iter < 100; ++iter) { + CUTLASS_CHECK(gemm_op.run()); + } + syclcompat::wait(); + float cute_time = timer.seconds() * 1000; + double cute_average_time = double(cute_time) / double(options.iterations); + double gflops = options.gflops(cute_average_time / 1000.0, options.problem_sizes_host); + std::cout << " Avg runtime : " << cute_average_time << " ms" << std::endl; + std::cout << " GFLOPS : " << gflops << std::endl; - stream->throw_asynchronous(); + } + stream->throw_asynchronous(); return cutlass::Status::kSuccess; } @@ -455,7 +488,7 @@ void kernel_functor( using ElementScale = cutlass::bfloat16_t; using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = cutlass::layout::RowMajor; using LayoutD = cutlass::layout::RowMajor; diff --git a/vllm_xpu_kernels/fused_moe_interface.py b/vllm_xpu_kernels/fused_moe_interface.py index 7e7250b..5cfc2ed 100644 --- a/vllm_xpu_kernels/fused_moe_interface.py +++ b/vllm_xpu_kernels/fused_moe_interface.py @@ -36,7 +36,7 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ # gemm1 input_A = torch.cat(grouped_input_A, dim=0).contiguous() - input_B = w13.transpose(-1, -2).contiguous().transpose(-1, -2) + 
input_B = w13 #.transpose(-1, -2).contiguous().transpose(-1, -2) assert(list(input_A.shape)[0] == total_input_size) gemm1_output = torch.empty((total_input_size, 2*intermediate_size), dtype=torch.float32,device=hidden_states.device) offset = torch.tensor(offset, dtype=torch.int64, device='cpu') @@ -56,7 +56,7 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ # gemm 2 group_A = act_output.to(torch.bfloat16).contiguous() output = torch.empty((list(group_A.shape)[0], hidden_size), dtype=torch.float32, device=hidden_states.device) - w2 = w2.transpose(-1, -2).contiguous().transpose(-1, -2) + w2 = w2 #.transpose(-1, -2).contiguous().transpose(-1, -2) cutlass_grouped_gemm(input_A=group_A, input_B=w2, output=output, From 76fe4bc98d130b69e3c5692ecbaf4fc3559d0793 Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Tue, 9 Sep 2025 07:58:49 +0000 Subject: [PATCH 27/47] dont not reserve block_C Signed-off-by: Ma, Liangliang --- csrc/xpu/cutlass_backend/grouped_gemm.h | 7 ++++--- tests/cutlass/test_fused_moe.py | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/csrc/xpu/cutlass_backend/grouped_gemm.h b/csrc/xpu/cutlass_backend/grouped_gemm.h index fa089ac..31838ff 100644 --- a/csrc/xpu/cutlass_backend/grouped_gemm.h +++ b/csrc/xpu/cutlass_backend/grouped_gemm.h @@ -272,7 +272,7 @@ void allocate(const Options &options, int64_t* offset) { stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1})); stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1})); } - block_C.reset(total_elements_C); + // block_C.reset(total_elements_C); block_alpha.reset(options.groups); block_beta.reset(options.groups); } @@ -339,7 +339,7 @@ void allocate(const Options &options, int64_t* offset) { beta_device.copy_from_host(ptr_beta_host.data()); - initialize_block(block_C, 0); + // initialize_block(block_C, 666); // Per-group alpha and beta values - note these are not directly passed to kernel - the pointers // (alpha_device/beta_device) are passed instead block_alpha.copy_from_host(alpha_host.data()); @@ -437,6 +437,7 @@ void allocate(const Options &options, int64_t* offset) { CUTLASS_CHECK(gemm_op.run()); if (collect_gflops){ + std::cout << "collect_gflops:" << collect_gflops << std::endl; GPU_Clock timer; timer.start(); for (int iter = 0; iter < 100; ++iter) { @@ -451,7 +452,7 @@ void allocate(const Options &options, int64_t* offset) { std::cout << " GFLOPS : " << gflops << std::endl; } - stream->throw_asynchronous(); + stream->throw_asynchronous(); return cutlass::Status::kSuccess; } diff --git a/tests/cutlass/test_fused_moe.py b/tests/cutlass/test_fused_moe.py index 79c9266..aa4ad37 100644 --- a/tests/cutlass/test_fused_moe.py +++ b/tests/cutlass/test_fused_moe.py @@ -28,6 +28,25 @@ DEVICE = "xpu" + +def calculate_device_mem(m, k, n, e, topk, dtype): + total = 0 + x = m*k + w13 = e*2*n*k + w2 = e*k*n + topk_w = topk*m + topk_id = topk*m + expert_cache = x + gemm1_out = m*2*n + gemm2_out = m*k + total += x + w13 + w2 + topk_w + topk_id + expert_cache + gemm1_out + gemm2_out + byte_per_data = 4 + if dtype == torch.bfloat16: + byte_per_data = 2 + total_bytes_G = total * byte_per_data / 1000 / 1000 / 1000 + print("fused moe should take device memory: ", total_bytes_G, "G") + + def test_grouped_gemm(num_experts, n, k, token_per_group): # input input_A = torch.randn((sum(token_per_group), k), dtype=torch.bfloat16, device="xpu") From ad0fdd68cde8b8dd2972c00b5aea0740d037d209 Mon Sep 17 00:00:00 2001 From: "Ma, 
Liangliang" Date: Thu, 11 Sep 2025 03:12:28 +0000 Subject: [PATCH 28/47] remove redundant allocation Signed-off-by: Ma, Liangliang --- CMakeLists.txt | 2 +- csrc/xpu/cutlass_backend/cutlass_kernels.h | 39 +++-- csrc/xpu/cutlass_backend/grouped_gemm.h | 178 ++++++++++----------- csrc/xpu/ops.h | 10 +- csrc/xpu/torch_bindings.cpp | 2 +- mll_build.sh | 2 +- tests/cutlass/test_fused_moe.py | 7 +- vllm_xpu_kernels/fused_moe_interface.py | 99 +++++++++--- 8 files changed, 200 insertions(+), 139 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ed65698..bd7351d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -206,7 +206,7 @@ if(VLLM_GPU_LANG STREQUAL "SYCL") # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided FetchContent_Declare( cutlass-sycl - GIT_REPOSITORY https://github.com/intel/cutlass-sycl.git + GIT_REPOSITORY https://github.com/Liangliang-Ma/cutlass-sycl.git # Please keep this in sync with CUTLASS_REVISION line above. GIT_TAG ${CUTLASS_REVISION} GIT_PROGRESS TRUE diff --git a/csrc/xpu/cutlass_backend/cutlass_kernels.h b/csrc/xpu/cutlass_backend/cutlass_kernels.h index 1a0ccfe..df6236f 100644 --- a/csrc/xpu/cutlass_backend/cutlass_kernels.h +++ b/csrc/xpu/cutlass_backend/cutlass_kernels.h @@ -16,31 +16,28 @@ namespace gpu::cutlass_kernel { /* gemm2(group_A, w2, output, offset) */ at::Tensor grouped_gemm_func( - at::Tensor& input, - at::Tensor& weight, - at::Tensor& res, + at::Tensor& ptr_A, + at::Tensor& ptr_B, + at::Tensor& ptr_D, + at::Tensor& ptr_alpha, + at::Tensor& ptr_beta, at::Tensor& offset, - int64_t hidden_size, - int64_t intermediate_size, - int64_t num_of_expert - ) { + int64_t N, + int64_t K, + int64_t groups) { auto dpcpp_queue = vllm::xpu::vllmGetQueue(); - if (input.scalar_type() != at::kBFloat16) { - std::cout << "error:wrong datatype, current only support bfloat16" << std::endl; - return at::Tensor(); - } - - grouped_gemm::kernel_functor( + grouped_gemm::kernel_functor( &dpcpp_queue, - input.data_ptr(), - weight.data_ptr(), - res.data_ptr(), + ptr_A.data_ptr(), + ptr_B.data_ptr(), + ptr_D.data_ptr(), + ptr_alpha.data_ptr(), + ptr_beta.data_ptr(), offset.data_ptr(), - hidden_size, - intermediate_size, - num_of_expert - ); - return res; + N, + K, + groups); + return ptr_D; } } // namespace gpu::cutlass_kernel diff --git a/csrc/xpu/cutlass_backend/grouped_gemm.h b/csrc/xpu/cutlass_backend/grouped_gemm.h index 31838ff..15f44b7 100644 --- a/csrc/xpu/cutlass_backend/grouped_gemm.h +++ b/csrc/xpu/cutlass_backend/grouped_gemm.h @@ -191,28 +191,24 @@ struct GroupedGemmRunner { using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; // Host-side allocations - std::vector offset_A; - std::vector offset_B; std::vector offset_C; - std::vector offset_D; + // std::vector offset_D; std::vector stride_A_host; std::vector stride_B_host; std::vector stride_C_host; std::vector stride_D_host; - std::vector alpha_host; - std::vector beta_host; + // std::vector alpha_host; + // std::vector beta_host; // Device-side allocations cutlass::DeviceAllocation problem_sizes; // This example defines all matrices in a single allocation (e.g. block_A), but this is not a // requirement. Matrix base pointers are read from device allocation (e.g. 
ptr_A) - cutlass::DeviceAllocation ptr_A; - cutlass::DeviceAllocation ptr_B; cutlass::DeviceAllocation ptr_C; - cutlass::DeviceAllocation ptr_D; + // cutlass::DeviceAllocation ptr_D; cutlass::DeviceAllocation stride_A; cutlass::DeviceAllocation stride_B; @@ -221,30 +217,22 @@ struct GroupedGemmRunner { cutlass::DeviceAllocation block_C; // Note, this is an array of pointers to alpha and beta scaling values per group - cutlass::DeviceAllocation alpha_device; - cutlass::DeviceAllocation beta_device; - cutlass::DeviceAllocation block_alpha; - cutlass::DeviceAllocation block_beta; + // cutlass::DeviceAllocation alpha_device; + // cutlass::DeviceAllocation beta_device; + // cutlass::DeviceAllocation block_alpha; + // cutlass::DeviceAllocation block_beta; /// Allocates device-side data -void allocate(const Options &options, int64_t* offset) { +void allocate(const Options &options) { if (debug){ std::cout << "void allocate()" << std::endl; } - int64_t total_elements_A = 0; - int64_t total_elements_B = 0; int64_t total_elements_C = 0; - int64_t total_elements_D = 0; + // int64_t total_elements_D = 0; - int offset_iter = 0; // Compute total allocation sizes across group for (int32_t i = 0; i < options.groups; ++i) { - while (offset[offset_iter] == 0){ - total_elements_B += options.n * options.k; - offset_iter++; - continue; - } - offset_iter++; + auto problem = options.problem_sizes_host.at(i); auto M = get<0>(problem); @@ -252,20 +240,14 @@ void allocate(const Options &options, int64_t* offset) { auto K = get<2>(problem); // Offset into block allocation of each matrix base pointer - offset_A.push_back(total_elements_A); - offset_B.push_back(total_elements_B); offset_C.push_back(total_elements_C); - offset_D.push_back(total_elements_D); + // offset_D.push_back(total_elements_D); - int64_t elements_A = M * K; - int64_t elements_B = K * N; int64_t elements_C = M * N; - int64_t elements_D = M * N; + // int64_t elements_D = M * N; - total_elements_A += elements_A; - total_elements_B += elements_B; total_elements_C += elements_C; - total_elements_D += elements_D; + // total_elements_D += elements_D; stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1})); stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1})); @@ -273,52 +255,40 @@ void allocate(const Options &options, int64_t* offset) { stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1})); } // block_C.reset(total_elements_C); - block_alpha.reset(options.groups); - block_beta.reset(options.groups); + // block_alpha.reset(options.groups); + // block_beta.reset(options.groups); } - void initialize(const Options &options, ElementA * block_A, ElementB * block_B, - ElementOutput* block_D) { + void initialize(const Options &options) { if (debug){ std::cout << "void initialize()" << std::endl; } problem_sizes.reset(options.groups); problem_sizes.copy_from_host(options.problem_sizes_host.data()); - std::vector ptr_A_host(options.groups); - std::vector ptr_B_host(options.groups); std::vector ptr_C_host(options.groups); - std::vector ptr_D_host(options.groups); - std::vector ptr_alpha_host(options.groups); - std::vector ptr_beta_host(options.groups); + // std::vector ptr_D_host(options.groups); + // std::vector ptr_alpha_host(options.groups); + // std::vector ptr_beta_host(options.groups); // Compute offsets, alpha & beta over group on host for (int32_t i = 0; i < options.groups; ++i) { - ptr_A_host.at(i) = block_A + offset_A.at(i); - ptr_B_host.at(i) = block_B + offset_B.at(i); 
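After this change the runner no longer derives per-group A/B pointers on the host from offsets into one packed allocation; the caller hands in ready-made arrays of device pointers (ptr_A, ptr_B, ptr_D, ptr_alpha, ptr_beta). On the Python side that is what the prepare_gemm_args helper added later in this patch does: it packs one 64-bit device address per non-empty expert group into a flat uint8 tensor, which the extension then reinterprets as a pointer array. A minimal sketch of that packing idea (pack_device_ptrs is an illustrative helper, not part of the patch):

import torch

def pack_device_ptrs(group_tensors):
    # One little-endian 64-bit device address per group, flattened into a
    # uint8 tensor so the C++ side can cast data_ptr() to ElementA** etc.
    raw = []
    for t in group_tensors:
        addr = t.data_ptr()
        raw.extend((addr >> (8 * i)) & 0xFF for i in range(8))
    return torch.tensor(raw, dtype=torch.uint8, device=group_tensors[0].device)

Building these arrays once in Python is what lets the C++ op drop the ptr_*_host vectors and the per-call DeviceAllocation copies removed in this hunk.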
ptr_C_host.at(i) = block_C.get() + offset_C.at(i); - ptr_D_host.at(i) = block_D + offset_D.at(i); + // ptr_D_host.at(i) = block_D + offset_D.at(i); // Fill host vector of alpha & beta with random values if using per-group values - alpha_host.push_back(static_cast(1)); - beta_host.push_back(static_cast(0)); - // Fill host ptr vectors with offset addresses into device alpha/beta blocks - ptr_alpha_host.at(i) = block_alpha.get() + i; - ptr_beta_host.at(i) = block_beta.get() + i; + // alpha_host.push_back(static_cast(1)); + // beta_host.push_back(static_cast(0)); + // // Fill host ptr vectors with offset addresses into device alpha/beta blocks + // ptr_alpha_host.at(i) = block_alpha.get() + i; + // ptr_beta_host.at(i) = block_beta.get() + i; } - // Allocate device memory & copy from host - ptr_A.reset(options.groups); - // Per-group alpha and beta - ptr_A.copy_from_host(ptr_A_host.data()); - - ptr_B.reset(options.groups); - ptr_B.copy_from_host(ptr_B_host.data()); - + // // Allocate device memory & copy from host ptr_C.reset(options.groups); ptr_C.copy_from_host(ptr_C_host.data()); - ptr_D.reset(options.groups); - ptr_D.copy_from_host(ptr_D_host.data()); + // ptr_D.reset(options.groups); + // ptr_D.copy_from_host(ptr_D_host.data()); stride_A.reset(options.groups); stride_A.copy_from_host(stride_A_host.data()); @@ -333,22 +303,29 @@ void allocate(const Options &options, int64_t* offset) { stride_D.copy_from_host(stride_D_host.data()); // Per-group alpha and beta ptrs - alpha_device.reset(options.groups); - alpha_device.copy_from_host(ptr_alpha_host.data()); - beta_device.reset(options.groups); - beta_device.copy_from_host(ptr_beta_host.data()); + // alpha_device.reset(options.groups); + // alpha_device.copy_from_host(ptr_alpha_host.data()); + // beta_device.reset(options.groups); + // beta_device.copy_from_host(ptr_beta_host.data()); // initialize_block(block_C, 666); // Per-group alpha and beta values - note these are not directly passed to kernel - the pointers // (alpha_device/beta_device) are passed instead - block_alpha.copy_from_host(alpha_host.data()); - block_beta.copy_from_host(beta_host.data()); + // block_alpha.copy_from_host(alpha_host.data()); + // block_beta.copy_from_host(beta_host.data()); } /// Populates a Gemm::Arguments structure from the given commandline options - typename Gemm::Arguments args_from_options(const Options &options, const cutlass::KernelHardwareInfo& hw_info, bool host_problem_shapes_available = true) + typename Gemm::Arguments args_from_options(const Options &options, + const cutlass::KernelHardwareInfo& hw_info, + const ElementA ** ptr_A, + const ElementB ** ptr_B, + ElementOutput ** ptr_D, + ElementAccumulator ** ptr_alpha, + ElementAccumulator ** ptr_beta, + bool host_problem_shapes_available = true) { typename Gemm::Arguments arguments; decltype(arguments.epilogue.thread) fusion_args; @@ -371,8 +348,8 @@ void allocate(const Options &options, int64_t* offset) { fusion_args.beta = 0; fusion_args.alpha_ptr = nullptr; fusion_args.beta_ptr = nullptr; - fusion_args.alpha_ptr_array = alpha_device.get(); - fusion_args.beta_ptr_array = beta_device.get(); + fusion_args.alpha_ptr_array = ptr_alpha; + fusion_args.beta_ptr_array = ptr_beta; // One alpha and beta per each group fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1}; fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1}; @@ -384,8 +361,8 @@ void allocate(const Options &options, int64_t* offset) { arguments = typename Gemm::Arguments { cutlass::gemm::GemmUniversalMode::kGrouped, {options.groups, problem_sizes.get(), 
options.problem_sizes_host.data()}, - {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, - {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + {ptr_A, stride_A.get(), ptr_B, stride_B.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D, stride_D.get()}, hw_info, {1, RasterOrderOptions::AlongN} }; @@ -394,8 +371,8 @@ void allocate(const Options &options, int64_t* offset) { arguments = typename Gemm::Arguments { cutlass::gemm::GemmUniversalMode::kGrouped, {options.groups, problem_sizes.get(), nullptr}, - {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, - {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + {ptr_A, stride_A.get(), ptr_B, stride_B.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D, stride_D.get()}, hw_info, {1, RasterOrderOptions::AlongN} }; @@ -409,19 +386,34 @@ void allocate(const Options &options, int64_t* offset) { const Options& options, sycl::queue* stream, const cutlass::KernelHardwareInfo& hw_info, - ElementA* inputA, - ElementB* inputB, - ElementOffset* offset, - ElementOutput* res) { + const ElementA** ptr_A, + const ElementB** ptr_B, + ElementOutput** ptr_D, + ElementAccumulator** ptr_alpha, + ElementAccumulator** ptr_beta) { if (debug){ std::cout << "enter run" << std::endl; } + + // std::vector ptr_AA_host(options.groups); - allocate(options, offset); - initialize(options, inputA, inputB, res); + // stream->memcpy(ptr_AA_host.data(), ptr_AA, options.groups * sizeof(int64_t)).wait(); + // // cutlass::device_memory::copy_from_device(ptr_A_host.data(), ptr_A, options.groups); + // for (int i = 0; i < options.groups; ++i){ + // std::cout << "AA ptr:" << ptr_AA_host.at(i) << std::endl; + // } + + allocate(options); + initialize(options); Gemm gemm_op; - auto arguments = args_from_options(options, hw_info, true); + auto arguments = args_from_options(options, hw_info, + ptr_A, + ptr_B, + ptr_D, + ptr_alpha, + ptr_beta, + true); size_t workspace_size = Gemm::get_workspace_size(arguments); cutlass::device_memory::allocation workspace(workspace_size); @@ -460,18 +452,20 @@ void allocate(const Options &options, int64_t* offset) { void kernel_functor( sycl::queue* stream, - void* input, - void* weight, - void* res, + void* ptr_A, + void* ptr_B, + void* ptr_D, + void* ptr_alpha, + void* ptr_beta, void* offset, - int64_t hidden_size, - int64_t intermediate_size, - int64_t num_of_expert){ + int64_t N, + int64_t K, + int64_t groups){ // // Run examples // auto offset_ptr = reinterpret_cast(offset); - Options options(offset_ptr, hidden_size, intermediate_size, num_of_expert); + Options options(offset_ptr, N, K, groups); // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This // information is used by the underlying kernel. 
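The per-group problem shapes behind these Arguments come straight from the token-count offsets: every non-zero entry m contributes an (m, N, K) GEMM, which is also what the gflops() helper from the profiling patch sums over. A rough Python equivalent of that throughput formula (two flops per multiply-add), useful for cross-checking the printed numbers:

def grouped_gemm_gflops(offsets, n, k, runtime_s):
    # Sum M*N*K over every non-empty group, two flops per multiply-add,
    # then normalize by the measured runtime in seconds.
    fmas = sum(int(m) * n * k for m in offsets if m != 0)
    return 2.0 * fmas / 1.0e9 / runtime_s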
cutlass::KernelHardwareInfo hw_info; @@ -489,7 +483,7 @@ void kernel_functor( using ElementScale = cutlass::bfloat16_t; using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; using LayoutC = cutlass::layout::RowMajor; using LayoutD = cutlass::layout::RowMajor; @@ -556,11 +550,11 @@ void kernel_functor( options, stream, hw_info, - reinterpret_cast(input), - reinterpret_cast(weight), - reinterpret_cast(offset), - reinterpret_cast(res) - ); + reinterpret_cast(ptr_A), + reinterpret_cast(ptr_B), + reinterpret_cast(ptr_D), + reinterpret_cast(ptr_alpha), + reinterpret_cast(ptr_beta)); } diff --git a/csrc/xpu/ops.h b/csrc/xpu/ops.h index 9292994..a3edf81 100644 --- a/csrc/xpu/ops.h +++ b/csrc/xpu/ops.h @@ -7,4 +7,12 @@ torch::Tensor fp8_gemm_w8a16(const torch::Tensor& A, const torch::Tensor& B, const std::optional& B_scale_, const std::optional& bias_); -torch::Tensor cutlass_grouped_gemm(torch::Tensor input, torch::Tensor weight, torch::Tensor res, torch::Tensor offset, int64_t hidden_size, int64_t intermediate_size, int64_t num_of_expert); +torch::Tensor cutlass_grouped_gemm(torch::Tensor ptr_A, + torch::Tensor ptr_B, + torch::Tensor ptr_D, + torch::Tensor ptr_alpha, + torch::Tensor ptr_beta, + torch::Tensor offset, + int64_t N, + int64_t K, + int64_t groups); diff --git a/csrc/xpu/torch_bindings.cpp b/csrc/xpu/torch_bindings.cpp index eed9bce..904501b 100644 --- a/csrc/xpu/torch_bindings.cpp +++ b/csrc/xpu/torch_bindings.cpp @@ -13,7 +13,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, xpu_ops) { "Tensor? bias_) -> Tensor"); xpu_ops.impl("fp8_gemm_w8a16", torch::kXPU, &fp8_gemm_w8a16); - xpu_ops.def("cutlass_grouped_gemm(Tensor input, Tensor weight, Tensor res, Tensor offset, int hidden_size, int intermediate_size, int num_of_expert) -> Tensor"); + xpu_ops.def("cutlass_grouped_gemm(Tensor ptr_A, Tensor ptr_B, Tensor ptr_D, Tensor ptr_alpha, Tensor ptr_beta, Tensor offset, int N, int K, int groups) -> Tensor"); xpu_ops.impl("cutlass_grouped_gemm", torch::kXPU, gpu::cutlass_kernel::grouped_gemm_func); } diff --git a/mll_build.sh b/mll_build.sh index 3e5ace0..47889fb 100644 --- a/mll_build.sh +++ b/mll_build.sh @@ -1,2 +1,2 @@ -python3 setup.py clean +# python3 setup.py clean VLLM_TARGET_DEVICE=xpu python3 setup.py develop diff --git a/tests/cutlass/test_fused_moe.py b/tests/cutlass/test_fused_moe.py index aa4ad37..a5cfc2f 100644 --- a/tests/cutlass/test_fused_moe.py +++ b/tests/cutlass/test_fused_moe.py @@ -49,7 +49,7 @@ def calculate_device_mem(m, k, n, e, topk, dtype): def test_grouped_gemm(num_experts, n, k, token_per_group): # input - input_A = torch.randn((sum(token_per_group), k), dtype=torch.bfloat16, device="xpu") + input_A = torch.randn((sum(token_per_group), k), dtype=torch.bfloat16, device="xpu").contiguous() # weight input_B = torch.randn((num_experts, n, k), dtype=torch.bfloat16, device="xpu") @@ -69,12 +69,15 @@ def test_grouped_gemm(num_experts, n, k, token_per_group): if cur_token_num == 0: continue input = input_A[pre_token_sum:pre_token_sum + cur_token_num, :] + print("refA ptr",i, ":", hex(input.data_ptr())) weight = input_B[i, :, :] expert_output = input @ weight.T ref.append(expert_output) pre_token_sum += cur_token_num ref = torch.cat(ref, dim=0).float() + print("kernel:", output) + print("reference:", ref) print(torch.allclose(output, ref, rtol=1, atol=1)) max_diff = (output - ref).abs().max() print("Max absolute difference:", max_diff) @@ -187,4 +190,4 @@ def test_fused_moe( ep_size = 1, dtype 
= torch.bfloat16 ) - # test_grouped_gemm(num_experts=16, n=5120, k=8192, token_per_group=[1,2,6,8,12,0,1,5,1,2,6,8,12,0,1,5]) + # test_grouped_gemm(num_experts=2, n=4096, k=4096, token_per_group=[512,512]) diff --git a/vllm_xpu_kernels/fused_moe_interface.py b/vllm_xpu_kernels/fused_moe_interface.py index 5cfc2ed..cff5233 100644 --- a/vllm_xpu_kernels/fused_moe_interface.py +++ b/vllm_xpu_kernels/fused_moe_interface.py @@ -5,8 +5,74 @@ import numpy import vllm_xpu_kernels._xpu_C +def prepare_gemm_args(n, k, offset, A, B, D, alpha, beta): + gemm_args = {} + device = A.device + # problem_sizes = [] + ptr_A = [] + ptr_B = [] + ptr_D = [] + ptr_alpha = [] + ptr_beta = [] + total_elements_A = 0 + total_elements_B = 0 + total_elements_D = 0 + + def process_data_ptr(tensor, offset, addr_list, dim): + mul = 2 + if tensor.dtype == torch.float32: + mul = 4 + # print("process data_ptr:", tensor.shape) + # print(tensor.data_ptr()) + # print(offset*mul) + # addr = tensor.data_ptr() + offset*mul + if dim == 1: + addr = tensor[offset].data_ptr() + elif dim == 2: + addr = tensor[offset, :].data_ptr() + elif dim == 3: + addr = tensor[offset, :, :].data_ptr() + for i in range(8): # 64bit -> 8 bytes + byte_val = (addr >> (i * 8)) & 0xFF + addr_list.append(byte_val) + + groups = 0 + for m in offset: + if m != 0: + # problem_sizes.extend([m, n, k]) + process_data_ptr(A, total_elements_A, ptr_A, 2) + process_data_ptr(B, total_elements_B, ptr_B, 3) + process_data_ptr(D, total_elements_D, ptr_D, 2) + process_data_ptr(alpha, groups, ptr_alpha, 1) + process_data_ptr(beta, groups, ptr_beta, 1) + total_elements_A += m; + total_elements_D += m; + groups += 1 + total_elements_B += 1; + + # problem_sizes = torch.tensor(problem_sizes, dtype=torch.int64, device='cpu').contiguous() + ptr_A = torch.tensor(ptr_A, dtype=torch.uint8, device=device).contiguous() + ptr_B = torch.tensor(ptr_B, dtype=torch.uint8, device=device).contiguous() + ptr_D = torch.tensor(ptr_D, dtype=torch.uint8, device=device).contiguous() + ptr_alpha = torch.tensor(ptr_alpha, dtype=torch.uint8, device=device).contiguous() + ptr_beta = torch.tensor(ptr_beta, dtype=torch.uint8, device=device).contiguous() + + # gemm_args["problem_sizes"] = problem_sizes + gemm_args["ptr_A"] = ptr_A + gemm_args["ptr_B"] = ptr_B + gemm_args["ptr_D"] = ptr_D + gemm_args["ptr_alpha"] = ptr_alpha + gemm_args["ptr_beta"] = ptr_beta + gemm_args["groups"] = groups + return gemm_args + + def cutlass_grouped_gemm(input_A, input_B, output, offset, n, k, num_experts): - torch.ops._xpu_C.cutlass_grouped_gemm(input_A, input_B, output, offset, n, k, num_experts) + device = "xpu" + alpha = torch.ones(num_experts, dtype=torch.float32, device=input_A.device) + beta = torch.zeros(num_experts, dtype=torch.float32, device=input_A.device) + gemm_args = prepare_gemm_args(n, k, offset, input_A, input_B, output, alpha, beta) + torch.ops._xpu_C.cutlass_grouped_gemm(offset=offset, N=n, K=k, **gemm_args) def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_per_token, activation, num_experts): token_cnt, hidden_size = list(hidden_states.shape) @@ -36,17 +102,15 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ # gemm1 input_A = torch.cat(grouped_input_A, dim=0).contiguous() - input_B = w13 #.transpose(-1, -2).contiguous().transpose(-1, -2) + input_B = w13.transpose(-1, -2).contiguous().transpose(-1, -2) assert(list(input_A.shape)[0] == total_input_size) gemm1_output = torch.empty((total_input_size, 2*intermediate_size), 
dtype=torch.float32,device=hidden_states.device) - offset = torch.tensor(offset, dtype=torch.int64, device='cpu') - cutlass_grouped_gemm(input_A=input_A, - input_B=input_B, - output=gemm1_output, - offset=offset, - n=2*intermediate_size, - k=hidden_size, - num_experts=num_experts) + alpha = torch.ones(num_experts, dtype=torch.float32, device=input_A.device) + beta = torch.zeros(num_experts, dtype=torch.float32, device=input_A.device) + gemm_args = prepare_gemm_args(2*intermediate_size, hidden_size, offset, input_A, input_B, gemm1_output, alpha, beta) + offset_t = torch.tensor(offset, dtype=torch.int64, device='cpu') + torch.ops._xpu_C.cutlass_grouped_gemm(offset=offset_t, N=2*intermediate_size, K=hidden_size, **gemm_args) + # act gate, up = torch.split(gemm1_output, intermediate_size, dim=1) act = torch.nn.SiLU() @@ -54,16 +118,11 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ # gemm 2 - group_A = act_output.to(torch.bfloat16).contiguous() - output = torch.empty((list(group_A.shape)[0], hidden_size), dtype=torch.float32, device=hidden_states.device) - w2 = w2 #.transpose(-1, -2).contiguous().transpose(-1, -2) - cutlass_grouped_gemm(input_A=group_A, - input_B=w2, - output=output, - offset=offset, - n=hidden_size, - k=intermediate_size, - num_experts=num_experts) + input_A = act_output.to(torch.bfloat16).contiguous() + output = torch.empty((list(input_A.shape)[0], hidden_size), dtype=torch.float32, device=hidden_states.device) + input_B = w2.transpose(-1, -2).contiguous().transpose(-1, -2) + gemm_args = prepare_gemm_args(hidden_size, intermediate_size, offset, input_A, input_B, output, alpha, beta) + torch.ops._xpu_C.cutlass_grouped_gemm(offset=offset_t, N=hidden_size, K=intermediate_size, **gemm_args) # apply scores for expert_id, end_idx in enumerate(tokens_per_expert): From 54e64a7087cde9919a0c511c6a6510d99f3a14f2 Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Thu, 11 Sep 2025 07:02:22 +0000 Subject: [PATCH 29/47] e2e debug Signed-off-by: Ma, Liangliang --- vllm_xpu_kernels/fused_moe_interface.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_xpu_kernels/fused_moe_interface.py b/vllm_xpu_kernels/fused_moe_interface.py index cff5233..4b78890 100644 --- a/vllm_xpu_kernels/fused_moe_interface.py +++ b/vllm_xpu_kernels/fused_moe_interface.py @@ -110,7 +110,7 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ gemm_args = prepare_gemm_args(2*intermediate_size, hidden_size, offset, input_A, input_B, gemm1_output, alpha, beta) offset_t = torch.tensor(offset, dtype=torch.int64, device='cpu') torch.ops._xpu_C.cutlass_grouped_gemm(offset=offset_t, N=2*intermediate_size, K=hidden_size, **gemm_args) - + print("@@@@@@@ cutlass fused moe gemm1 done") # act gate, up = torch.split(gemm1_output, intermediate_size, dim=1) act = torch.nn.SiLU() @@ -140,5 +140,5 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ expert_out, reduce='sum' ) - + print("@@@@@@@ cutlass fused moe gemm2 done") return expert_cache From 3c4000896248ff1dbe5ebd57d8ce6d1e04922d86 Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Thu, 11 Sep 2025 07:23:43 +0000 Subject: [PATCH 30/47] add release func Signed-off-by: Ma, Liangliang --- csrc/xpu/cutlass_backend/grouped_gemm.h | 12 +++++++++++- vllm_xpu_kernels/fused_moe_interface.py | 3 ++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/csrc/xpu/cutlass_backend/grouped_gemm.h b/csrc/xpu/cutlass_backend/grouped_gemm.h index 
15f44b7..02b25ee 100644 --- a/csrc/xpu/cutlass_backend/grouped_gemm.h +++ b/csrc/xpu/cutlass_backend/grouped_gemm.h @@ -222,6 +222,16 @@ struct GroupedGemmRunner { // cutlass::DeviceAllocation block_alpha; // cutlass::DeviceAllocation block_beta; + void release(){ + problem_sizes.release(); + ptr_C.release(); + stride_A.release(); + stride_B.release(); + stride_C.release(); + stride_D.release(); + block_C.release(); + } + /// Allocates device-side data void allocate(const Options &options) { if (debug){ @@ -445,7 +455,7 @@ void allocate(const Options &options) { } stream->throw_asynchronous(); - + release(); return cutlass::Status::kSuccess; } }; diff --git a/vllm_xpu_kernels/fused_moe_interface.py b/vllm_xpu_kernels/fused_moe_interface.py index 4b78890..82a54a2 100644 --- a/vllm_xpu_kernels/fused_moe_interface.py +++ b/vllm_xpu_kernels/fused_moe_interface.py @@ -75,6 +75,7 @@ def cutlass_grouped_gemm(input_A, input_B, output, offset, n, k, num_experts): torch.ops._xpu_C.cutlass_grouped_gemm(offset=offset, N=n, K=k, **gemm_args) def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_per_token, activation, num_experts): + token_cnt, hidden_size = list(hidden_states.shape) intermediate_size = list(w2.shape)[-1] expert_cache = torch.empty((token_cnt, hidden_size), @@ -99,7 +100,7 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ grouped_input_A.append(expert_tokens) total_input_size = token_cnt * num_per_tok - + print("@@@@@@@ cutlass fused moe enter") # gemm1 input_A = torch.cat(grouped_input_A, dim=0).contiguous() input_B = w13.transpose(-1, -2).contiguous().transpose(-1, -2) From 985004d11f5da43cbf65ee60b287bdbb03a701b3 Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Thu, 11 Sep 2025 08:02:18 +0000 Subject: [PATCH 31/47] gemm args allocate once Signed-off-by: Ma, Liangliang --- tests/cutlass/test_fused_moe.py | 2 +- vllm_xpu_kernels/fused_moe_interface.py | 78 +++++++++++++++---------- 2 files changed, 49 insertions(+), 31 deletions(-) diff --git a/tests/cutlass/test_fused_moe.py b/tests/cutlass/test_fused_moe.py index a5cfc2f..5957576 100644 --- a/tests/cutlass/test_fused_moe.py +++ b/tests/cutlass/test_fused_moe.py @@ -174,7 +174,7 @@ def test_fused_moe( print("ref result", ref_out, ref_out.shape) print("kernel result", out, out.shape) - print(torch.allclose(out, ref_out, rtol=1, atol=1)) + print(torch.allclose(out.float(), ref_out, rtol=1, atol=1)) max_diff = (out - ref_out).abs().max() print("Max absolute difference:", max_diff) diff --git a/vllm_xpu_kernels/fused_moe_interface.py b/vllm_xpu_kernels/fused_moe_interface.py index 82a54a2..a9426a2 100644 --- a/vllm_xpu_kernels/fused_moe_interface.py +++ b/vllm_xpu_kernels/fused_moe_interface.py @@ -5,20 +5,37 @@ import numpy import vllm_xpu_kernels._xpu_C -def prepare_gemm_args(n, k, offset, A, B, D, alpha, beta): - gemm_args = {} - device = A.device +def prepare_gemm_args(n, k, offset, A, B, D, alpha, beta, e): + + if not hasattr(prepare_gemm_args, "gemm_args"): + print("@cutlass fusedMoe allocate gemm args once") + gemm_args = {} + device = A.device + ptr_A = torch.empty(e*8, dtype=torch.uint8, device=device).contiguous() + ptr_B = torch.empty(e*8, dtype=torch.uint8, device=device).contiguous() + ptr_D = torch.empty(e*8, dtype=torch.uint8, device=device).contiguous() + ptr_alpha = torch.empty(e*8, dtype=torch.uint8, device=device).contiguous() + ptr_beta = torch.empty(e*8, dtype=torch.uint8, device=device).contiguous() + gemm_args["ptr_A"] = ptr_A + gemm_args["ptr_B"] = 
ptr_B + gemm_args["ptr_D"] = ptr_D + gemm_args["ptr_alpha"] = ptr_alpha + gemm_args["ptr_beta"] = ptr_beta + prepare_gemm_args.gemm_args = gemm_args + + # gemm_args = {} + # problem_sizes = [] - ptr_A = [] - ptr_B = [] - ptr_D = [] - ptr_alpha = [] - ptr_beta = [] + ptr_A = prepare_gemm_args.gemm_args["ptr_A"] + ptr_B = prepare_gemm_args.gemm_args["ptr_B"] + ptr_D = prepare_gemm_args.gemm_args["ptr_D"] + ptr_alpha = prepare_gemm_args.gemm_args["ptr_alpha"] + ptr_beta = prepare_gemm_args.gemm_args["ptr_beta"] total_elements_A = 0 total_elements_B = 0 total_elements_D = 0 - def process_data_ptr(tensor, offset, addr_list, dim): + def process_data_ptr(tensor, offset, addr_tensor, dim, group): mul = 2 if tensor.dtype == torch.float32: mul = 4 @@ -34,37 +51,37 @@ def process_data_ptr(tensor, offset, addr_list, dim): addr = tensor[offset, :, :].data_ptr() for i in range(8): # 64bit -> 8 bytes byte_val = (addr >> (i * 8)) & 0xFF - addr_list.append(byte_val) + addr_tensor[8*group + i] = byte_val groups = 0 for m in offset: if m != 0: # problem_sizes.extend([m, n, k]) - process_data_ptr(A, total_elements_A, ptr_A, 2) - process_data_ptr(B, total_elements_B, ptr_B, 3) - process_data_ptr(D, total_elements_D, ptr_D, 2) - process_data_ptr(alpha, groups, ptr_alpha, 1) - process_data_ptr(beta, groups, ptr_beta, 1) + process_data_ptr(A, total_elements_A, ptr_A, 2, groups) + process_data_ptr(B, total_elements_B, ptr_B, 3, groups) + process_data_ptr(D, total_elements_D, ptr_D, 2, groups) + process_data_ptr(alpha, groups, ptr_alpha, 1, groups) + process_data_ptr(beta, groups, ptr_beta, 1, groups) total_elements_A += m; total_elements_D += m; groups += 1 total_elements_B += 1; # problem_sizes = torch.tensor(problem_sizes, dtype=torch.int64, device='cpu').contiguous() - ptr_A = torch.tensor(ptr_A, dtype=torch.uint8, device=device).contiguous() - ptr_B = torch.tensor(ptr_B, dtype=torch.uint8, device=device).contiguous() - ptr_D = torch.tensor(ptr_D, dtype=torch.uint8, device=device).contiguous() - ptr_alpha = torch.tensor(ptr_alpha, dtype=torch.uint8, device=device).contiguous() - ptr_beta = torch.tensor(ptr_beta, dtype=torch.uint8, device=device).contiguous() + # ptr_A = torch.tensor(ptr_A, dtype=torch.uint8, device=device).contiguous() + # ptr_B = torch.tensor(ptr_B, dtype=torch.uint8, device=device).contiguous() + # ptr_D = torch.tensor(ptr_D, dtype=torch.uint8, device=device).contiguous() + # ptr_alpha = torch.tensor(ptr_alpha, dtype=torch.uint8, device=device).contiguous() + # ptr_beta = torch.tensor(ptr_beta, dtype=torch.uint8, device=device).contiguous() # gemm_args["problem_sizes"] = problem_sizes - gemm_args["ptr_A"] = ptr_A - gemm_args["ptr_B"] = ptr_B - gemm_args["ptr_D"] = ptr_D - gemm_args["ptr_alpha"] = ptr_alpha - gemm_args["ptr_beta"] = ptr_beta - gemm_args["groups"] = groups - return gemm_args + # gemm_args["ptr_A"] = ptr_A + # gemm_args["ptr_B"] = ptr_B + # gemm_args["ptr_D"] = ptr_D + # gemm_args["ptr_alpha"] = ptr_alpha + # gemm_args["ptr_beta"] = ptr_beta + prepare_gemm_args.gemm_args["groups"] = groups + return prepare_gemm_args.gemm_args def cutlass_grouped_gemm(input_A, input_B, output, offset, n, k, num_experts): @@ -108,7 +125,7 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ gemm1_output = torch.empty((total_input_size, 2*intermediate_size), dtype=torch.float32,device=hidden_states.device) alpha = torch.ones(num_experts, dtype=torch.float32, device=input_A.device) beta = torch.zeros(num_experts, dtype=torch.float32, device=input_A.device) - gemm_args 
= prepare_gemm_args(2*intermediate_size, hidden_size, offset, input_A, input_B, gemm1_output, alpha, beta) + gemm_args = prepare_gemm_args(2*intermediate_size, hidden_size, offset, input_A, input_B, gemm1_output, alpha, beta, num_experts) offset_t = torch.tensor(offset, dtype=torch.int64, device='cpu') torch.ops._xpu_C.cutlass_grouped_gemm(offset=offset_t, N=2*intermediate_size, K=hidden_size, **gemm_args) print("@@@@@@@ cutlass fused moe gemm1 done") @@ -122,7 +139,7 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ input_A = act_output.to(torch.bfloat16).contiguous() output = torch.empty((list(input_A.shape)[0], hidden_size), dtype=torch.float32, device=hidden_states.device) input_B = w2.transpose(-1, -2).contiguous().transpose(-1, -2) - gemm_args = prepare_gemm_args(hidden_size, intermediate_size, offset, input_A, input_B, output, alpha, beta) + gemm_args = prepare_gemm_args(hidden_size, intermediate_size, offset, input_A, input_B, output, alpha, beta, num_experts) torch.ops._xpu_C.cutlass_grouped_gemm(offset=offset_t, N=hidden_size, K=intermediate_size, **gemm_args) # apply scores @@ -142,4 +159,5 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ reduce='sum' ) print("@@@@@@@ cutlass fused moe gemm2 done") - return expert_cache + hidden_states = expert_cache.to(hidden_states.dtype) + return hidden_states From 9c1809237a74cc4e15798785992fccbea3907860 Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Thu, 11 Sep 2025 08:30:05 +0000 Subject: [PATCH 32/47] hidden_states copy Signed-off-by: Ma, Liangliang --- vllm_xpu_kernels/fused_moe_interface.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm_xpu_kernels/fused_moe_interface.py b/vllm_xpu_kernels/fused_moe_interface.py index a9426a2..199dd64 100644 --- a/vllm_xpu_kernels/fused_moe_interface.py +++ b/vllm_xpu_kernels/fused_moe_interface.py @@ -120,7 +120,7 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ print("@@@@@@@ cutlass fused moe enter") # gemm1 input_A = torch.cat(grouped_input_A, dim=0).contiguous() - input_B = w13.transpose(-1, -2).contiguous().transpose(-1, -2) + input_B = w13#.transpose(-1, -2).contiguous().transpose(-1, -2) assert(list(input_A.shape)[0] == total_input_size) gemm1_output = torch.empty((total_input_size, 2*intermediate_size), dtype=torch.float32,device=hidden_states.device) alpha = torch.ones(num_experts, dtype=torch.float32, device=input_A.device) @@ -138,7 +138,7 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ # gemm 2 input_A = act_output.to(torch.bfloat16).contiguous() output = torch.empty((list(input_A.shape)[0], hidden_size), dtype=torch.float32, device=hidden_states.device) - input_B = w2.transpose(-1, -2).contiguous().transpose(-1, -2) + input_B = w2#.transpose(-1, -2).contiguous().transpose(-1, -2) gemm_args = prepare_gemm_args(hidden_size, intermediate_size, offset, input_A, input_B, output, alpha, beta, num_experts) torch.ops._xpu_C.cutlass_grouped_gemm(offset=offset_t, N=hidden_size, K=intermediate_size, **gemm_args) @@ -159,5 +159,5 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ reduce='sum' ) print("@@@@@@@ cutlass fused moe gemm2 done") - hidden_states = expert_cache.to(hidden_states.dtype) + hidden_states.copy_(expert_cache.to(hidden_states.dtype)) return hidden_states From a47ecef028d9eead7705b760fd71a062da4288df Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Sun, 14 Sep 2025 
08:15:24 +0000 Subject: [PATCH 33/47] output bf16 Signed-off-by: Ma, Liangliang --- CMakeLists.txt | 2 +- csrc/xpu/cutlass_backend/grouped_gemm.h | 61 ++++++++++++++----------- vllm_xpu_kernels/fused_moe_interface.py | 21 ++++++--- 3 files changed, 49 insertions(+), 35 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index bd7351d..09e00ef 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -201,7 +201,7 @@ if(VLLM_GPU_LANG STREQUAL "SYCL") set(CUTLASS_ENABLE_HEADERS_ONLY "ON" CACHE BOOL "Enable only the header library") # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. - set(CUTLASS_REVISION "main" CACHE STRING "CUTLASS revision to use") + set(CUTLASS_REVISION "dev" CACHE STRING "CUTLASS revision to use") # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided FetchContent_Declare( diff --git a/csrc/xpu/cutlass_backend/grouped_gemm.h b/csrc/xpu/cutlass_backend/grouped_gemm.h index 02b25ee..84266f8 100644 --- a/csrc/xpu/cutlass_backend/grouped_gemm.h +++ b/csrc/xpu/cutlass_backend/grouped_gemm.h @@ -72,6 +72,7 @@ #include "cutlass/epilogue/collective/default_epilogue.hpp" #include "cutlass/epilogue/collective/xe_array_epilogue.hpp" #include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/group_array_problem_shape.hpp" #include "cutlass/gemm/device/gemm_universal.h" #include "cutlass/gemm/device/gemm_universal_adapter.h" @@ -98,7 +99,7 @@ using ElementAccumulator = float; // <- data type of accumulator using ElementComputeEpilogue = float; // <- data type of epilogue operations using ElementA = bfloat16_t; // <- data type of elements in input matrix A using ElementB = bfloat16_t; // <- data type of elements in input matrix B -using ElementOutput = float; // <- data type of elements in output matrix D +using ElementOutput = bfloat16_t; // <- data type of elements in output matrix D bool debug = false; bool collect_gflops = false; /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -178,14 +179,11 @@ struct GroupedGemmRunner { using ElementA = typename Gemm::ElementA; using ElementB = typename Gemm::ElementB; using ElementC = typename Gemm::ElementC; - using ElementAcc = typename Gemm::ElementAccumulator; - using ElementScaleA = cutlass::half_t; - using ElementScaleB = cutlass::half_t; - using ElementOffset = int64_t; using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; - using ElementOutput = typename CollectiveEpilogue::ElementOutput; - using ElementAccumulator = ElementOutput; + // using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementOutput = bfloat16_t; + using ElementAccumulator = float_t; using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; @@ -488,12 +486,11 @@ void kernel_functor( using ElementComputeEpilogue = float; using ElementA = cutlass::bfloat16_t; using ElementB = cutlass::bfloat16_t; - using ElementOffset = int64_t; - using ElementOutput = float; + using ElementOutput = bfloat16_t; using ElementScale = cutlass::bfloat16_t; using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = cutlass::layout::RowMajor; using LayoutD = cutlass::layout::RowMajor; @@ -515,23 +512,33 @@ void kernel_functor( using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16Group; using EpilogueDispatchPolicy = 
cutlass::epilogue::IntelXeXMX16Group; - using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; + // using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; + + // using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + // using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + // EpilogueDispatchPolicy, + // TileShape, + // ElementAccumulator, + // cutlass::gemm::TagToStrideC_t, + // ElementOutput, + // cutlass::gemm::TagToStrideC_t, + // FusionCallBacks, + // XE_2D_U32x8x16_LD_N, + // void, void, + // XE_2D_U16x8x16_ST_N, + // void, void>; + using EpilogueOp = + cutlass::epilogue::fusion::LinearCombination; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp, TileShape, + Shape<_1, _1, _1>, cutlass::epilogue::collective::EpilogueTileAuto, + float, float, float, LayoutC, 1, ElementOutput, LayoutC, 1, + EpilogueDispatchPolicy, EpilogueOp>::CollectiveOp; - using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; - using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< - EpilogueDispatchPolicy, - TileShape, - ElementAccumulator, - cutlass::gemm::TagToStrideC_t, - ElementOutput, - cutlass::gemm::TagToStrideC_t, - FusionCallBacks, - XE_2D_U32x8x16_LD_N, - void, void, - XE_2D_U32x8x16_ST_N, - void, void>; // Mainloop using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< @@ -540,7 +547,7 @@ void kernel_functor( ElementA, cutlass::gemm::TagToStrideA_t, ElementB, - cutlass::gemm::TagToStrideB_t, + cutlass::gemm::TagToStrideA_t, TiledMma, GmemTiledCopyA, void, void, cute::identity, // A GmemTiledCopyB, void, void, cute::identity // B diff --git a/vllm_xpu_kernels/fused_moe_interface.py b/vllm_xpu_kernels/fused_moe_interface.py index 199dd64..819b1b1 100644 --- a/vllm_xpu_kernels/fused_moe_interface.py +++ b/vllm_xpu_kernels/fused_moe_interface.py @@ -95,9 +95,17 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ token_cnt, hidden_size = list(hidden_states.shape) intermediate_size = list(w2.shape)[-1] - expert_cache = torch.empty((token_cnt, hidden_size), - dtype=torch.float32, - device=hidden_states.device) + if not hasattr(cutlass_fused_moe, "moe_buffer"): + print("@cutlass fusedMoe allocate moe_buffer once") + moe_buffer = {} + device = hidden_states.device + moe_buffer["expert_cache"] = torch.empty((token_cnt* hidden_size), + dtype=hidden_states.dtype, + device=hidden_states.device) + # gemm_args["ptr_A"] = ptr_A + cutlass_fused_moe.moe_buffer = moe_buffer + + expert_cache = moe_buffer["expert_cache"][:hidden_states.numel()].view_as(hidden_states).zero_() # map token to experts idxs = topk_ids.argsort() @@ -122,7 +130,7 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ input_A = torch.cat(grouped_input_A, dim=0).contiguous() input_B = w13#.transpose(-1, -2).contiguous().transpose(-1, -2) assert(list(input_A.shape)[0] == total_input_size) - gemm1_output = torch.empty((total_input_size, 2*intermediate_size), dtype=torch.float32,device=hidden_states.device) + gemm1_output = torch.empty((total_input_size, 2*intermediate_size), dtype=hidden_states.dtype, device=hidden_states.device) alpha = torch.ones(num_experts, dtype=torch.float32, device=input_A.device) beta = torch.zeros(num_experts, dtype=torch.float32, device=input_A.device) gemm_args = prepare_gemm_args(2*intermediate_size, hidden_size, offset, input_A, input_B, gemm1_output, 
alpha, beta, num_experts) @@ -137,7 +145,7 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ # gemm 2 input_A = act_output.to(torch.bfloat16).contiguous() - output = torch.empty((list(input_A.shape)[0], hidden_size), dtype=torch.float32, device=hidden_states.device) + output = torch.empty((list(input_A.shape)[0], hidden_size), dtype=hidden_states.dtype, device=hidden_states.device) input_B = w2#.transpose(-1, -2).contiguous().transpose(-1, -2) gemm_args = prepare_gemm_args(hidden_size, intermediate_size, offset, input_A, input_B, output, alpha, beta, num_experts) torch.ops._xpu_C.cutlass_grouped_gemm(offset=offset_t, N=hidden_size, K=intermediate_size, **gemm_args) @@ -149,7 +157,6 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ continue exp_token_idxs = token_idxs[start_idx:end_idx] - expert_tokens = hidden_states[exp_token_idxs] expert_out = output[start_idx:end_idx] expert_out.mul_(topk_weights[idxs[start_idx:end_idx]]) expert_cache.scatter_reduce_( @@ -159,5 +166,5 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ reduce='sum' ) print("@@@@@@@ cutlass fused moe gemm2 done") - hidden_states.copy_(expert_cache.to(hidden_states.dtype)) + hidden_states.copy_(expert_cache) return hidden_states From 1a2d6556d634314e66c179e1cc6e48714071f353 Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Sun, 14 Sep 2025 08:53:38 +0000 Subject: [PATCH 34/47] use static tensor buffer Signed-off-by: Ma, Liangliang --- tests/cutlass/test_fused_moe.py | 18 +++++---- vllm_xpu_kernels/fused_moe_interface.py | 49 ++++++++++++++++--------- 2 files changed, 41 insertions(+), 26 deletions(-) diff --git a/tests/cutlass/test_fused_moe.py b/tests/cutlass/test_fused_moe.py index 5957576..5370c43 100644 --- a/tests/cutlass/test_fused_moe.py +++ b/tests/cutlass/test_fused_moe.py @@ -154,14 +154,16 @@ def test_fused_moe( flat_expert_indices = expert_indices.view(-1) flat_expert_weights = expert_scores.view(-1, 1) - out = cutlass_fused_moe(hidden_states=a, - w13=w13, - w2=w2, - topk_weights=flat_expert_weights, - topk_ids=flat_expert_indices, - n_experts_per_token=topk, - activation="silu", - num_experts=e) + iteration = 1 + for _ in range(iteration): + out = cutlass_fused_moe(hidden_states=a, + w13=w13, + w2=w2, + topk_weights=flat_expert_weights, + topk_ids=flat_expert_indices, + n_experts_per_token=topk, + activation="silu", + num_experts=e) ref_out = ref_fused_moe(a, w13, diff --git a/vllm_xpu_kernels/fused_moe_interface.py b/vllm_xpu_kernels/fused_moe_interface.py index 819b1b1..96d8912 100644 --- a/vllm_xpu_kernels/fused_moe_interface.py +++ b/vllm_xpu_kernels/fused_moe_interface.py @@ -95,6 +95,7 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ token_cnt, hidden_size = list(hidden_states.shape) intermediate_size = list(w2.shape)[-1] + total_input_size = token_cnt * n_experts_per_token if not hasattr(cutlass_fused_moe, "moe_buffer"): print("@cutlass fusedMoe allocate moe_buffer once") moe_buffer = {} @@ -102,10 +103,27 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ moe_buffer["expert_cache"] = torch.empty((token_cnt* hidden_size), dtype=hidden_states.dtype, device=hidden_states.device) - # gemm_args["ptr_A"] = ptr_A + moe_buffer["gemm1_input"] = torch.empty((total_input_size, hidden_size), + dtype=hidden_states.dtype, + device=hidden_states.device) + moe_buffer["gemm1_output"] = torch.empty((total_input_size, 2*intermediate_size), + 
dtype=hidden_states.dtype, + device=hidden_states.device) + moe_buffer["gemm2_output"] = torch.empty((total_input_size, hidden_size), + dtype=hidden_states.dtype, + device=hidden_states.device) + moe_buffer["alpha"] = torch.ones(num_experts, dtype=torch.float32, device=hidden_states.device) + moe_buffer["beta"] = torch.zeros(num_experts, dtype=torch.float32, device=hidden_states.device) + cutlass_fused_moe.moe_buffer = moe_buffer - expert_cache = moe_buffer["expert_cache"][:hidden_states.numel()].view_as(hidden_states).zero_() + expert_cache = cutlass_fused_moe.moe_buffer["expert_cache"][:hidden_states.numel()].view_as(hidden_states).zero_() + input_A = cutlass_fused_moe.moe_buffer["gemm1_input"][:total_input_size, :] + gemm1_output = cutlass_fused_moe.moe_buffer["gemm1_output"][:total_input_size, :] + gemm2_output = cutlass_fused_moe.moe_buffer["gemm2_output"][:total_input_size, :] + alpha = cutlass_fused_moe.moe_buffer["alpha"] + beta = cutlass_fused_moe.moe_buffer["beta"] + # map token to experts idxs = topk_ids.argsort() @@ -113,7 +131,6 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ tokens_per_expert = counts.cumsum() num_per_tok = n_experts_per_token token_idxs = idxs // num_per_tok - grouped_input_A = [] offset = [] for expert_id, end_idx in enumerate(tokens_per_expert): start_idx = 0 if expert_id == 0 else tokens_per_expert[expert_id - 1] @@ -121,18 +138,15 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ if start_idx == end_idx: continue exp_token_idxs = token_idxs[start_idx:end_idx] - expert_tokens = hidden_states[exp_token_idxs] - grouped_input_A.append(expert_tokens) + # expert_tokens = hidden_states[exp_token_idxs] + # grouped_input_A.append(expert_tokens) + input_A[start_idx:end_idx, :].copy_(hidden_states[exp_token_idxs]) + - total_input_size = token_cnt * num_per_tok + ########### gemm1 ################## print("@@@@@@@ cutlass fused moe enter") - # gemm1 - input_A = torch.cat(grouped_input_A, dim=0).contiguous() - input_B = w13#.transpose(-1, -2).contiguous().transpose(-1, -2) + input_B = w13 #.transpose(-1, -2).contiguous().transpose(-1, -2) assert(list(input_A.shape)[0] == total_input_size) - gemm1_output = torch.empty((total_input_size, 2*intermediate_size), dtype=hidden_states.dtype, device=hidden_states.device) - alpha = torch.ones(num_experts, dtype=torch.float32, device=input_A.device) - beta = torch.zeros(num_experts, dtype=torch.float32, device=input_A.device) gemm_args = prepare_gemm_args(2*intermediate_size, hidden_size, offset, input_A, input_B, gemm1_output, alpha, beta, num_experts) offset_t = torch.tensor(offset, dtype=torch.int64, device='cpu') torch.ops._xpu_C.cutlass_grouped_gemm(offset=offset_t, N=2*intermediate_size, K=hidden_size, **gemm_args) @@ -143,11 +157,10 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ act_output = act(gate) * up - # gemm 2 - input_A = act_output.to(torch.bfloat16).contiguous() - output = torch.empty((list(input_A.shape)[0], hidden_size), dtype=hidden_states.dtype, device=hidden_states.device) - input_B = w2#.transpose(-1, -2).contiguous().transpose(-1, -2) - gemm_args = prepare_gemm_args(hidden_size, intermediate_size, offset, input_A, input_B, output, alpha, beta, num_experts) + ########### gemm2 ################## + input_A = act_output.contiguous() + input_B = w2 #.transpose(-1, -2).contiguous().transpose(-1, -2) + gemm_args = prepare_gemm_args(hidden_size, intermediate_size, offset, input_A, input_B, gemm2_output, 
alpha, beta, num_experts) torch.ops._xpu_C.cutlass_grouped_gemm(offset=offset_t, N=hidden_size, K=intermediate_size, **gemm_args) # apply scores @@ -157,7 +170,7 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ continue exp_token_idxs = token_idxs[start_idx:end_idx] - expert_out = output[start_idx:end_idx] + expert_out = gemm2_output[start_idx:end_idx] expert_out.mul_(topk_weights[idxs[start_idx:end_idx]]) expert_cache.scatter_reduce_( 0, From f7dee65853bae2ade650eefc0b1d8b55fc175c9a Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Mon, 15 Sep 2025 07:38:15 +0000 Subject: [PATCH 35/47] remove ptr_C Signed-off-by: Ma, Liangliang --- csrc/xpu/cutlass_backend/grouped_gemm.h | 145 +++--------------------- 1 file changed, 15 insertions(+), 130 deletions(-) diff --git a/csrc/xpu/cutlass_backend/grouped_gemm.h b/csrc/xpu/cutlass_backend/grouped_gemm.h index 84266f8..612c008 100644 --- a/csrc/xpu/cutlass_backend/grouped_gemm.h +++ b/csrc/xpu/cutlass_backend/grouped_gemm.h @@ -168,9 +168,6 @@ struct GroupedGemmRunner { using StrideC = typename Gemm::GemmKernel::InternalStrideC; using StrideD = typename Gemm::GemmKernel::InternalStrideD; - using StrideScaleA = typename Gemm::GemmKernel::StrideA; - using StrideScaleB = typename Gemm::GemmKernel::StrideB; - using LayoutA = typename Gemm::LayoutA; using LayoutB = typename Gemm::LayoutB; using LayoutC = typename Gemm::LayoutC; @@ -181,53 +178,33 @@ struct GroupedGemmRunner { using ElementC = typename Gemm::ElementC; using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; - // using ElementOutput = typename CollectiveEpilogue::ElementOutput; using ElementOutput = bfloat16_t; using ElementAccumulator = float_t; using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; - // Host-side allocations - std::vector offset_C; - // std::vector offset_D; - std::vector stride_A_host; std::vector stride_B_host; std::vector stride_C_host; std::vector stride_D_host; - // std::vector alpha_host; - // std::vector beta_host; - // Device-side allocations cutlass::DeviceAllocation problem_sizes; - // This example defines all matrices in a single allocation (e.g. block_A), but this is not a - // requirement. Matrix base pointers are read from device allocation (e.g. 
ptr_A) - cutlass::DeviceAllocation ptr_C; - // cutlass::DeviceAllocation ptr_D; - cutlass::DeviceAllocation stride_A; cutlass::DeviceAllocation stride_B; cutlass::DeviceAllocation stride_C; cutlass::DeviceAllocation stride_D; - cutlass::DeviceAllocation block_C; - // Note, this is an array of pointers to alpha and beta scaling values per group - // cutlass::DeviceAllocation alpha_device; - // cutlass::DeviceAllocation beta_device; - // cutlass::DeviceAllocation block_alpha; - // cutlass::DeviceAllocation block_beta; - - void release(){ +void release(){ problem_sizes.release(); - ptr_C.release(); + // ptr_C.release(); stride_A.release(); stride_B.release(); stride_C.release(); stride_D.release(); - block_C.release(); + // block_C.release(); } /// Allocates device-side data @@ -235,10 +212,6 @@ void allocate(const Options &options) { if (debug){ std::cout << "void allocate()" << std::endl; } - int64_t total_elements_C = 0; - // int64_t total_elements_D = 0; - - // Compute total allocation sizes across group for (int32_t i = 0; i < options.groups; ++i) { @@ -247,24 +220,11 @@ void allocate(const Options &options) { auto N = get<1>(problem); auto K = get<2>(problem); - // Offset into block allocation of each matrix base pointer - offset_C.push_back(total_elements_C); - // offset_D.push_back(total_elements_D); - - int64_t elements_C = M * N; - // int64_t elements_D = M * N; - - total_elements_C += elements_C; - // total_elements_D += elements_D; - stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1})); stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1})); stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1})); stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1})); } - // block_C.reset(total_elements_C); - // block_alpha.reset(options.groups); - // block_beta.reset(options.groups); } void initialize(const Options &options) { @@ -274,29 +234,6 @@ void allocate(const Options &options) { problem_sizes.reset(options.groups); problem_sizes.copy_from_host(options.problem_sizes_host.data()); - std::vector ptr_C_host(options.groups); - // std::vector ptr_D_host(options.groups); - // std::vector ptr_alpha_host(options.groups); - // std::vector ptr_beta_host(options.groups); - - // Compute offsets, alpha & beta over group on host - for (int32_t i = 0; i < options.groups; ++i) { - ptr_C_host.at(i) = block_C.get() + offset_C.at(i); - // ptr_D_host.at(i) = block_D + offset_D.at(i); - // Fill host vector of alpha & beta with random values if using per-group values - // alpha_host.push_back(static_cast(1)); - // beta_host.push_back(static_cast(0)); - // // Fill host ptr vectors with offset addresses into device alpha/beta blocks - // ptr_alpha_host.at(i) = block_alpha.get() + i; - // ptr_beta_host.at(i) = block_beta.get() + i; - } - - // // Allocate device memory & copy from host - ptr_C.reset(options.groups); - ptr_C.copy_from_host(ptr_C_host.data()); - - // ptr_D.reset(options.groups); - // ptr_D.copy_from_host(ptr_D_host.data()); stride_A.reset(options.groups); stride_A.copy_from_host(stride_A_host.data()); @@ -310,19 +247,6 @@ void allocate(const Options &options) { stride_D.reset(options.groups); stride_D.copy_from_host(stride_D_host.data()); - // Per-group alpha and beta ptrs - // alpha_device.reset(options.groups); - // alpha_device.copy_from_host(ptr_alpha_host.data()); - // beta_device.reset(options.groups); - // beta_device.copy_from_host(ptr_beta_host.data()); - - - // initialize_block(block_C, 
666); - // Per-group alpha and beta values - note these are not directly passed to kernel - the pointers - // (alpha_device/beta_device) are passed instead - // block_alpha.copy_from_host(alpha_host.data()); - // block_beta.copy_from_host(beta_host.data()); - } /// Populates a Gemm::Arguments structure from the given commandline options @@ -338,30 +262,17 @@ void allocate(const Options &options) { typename Gemm::Arguments arguments; decltype(arguments.epilogue.thread) fusion_args; - if (options.alpha != FLT_MAX && options.beta != FLT_MAX) { - // If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches. - fusion_args.alpha = options.alpha; - fusion_args.beta = options.beta; - fusion_args.alpha_ptr = nullptr; - fusion_args.beta_ptr = nullptr; - fusion_args.alpha_ptr_array = nullptr; - fusion_args.beta_ptr_array = nullptr; - // Single alpha and beta for all groups - fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0}; - fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0}; - } - else { + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups. - fusion_args.alpha = 0; - fusion_args.beta = 0; - fusion_args.alpha_ptr = nullptr; - fusion_args.beta_ptr = nullptr; - fusion_args.alpha_ptr_array = ptr_alpha; - fusion_args.beta_ptr_array = ptr_beta; - // One alpha and beta per each group - fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1}; - fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1}; - } + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = ptr_alpha; + fusion_args.beta_ptr_array = ptr_beta; + // One alpha and beta per each group + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1}; using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerXeGroup::RasterOrderOptions; // Per-GEMM problem shape info may only exist on the device. 
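With ptr_C and the scalar alpha/beta branch removed, the runner always reads one alpha/beta scalar per group through alpha_ptr_array / beta_ptr_array (stride {_0, _0, 1}) and passes a null C pointer, so the grouped epilogue D_g = alpha_g * (A_g @ B_g.T) + beta_g * C_g reduces to the first term; the MoE interface feeds alpha = 1 and beta = 0 for every expert. A minimal PyTorch sketch of that contract, for illustration only — the function name and the plain Python loop are assumptions, not the CUTLASS kernel; input_A/input_B/offset/alpha/beta mirror the names used by cutlass_grouped_gemm in this series:

import torch

def grouped_gemm_reference(input_A, input_B, offset, alpha, beta):
    # input_A: (sum(offset), K) tokens already grouped by expert
    # input_B: (num_experts, N, K) per-expert weights
    # offset:  tokens per expert; alpha/beta: one float32 scalar per expert
    outputs, start = [], 0
    for g, rows in enumerate(offset):
        if rows == 0:
            continue  # empty groups are skipped, like the group_cnt filtering in Options()
        a = input_A[start:start + rows, :]    # (rows, K)
        d = alpha[g] * (a @ input_B[g].T)     # (rows, N)
        # beta[g] * C_g would be accumulated here, but the C operand was removed from the runner
        outputs.append(d)
        start += rows
    return torch.cat(outputs, dim=0)

With alpha = torch.ones(num_experts) and beta = torch.zeros(num_experts), as in fused_moe_interface.py, this matches the per-expert `input @ weight.T` reference loop used in test_grouped_gemm.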
@@ -370,7 +281,7 @@ void allocate(const Options &options) { cutlass::gemm::GemmUniversalMode::kGrouped, {options.groups, problem_sizes.get(), options.problem_sizes_host.data()}, {ptr_A, stride_A.get(), ptr_B, stride_B.get()}, - {fusion_args, ptr_C.get(), stride_C.get(), ptr_D, stride_D.get()}, + {fusion_args, nullptr, stride_C.get(), ptr_D, stride_D.get()}, hw_info, {1, RasterOrderOptions::AlongN} }; @@ -380,7 +291,7 @@ void allocate(const Options &options) { cutlass::gemm::GemmUniversalMode::kGrouped, {options.groups, problem_sizes.get(), nullptr}, {ptr_A, stride_A.get(), ptr_B, stride_B.get()}, - {fusion_args, ptr_C.get(), stride_C.get(), ptr_D, stride_D.get()}, + {fusion_args, nullptr, stride_C.get(), ptr_D, stride_D.get()}, hw_info, {1, RasterOrderOptions::AlongN} }; @@ -403,14 +314,6 @@ void allocate(const Options &options) { std::cout << "enter run" << std::endl; } - // std::vector ptr_AA_host(options.groups); - - // stream->memcpy(ptr_AA_host.data(), ptr_AA, options.groups * sizeof(int64_t)).wait(); - // // cutlass::device_memory::copy_from_device(ptr_A_host.data(), ptr_A, options.groups); - // for (int i = 0; i < options.groups; ++i){ - // std::cout << "AA ptr:" << ptr_AA_host.at(i) << std::endl; - // } - allocate(options); initialize(options); Gemm gemm_op; @@ -511,24 +414,6 @@ void kernel_functor( constexpr int PipelineStages = 2; using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16Group; using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group; - - // using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; - - // using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; - // using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< - // EpilogueDispatchPolicy, - // TileShape, - // ElementAccumulator, - // cutlass::gemm::TagToStrideC_t, - // ElementOutput, - // cutlass::gemm::TagToStrideC_t, - // FusionCallBacks, - // XE_2D_U32x8x16_LD_N, - // void, void, - // XE_2D_U16x8x16_ST_N, - // void, void>; using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; From ad2dc48ff1e9b5cb96e6854cb57aa626c484baf2 Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Wed, 17 Sep 2025 07:26:03 +0000 Subject: [PATCH 36/47] fix device lost Signed-off-by: Ma, Liangliang --- csrc/xpu/cutlass_backend/grouped_gemm.h | 7 ++++ tests/cutlass/test_fused_moe.py | 53 ++++++++++++++----------- vllm_xpu_kernels/fused_moe_interface.py | 16 ++++++-- 3 files changed, 49 insertions(+), 27 deletions(-) diff --git a/csrc/xpu/cutlass_backend/grouped_gemm.h b/csrc/xpu/cutlass_backend/grouped_gemm.h index 612c008..ef8644f 100644 --- a/csrc/xpu/cutlass_backend/grouped_gemm.h +++ b/csrc/xpu/cutlass_backend/grouped_gemm.h @@ -127,11 +127,15 @@ struct Options { std::cout << "Options()" << std::endl; } int group_cnt = 0; + std::cout << "****Options() num_of_expert " << num_of_expert << std::endl; for (int i = 0; i < num_of_expert; ++i){ + std::cout << "****Options() i " << i << std::endl; + std::cout << "****Options() offset[i] " << offset[i] << std::endl; if (offset[i] != 0){ group_cnt++; } } + std::cout << "****Options() group_cnt " << group_cnt << std::endl; problem_sizes_host.reserve(group_cnt); for (int i = 0; i < num_of_expert; ++i){ if (offset[i] != 0){ @@ -275,6 +279,9 @@ void allocate(const Options &options) { fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1}; using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerXeGroup::RasterOrderOptions; + std::cout << "grouped_gemm arguments" << std::endl; + std::cout << 
"options.groups " << options.groups << std::endl; + // Per-GEMM problem shape info may only exist on the device. if (host_problem_shapes_available) { arguments = typename Gemm::Arguments { diff --git a/tests/cutlass/test_fused_moe.py b/tests/cutlass/test_fused_moe.py index 5370c43..aaea85c 100644 --- a/tests/cutlass/test_fused_moe.py +++ b/tests/cutlass/test_fused_moe.py @@ -50,16 +50,15 @@ def calculate_device_mem(m, k, n, e, topk, dtype): def test_grouped_gemm(num_experts, n, k, token_per_group): # input input_A = torch.randn((sum(token_per_group), k), dtype=torch.bfloat16, device="xpu").contiguous() - + ref_A = input_A.clone() # weight input_B = torch.randn((num_experts, n, k), dtype=torch.bfloat16, device="xpu") - input_B = input_B.transpose(-1, -2).contiguous().transpose(-1, -2) + input_B = input_B #.transpose(-1, -2).contiguous().transpose(-1, -2) # output offset - output = torch.empty((sum(token_per_group), n), dtype=torch.float32, device="xpu") - offset = torch.tensor(token_per_group, dtype=torch.int64, device="cpu" ) + output = torch.empty((sum(token_per_group), n), dtype=torch.bfloat16, device="xpu") - cutlass_grouped_gemm(input_A, input_B, output, offset, n, k, num_experts) + cutlass_grouped_gemm(input_A, input_B, output, token_per_group, n, k, num_experts) # ref gg ref = [] @@ -68,19 +67,25 @@ def test_grouped_gemm(num_experts, n, k, token_per_group): cur_token_num = token_per_group[i] if cur_token_num == 0: continue - input = input_A[pre_token_sum:pre_token_sum + cur_token_num, :] + input = ref_A[pre_token_sum:pre_token_sum + cur_token_num, :] print("refA ptr",i, ":", hex(input.data_ptr())) weight = input_B[i, :, :] expert_output = input @ weight.T ref.append(expert_output) pre_token_sum += cur_token_num - ref = torch.cat(ref, dim=0).float() + ref = torch.cat(ref, dim=0) print("kernel:", output) print("reference:", ref) print(torch.allclose(output, ref, rtol=1, atol=1)) max_diff = (output - ref).abs().max() print("Max absolute difference:", max_diff) + try: + torch.testing.assert_close(output, ref, rtol=1, atol=1) + print("a 和 b 足够接近 ✅") + except AssertionError as e: + print("a 和 b 有差异 ❌") + print(e) def ref_fused_moe(x, w13, @@ -90,8 +95,9 @@ def ref_fused_moe(x, num_per_tok, activation, num_experts): - - expert_cache = torch.zeros_like(x).float() + # import ipdb + # ipdb.set_trace() + expert_cache = torch.zeros_like(x) idxs = flat_expert_indices.argsort() counts = flat_expert_indices.bincount().cpu().numpy() tokens_per_expert = counts.cumsum() @@ -115,7 +121,7 @@ def ref_fused_moe(x, expert_cache.scatter_reduce_( 0, exp_token_idxs.view(-1, 1).repeat(1, x.shape[-1]), - expert_out.float(), + expert_out, reduce='sum' ) @@ -139,9 +145,10 @@ def test_fused_moe( # todo: seed verbose = False # Setup test data - a = torch.randn((m, k), device=DEVICE, dtype=dtype) / 10 + a = torch.ones((m, k), device=DEVICE, dtype=dtype) / 10 w13 = torch.randn((e, 2 * n, k), device=DEVICE, dtype=dtype) / 10 w2 = torch.randn((e, k, n), device=DEVICE, dtype=dtype) / 10 + ref_a = a.clone() # moe gate scores = torch.randn((m, e), device=DEVICE, dtype=dtype) @@ -165,7 +172,7 @@ def test_fused_moe( activation="silu", num_experts=e) - ref_out = ref_fused_moe(a, + ref_out = ref_fused_moe(ref_a, w13, w2, flat_expert_weights, @@ -176,20 +183,20 @@ def test_fused_moe( print("ref result", ref_out, ref_out.shape) print("kernel result", out, out.shape) - print(torch.allclose(out.float(), ref_out, rtol=1, atol=1)) + print(torch.allclose(out, ref_out, rtol=1, atol=1)) max_diff = (out - ref_out).abs().max() 
print("Max absolute difference:", max_diff) if __name__ == "__main__": - test_fused_moe( - m = 33, - n = 2048, - k = 128, - e = 16, - topk = 2, - ep_size = 1, - dtype = torch.bfloat16 - ) - # test_grouped_gemm(num_experts=2, n=4096, k=4096, token_per_group=[512,512]) + # test_fused_moe( + # m = 4, + # n = 8192, + # k = 5120, + # e = 16, + # topk = 1, + # ep_size = 1, + # dtype = torch.bfloat16 + # ) + test_grouped_gemm(num_experts=2, n=4096, k=4096, token_per_group=[512,512]) diff --git a/vllm_xpu_kernels/fused_moe_interface.py b/vllm_xpu_kernels/fused_moe_interface.py index 96d8912..46936ff 100644 --- a/vllm_xpu_kernels/fused_moe_interface.py +++ b/vllm_xpu_kernels/fused_moe_interface.py @@ -80,7 +80,7 @@ def process_data_ptr(tensor, offset, addr_tensor, dim, group): # gemm_args["ptr_D"] = ptr_D # gemm_args["ptr_alpha"] = ptr_alpha # gemm_args["ptr_beta"] = ptr_beta - prepare_gemm_args.gemm_args["groups"] = groups + prepare_gemm_args.gemm_args["groups"] = e # FIXME: groups return prepare_gemm_args.gemm_args @@ -88,7 +88,8 @@ def cutlass_grouped_gemm(input_A, input_B, output, offset, n, k, num_experts): device = "xpu" alpha = torch.ones(num_experts, dtype=torch.float32, device=input_A.device) beta = torch.zeros(num_experts, dtype=torch.float32, device=input_A.device) - gemm_args = prepare_gemm_args(n, k, offset, input_A, input_B, output, alpha, beta) + gemm_args = prepare_gemm_args(n, k, offset, input_A, input_B, output, alpha, beta, num_experts) + offset = torch.tensor(offset, dtype=torch.int64, device="cpu" ) torch.ops._xpu_C.cutlass_grouped_gemm(offset=offset, N=n, K=k, **gemm_args) def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_per_token, activation, num_experts): @@ -142,19 +143,26 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ # grouped_input_A.append(expert_tokens) input_A[start_idx:end_idx, :].copy_(hidden_states[exp_token_idxs]) + while len(offset) < num_experts: + offset.append(0) + # import ipdb + # ipdb.set_trace() ########### gemm1 ################## print("@@@@@@@ cutlass fused moe enter") input_B = w13 #.transpose(-1, -2).contiguous().transpose(-1, -2) assert(list(input_A.shape)[0] == total_input_size) gemm_args = prepare_gemm_args(2*intermediate_size, hidden_size, offset, input_A, input_B, gemm1_output, alpha, beta, num_experts) + print("**python offset", offset) offset_t = torch.tensor(offset, dtype=torch.int64, device='cpu') + print("**python offset_t", offset_t) torch.ops._xpu_C.cutlass_grouped_gemm(offset=offset_t, N=2*intermediate_size, K=hidden_size, **gemm_args) print("@@@@@@@ cutlass fused moe gemm1 done") + print("gemm1 out", gemm1_output) # act - gate, up = torch.split(gemm1_output, intermediate_size, dim=1) + gate, up_ = torch.split(gemm1_output, intermediate_size, dim=1) act = torch.nn.SiLU() - act_output = act(gate) * up + act_output = act(gate) * up_ ########### gemm2 ################## From 56cb570749d250a29677bcc28d0895ecf30652fa Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Fri, 19 Sep 2025 05:18:40 +0000 Subject: [PATCH 37/47] acc and oom fixed Signed-off-by: Ma, Liangliang --- csrc/xpu/cutlass_backend/cutlass_kernels.h | 4 +- csrc/xpu/cutlass_backend/grouped_gemm.h | 16 ++++---- tests/cutlass/test_fused_moe.py | 43 +++++++++++++--------- vllm_xpu_kernels/fused_moe_interface.py | 4 +- 4 files changed, 38 insertions(+), 29 deletions(-) diff --git a/csrc/xpu/cutlass_backend/cutlass_kernels.h b/csrc/xpu/cutlass_backend/cutlass_kernels.h index df6236f..3a2bdf7 100644 --- 
a/csrc/xpu/cutlass_backend/cutlass_kernels.h +++ b/csrc/xpu/cutlass_backend/cutlass_kernels.h @@ -25,9 +25,9 @@ at::Tensor grouped_gemm_func( int64_t N, int64_t K, int64_t groups) { - auto dpcpp_queue = vllm::xpu::vllmGetQueue(); + auto& dpcpp_queue = vllm::xpu::vllmGetQueue(); grouped_gemm::kernel_functor( - &dpcpp_queue, + dpcpp_queue, ptr_A.data_ptr(), ptr_B.data_ptr(), ptr_D.data_ptr(), diff --git a/csrc/xpu/cutlass_backend/grouped_gemm.h b/csrc/xpu/cutlass_backend/grouped_gemm.h index ef8644f..2442538 100644 --- a/csrc/xpu/cutlass_backend/grouped_gemm.h +++ b/csrc/xpu/cutlass_backend/grouped_gemm.h @@ -310,7 +310,7 @@ void allocate(const Options &options) { cutlass::Status run( const Options& options, - sycl::queue* stream, + sycl::queue& stream, const cutlass::KernelHardwareInfo& hw_info, const ElementA** ptr_A, const ElementB** ptr_B, @@ -344,14 +344,14 @@ void allocate(const Options &options) { std::cout << "before run kernel" << std::endl; } // Run the GEMM - CUTLASS_CHECK(gemm_op.run()); - + CUTLASS_CHECK(gemm_op.run(stream)); + // syclcompat::wait(); if (collect_gflops){ std::cout << "collect_gflops:" << collect_gflops << std::endl; GPU_Clock timer; timer.start(); for (int iter = 0; iter < 100; ++iter) { - CUTLASS_CHECK(gemm_op.run()); + CUTLASS_CHECK(gemm_op.run(stream)); } syclcompat::wait(); @@ -362,14 +362,14 @@ void allocate(const Options &options) { std::cout << " GFLOPS : " << gflops << std::endl; } - stream->throw_asynchronous(); + stream.throw_asynchronous(); release(); return cutlass::Status::kSuccess; } }; void kernel_functor( - sycl::queue* stream, + sycl::queue& stream, void* ptr_A, void* ptr_B, void* ptr_D, @@ -400,7 +400,7 @@ void kernel_functor( using ElementScale = cutlass::bfloat16_t; using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; using LayoutC = cutlass::layout::RowMajor; using LayoutD = cutlass::layout::RowMajor; @@ -439,7 +439,7 @@ void kernel_functor( ElementA, cutlass::gemm::TagToStrideA_t, ElementB, - cutlass::gemm::TagToStrideA_t, + cutlass::gemm::TagToStrideB_t, TiledMma, GmemTiledCopyA, void, void, cute::identity, // A GmemTiledCopyB, void, void, cute::identity // B diff --git a/tests/cutlass/test_fused_moe.py b/tests/cutlass/test_fused_moe.py index aaea85c..8a22f92 100644 --- a/tests/cutlass/test_fused_moe.py +++ b/tests/cutlass/test_fused_moe.py @@ -53,13 +53,16 @@ def test_grouped_gemm(num_experts, n, k, token_per_group): ref_A = input_A.clone() # weight input_B = torch.randn((num_experts, n, k), dtype=torch.bfloat16, device="xpu") - input_B = input_B #.transpose(-1, -2).contiguous().transpose(-1, -2) + input_B = input_B.transpose(-1, -2).contiguous().transpose(-1, -2) # output offset output = torch.empty((sum(token_per_group), n), dtype=torch.bfloat16, device="xpu") + print("input A is ", input_A) + print("input_B is ", input_B) + print("K sum is ", input_B[0].sum(dim=-1)) cutlass_grouped_gemm(input_A, input_B, output, token_per_group, n, k, num_experts) - + torch.xpu.synchronize() # ref gg ref = [] pre_token_sum = 0 @@ -68,7 +71,8 @@ def test_grouped_gemm(num_experts, n, k, token_per_group): if cur_token_num == 0: continue input = ref_A[pre_token_sum:pre_token_sum + cur_token_num, :] - print("refA ptr",i, ":", hex(input.data_ptr())) + # print("refA ptr",i, ":", hex(input.data_ptr())) + print("refA ", ref_A, ref_A.shape) weight = input_B[i, :, :] expert_output = input @ weight.T ref.append(expert_output) @@ -77,11 +81,11 @@ def test_grouped_gemm(num_experts, n, 
k, token_per_group): print("kernel:", output) print("reference:", ref) - print(torch.allclose(output, ref, rtol=1, atol=1)) + print(torch.allclose(output, ref, rtol=1e-2, atol=1e-2)) max_diff = (output - ref).abs().max() print("Max absolute difference:", max_diff) try: - torch.testing.assert_close(output, ref, rtol=1, atol=1) + torch.testing.assert_close(output, ref, rtol=1e-2, atol=1e-2) print("a 和 b 足够接近 ✅") except AssertionError as e: print("a 和 b 有差异 ❌") @@ -145,7 +149,7 @@ def test_fused_moe( # todo: seed verbose = False # Setup test data - a = torch.ones((m, k), device=DEVICE, dtype=dtype) / 10 + a = torch.randn((m, k), device=DEVICE, dtype=dtype) / 10 w13 = torch.randn((e, 2 * n, k), device=DEVICE, dtype=dtype) / 10 w2 = torch.randn((e, k, n), device=DEVICE, dtype=dtype) / 10 ref_a = a.clone() @@ -186,17 +190,22 @@ def test_fused_moe( print(torch.allclose(out, ref_out, rtol=1, atol=1)) max_diff = (out - ref_out).abs().max() print("Max absolute difference:", max_diff) - + try: + torch.testing.assert_close(out, ref_out, rtol=1e-2, atol=1e-2) + print("a 和 b 足够接近 ✅") + except AssertionError as e: + print("a 和 b 有差异 ❌") + print(e) if __name__ == "__main__": - # test_fused_moe( - # m = 4, - # n = 8192, - # k = 5120, - # e = 16, - # topk = 1, - # ep_size = 1, - # dtype = torch.bfloat16 - # ) - test_grouped_gemm(num_experts=2, n=4096, k=4096, token_per_group=[512,512]) + test_fused_moe( + m = 4, + n = 8192, + k = 5120, + e = 16, + topk = 1, + ep_size = 1, + dtype = torch.bfloat16 + ) + # test_grouped_gemm(num_experts=2, n=5120, k=8192, token_per_group=[2,2]) diff --git a/vllm_xpu_kernels/fused_moe_interface.py b/vllm_xpu_kernels/fused_moe_interface.py index 46936ff..dc7dcb6 100644 --- a/vllm_xpu_kernels/fused_moe_interface.py +++ b/vllm_xpu_kernels/fused_moe_interface.py @@ -150,7 +150,7 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ # ipdb.set_trace() ########### gemm1 ################## print("@@@@@@@ cutlass fused moe enter") - input_B = w13 #.transpose(-1, -2).contiguous().transpose(-1, -2) + input_B = w13.transpose(-1, -2).contiguous().transpose(-1, -2) assert(list(input_A.shape)[0] == total_input_size) gemm_args = prepare_gemm_args(2*intermediate_size, hidden_size, offset, input_A, input_B, gemm1_output, alpha, beta, num_experts) print("**python offset", offset) @@ -167,7 +167,7 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ ########### gemm2 ################## input_A = act_output.contiguous() - input_B = w2 #.transpose(-1, -2).contiguous().transpose(-1, -2) + input_B = w2.transpose(-1, -2).contiguous().transpose(-1, -2) gemm_args = prepare_gemm_args(hidden_size, intermediate_size, offset, input_A, input_B, gemm2_output, alpha, beta, num_experts) torch.ops._xpu_C.cutlass_grouped_gemm(offset=offset_t, N=hidden_size, K=intermediate_size, **gemm_args) From d1edf177f9076097d534b2f1a9fdac094715caa7 Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Tue, 23 Sep 2025 02:20:29 +0000 Subject: [PATCH 38/47] base Signed-off-by: Ma, Liangliang --- csrc/xpu/cutlass_backend/grouped_gemm.h | 22 +++++-- tests/cutlass/profile_moe.py | 76 +++++++++++++++++++++++++ 2 files changed, 92 insertions(+), 6 deletions(-) create mode 100644 tests/cutlass/profile_moe.py diff --git a/csrc/xpu/cutlass_backend/grouped_gemm.h b/csrc/xpu/cutlass_backend/grouped_gemm.h index 2442538..19114bd 100644 --- a/csrc/xpu/cutlass_backend/grouped_gemm.h +++ b/csrc/xpu/cutlass_backend/grouped_gemm.h @@ -127,15 +127,15 @@ struct Options { std::cout << 
"Options()" << std::endl; } int group_cnt = 0; - std::cout << "****Options() num_of_expert " << num_of_expert << std::endl; + // std::cout << "****Options() num_of_expert " << num_of_expert << std::endl; for (int i = 0; i < num_of_expert; ++i){ - std::cout << "****Options() i " << i << std::endl; - std::cout << "****Options() offset[i] " << offset[i] << std::endl; + // std::cout << "****Options() i " << i << std::endl; + // std::cout << "****Options() offset[i] " << offset[i] << std::endl; if (offset[i] != 0){ group_cnt++; } } - std::cout << "****Options() group_cnt " << group_cnt << std::endl; + // std::cout << "****Options() group_cnt " << group_cnt << std::endl; problem_sizes_host.reserve(group_cnt); for (int i = 0; i < num_of_expert; ++i){ if (offset[i] != 0){ @@ -344,7 +344,18 @@ void allocate(const Options &options) { std::cout << "before run kernel" << std::endl; } // Run the GEMM + + GPU_Clock timer; + timer.start(); CUTLASS_CHECK(gemm_op.run(stream)); + stream.wait(); + // syclcompat::wait(); + // stream.throw_asynchronous(); + float cute_time = timer.seconds() * 1000; + double cute_average_time = double(cute_time) / double(1); + std::cout << " Avg runtimei : " << cute_average_time << " ms" << std::endl; + + // syclcompat::wait(); if (collect_gflops){ std::cout << "collect_gflops:" << collect_gflops << std::endl; @@ -379,7 +390,7 @@ void kernel_functor( int64_t N, int64_t K, int64_t groups){ - // + // // Run examples // auto offset_ptr = reinterpret_cast(offset); @@ -464,7 +475,6 @@ void kernel_functor( reinterpret_cast(ptr_D), reinterpret_cast(ptr_alpha), reinterpret_cast(ptr_beta)); - } } // namespace grouped_gemm diff --git a/tests/cutlass/profile_moe.py b/tests/cutlass/profile_moe.py new file mode 100644 index 0000000..a5c1637 --- /dev/null +++ b/tests/cutlass/profile_moe.py @@ -0,0 +1,76 @@ +import torch +import torch.profiler +import intel_extension_for_pytorch +from vllm_xpu_kernels.fused_moe_interface import cutlass_grouped_gemm + +def linear_silu_mul(x): + half = x.shape[-1] // 2 + output_shape = x.shape[:-1] + (half,) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + return torch.ops.torch_ipex.silu_and_mul(x, out) + +def test_moe_gemm(n_experts, intermediate_size, hidden_size, tokens, topk, dtype): + + total_m = tokens * topk + + input_a1 = torch.randn(total_m, hidden_size, dtype=dtype, device="xpu") + input_a2 = input_a1.clone() + w13 = torch.randn(n_experts, 2*intermediate_size, hidden_size, dtype=dtype, device="xpu") + w2 = torch.randn(n_experts, hidden_size, intermediate_size, dtype=dtype, device="xpu") + w13 = w13.transpose(1, 2).contiguous() + w2 = w2.transpose(1, 2).contiguous() + + offset = [int(total_m//n_experts)] * n_experts + rows_for_experts = torch.tensor(offset, device="xpu", dtype=torch.int32) + rows_for_experts_ipex = rows_for_experts.to(torch.int32).to("cpu") + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.XPU # ✅ 使用 XPU + ], + # record_shapes=True, + with_stack=False, + # profile_memory=True, + with_flops=True, + ) as prof: + + with torch.profiler.record_function("#####ipex_moegemm_1"): + ipex_o1 = torch.xpu.moe_gemm(input_a1, w13, rows_for_experts, rows_for_experts_ipex, n_experts) + torch.xpu.synchronize() + + ipex_o2 = linear_silu_mul(ipex_o1) + torch.xpu.synchronize() + + with torch.profiler.record_function("#####ipex_moegemm_2"): + ipex_o3 = torch.xpu.moe_gemm(ipex_o2, w2, rows_for_experts, rows_for_experts_ipex, n_experts) + torch.xpu.synchronize() + + + 
w13 = w13.transpose(1, 2) + w2 = w2.transpose(1, 2) + + cutlass_o1 = torch.empty((total_m, 2*intermediate_size), dtype=dtype, device="xpu") + cutlass_o3 = torch.empty((total_m, hidden_size), dtype=dtype, device="xpu") + with torch.profiler.record_function("@@@@@cutlass_moegemm_1"): + cutlass_grouped_gemm(input_a2, w13, cutlass_o1, offset, 2*intermediate_size, hidden_size, n_experts) + torch.xpu.synchronize() + cutlass_o2 = linear_silu_mul(cutlass_o1) + torch.xpu.synchronize() + + with torch.profiler.record_function("@@@@@cutlass_moegemm_2"): + cutlass_grouped_gemm(cutlass_o2, w2, cutlass_o3, offset, hidden_size, intermediate_size, n_experts) + torch.xpu.synchronize() + + + print(prof.key_averages().table( + sort_by="self_xpu_time_total", # 以XPU耗时排序 + row_limit=-1 + )) + + print("ipex out: \n", ipex_o3, ipex_o3.shape) + print("cutlass out: \n", cutlass_o3, cutlass_o3.shape) + torch.testing.assert_close(ipex_o3.to(float), cutlass_o3.to(float), rtol=1e-2, atol=1e-2) + +if __name__ == "__main__": + test_moe_gemm(n_experts=16, intermediate_size=8192, hidden_size=5120, topk=1, dtype=torch.bfloat16, tokens=16*512) From 55f36a808c38d13a30228799fdd647aa3469d9de Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Tue, 23 Sep 2025 05:22:23 +0000 Subject: [PATCH 39/47] update CMakeLists Signed-off-by: Ma, Liangliang --- CMakeLists.txt | 132 ++++++++++++++++++++++++++++++++++--------------- 1 file changed, 91 insertions(+), 41 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 09e00ef..c8a7466 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -47,6 +47,8 @@ set(SYCL_SUPPORTED_ARCHS "intel_gpu_pvc;intel_gpu_bmg_g21") # set(TORCH_SUPPORTED_VERSION_XPU "2.8.0") +set(FA2_ENABLED ON) + # # Try to find python package with an executable that exactly matches # `VLLM_PYTHON_EXECUTABLE` and is one of the supported versions. @@ -155,48 +157,16 @@ if(VLLM_GPU_LANG STREQUAL "SYCL") "csrc/quantization/fp8/fp8_quant.cpp" ) include_directories("/usr/include") - set(CMPLR_ROOT $ENV{CMPLR_ROOT}) - set(VLLM_EXTRA_INCLUDE_DIRECTORIES ${CMPLR_ROOT}/include/sycl) - list(APPEND VLLM_GPU_FLAGS "-DVLLM_BUILD_XPU_OPS" ) - list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64") - list(APPEND VLLM_LINK_LIBRARIES "sycl" "OpenCL" "pthread" "m" "dl" "torch" ) -endif() - -message(STATUS "Enabling C extension.") -define_gpu_extension_target( - _C - DESTINATION vllm_xpu_kernels - LANGUAGE ${VLLM_GPU_LANG} - SOURCES ${VLLM_EXT_SRC} - COMPILE_FLAGS ${VLLM_GPU_FLAGS} - LINK_FLAGS ${VLLM_GPU_LINK_FLAGS} - ARCHITECTURES ${VLLM_GPU_ARCHES} - INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} - INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} - USE_SABI 3 - WITH_SOABI) - -# -# xpu only ops/kernels, implemented with cutlass/onednn/sycl. 
-# -file(GLOB CUTLASS_BACKEND_SRCS - csrc/xpu/cutlass_backend/*.cpp -) - -if(VLLM_GPU_LANG STREQUAL "SYCL") - set(VLLM_EXT_XPU_SRC - "csrc/xpu/torch_bindings.cpp" - ${CUTLASS_BACKEND_SRCS} - ) - include_directories("/usr/include") list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/) list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/sycl/) - list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/sycl/syclcompat/) + list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/syclcompat/) message(STATUS "VLLM_INCLUDE_DIR: ${VLLM_INCLUDE_DIR}") set(VLLM_EXTRA_INCLUDE_DIRECTORIES ${CMPLR_ROOT}/include/sycl) list(APPEND VLLM_GPU_FLAGS "-DVLLM_BUILD_XPU_OPS" ) list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64" "-Xspirv-translator" "-spirv-ext=+SPV_INTEL_split_barrier") list(APPEND VLLM_LINK_LIBRARIES "sycl" "OpenCL" "pthread" "m" "dl" "torch" ) + + # add cutlass dependency set(CUTLASS_ENABLE_HEADERS_ONLY "ON" CACHE BOOL "Enable only the header library") @@ -206,7 +176,7 @@ if(VLLM_GPU_LANG STREQUAL "SYCL") # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided FetchContent_Declare( cutlass-sycl - GIT_REPOSITORY https://github.com/Liangliang-Ma/cutlass-sycl.git + GIT_REPOSITORY https://github.com/Liangliang-Ma/cutlass-sycl # Please keep this in sync with CUTLASS_REVISION line above. GIT_TAG ${CUTLASS_REVISION} GIT_PROGRESS TRUE @@ -243,21 +213,98 @@ if(VLLM_GPU_LANG STREQUAL "SYCL") endif() +message(STATUS "Enabling C extension.") +define_gpu_extension_target( + _C + DESTINATION vllm_xpu_kernels + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${VLLM_EXT_SRC} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + LINK_FLAGS ${VLLM_GPU_LINK_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_APP_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${VLLM_INCLUDE_DIR} + USE_SABI 3 + WITH_SOABI) + +# +# flash attention _C extension +# + +if (FA2_ENABLED) + message(STATUS "Enabling fa2 extension.") + file(GLOB FA2_GEN_SRCS "csrc/flash_attn/*.cpp") + + set(CUTLASS_GPU_FLAGS ${VLLM_GPU_FLAGS}) + set(CUTLASS_LINK_FLAGS ${VLLM_GPU_LINK_FLAGS}) + + # XPU FLAGS + list(APPEND CUTLASS_GPU_FLAGS "-O3" "-DNDEBUG") + list(APPEND CUTLASS_GPU_FLAGS "-gline-tables-only") + list(APPEND CUTLASS_GPU_FLAGS "-fsycl" "-fsycl-targets=spir64_gen" "-ftemplate-backtrace-limit=10") + + list(APPEND CUTLASS_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64_gen") + list(APPEND CUTLASS_LINK_FLAGS -Xsycl-target-backend=spir64_gen "-device bmg-g21-a0 -internal_options -cl-intel-256-GRF-per-thread") + + define_gpu_extension_target( + _vllm_fa2_C + DESTINATION vllm_xpu_kernels + LANGUAGE ${VLLM_GPU_LANG} + SOURCES + csrc/flash_attn/flash_api.cpp + ${FA2_GEN_SRCS} + COMPILE_FLAGS ${CUTLASS_GPU_FLAGS} + LINK_FLAGS ${CUTLASS_LINK_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_APP_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${VLLM_INCLUDE_DIR} + USE_SABI 3 + WITH_SOABI) +endif () + +# +# xpu only ops/kernels, implemented with cutlass/onednn/sycl. 
+# +file(GLOB CUTLASS_BACKEND_SRCS + csrc/xpu/cutlass_backend/*.cpp +) +if(VLLM_GPU_LANG STREQUAL "SYCL") + set(VLLM_EXT_XPU_SRC + "csrc/xpu/torch_bindings.cpp" + "csrc/xpu/lora/lora_shrink.cpp" + "csrc/xpu/lora/lora_expand.cpp" + ${CUTLASS_BACKEND_SRCS} + ) + include_directories("/usr/include") + set(CMPLR_ROOT $ENV{CMPLR_ROOT}) + set(CMAKE_CXX_COMPILER icpx) + set(VLLM_EXTRA_INCLUDE_DIRECTORIES ${CMPLR_ROOT}/include/sycl) + list(APPEND VLLM_GPU_FLAGS "-DVLLM_BUILD_XPU_OPS" ) + list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64") + list(APPEND VLLM_LINK_LIBRARIES "sycl" "OpenCL" "pthread" "m" "dl" "torch" ) + # CUTLASS FLAGS + list(APPEND VLLM_GPU_FLAGS "-O3" "-DNDEBUG") + list(APPEND VLLM_GPU_FLAGS "-gline-tables-only") + list(APPEND VLLM_GPU_FLAGS "-fsycl" "-fsycl-targets=spir64_gen" "-ftemplate-backtrace-limit=10") + list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64_gen") + list(APPEND VLLM_GPU_LINK_FLAGS -Xsycl-target-backend=spir64_gen "-device bmg-g21-a0 -internal_options -cl-intel-256-GRF-per-thread") +endif() + if(ONEDNN_FOUND) set(_ONEDNN_SRC) file(GLOB _ONEDNN_SRC csrc/xpu/onednn/*.cpp) list(APPEND VLLM_EXT_XPU_SRC ${_ONEDNN_SRC} + "csrc/xpu/sycl/deepseek_scaling_rope.cpp" ) include_directories(${ONEDNN_INCLUDE_DIR}) link_libraries(${ONEDNN_LIBRARY}) endif() - - -list(APPEND VLLM_GPU_FLAGS "-gline-tables-only") -list(APPEND VLLM_GPU_FLAGS "-O3") -list(APPEND VLLM_GPU_FLAGS "-DNDEBUG") define_gpu_extension_target( _xpu_C DESTINATION vllm_xpu_kernels @@ -278,6 +325,9 @@ define_gpu_extension_target( # set(VLLM_MOE_EXT_SRC "csrc/moe/torch_bindings.cpp" + "csrc/moe/grouped_topk.cpp" + "csrc/moe/fused_grouped_topk.cpp" + "csrc/moe/topk_softmax.cpp" "csrc/moe/moe_align_sum_kernels.cpp") message(STATUS "Enabling moe extension.") From 513377af0fcdbd21df0d3c0f5873114d6959b05b Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Tue, 23 Sep 2025 06:25:19 +0000 Subject: [PATCH 40/47] refactor csrc of cutlass Signed-off-by: Ma, Liangliang --- CMakeLists.txt | 5 +- csrc/xpu/cutlass_backend/sycl_common.hpp | 149 ------------------ .../grouped_gemm.hpp} | 16 +- .../grouped_gemm_kernel.cpp} | 1 - .../helper.h | 0 csrc/xpu/torch_bindings.cpp | 2 +- 6 files changed, 19 insertions(+), 154 deletions(-) delete mode 100644 csrc/xpu/cutlass_backend/sycl_common.hpp rename csrc/xpu/{cutlass_backend/cutlass_kernels.h => cutlass_kernels/grouped_gemm.hpp} (76%) rename csrc/xpu/{cutlass_backend/grouped_gemm.h => cutlass_kernels/grouped_gemm_kernel.cpp} (99%) rename csrc/xpu/{cutlass_backend => cutlass_kernels}/helper.h (100%) diff --git a/CMakeLists.txt b/CMakeLists.txt index c7801b0..10b63fc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -177,7 +177,7 @@ if(VLLM_GPU_LANG STREQUAL "SYCL") FetchContent_Declare( cutlass-sycl GIT_REPOSITORY https://github.com/Liangliang-Ma/cutlass-sycl - + # Please keep this in sync with CUTLASS_REVISION line above. GIT_TAG ${CUTLASS_REVISION} GIT_PROGRESS TRUE @@ -270,13 +270,14 @@ endif () # xpu only ops/kernels, implemented with cutlass/onednn/sycl. 
# file(GLOB CUTLASS_BACKEND_SRCS - csrc/xpu/cutlass_backend/*.cpp + csrc/xpu/cutlass_kernels/*.cpp ) if(VLLM_GPU_LANG STREQUAL "SYCL") set(VLLM_EXT_XPU_SRC "csrc/xpu/torch_bindings.cpp" "csrc/xpu/lora/lora_shrink.cpp" "csrc/xpu/lora/lora_expand.cpp" + ${CUTLASS_BACKEND_SRCS} ) include_directories("/usr/include") set(CMPLR_ROOT $ENV{CMPLR_ROOT}) diff --git a/csrc/xpu/cutlass_backend/sycl_common.hpp b/csrc/xpu/cutlass_backend/sycl_common.hpp deleted file mode 100644 index 06fcd44..0000000 --- a/csrc/xpu/cutlass_backend/sycl_common.hpp +++ /dev/null @@ -1,149 +0,0 @@ -/*************************************************************************************************** -* Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/util/device_memory.h" -#include "cutlass/util/reference/device/sycl_tensor_fill.h" - -/// Helper to initialize a block of device data -template -bool initialize_block(Element* block, std::size_t size, uint64_t seed=2023) { - - Element scope_max = Element(1 << cute::ceil_div(std::numeric_limits::digits, 4)); - Element scope_min = cute::is_signed::value ? 
Element(-scope_max) : Element(0); - - cutlass::reference::device::BlockFillRandomUniform( - block, size, seed, scope_max, scope_min, 0); - - syclcompat::wait(); - return true; -} - -template -bool initialize_block( - cutlass::DeviceAllocation& block, - uint64_t seed=2023) { - return initialize_block(block.get(), block.size(), seed); -} - -template -void initialize_mixed_dtype_block(cutlass::DeviceAllocation& block_device, - cutlass::DeviceAllocation& block_device_dq, - uint64_t seed) { - static_assert(cute::sizeof_bits_v >= 8); - - std::ranlux24_base rng(std::random_device{}()); - rng.seed(seed); - - T1 scope_max = T1(1 << cute::ceil_div(std::numeric_limits::digits, 4)); - T1 scope_min = cute::is_signed::value ? T1(-scope_max) : T1(0); - - std::uniform_int_distribution<> dist(scope_min, scope_max); - - if constexpr (cute::sizeof_bits_v >= 8) { - auto block_host = std::vector(block_device.size()); - auto block_host_dq = std::vector(block_device.size()); - for (int i = 0; i < block_host.size(); ++i) { - block_host[i] = static_cast(dist(rng)); - block_host_dq[i] = static_cast(block_host[i]); - } - - block_device.copy_from_host(block_host.data()); - block_device_dq.copy_from_host(block_host_dq.data()); - } else { - static constexpr auto array_size = 1024; - - cute::array_subbyte block_host{}; - auto block_host_dq = std::vector(array_size); - - for (int i = 0; i < block_host.size(); ++i) { - block_host[i] = static_cast(dist(rng)); - block_host_dq[i] = static_cast(block_host[i].get()); - } - - static constexpr auto elements_per_byte = cute::sizeof_bits_v / cute::sizeof_bits_v; - - int loop_cnt = block_device.size() / array_size; - for (int i = 0; i < loop_cnt; i++) { - cutlass::device_memory::copy_to_device(block_device.get() + (i * array_size) / elements_per_byte, - raw_pointer_cast(block_host.begin()), - array_size / elements_per_byte); - cutlass::device_memory::copy_to_device(block_device_dq.get() + i * array_size, - block_host_dq.data(), - array_size); - } - - auto tail_size = block_device.size() % array_size; - if (tail_size) { - cutlass::device_memory::copy_to_device(block_device.get() + (loop_cnt * array_size) / elements_per_byte, - raw_pointer_cast(block_host.begin()), - tail_size / elements_per_byte); - cutlass::device_memory::copy_to_device(block_device_dq.get() + loop_cnt * array_size, - block_host_dq.data(), - tail_size); - } - } -} - -template -inline -bool is_close(T a, T b, float atol, float rtol) { - return std::abs((float)a - (float)b) <= atol + rtol * std::abs((float)b); -} - -// TODO(Codeplay): use on device initialisation for this -template -inline -void random_fill(T *src, int seed, size_t N, float max, float min) { - if constexpr(std::is_same_v || std::is_same_v || std::is_same_v) { - std::random_device rd; - std::mt19937 gen(seed); - std::uniform_real_distribution dis(min, max); - auto buff = std::vector(N); - - for (size_t i = 0; i < N; ++i) { - buff[i] = (T)(dis(gen)); - } - syclcompat::memcpy(src, buff.data(), N); - syclcompat::wait(); - } else { - assert(0 & "Not supported dtype"); - } -} - -template -void convert_dtype(const SrcT* d_src, DstT* d_dst, size_t size) { - syclcompat::get_default_queue().parallel_for(size, [=](auto indx) { - d_dst[indx] = static_cast(d_src[indx]); - }).wait(); -} diff --git a/csrc/xpu/cutlass_backend/cutlass_kernels.h b/csrc/xpu/cutlass_kernels/grouped_gemm.hpp similarity index 76% rename from csrc/xpu/cutlass_backend/cutlass_kernels.h rename to csrc/xpu/cutlass_kernels/grouped_gemm.hpp index 3a2bdf7..92b2c54 100644 --- 
a/csrc/xpu/cutlass_backend/cutlass_kernels.h +++ b/csrc/xpu/cutlass_kernels/grouped_gemm.hpp @@ -8,11 +8,25 @@ /* #include "pytorch_shim.h" */ #include -#include "grouped_gemm.h" #include "utils.h" namespace gpu::cutlass_kernel { +namespace grouped_gemm { + void kernel_functor( + sycl::queue& stream, + void* ptr_A, + void* ptr_B, + void* ptr_D, + void* ptr_alpha, + void* ptr_beta, + void* offset, + int64_t N, + int64_t K, + int64_t groups); +} + + /* gemm2(group_A, w2, output, offset) */ at::Tensor grouped_gemm_func( diff --git a/csrc/xpu/cutlass_backend/grouped_gemm.h b/csrc/xpu/cutlass_kernels/grouped_gemm_kernel.cpp similarity index 99% rename from csrc/xpu/cutlass_backend/grouped_gemm.h rename to csrc/xpu/cutlass_kernels/grouped_gemm_kernel.cpp index 19114bd..2df3e10 100644 --- a/csrc/xpu/cutlass_backend/grouped_gemm.h +++ b/csrc/xpu/cutlass_kernels/grouped_gemm_kernel.cpp @@ -87,7 +87,6 @@ #include "cutlass/util/packed_stride.hpp" #include "cutlass/util/reference/device/gemm_complex.h" #include "cutlass/util/reference/device/tensor_compare.h" -#include "sycl_common.hpp" #include "helper.h" #include diff --git a/csrc/xpu/cutlass_backend/helper.h b/csrc/xpu/cutlass_kernels/helper.h similarity index 100% rename from csrc/xpu/cutlass_backend/helper.h rename to csrc/xpu/cutlass_kernels/helper.h diff --git a/csrc/xpu/torch_bindings.cpp b/csrc/xpu/torch_bindings.cpp index e58169f..1ec4974 100644 --- a/csrc/xpu/torch_bindings.cpp +++ b/csrc/xpu/torch_bindings.cpp @@ -1,6 +1,6 @@ #include "core/registration.h" #include "xpu/ops.h" -#include "xpu/cutlass_backend/cutlass_kernels.h" +#include "xpu/cutlass_kernels/grouped_gemm.hpp" #include "xpu/lora/lora_ops.h" #include From 534c7c353d7ff645c9e347bf356111304919d345 Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Tue, 23 Sep 2025 09:15:01 +0000 Subject: [PATCH 41/47] put src in vllm Signed-off-by: Ma, Liangliang --- .../collective/gemm/xe_array_epilogue.hpp | 511 ++++++++++++ .../collective/gemm/xe_array_mma.hpp | 300 +++++++ .../collective/gemm/xe_builder.hpp | 295 +++++++ .../collective/gemm/xe_callbacks.hpp | 785 ++++++++++++++++++ .../cutlass_kernels/grouped_gemm_kernel.cpp | 21 +- mll_build.sh | 1 + 6 files changed, 1907 insertions(+), 6 deletions(-) create mode 100644 csrc/xpu/cutlass_kernels/collective/gemm/xe_array_epilogue.hpp create mode 100644 csrc/xpu/cutlass_kernels/collective/gemm/xe_array_mma.hpp create mode 100644 csrc/xpu/cutlass_kernels/collective/gemm/xe_builder.hpp create mode 100644 csrc/xpu/cutlass_kernels/collective/gemm/xe_callbacks.hpp diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_epilogue.hpp b/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_epilogue.hpp new file mode 100644 index 0000000..86cd234 --- /dev/null +++ b/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_epilogue.hpp @@ -0,0 +1,511 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +// #include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/xe_visitor_softmax.hpp" +#include "cutlass/detail/layout.hpp" + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class DispatchPolicy, + class... 
Args +> +class CollectiveEpilogue { + static_assert(cutlass::detail::dependent_false, "Could not find an epilogue specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class CtaTileMNK_, + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class FusionCallbacks_, + class CopyOpG2R_, + class SmemLayoutAtomC_, + class CopyOpS2R_, + class CopyOpR2G_, + class SmemLayoutAtomD_, + class CopyOpR2S_ +> +class CollectiveEpilogue< + IntelXeXMX16Group, + CtaTileMNK_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpG2R_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpR2G_, + SmemLayoutAtomD_, + CopyOpR2S_ +> { +public: + // + // Type Aliases + // + using DispatchPolicy = IntelXeXMX16Group; + using CtaTileMNK = CtaTileMNK_; + using FusionCallbacks = FusionCallbacks_; + using ElementC = ElementC_; + using ElementAccumulator = ElementC_; + using StrideC = StrideC_; + using InternalStrideC = cute::remove_pointer_t; + using ElementD = ElementD_; + using StrideD = StrideD_; + using InternalStrideD = cute::remove_pointer_t; + using CopyOpG2R = CopyOpG2R_; + using SmemLayoutAtomC = SmemLayoutAtomC_; + using CopyOpS2R = CopyOpS2R_; + using CopyOpR2G = CopyOpR2G_; + using SmemLayoutAtomD = SmemLayoutAtomD_; + using CopyOpR2S = CopyOpR2S_; + + using ThreadEpilogueOp = typename fusion::FusionCallbacksTraits::Operation; + using GmemTiledCopyC = CopyOpG2R; + using GmemTiledCopyD = cute::conditional_t && not cute::is_void_v, + CopyOpR2G, XE_2D_U32x8x16_ST_N>; + using ElementOutput = ElementD; + using ElementCompute = ElementAccumulator; + using ElementSource = typename FusionCallbacks::ElementSource; + using ElementScalar = typename FusionCallbacks::ElementScalar; + static constexpr FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest; + + static_assert(cute::is_same_v>, + "Only Linear Combination Epilogue is supported for Grouped GEMM at the moment."); + + static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; + + static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(InternalStrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); + static_assert(cute::rank(InternalStrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); + + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + + using CopyThreadShape = Shape<_1, Int>; + using Trait_C = Copy_Traits; + using XE_Copy_C = decltype(make_tiled_copy(Copy_Atom{}, + Layout{}, + make_layout(shape_div(typename Trait_C::BlockShape{}, CopyThreadShape{})))); + using Trait_D = Copy_Traits; + using XE_Copy_D = decltype(make_tiled_copy(Copy_Atom{}, + Layout{}, + make_layout(shape_div(typename Trait_D::BlockShape{}, CopyThreadShape{})))); +private: + // constexpr static 
bool is_source_supported = not cute::is_void_v; + constexpr static bool is_source_supported = false; + constexpr static bool is_destination_supported = not cute::is_void_v && not cute::is_void_v; + +public: + + using EmptyType = cute::tuple<>; + using SmemCStorage = EmptyType; + using SmemDStorage = EmptyType; + + struct TensorStorageImpl: cute::tuple { + using FusionStorage = typename FusionCallbacks::SharedStorage; + FusionStorage thread; + }; + + struct SharedStorage { + using TensorStorage = TensorStorageImpl; + + TensorStorage tensors; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + + using TensorC = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), InternalStrideC{})); //(m, n) + using TensorD = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), InternalStrideD{})); //(m, n) + using EpilogueTensors = cute::tuple; + + // Host side epilogue arguments + struct Arguments { + typename FusionCallbacks::Arguments thread{}; + ElementC const** ptr_C; + StrideC dC; + ElementD** ptr_D; + StrideD dD; + }; + + // Device side epilogue params + struct Params { + typename FusionCallbacks::Params thread{}; + XE_Copy_C xe_load_c; + XE_Copy_D xe_store_d; + ElementC const** ptr_C; + StrideC dC; + ElementD** ptr_D; + StrideD dD; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNL = repeat_like(typename ProblemShape::UnderlyingProblemShape{}, int32_t(1)); + auto [M, N, L] = problem_shape_MNL; + + XE_Copy_C xe_load_c = {}; + if constexpr (is_source_supported) { + ElementC const* ptr_C_first_batch = reinterpret_cast(args.ptr_C); + TensorC mC_mnl = make_tensor(make_gmem_ptr(ptr_C_first_batch), make_layout(make_shape(M, N, L), InternalStrideC{})); + xe_load_c = {xe_load_c.with(mC_mnl)}; + } + + XE_Copy_D xe_store_d = {}; + if constexpr (is_destination_supported) { + ElementD* ptr_D_first_batch = reinterpret_cast(args.ptr_D); + TensorD mD_mnl = make_tensor(make_gmem_ptr(ptr_D_first_batch), make_layout(make_shape(M, N, L), InternalStrideD{})); + xe_store_d = {xe_store_d.with(mD_mnl)}; + } + + return { + FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), + xe_load_c, + xe_store_d, + args.ptr_C, + args.dC, + args.ptr_D, + args.dD + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + template + static bool + can_implement( + ProblemShape problem_shape, + Arguments const& args) { + constexpr int copy_alignment_bits = 128; + constexpr int batch_alignment_bits = 512; + + bool implementable = true; + bool fusion_implementable = true; + + for (int i = 0; i < problem_shape.groups(); ++i) { + auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + + if constexpr (is_destination_supported) { + constexpr int min_aligned_elements_D = copy_alignment_bits / sizeof_bits::value; + implementable &= cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideD{}); + if (L > 1) { + 
constexpr int min_batch_aligned_elements_D = batch_alignment_bits / sizeof_bits::value; + implementable &= get<2>(InternalStrideD{}) % min_batch_aligned_elements_D == 0; + } + } + + if constexpr (is_source_supported) { + constexpr int min_aligned_elements_C = copy_alignment_bits / sizeof_bits::value; + implementable &= cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideC{}); + if (L > 1) { + constexpr int min_batch_aligned_elements_C = batch_alignment_bits / sizeof_bits::value; + implementable &= get<2>(InternalStrideC{}) % min_batch_aligned_elements_C == 0; + } + } + + fusion_implementable = fusion_implementable && FusionCallbacks::can_implement(problem_shape_MNKL, args.thread); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for XE 2D copy.\n"); + } + + if (!fusion_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); + } + + return implementable && fusion_implementable; + } + + CUTLASS_HOST_DEVICE + CollectiveEpilogue(Params const& params_, TensorStorage const& shared_storage_) + : params(params_), fusion_callbacks(params_.thread, shared_storage_.thread) {} + + CUTLASS_DEVICE + bool + is_producer_load_needed() const { + return fusion_callbacks.is_producer_load_needed(); + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class Accumulator, + class TiledMma, + class LoadStoreTensor + > + CUTLASS_DEVICE void + operator() ( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + Accumulator accumulators, + TiledMma tiled_mma, + int thread_idx, + LoadStoreTensor const& load_store_tensors) { + + (void) tiled_mma; + using namespace cute; + + static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(InternalStrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); + static_assert(cute::rank(InternalStrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); + + using MmaAtomShape = typename TiledMma::AtomShape_MNK; + static constexpr auto BLK_M = get<0>(CtaTileMNK{}); + static constexpr auto BLK_N = get<1>(CtaTileMNK{}); + static constexpr auto BLK_K = get<2>(CtaTileMNK{}); + // static_assert(is_same_v, "assertation fail"); + static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); + + static_assert( + BLK_M % ATOM_M == 0 && + BLK_N % ATOM_N == 0 && + BLK_K % ATOM_K == 0, + "expected CTATileMNK to be evenly divided by TiledMma::ThrLayoutVMNK"); + static constexpr auto SG_M = BLK_M / ATOM_M; + static constexpr auto SG_N = BLK_N / ATOM_N; + static constexpr auto SG_K = BLK_K / ATOM_K; + using SubgroupTileShape = Shape; + + static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // A frags per sub_group + static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group + + static constexpr int FragmentSize = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize; + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + auto m_sg = get_sub_group_id() / ATOM_N; + auto n_sg = get_sub_group_id() % ATOM_N; + + using EpilogueTile = 
decltype(get<0>(params.xe_store_d.get_layoutS_MN()).shape()); + + auto sg_local_m_coord = get_sub_group_id() / ATOM_N; + auto sg_local_n_coord = get_sub_group_id() % ATOM_N; + + auto sg_m_coord = m_coord * ATOM_M + sg_local_m_coord; + auto sg_n_coord = n_coord * ATOM_N + sg_local_n_coord; + auto sg_coord = make_coord(sg_m_coord, sg_n_coord, k_coord, l_coord); + + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // Represent the full output tensor + Tensor mD_mnl = cute::get_xe_tensor(make_shape(M,N,L)); + + // Tile the output tensor per WG and select the tile for current WG + Tensor g_wg_D = local_tile(mD_mnl, take<0,2>(CtaTileMNK{}), make_coord(m_coord,n_coord,l_coord)); // (BLK_M,BLK_N) + + // Tile the output tensor per SG and select tile for the current SG + Tensor gD = local_tile(g_wg_D, take<0,2>(SubgroupTileShape{}), make_coord(m_sg,n_sg)); // (SG_M,SG_N) + + auto thread_xe_store_d = params.xe_store_d.get_thread_slice(thread_idx); + Tensor tCgD = thread_xe_store_d.partition_D(gD); + + Tensor trC = make_tensor(Shape>{}); + Tensor trD_compute = make_tensor(Shape>{}); + + // Because Sm90 uses shared memory, they are not tied to using the same accumulator values + // for MMA and Epilogue. But because we are operating directly in the accumulators, we need to be + // sure that we are operating on the same values. + ThrCopy thread_g2r = params.xe_load_c.get_slice(thread_idx); + + // OOB predication for tile quantization "residue" + // Absolute coordinate tensors (dynamic) + Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) + Tensor cD = local_tile(mD_crd, take<0,2>(SubgroupTileShape{}), make_coord(sg_m_coord, sg_n_coord)); + Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) + Tensor tRS_cD_mn = thread_g2r.partition_S(flat_divide(cD_mn, EpilogueTile{})); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) + + Tensor tRS_cD = make_counting_tensor(tRS_cD_mn.layout()); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) + + // Get the fusion callbacks + // Arguments passed here relate to sub-group tiles, rather than CTA (work-group) tiles + constexpr bool RefSrc = true; + auto residue_mn = make_coord(M, N); //TODO(Codeplay): this is not correct + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ + problem_shape_mnkl, + SubgroupTileShape{}, + sg_coord, + tiled_mma, + EpilogueTile{}, + params.xe_store_d, + cD, + residue_mn, + tRS_cD, + residue_mn, + trC, + thread_idx, + }; + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + + cst_callbacks.begin(); + + auto acc_frag = recast>(accumulators); + auto trD_compute_frag = recast>(trD_compute); + + Tensor trD = make_tensor(Shape>{}); + auto trD_frag = recast>(trD); + + constexpr int ValuesLoaded = + FragsM * FragsN * FragmentSize * SubgroupSize * ATOM_M * ATOM_N * ATOM_K; + constexpr int MN = get<0>(CtaTileMNK{}) * get<1>(CtaTileMNK{}); + static_assert(ValuesLoaded == MN, "the total elements loaded by all threads should be the same as MxN" ); + + auto synchronize = [&] () {}; + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < FragsN; epi_n++) { + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < FragsM; epi_m++) { + + if (is_C_load_needed) { + //cordinates for C and D are the same + copy(params.xe_load_c.with(get<0>(load_store_tensors)), tCgD(_, epi_m, epi_n), trC); + } + + cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed); + + auto acc_frag_mn = acc_frag(_, epi_m, epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int 
epi_v = 0; epi_v < size<0>(trD_compute_frag); ++epi_v) { + trD_compute_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n); + } + cst_callbacks.reduce(nullptr, synchronize, epi_m, epi_n, (epi_m == FragsM - 1 && epi_n == FragsN - 1), trD_compute_frag); + + if constexpr (is_destination_supported) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(trD_compute_frag); ++i) { + trD_frag(i) = cutlass::NumericArrayConverter{}(trD_compute_frag(i)); + } + copy(params.xe_store_d.with(get<1>(load_store_tensors)), trD, tCgD(_, epi_m, epi_n)); + } + } + } + + cst_callbacks.end(); + } + + template + CUTLASS_DEVICE auto update_tensor_shape_stride( + int32_t const& next_group, + ProblemShape_MNKL const& problem_shape_mnkl) { + auto [M, N, K, L] = problem_shape_mnkl; + + TensorC mC_mnl; + TensorD mD_mnl; + if constexpr (is_source_supported) { + ElementC const* ptr_C_curr_batch = reinterpret_cast(params.ptr_C[next_group]); + mC_mnl = make_tensor(make_gmem_ptr(ptr_C_curr_batch), make_layout(make_shape(M, N, L), params.dC[next_group])); + } + + if constexpr (is_destination_supported) { + ElementD* ptr_D_curr_batch = reinterpret_cast(params.ptr_D[next_group]); + mD_mnl = make_tensor(make_gmem_ptr(ptr_D_curr_batch), make_layout(make_shape(M, N, L), params.dD[next_group])); + } + return cute::make_tuple(mC_mnl, mD_mnl); + } + +private: + Params const& params; + FusionCallbacks fusion_callbacks; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_mma.hpp b/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_mma.hpp new file mode 100644 index 0000000..42c2f8f --- /dev/null +++ b/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_mma.hpp @@ -0,0 +1,300 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct CollectiveMma, TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, + GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, GmemTiledCopyB_, SmemLayoutAtomB_, + SmemCopyAtomB_, TransformB_> { + // + // Type Aliases + // + using DispatchPolicy = MainloopIntelXeXMX16Group; + using WorkgroupTileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using InternalStrideA = cute::remove_pointer_t; + using ElementB = ElementB_; + using StrideB = StrideB_; + using InternalStrideB = cute::remove_pointer_t; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + static_assert(platform::is_same::value, "MainloopIntelXeXMX16Array requires that A and B have same type."); + + static_assert(std::is_same_v, "Transformation for A is not currently supported on Intel PVC"); + static_assert(std::is_same_v, "Transformation for B is not currently supported on Intel PVC"); + + static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; + + using MmaAtomShape = typename TiledMma::AtomShape_MNK; + + static constexpr auto BLK_M = get<0>(WorkgroupTileShape{}); + static constexpr auto BLK_N = get<1>(WorkgroupTileShape{}); + static constexpr auto BLK_K = get<2>(WorkgroupTileShape{}); + + static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); + + static constexpr auto SG_M = ceil_div(BLK_M, ATOM_M); + static constexpr auto SG_N = ceil_div(BLK_N, ATOM_N); + static constexpr auto SG_K = ceil_div(BLK_K, ATOM_K); + using SubgroupTileShape = Shape; + + static constexpr auto Num_SGs = ATOM_N * ATOM_M * ATOM_K; + static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}); + + using Copy_A = typename Copy_Traits::template DefaultTiledCopy; + using Copy_B = typename Copy_Traits::template DefaultTiledCopy; + + using TensorMKL 
= decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), InternalStrideA{})); //(m, k) + using TensorNKL = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), InternalStrideB{})); //(n, k) + using MainloopTensors = cute::tuple; + // Host side kernel arguments + struct Arguments { + ElementA const** ptr_A; + StrideA dA; + ElementB const** ptr_B; + StrideB dB; + }; + + struct Params { + ElementA const** ptr_A; + StrideA dA; + ElementB const** ptr_B; + StrideB dB; + }; + + // + // Methods + // + + CollectiveMma() = default; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + auto problem_shape_MNK = repeat_like(typename ProblemShape::UnderlyingProblemShape{}, int32_t(1));; + auto init_M = get<0>(problem_shape_MNK); + auto init_N = get<1>(problem_shape_MNK); + auto init_K = get<2>(problem_shape_MNK); + + return Params{ + args.ptr_A, + args.dA, + args.ptr_B, + args.dB + }; + } + + template + static bool + can_implement( + ProblemShape problem_shapes, + Arguments const& args) { + constexpr int copy_alignment_bits = 128; + constexpr int batch_alignment_bits = 512; + auto problem_shape_MNKL = append<4>(problem_shapes, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + + constexpr int min_aligned_elements_A = copy_alignment_bits / sizeof_bits::value; + constexpr int min_aligned_elements_B = copy_alignment_bits / sizeof_bits::value; + constexpr int min_batch_aligned_elements_A = batch_alignment_bits / sizeof_bits::value; + constexpr int min_batch_aligned_elements_B = batch_alignment_bits / sizeof_bits::value; + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + + implementable &= cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); + implementable &= cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); + + if (L > 1) { + implementable &= get<2>(InternalStrideA{}) % min_batch_aligned_elements_A == 0; + implementable &= get<2>(InternalStrideB{}) % min_batch_aligned_elements_B == 0; + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for XE 2D copy.\n"); + } + + return implementable; + } + + /// Perform a subgroup-scoped matrix multiply-accumulate + template + CUTLASS_DEVICE void operator()(FrgTensorD &accum, TensorA gA, TensorB gB, FrgTensorC const &src_accum, + KTileIterator k_tile_iter, int const& k_tile_count, + BlkCoord const &blk_coord, int const &K_start, int const& thread_idx, + Params const &mainloop, LoadTensors const& load_tensors) { + static_assert(is_rmem::value, "D tensor must be rmem resident."); + static_assert(is_rmem::value, "C tensor must be rmem resident."); + + (void)thread_idx; + + Copy_A tiled_copy_a{Copy_A{}.with(get<0>(load_tensors))}; + Copy_B tiled_copy_b{Copy_B{}.with(get<1>(load_tensors))}; + + auto thr_copy_A = tiled_copy_a.get_slice(thread_idx); + auto thr_copy_B = tiled_copy_b.get_slice(thread_idx); + + // Instantiate the MMA object and get thread slice + TiledMma tiled_mma; + // TODO(Codeplay): see if we can make this nicer + // To make all work items in a subgroup have the same global tensors pass in the index of work item 0 in each subgroup + auto sg = syclcompat::get_nd_item<1>().get_sub_group(); + auto first_thread_in_sg_idx = 
sg.get_group_linear_id() * DispatchPolicy::SubgroupSize; + auto thr_mma = tiled_mma.get_slice(first_thread_in_sg_idx); + + // Partition global counting tensors for MMA + Tensor tCgA = thr_mma.partition_A(gA); + Tensor tCgB = thr_mma.partition_B(gB); + + Tensor tCrA = make_tensor(make_fragment_layout(tiled_copy_a, tCgA(_,_,_,0).shape())); + Tensor tCrB = make_tensor(make_fragment_layout(tiled_copy_b, tCgB(_,_,_,0).shape())); + + // Retile registers for copies + Tensor tArA = thr_copy_A.retile_D(tCrA); + Tensor tBrB = thr_copy_B.retile_D(tCrB); + + // Retile global counting tensors for copies + Tensor tAgA = thr_copy_A.retile_S(tCgA); + Tensor tBgB = thr_copy_B.retile_S(tCgB); + + auto tiled_prefetch_a = cute::prefetch_selector,Int>, Num_SGs>(tiled_copy_a); + auto tiled_prefetch_b = cute::prefetch_selector,Int>, Num_SGs>(tiled_copy_b); + auto thr_prefetch_A = tiled_prefetch_a.get_slice(thread_idx); + auto thr_prefetch_B = tiled_prefetch_b.get_slice(thread_idx); + + // Partition global tile for prefetch + auto pAgA = thr_prefetch_A.partition_S(gA); + auto pBgB = thr_prefetch_B.partition_S(gB); + +#if CUTLASS_ENABLE_DEBUG_PRINTS + if (cutlass::thread(LOG_THREAD, LOG_GROUP)) { + print("======================= A: \n"); + print(" gA : "); print(gA); print("\n"); + print("tCgA : "); print(tCgA); print("\n"); + print("tAgA : "); print(tAgA); print("\n"); + + print("===================== B :\n"); + print(" gB : "); print(gB); print("\n"); + print("tCgB : "); print(tCgB); print("\n"); + print("tBgB : "); print(tBgB); print("\n"); + + print("===================== Config: \n"); + print(" threads per workgroup : "); print(MaxThreadsPerBlock); print("\n"); + print(" SubgroupTileShape : "); print(SubgroupTileShape{}); print("\n"); + } +#endif + + // + // Mainloop + // + const auto k_start_idx = crd2idx((*k_tile_iter), make_shape(K_start)); + constexpr int barrier_scope = 2; + int prefetch_k = k_start_idx; + + CUTLASS_PRAGMA_UNROLL + for (; prefetch_k < DispatchPolicy::Stages; prefetch_k++) { + prefetch(tiled_prefetch_a, pAgA(_, _, _, prefetch_k)); + prefetch(tiled_prefetch_b, pBgB(_, _, _, prefetch_k)); + } + + for (int k_tile = k_start_idx; k_tile < k_tile_count + k_start_idx; k_tile++, prefetch_k++) { + barrier_arrive(barrier_scope); + // Copy gmem to rmem for the first k_tile + copy(tiled_copy_a, tAgA(_,_,_,k_tile), tArA); + copy(tiled_copy_b, tBgB(_,_,_,k_tile), tBrB); + + if (prefetch_k < k_tile_count) { + prefetch(tiled_prefetch_a, pAgA(_, _, _, prefetch_k)); + prefetch(tiled_prefetch_b, pBgB(_, _, _, prefetch_k)); + } + + cute::gemm(tiled_mma, tCrA, tCrB, accum); + barrier_wait(barrier_scope); + } + } + + template + CUTLASS_DEVICE auto update_tensor_shape_stride( + Params const& mainloop_params, + int32_t const& next_group, + ProblemShape_MNKL const& problem_shape_mnkl) { + const int32_t M = get<0>(problem_shape_mnkl); + const int32_t N = get<1>(problem_shape_mnkl); + const int32_t K = get<2>(problem_shape_mnkl); + + ElementA const* ptr_A_curr_batch = reinterpret_cast(mainloop_params.ptr_A[next_group]); + ElementB const* ptr_B_curr_batch = reinterpret_cast(mainloop_params.ptr_B[next_group]); + + Tensor mA = make_tensor(make_gmem_ptr(ptr_A_curr_batch), make_shape(M, K,(int32_t)1), mainloop_params.dA[next_group]); + Tensor mB = make_tensor(make_gmem_ptr(ptr_B_curr_batch), make_shape(N, K,(int32_t)1), mainloop_params.dB[next_group]); + + return cute::make_tuple(mA, mB); + } +}; + +} // namespace cutlass::gemm::collective + 
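The mainloop in `operator()` above software-pipelines global-memory traffic against the XMX MMA: it issues `DispatchPolicy::Stages` prefetches up front, then for each K-tile copies the current A/B fragments from global memory into register fragments, prefetches one tile further ahead, and accumulates with `cute::gemm`. The sketch below is a minimal, self-contained illustration of that schedule only and is not part of this patch; `prefetch_tile`, `load_tile` and `mma_tile` are hypothetical stand-ins for the `prefetch`, `cute::copy` and `cute::gemm` calls, and the split barriers and sub-group bookkeeping are omitted.

#include <cstdio>

constexpr int Stages = 3;  // stand-in for DispatchPolicy::Stages

void prefetch_tile(int k) { std::printf("prefetch k=%d\n", k); }
void load_tile(int k)     { std::printf("copy     k=%d\n", k); }
void mma_tile(int k)      { std::printf("gemm     k=%d\n", k); }

void mainloop(int k_tile_count) {
  int prefetch_k = 0;
  // Warm-up: put `Stages` K-tiles in flight before any compute.
  for (; prefetch_k < Stages; ++prefetch_k)
    prefetch_tile(prefetch_k);

  for (int k_tile = 0; k_tile < k_tile_count; ++k_tile, ++prefetch_k) {
    load_tile(k_tile);               // gmem -> registers for the current K-tile
    if (prefetch_k < k_tile_count)   // keep the prefetch distance at `Stages`
      prefetch_tile(prefetch_k);
    mma_tile(k_tile);                // accumulate into the C fragment
  }
}

int main() { mainloop(6); }

Since the copies above land directly in register fragments, there is no shared-memory staging in this mainloop; the fixed prefetch distance is what hides load latency behind the MMA work.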
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/xe_builder.hpp b/csrc/xpu/cutlass_kernels/collective/gemm/xe_builder.hpp new file mode 100644 index 0000000..abaa091 --- /dev/null +++ b/csrc/xpu/cutlass_kernels/collective/gemm/xe_builder.hpp @@ -0,0 +1,295 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include // cute::DefaultCopy +#include // cute::is_base_of_v +// #include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "xe_array_epilogue.hpp" +#include "xe_callbacks.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Used to specify epilogue subtile shape or dispatch to automatic computation of subtile shape +struct EpilogueTileAuto {}; + +// Used to let the builder pick the epilogue schedule automatically. 
+// Can be overridden with kernel schedule tags in cutlass/gemm/dispatch_policy.hpp +struct EpilogueScheduleAuto {}; + +template < + class ArchTag, + class OpClass, + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC, + class GmemLayoutTagC, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class EpilogueScheduleType, + class FusionOpOrCallbacks = cutlass::epilogue::fusion::LinearCombination, + class Enable = void +> +struct CollectiveBuilder { + static_assert(cutlass::detail::dependent_false, + "Could not build a collective epilogue for given parameters."); +}; + +// helper sub-builder for epilogue fusion callbacks (for internal use by CollectiveBuilder only) +namespace detail { + +// callbacks builder with operation tag +template< + class DispatchPolicy, + class FusionOp, + class TileShape_MNK, + class EpilogueTile_MN, + class ElementAccumulator, + class AccLoadOp = cute::DefaultCopy, + class = void +> +struct CallbacksBuilder { + using Callbacks = fusion::FusionCallbacks; +}; + +// callbacks builder with callbacks passthrough +template < + class DispatchPolicy, + class FusionCallbacks, + class TileShape_MNK, + class EpilogueTile_MN, + class AccLoadOp, + class ElementAccumulator +> +struct CallbacksBuilder< + DispatchPolicy, + FusionCallbacks, + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + AccLoadOp, + cute::enable_if_t> +> { + using Callbacks = FusionCallbacks; +}; + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::collective { + +namespace detail { + template + struct FusionOpInfo { + static_assert(cutlass::detail::dependent_false, + "Could not find a builder specialization."); + }; + + template < + class ElementD, + class ElementCompute, + class ElementC + > + struct FusionOpInfo> { + constexpr static bool HasBuilder = true; + + template < + class DispatchPolicy, + class TileShape_MNK, + class EpilogueTile, + class> + using FusionCallbacks = cutlass::epilogue::fusion::FusionCallbacks< + DispatchPolicy, + cutlass::epilogue::fusion::LinearCombination, + TileShape_MNK, + EpilogueTile + >; + }; + + template < + template class ActivationFn, + class ElementD, + class ElementCompute, + class ElementC + > + struct FusionOpInfo> { + constexpr static bool HasBuilder = true; + template < + class DispatchPolicy, + class TileShape_MNK, + class EpilogueTile, + class> + + using FusionCallbacks = cutlass::epilogue::fusion::FusionCallbacks< + DispatchPolicy, + cutlass::epilogue::fusion::LinCombEltAct, + TileShape_MNK, + EpilogueTile + >; + }; + + template < + class GmemLayoutTagC, + template class ActivationFn, + class ElementD, + class ElementCompute, + class ElementC + > + struct FusionOpInfo> { + constexpr static bool HasBuilder = true; + + template < + class DispatchPolicy, + class TileShape_MNK, + class EpilogueTile, + class CopyOpG2R> + using FusionCallbacks = cutlass::epilogue::fusion::FusionCallbacks< + DispatchPolicy, + cutlass::epilogue::fusion::LinCombDeEltAct, + TileShape_MNK, + EpilogueTile, + CopyOpG2R + >; + }; +} // namespace detail + + +// Intel epilogue builder +template < + class TileShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC, + class 
GmemLayoutTagC, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class EpilogueScheduleType, + class FusionOpOrCallbacks + > + struct CollectiveBuilder< + arch::IntelXe, + arch::OpClassTensorOp, + TileShape_MNK, + Shape<_1, _1, _1>, // Cluster Shape + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + EpilogueScheduleType, + FusionOpOrCallbacks, + cute::enable_if_t< + cute::is_same_v && + cute::is_any_of_v && + detail::FusionOpInfo::HasBuilder + > + >{ + #ifdef SYCL_NVIDIA_TARGET + static_assert(cutlass::detail::dependent_false, + "Trying to use Intel pipeline on Non Intel hardware"); + #endif + static_assert(is_static::value); + static_assert(cute::is_any_of_v, + "ElementC needs to be one of: float, bfloat, half for the Intel pipeline"); + + using EpilogueSchedule = std::conditional_t, + IntelXeXMX16, + EpilogueScheduleType>; + static constexpr bool IsGroup = cute::is_same_v; + using DispatchPolicy = std::conditional_t; + + using StrideC = std::conditional_t>, GmemLayoutTagC, cutlass::detail::TagToStrideC_t>>; + using StrideD = std::conditional_t>, GmemLayoutTagD, cutlass::detail::TagToStrideC_t>>; + + static_assert(IsGroup == std::is_pointer_v, "Group GEMM should have a pointer to strides"); + static_assert(IsGroup == std::is_pointer_v, "Group GEMM should have a pointer to strides"); + static_assert(get<1>(std::remove_pointer_t{}) == 1, "Only N-Major/Row-Major layouts for C are supported in the xe epilogue collective builder"); + static_assert(get<1>(std::remove_pointer_t{}) == 1, "Only N-Major/Row-Major layouts for D are supported in the xe epilogue collective builder"); + + using CopyOpG2R = std::conditional_t, void, std::conditional_t == 32, XE_2D_U32x8x16_LD_N, XE_2D_U16x8x16_LD_N>>; + using CopyOpR2G = std::conditional_t == 32, XE_2D_U32x8x16_ST_N, XE_2D_U16x8x16_ST_N>; + + // Intel Epilogue with Linear Combination does not use shared memory + using SmemLayoutAtomC_ = void; + using CopyOpS2R_ = void; + using SmemLayoutAtomD_ = void; + using CopyOpR2S_ = void; + + //TODO(Codeplay): Should FusionCallbacks use DispatchPolicy IntelXeGroupEpilogue for group gemm? That does not work. + using FusionCallbacks = typename detail::FusionOpInfo::template FusionCallbacks< + std::conditional_t, TileShape_MNK, TileShape_MNK, CopyOpG2R>; + using CollectiveOp = cutlass::epilogue::collective::CollectiveEpilogue< + DispatchPolicy, + TileShape_MNK, + ElementAccumulator, + StrideC, + ElementD, + StrideD, + FusionCallbacks, + CopyOpG2R, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpR2G, + SmemLayoutAtomD_, + CopyOpR2S_ + >; + }; +} // namespace cutlass::epilogue::collective diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/xe_callbacks.hpp b/csrc/xpu/cutlass_kernels/collective/gemm/xe_callbacks.hpp new file mode 100644 index 0000000..5173d77 --- /dev/null +++ b/csrc/xpu/cutlass_kernels/collective/gemm/xe_callbacks.hpp @@ -0,0 +1,785 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Fusion callbacks specializations for the Intel Xe epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/xe_visitor.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/xe_visitor_softmax.hpp" +#include "cutlass/epilogue/fusion/xe_visitor_splitk.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ElementOutput_, + class ElementCompute_, + class ElementSource_, + class ElementScalar_, + FloatRoundStyle RoundStyle_, + class CtaTileShapeMNK_, + class EpilogueTile_ +> +struct FusionCallbacks< + epilogue::IntelXeXMX16, + fusion::LinearCombination, + CtaTileShapeMNK_, + EpilogueTile_ +> : Sm90LinearCombination::type, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_> { + + using Impl = Sm90LinearCombination::type, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_>; + using ElementOutput = ElementOutput_; + using ElementCompute = ElementCompute_; + using ElementSource = ElementSource_; + using ElementScalar = ElementScalar_; + using Operation = fusion::LinearCombination; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + operator typename Impl::Arguments() const { + return + { // ternary op : beta 
* C + (alpha * acc) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + + +template < + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementSource_, + class ElementScalar_, + FloatRoundStyle RoundStyle_, + class CtaTileShapeMNK_, + class EpilogueTile_ +> +struct FusionCallbacks< + epilogue::IntelXeXMX16, + fusion::LinCombEltAct, + CtaTileShapeMNK_, + EpilogueTile_ +> : Sm90LinCombEltAct { + + using Impl = Sm90LinCombEltAct::type, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_>; + using ElementOutput = ElementOutput_; + using ElementCompute = ElementCompute_; + using ElementSource = ElementSource_; + using ElementScalar = ElementScalar_; + using Operation = fusion::LinCombEltAct; + + struct Arguments { + ElementScalar_ alpha = ElementScalar_(1); + ElementScalar_ beta = ElementScalar_(0); + ElementScalar_ const* alpha_ptr = nullptr; + ElementScalar_ const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // unary op: activation(beta * C + (alpha * acc)) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args: activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +// D = splitk(alpha * acc + beta * C) +template< + // int FragmentSize, + class CtaTileShapeMNK, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + class CopyOpR2G, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using XeLinCombSplitK = + Sm90EVT, // splitk(beta * C + (alpha * acc)) + Sm90LinearCombination // beta * C + (alpha * acc) + >; + +template < + // int FragmentSize, + class ElementOutput_, + class ElementCompute_, + class ElementSource_, + class ElementScalar_, + class CopyOpR2G_, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::IntelXeXMX16, + fusion::LinCombSplitK, + CtaTileShapeMNK, + EpilogueTile +> : XeLinCombSplitK { + + using ElementOutput = ElementOutput_; + using ElementCompute = ElementCompute_; + using ElementSource = ElementSource_; + using ElementScalar = ElementScalar_; + using Impl = XeLinCombSplitK::type, ElementCompute, CopyOpR2G_, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinCombSplitK; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementOutput* output_ptr = nullptr; + ElementOutput *output_ptr1 = nullptr; + ElementOutput 
*output_ptr2 = nullptr; + size_t NUM_HEAD = 0; + size_t NOPE_DIM = 0; + size_t ROPE_DIM = 0; + operator typename Impl::Arguments() const { + return + { // unary op: activation(beta * C + (alpha * acc)) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + {output_ptr, output_ptr1, output_ptr2, NUM_HEAD, NOPE_DIM, ROPE_DIM} // unary args: activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +// D = softmax(alpha * acc + beta * C) +template< + // int FragmentSize, + class CtaTileShapeMNK, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + class CopyOpR2G, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using XeLinCombSoftmaxRow = + Sm90EVT, // softmax(beta * C + (alpha * acc)) + Sm90LinearCombination // beta * C + (alpha * acc) + >; + +template < + // int FragmentSize, + class ElementOutput_, + class ElementCompute_, + class ElementSource_, + class ElementScalar_, + class CopyOpR2G_, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::IntelXeXMX16, + fusion::LinCombSoftmaxRow, + CtaTileShapeMNK, + EpilogueTile +> : XeLinCombSoftmaxRow { + + using ElementOutput = ElementOutput_; + using ElementCompute = ElementCompute_; + using ElementSource = ElementSource_; + using ElementScalar = ElementScalar_; + using Impl = XeLinCombSoftmaxRow::type, ElementCompute, CopyOpR2G_, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinCombSoftmaxRow; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementOutput* output_ptr = nullptr; + + operator typename Impl::Arguments() const { + return + { // unary op: activation(beta * C + (alpha * acc)) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + {output_ptr} // unary args: activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +template< + class StrideAux, + class CopyOpG2R, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using XeLinCombDeEltAct = + Sm90EVT, // activation(beta * C + (alpha * acc), aux) + Sm90LinearCombination, // beta * C + (alpha * acc) + XeAuxLoad // aux + >; + +// Z = Aux +// dY = alpha * acc + beta * C +// D = activation(dY, Z) +// +template < + class GmemLayoutTagAux, + template class ActivationFn, + class ElementOutput_, + class ElementCompute_, + class ElementAux, + class ElementSource, + class ElementScalar, + int AlignmentAux, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class CopyOpG2R +> +struct FusionCallbacks< + 
epilogue::IntelXeXMX16, + fusion::LinCombDeEltAct< + GmemLayoutTagAux, ActivationFn, ElementOutput_, ElementCompute_, + ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile, + CopyOpG2R +> : XeLinCombDeEltAct< + cutlass::gemm::TagToStrideC_t, CopyOpG2R, ActivationFn, ElementOutput_, + ElementCompute_, ElementAux, ElementSource, ElementScalar, RoundStyle + > { + + using ElementOutput = ElementOutput_; + using ElementCompute = ElementCompute_; + + using Impl = + XeLinCombDeEltAct< + cutlass::gemm::TagToStrideC_t, CopyOpG2R, ActivationFn, ElementOutput, + ElementCompute, ElementAux, ElementSource, ElementScalar, RoundStyle + >; + using Operation = + fusion::LinCombDeEltAct< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + using StrideAux = cutlass::gemm::TagToStrideC_t; + ElementAux const* aux_ptr = nullptr; + StrideAux dAux = {}; + + operator typename Impl::Arguments() const { + return + { // binary op : activation(beta * C + (alpha * acc), aux) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + {aux_ptr, ElementAux(0), dAux}, // leaf args : aux + activation // binary args : activation + }; // end binary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// D = alpha * acc + beta * C + per-row bias +template < + class ElementOutput_, + class ElementCompute_, + class ElementBias_, + class ElementSource_, + class ElementScalar_, + int AlignmentBias_, + FloatRoundStyle RoundStyle_, + class CtaTileShapeMNK_, + class EpilogueTile_ +> +struct FusionCallbacks< + epilogue::IntelXeXMX16, + fusion::LinCombPerRowBias, + CtaTileShapeMNK_, + EpilogueTile_ +> : Sm90LinCombPerRowBias { + + using Impl = Sm90LinCombPerRowBias< + CtaTileShapeMNK_, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute_, ElementBias_, ElementSource_, ElementScalar_, + AlignmentBias_, RoundStyle_>; + using ElementOutput = ElementOutput_; + using ElementCompute = ElementCompute_; + using ElementBias = ElementBias_; + using ElementSource = ElementSource_; + using ElementScalar = ElementScalar_; + static constexpr int AlignmentBias = AlignmentBias_; + using Operation = fusion::LinCombPerRowBias; + + struct Arguments { + ElementScalar_ alpha = ElementScalar_(1); + ElementScalar_ beta = ElementScalar_(0); + ElementScalar_ const* alpha_ptr = nullptr; + ElementScalar_ const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + 
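+    // Per-row bias: one ElementBias value per output row (M), broadcast across
+    // the N dimension; the <_1, _0, batch> stride below encodes that layout.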
using StrideBias = Stride<_1, _0, int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + operator typename Impl::Arguments() const { + return + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +// D = alpha * acc + beta * C + per-column bias +template< + int StagesC, + class CtaTileShapeMNK, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using XeLinCombPerColBias = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ScalarBroadcast>, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast>, // alpha + Sm90AccFetch, // acc + XeRowBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_0,_1,int64_t>, AlignmentBias> // bias + > + >; + +template < + class ElementOutput_, + class ElementCompute_, + class ElementBias_, + class ElementSource_, + class ElementScalar_, + int AlignmentBias_, + FloatRoundStyle RoundStyle_, + class CtaTileShapeMNK_, + class EpilogueTile_ +> +struct FusionCallbacks< + epilogue::IntelXeXMX16, + fusion::LinCombPerColBias, + CtaTileShapeMNK_, + EpilogueTile_ +> : XeLinCombPerColBias<_1{} /* Stages */, CtaTileShapeMNK_, EpilogueTile_, ElementOutput_, ElementCompute_, ElementBias_, ElementSource_, ElementScalar_, AlignmentBias_, RoundStyle_> { + + using Impl = XeLinCombPerColBias< + _1{}, + CtaTileShapeMNK_, + EpilogueTile_, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute_, ElementBias_, ElementSource_, ElementScalar_, + AlignmentBias_, RoundStyle_>; + using ElementOutput = ElementOutput_; + using ElementCompute = ElementCompute_; + using ElementBias = ElementBias_; + using ElementSource = ElementSource_; + using ElementScalar = ElementScalar_; + static constexpr int AlignmentBias = AlignmentBias_; + using Operation = fusion::LinCombPerColBias; + + struct Arguments { + ElementScalar_ alpha = ElementScalar_(1); + ElementScalar_ beta = ElementScalar_(0); + ElementScalar_ const* alpha_ptr = nullptr; + ElementScalar_ const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_0, _1, int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + operator typename Impl::Arguments() const { + return + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +template < + int TopK, + class ElementOutput_, + class 
ElementCompute_, + class ElementSource_, + class ElementScalar_, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::IntelXeXMX16, + fusion::LinCombTopKSoftmaxCol, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinCombTopKSoftmaxCol { + + static constexpr int FragmentSize = 8; + using ElementOutput = ElementOutput_; + using ElementCompute = ElementCompute_; + using ElementSource = ElementSource_; + using ElementScalar = ElementScalar_; + using Impl = Sm90LinCombTopKSoftmaxCol::type, + ElementCompute, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinCombTopKSoftmaxCol; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + operator typename Impl::Arguments() const { + return + { // unary op: activation(beta * C + (alpha * acc)) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + {} // unary args: activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = activation(alpha * acc + beta * C + per-row bias) + +template < + // int FragmentSize, + //bool ReuseSmemC, + // bool DelayTmaStore, + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementBias_, + class ElementSource_, + class ElementScalar_, + int AlignmentBias_, + FloatRoundStyle RoundStyle_, + class CtaTileShapeMNK_, + class EpilogueTile_ +> +struct FusionCallbacks< + epilogue::IntelXeXMX16, + fusion::LinCombPerRowBiasEltAct< + ActivationFn_, ElementOutput_, ElementCompute_, ElementBias_, ElementSource_, ElementScalar_, AlignmentBias_, RoundStyle_ + >, + CtaTileShapeMNK_, + EpilogueTile_ +> : Sm90LinCombPerRowBiasEltAct< + CtaTileShapeMNK_, ActivationFn_, ElementOutput_, ElementCompute_, ElementBias_, ElementSource_, ElementScalar_, AlignmentBias_, RoundStyle_ + > { + + using ElementOutput = ElementOutput_; + using ElementCompute = ElementCompute_; + using ElementBias = ElementBias_; + using ElementSource = ElementSource_; + using ElementScalar = ElementScalar_; + static constexpr int AlignmentBias = AlignmentBias_; + using Impl = + Sm90LinCombPerRowBiasEltAct< + CtaTileShapeMNK_, ActivationFn_, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle_ + >; + using Operation = + fusion::LinCombPerRowBiasEltAct< + ActivationFn_, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle_ + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + 
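+    // Flattens the flat host-side fields above into the nested EVT argument tree
+    // activation(beta * C + (alpha * acc + bias)) expected by Impl.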
operator typename Impl::Arguments() const { + return + { // unary op : activation(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////// +// D = alpha * acc + beta * C, where beta and alpha can be vectors for each batch +template < + class ElementOutput_, + class ElementCompute_, + class ElementSource_, + class ElementScalar_, + FloatRoundStyle RoundStyle_, + class CtaTileShapeMNK_, + class EpilogueTile_ +> +struct FusionCallbacks< + epilogue::IntelXeXMX16Group, + fusion::LinearCombination, + CtaTileShapeMNK_, + EpilogueTile_ +> : Sm90LinearCombinationPtrArray::type, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_> { + + using Impl = Sm90LinearCombinationPtrArray::type, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_>; + using ElementOutput = ElementOutput_; + using ElementCompute = ElementCompute_; + using ElementSource = ElementSource_; + using ElementScalar = ElementScalar_; + using Operation = fusion::LinearCombination; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementScalar const* const* alpha_ptr_array = nullptr; + ElementScalar const* const* beta_ptr_array = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + operator typename Impl::Arguments() const { + return + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/xpu/cutlass_kernels/grouped_gemm_kernel.cpp b/csrc/xpu/cutlass_kernels/grouped_gemm_kernel.cpp index 2df3e10..bec7d72 100644 --- a/csrc/xpu/cutlass_kernels/grouped_gemm_kernel.cpp +++ b/csrc/xpu/cutlass_kernels/grouped_gemm_kernel.cpp @@ -69,14 +69,14 @@ #pragma once -#include "cutlass/epilogue/collective/default_epilogue.hpp" -#include "cutlass/epilogue/collective/xe_array_epilogue.hpp" -#include "cutlass/epilogue/fusion/xe_callbacks.hpp" -#include "cutlass/epilogue/collective/collective_builder.hpp" +// #include "cutlass/epilogue/collective/default_epilogue.hpp" +// #include "cutlass/epilogue/collective/xe_array_epilogue.hpp" +// #include "cutlass/epilogue/fusion/xe_callbacks.hpp" +// #include 
"cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/group_array_problem_shape.hpp" #include "cutlass/gemm/device/gemm_universal.h" #include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/collective/collective_mma.hpp" +// #include "cutlass/gemm/collective/collective_mma.hpp" #include "cutlass/util/GPU_Clock.hpp" #include @@ -88,9 +88,18 @@ #include "cutlass/util/reference/device/gemm_complex.h" #include "cutlass/util/reference/device/tensor_compare.h" #include "helper.h" - #include +#include "cutlass/gemm/collective/collective_mma_decl.hpp" +// #include "./collective/gemm/gemm_universal.h" +#include "./collective/gemm/xe_array_mma.hpp" +#include "./collective/gemm/xe_array_epilogue.hpp" +#include "./collective/gemm/xe_builder.hpp" +#include "./collective/gemm/xe_callbacks.hpp" +// #include "./collective/gemm/xe_gemm_array_cooperative.hpp" +// #include "./collective/gemm/gemm_universal_adapter.hpp" + + using namespace cute; using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group diff --git a/mll_build.sh b/mll_build.sh index 47889fb..cdb80cb 100644 --- a/mll_build.sh +++ b/mll_build.sh @@ -1,2 +1,3 @@ +clear # python3 setup.py clean VLLM_TARGET_DEVICE=xpu python3 setup.py develop From 1fc69590210e8d67ae6fde29198953f67194a067 Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Tue, 23 Sep 2025 09:58:36 +0000 Subject: [PATCH 42/47] add adapter src Signed-off-by: Ma, Liangliang --- .../collective/gemm/default_gemm_universal.h | 396 ++++++++ .../collective/gemm/gemm_universal.h | 442 +++++++++ .../collective/gemm/gemm_universal.hpp | 55 ++ .../collective/gemm/gemm_universal_adapter.h | 842 ++++++++++++++++++ .../collective/gemm/gemm_universal_base.h | 539 +++++++++++ .../collective/gemm/gemm_universal_k.h | 702 +++++++++++++++ .../gemm/xe_gemm_array_cooperative.hpp | 348 ++++++++ .../cutlass_kernels/grouped_gemm_kernel.cpp | 7 +- 8 files changed, 3328 insertions(+), 3 deletions(-) create mode 100644 csrc/xpu/cutlass_kernels/collective/gemm/default_gemm_universal.h create mode 100644 csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal.h create mode 100644 csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal.hpp create mode 100644 csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_adapter.h create mode 100644 csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_base.h create mode 100644 csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_k.h create mode 100644 csrc/xpu/cutlass_kernels/collective/gemm/xe_gemm_array_cooperative.hpp diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/default_gemm_universal.h b/csrc/xpu/cutlass_kernels/collective/gemm/default_gemm_universal.h new file mode 100644 index 0000000..6b9b15d --- /dev/null +++ b/csrc/xpu/cutlass_kernels/collective/gemm/default_gemm_universal.h @@ -0,0 +1,396 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with + the appropriate threadblock-scoped epilogue. + + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are + accommodated by exchanging A and B operands and assuming transposed layouts. Partial + specializations here choose 'device::GemmTransposed' to implement this functionality. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/complex.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "gemm_universal_k.h" +#include "cutlass/gemm/kernel/gemm_universal_streamk.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/kernel/default_gemm_complex.h" + +#include "cutlass/layout/permute.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction tile size (concept: GemmShape) + typename InstructionShape, + /// 
Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Gather operand A by using an index array + bool GatherA = false, + /// Gather operand B by using an index array + bool GatherB = false, + /// Scatter result D by using an index array + bool ScatterD = false, + /// Permute result D + typename PermuteDLayout = layout::NoPermute, + /// Permute operand A + typename PermuteALayout_ = layout::NoPermute, + /// Permute operand B + typename PermuteBLayout_ = layout::NoPermute, + /// + typename Enable = void + > +struct DefaultGemmUniversal; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Real-valued GEMM kernels +// + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB, + /// Scatter result D by using an index array + bool ScatterD, + /// Permute result D + typename PermuteDLayout, + /// Permute operand A + typename PermuteALayout, + /// Permute operand B + typename PermuteBLayout +> +struct DefaultGemmUniversal< + ElementA, + LayoutA, + ComplexTransform::kNone, // transform A + kAlignmentA, + ElementB, + LayoutB, + ComplexTransform::kNone, // transform B + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + Operator, + SharedMemoryClear, + GatherA, + GatherB, + ScatterD, + PermuteDLayout, + PermuteALayout, + PermuteBLayout, + typename platform::enable_if< ! 
cutlass::is_complex::value>::type +> { + + using DefaultGemmKernel = typename kernel::DefaultGemm< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + true, + Operator, + SharedMemoryClear, + GatherA, + GatherB, + ScatterD, + PermuteDLayout, + PermuteALayout, + PermuteBLayout + >::GemmKernel; + + /// Universal kernel without StreamkFeature member type + template + class SelectBase : + public kernel::GemmUniversal< + typename DefaultGemmKernel::Mma, + typename DefaultGemmKernel::Epilogue, + SwizzleT> + {}; + + /// Universal kernel with StreamkFeature member type + template + class SelectBase : + public kernel::GemmUniversalStreamk< + typename DefaultGemmKernel::Mma, + typename DefaultGemmKernel::Epilogue, + SwizzleT> + {}; + + /// Select kernel by ThreadblockSwizzle's support for StreamkFeature + using GemmKernel = SelectBase; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Complex-valued GEMM kernels +// + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear + > +struct DefaultGemmUniversal< + ElementA, + LayoutA, + TransformA, + kAlignmentA, + ElementB, + LayoutB, + TransformB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + Operator, + SharedMemoryClear, + false, + false, + false, + layout::NoPermute, + layout::NoPermute, + layout::NoPermute, + typename platform::enable_if::value>::type +> { + + using DefaultGemmKernel = typename kernel::DefaultGemmComplex< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + TransformA, + TransformB, + Operator, + false + >::GemmKernel; + + /// Universal kernel 
without StreamkFeature member type + template + class SelectBase : + public kernel::GemmUniversal< + typename DefaultGemmKernel::Mma, + typename DefaultGemmKernel::Epilogue, + SwizzleT> + {}; + + /// Universal kernel with StreamkFeature member type + template + class SelectBase : + public kernel::GemmUniversalStreamk< + typename DefaultGemmKernel::Mma, + typename DefaultGemmKernel::Epilogue, + SwizzleT> + {}; + + /// Select kernel by ThreadblockSwizzle's support for StreamkFeature + using GemmKernel = SelectBase; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal.h b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal.h new file mode 100644 index 0000000..8b9dd35 --- /dev/null +++ b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal.h @@ -0,0 +1,442 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief +*/ + +#pragma once + +#include "cutlass/arch/mma.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "gemm_universal_k.h" + +#include "default_gemm_universal.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "gemm_universal_base.h" + +#include "cutlass/layout/permute.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/*! + GemmUniversal is a stateful, reusable GEMM handle. Once initialized for a given GEMM computation + (problem geometry and data references), it can be reused across different GEMM problems having the + geometry. (Once initialized, details regarding problem geometry and references to workspace memory + cannot be updated.) + + The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and + batched array variants. +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. 
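+    /// (A rough usage sketch, not specific to this XPU port: a row-major fp16
+    /// tensor-op GEMM tuned for SM80 could be spelled
+    ///   cutlass::gemm::device::GemmUniversal<
+    ///       cutlass::half_t, cutlass::layout::RowMajor,
+    ///       cutlass::half_t, cutlass::layout::ColumnMajor,
+    ///       float, cutlass::layout::RowMajor,
+    ///       float, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80>;
+    /// every parameter from ThreadblockShape_ onward then falls back to
+    /// DefaultGemmConfiguration for that OperatorClass/ArchTag/element combination.)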
+ typename ArchTag_ = arch::Sm70, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA = ComplexTransform::kNone, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB = ComplexTransform::kNone, + /// Gather operand A by using an index array + bool GatherA = false, + /// Gather operand B by using an index array + bool GatherB = false, + /// Scatter result D by using an index array + bool ScatterD = false, + /// Permute result D + typename PermuteDLayout_ = layout::NoPermute, + /// Permute operand A + typename PermuteALayout_ = layout::NoPermute, + /// Permute operand B + typename PermuteBLayout_ = layout::NoPermute +> +class GemmUniversal : + public GemmUniversalBase< + typename kernel::DefaultGemmUniversal< + ElementA_, + LayoutA_, + TransformA, + AlignmentA, + ElementB_, + LayoutB_, + TransformB, + AlignmentB, + ElementC_, + LayoutC_, + ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_, + SharedMemoryClearOption::kNone, + GatherA, + GatherB, + ScatterD, + PermuteDLayout_, + PermuteALayout_, + PermuteBLayout_ + >::GemmKernel + > { + + public: + + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + using PermuteDLayout = PermuteDLayout_; + using PermuteALayout = PermuteALayout_; + using PermuteBLayout = PermuteBLayout_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + + using 
Base = GemmUniversalBase< + typename kernel::DefaultGemmUniversal< + ElementA_, + LayoutA_, + TransformA, + AlignmentA, + ElementB_, + LayoutB_, + TransformB, + AlignmentB, + ElementC_, + LayoutC_, + ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_, + SharedMemoryClearOption::kNone, + GatherA, + GatherB, + ScatterD, + PermuteDLayout_, + PermuteALayout_, + PermuteBLayout_ + >::GemmKernel + >; + + using Arguments = typename Base::Arguments; + using GemmKernel = typename Base::GemmKernel; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for column-major output exchanges problem size and operand. +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Epilogue output operator + typename EpilogueOutputOp_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Access granularity of A matrix in units of elements + int AlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB, + /// Operation performed by GEMM + typename Operator_, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB, + /// Scatter result D by using an index array + bool ScatterD, + /// Permute result D + typename PermuteDLayout_, + /// Permute operand A + typename PermuteALayout_, + /// Permute operand B + typename PermuteBLayout_ +> +class GemmUniversal { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = layout::ColumnMajor; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + using PermuteDLayout = PermuteDLayout_; + using PermuteALayout = PermuteALayout_; + using PermuteBLayout = PermuteBLayout_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + 
static int const kAlignmentB = AlignmentB; + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + + using UnderlyingOperator = typename GemmUniversal< + ElementB, + typename layout::LayoutTranspose::type, + ElementA, + typename layout::LayoutTranspose::type, + ElementC, + layout::RowMajor, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + kAlignmentB, + kAlignmentA, + Operator, + kTransformB, + kTransformA, + GatherB, + GatherA, + ScatterD, + PermuteDLayout, + PermuteBLayout, + PermuteALayout + >::Base; + + using GemmKernel = typename UnderlyingOperator::GemmKernel; + static int const kAlignmentC = EpilogueOutputOp::kCount; + + /// Argument structure + using Arguments = typename UnderlyingOperator::Arguments; + +private: + + UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the GEMM. + GemmUniversal() { } + + /// Helper to construct a transposed equivalent for the underying GEMM operator + static Arguments to_underlying_arguments(Arguments const &args) { + return args.transposed_problem(); + } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args)); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const &args) { + return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + return UnderlyingOperator::maximum_active_blocks(smem_capacity); + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + return underlying_operator_.update(to_underlying_arguments(args), workspace); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + return underlying_operator_.run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. 
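+  /// (This overload is the one-shot path: it calls initialize(args, workspace, stream)
+  /// and then run(stream), so no prior call to initialize() is required.)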
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal.hpp b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal.hpp new file mode 100644 index 0000000..e1eb7b9 --- /dev/null +++ b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal.hpp @@ -0,0 +1,55 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/gemm/kernel/gemm_universal_decl.h" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +// In cases where ProblemShape is not a tuple, this is used to check if the +// underlying problem shape type is aliased within or not. 
+// Used for dispatching GemmUniversal to 2.x API or 3.x API +template +struct IsCutlass3ArrayKernel : cute::false_type { }; + +template +struct IsCutlass3ArrayKernel> + : cute::true_type { }; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel + +//////////////////////////////////////////////////////////////////////////////// +#include "xe_gemm_array_cooperative.hpp" diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_adapter.h b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_adapter.h new file mode 100644 index 0000000..3633072 --- /dev/null +++ b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_adapter.h @@ -0,0 +1,842 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and + batched array variants. 
+*/ + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/layout.hpp" +#include "cutlass/detail/mma.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +#include "cutlass/kernel_launch.h" +#if !defined(__CUDACC_RTC__) +#include "cutlass/cluster_launch.hpp" +#include "cutlass/trace.h" +#endif // !defined(__CUDACC_RTC__) + +// 2.x +#include "gemm_universal_base.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h" + +// 3.x +#include "gemm_universal.hpp" + +#if defined(CUTLASS_ENABLE_SYCL) +#include "cutlass/util/sycl_event_manager.hpp" +#endif + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::device { + +//////////////////////////////////////////////////////////////////////////////// + +/*! + GemmUniversalAdapter is a stateful, reusable GEMM handle built around a kernel + of type cutlass::gemm::kernel::Gemm or cutlass::gemm::kernel::GemmUniversal. + + It manages the lifetime of the underlying `kernel::Params` struct, and exposes APIs + to create it from the host facing arguments. For power users, new static methods + are exposed in 3.x APIs that bypass the stateful methods or args->params lowering. + + It supports kernel types that implement both the 2.x and 3.0 APIs, + however, this is done by specializing the implementation of GemmUniversalAdapter + on the two kernel API types, and thus, GemmUniversalAdapter's behaviour might + differ between the two specializations. +*/ +template +class GemmUniversalAdapter; + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// Work-around for some DispatchPolicy types not having a Stages member. +// In that case, the Stages value is 0. Most code should static_assert +// that the number of stages is valid. + +// Whether DispatchPolicy::Stages is valid. +// It should also be convertible to int, but if not, that will show up +// as a build error when GemmUniversalAdapter attempts to assign it to kStages. 
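+// (Illustration: a mainloop dispatch policy that declares
+//  `static constexpr int Stages = 3;` selects the true_type specialization below,
+//  so stages_member() reports 3; a policy without a Stages member falls back to
+//  the primary template and stages_member() returns 0.)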
+template +struct has_Stages : cute::false_type {}; + +template +struct has_Stages> : cute::true_type {}; + +template +constexpr int stages_member(DispatchPolicy) { + if constexpr (has_Stages::value) { + return DispatchPolicy::Stages; + } + else { + return 0; + } +} + +} // namespace detail + +template +class GemmUniversalAdapter< + GemmKernel_, + cute::enable_if_t>::value>> +{ +public: + using GemmKernel = GetUnderlyingKernel_t; + using TileShape = typename GemmKernel::TileShape; + using ElementA = typename GemmKernel::ElementA; + using ElementB = typename GemmKernel::ElementB; + using ElementC = typename GemmKernel::ElementC; + using ElementD = typename GemmKernel::ElementD; + using ElementAccumulator = typename GemmKernel::ElementAccumulator; + using DispatchPolicy = typename GemmKernel::DispatchPolicy; + using CollectiveMainloop = typename GemmKernel::CollectiveMainloop; + using CollectiveEpilogue = typename GemmKernel::CollectiveEpilogue; + + // Map back to 2.x type as best as possible + using LayoutA = gemm::detail::StrideToLayoutTagA_t; + using LayoutB = gemm::detail::StrideToLayoutTagB_t; + using LayoutC = gemm::detail::StrideToLayoutTagC_t; + using LayoutD = gemm::detail::StrideToLayoutTagC_t; + + static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; + + static ComplexTransform const kTransformA = cute::is_same_v ? + ComplexTransform::kConjugate : ComplexTransform::kNone; + static ComplexTransform const kTransformB = cute::is_same_v ? + ComplexTransform::kConjugate : ComplexTransform::kNone; + + // Legacy: Assume MultiplyAdd only since we do not use this tag type in 3.0 + using MathOperator = cutlass::arch::OpMultiplyAdd; + + using OperatorClass = cutlass::detail::get_operator_class_t; + + using ArchTag = typename GemmKernel::ArchTag; + + // NOTE: Assume identity swizzle for now + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + + // Assume TiledMma's ShapeMNK is the same as 2.x's ThreadblockShape + using ThreadblockShape = cutlass::gemm::GemmShape< + cute::size<0>(TileShape{}), + cute::size<1>(TileShape{}), + cute::size<2>(TileShape{})>; + + using ClusterShape = cutlass::gemm::GemmShape< + cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})>; + + // Instruction shape is easy too, since we get that directly from our TiledMma's atom shape + using InstructionShape = cutlass::gemm::GemmShape< + cute::size<0>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), + cute::size<1>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), + cute::size<2>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{})>; + + // Legacy: provide a correct warp count, but no reliable warp shape + static int const kThreadCount = GemmKernel::MaxThreadsPerBlock; + + // Warp shape is not a primary API type in 3.x + // But we can best approximate it by inspecting the TiledMma + // For this, we make the assumption that we always have 4 warps along M, and rest along N, none along K + // We also always round up the warp count to 4 if the tiled mma is smaller than 128 threads + static constexpr int WarpsInMma = cute::max(4, CUTE_STATIC_V(cute::size(typename GemmKernel::TiledMma{})) / 32); + static constexpr int WarpsInMmaM = 4; + static constexpr int WarpsInMmaN = cute::ceil_div(WarpsInMma, WarpsInMmaM); + using WarpCount = cutlass::gemm::GemmShape; + using WarpShape = 
cutlass::gemm::GemmShape< + CUTE_STATIC_V(cute::tile_size<0>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaM, + CUTE_STATIC_V(cute::tile_size<1>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaN, + CUTE_STATIC_V(cute::tile_size<2>(typename CollectiveMainloop::TiledMma{}))>; + + static int constexpr kStages = detail::stages_member(typename CollectiveMainloop::DispatchPolicy{}); + + // Inspect TiledCopy for A and B to compute the alignment size + static int constexpr kAlignmentA = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveMainloop::GmemTiledCopyA, ElementA, typename CollectiveMainloop::TiledMma::ValTypeA>(); + static int constexpr kAlignmentB = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveMainloop::GmemTiledCopyB, ElementB, typename CollectiveMainloop::TiledMma::ValTypeB>(); + static int constexpr kAlignmentC = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveEpilogue::GmemTiledCopyC, ElementC>(); + static int constexpr kAlignmentD = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveEpilogue::GmemTiledCopyD, ElementD>(); + + using EpilogueOutputOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + // Split-K preserves splits that are 128b aligned + static int constexpr kSplitKAlignment = cute::max( + 128 / sizeof_bits::value, 128 / sizeof_bits::value); + + /// Argument structure: User API + using Arguments = typename GemmKernel::Arguments; + /// Argument structure: Kernel API + using Params = typename GemmKernel::Params; + +private: + + /// Kernel API parameters object + Params params_; + +public: + + /// Access the Params structure + Params const& params() const { + return params_; + } + + /// Determines whether the GEMM can execute the given problem. 
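+  /// (Thin wrapper: maps GemmKernel::can_implement(args) onto Status::kSuccess or
+  /// Status::kInvalid; callers typically check this before initialize()/run().)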
+ static Status + can_implement(Arguments const& args) { + if (GemmKernel::can_implement(args)) { + return Status::kSuccess; + } + else { + return Status::kInvalid; + } + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + if (args.mode == GemmUniversalMode::kGemmSplitKParallel) { + workspace_bytes += sizeof(int) * size_t(cute::size<0>(TileShape{})) * size_t(cute::size<1>(TileShape{})); + } + + workspace_bytes += GemmKernel::get_workspace_size(args); + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Arguments const& args, void* workspace = nullptr) { + auto tmp_params = GemmKernel::to_underlying_arguments(args, workspace); + return GemmKernel::get_grid_shape(tmp_params); + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Params const& params) { + return GemmKernel::get_grid_shape(params); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("GemmUniversal::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = GemmKernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + device_kernel, + GemmKernel::MaxThreadsPerBlock, + smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes GEMM state from arguments. + Status + initialize( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + + CUTLASS_TRACE_HOST("GemmUniversal::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + // Initialize the workspace + Status status = GemmKernel::initialize_workspace(args, workspace, stream, cuda_adapter); + if (status != Status::kSuccess) { + return status; + } + // Initialize the Params structure + params_ = GemmKernel::to_underlying_arguments(args, workspace); + // Don't set the function attributes - require the CudaHostAdapter to set it. 
+ if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + return Status::kSuccess; + } + else { + // + // Account for dynamic smem capacity if needed + // + int smem_size = GemmKernel::SharedStorageSize; + + CUTLASS_ASSERT(cuda_adapter == nullptr); + +#if !defined(CUTLASS_ENABLE_SYCL) + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } +#endif + } + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. + Status + update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversal()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + params_ = GemmKernel::to_underlying_arguments(args, workspace); + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. + /// Supplied params struct must be construct by calling GemmKernel::to_underlying_arguments() + static Status + run(Params& params, + sycl::queue& stream, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false) { + CUTLASS_TRACE_HOST("GemmUniversal::run()"); + dim3 const block = GemmKernel::get_block_shape(); + dim3 const grid = get_grid_shape(params); + +#if defined(CUTLASS_ENABLE_SYCL) + const syclcompat::dim3 sycl_block(block.x, block.y, block.z); + const syclcompat::dim3 sycl_grid(grid.x, grid.y, grid.z); +#endif + + // configure smem size and carveout + int smem_size = GemmKernel::SharedStorageSize; + + Status launch_result{ Status::kSuccess }; + // Use extended launch API only for mainloops that use it + if constexpr (GemmKernel::ArchTag::kMinComputeCapability >= 90) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Use extended launch API"); +#endif +#if !defined(CUTLASS_ENABLE_SYCL) + [[maybe_unused]] constexpr bool is_static_1x1x1 = + cute::is_static_v and + cute::size(typename GemmKernel::DispatchPolicy::ClusterShape{}) == 1; + [[maybe_unused]] dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})); + + // Dynamic cluster support + [[maybe_unused]] dim3 fallback_cluster = dim3{0,0,0}; + if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 100 + || GemmKernel::ArchTag::kMinComputeCapability == 101 + ) { + if constexpr (!cute::is_static_v) { + fallback_cluster = params.hw_info.cluster_shape_fallback; + cluster = params.hw_info.cluster_shape; + } + } + + [[maybe_unused]] void* kernel_params[] = {¶ms}; + + if constexpr (kEnableCudaHostAdapter) { + // + // Use the cuda host adapter + // + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + if (launch_with_pdl) { + CUTLASS_TRACE_HOST( + "GemmUniversal::run() does not support launching with PDL and a custom cuda adapter."); + return Status::kErrorInternal; + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching 
kernel with CUDA host adapter"); +#endif + if constexpr (is_static_1x1x1) { + launch_result = cuda_adapter->launch(grid, + block, + smem_size, + stream, + kernel_params, + 0); + } + else { + launch_result = cuda_adapter->launch(grid, + cluster, + fallback_cluster, + block, + smem_size, + stream, + kernel_params, + 0); + } + } + else { + CUTLASS_TRACE_HOST("GemmUniversal::run: kEnableCudaHostAdapter is true, but CUDA host adapter is null"); + return Status::kErrorInternal; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); + [[maybe_unused]] void const* kernel = (void const*) device_kernel; + static constexpr bool kClusterLaunch = GemmKernel::ArchTag::kMinComputeCapability == 90; + if constexpr (kClusterLaunch) { + if constexpr (is_static_1x1x1) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching static 1x1x1 kernel"); +#endif + launch_result = cutlass::kernel_launch( + grid, block, smem_size, stream, params, launch_with_pdl); + if (launch_result != Status::kSuccess) { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports failure"); + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports success"); + } +#endif + } + else { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching dynamic cluster kernel"); +#endif + launch_result = ClusterLauncher::launch( + grid, cluster, block, smem_size, stream, kernel, kernel_params, launch_with_pdl); + } + } + + else { + if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 100 + || GemmKernel::ArchTag::kMinComputeCapability == 101 + || GemmKernel::ArchTag::kMinComputeCapability == 120 + ) { + if constexpr (is_static_1x1x1) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching static 1x1x1 kernel"); +#endif + launch_result = cutlass::kernel_launch(grid, block, smem_size, stream, params, launch_with_pdl); + if (launch_result != Status::kSuccess) { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports failure"); + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports success"); + } +#endif + } + else { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching kernel with fall-back cluster"); +#endif + launch_result = ClusterLauncher::launch_with_fallback_cluster( + grid, + cluster, + fallback_cluster, + block, + smem_size, + stream, + kernel, + kernel_params, + launch_with_pdl); + } + } + } + + } +#endif + } + else { + launch_result = Status::kSuccess; + cutlass::arch::synclog_setup(); + + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + void* kernel_params[] = {¶ms}; +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching kernel with CUDA host adapter"); +#endif + launch_result = cuda_adapter->launch( + grid, block, smem_size, stream, kernel_params, 0 + ); + + } + else { + CUTLASS_TRACE_HOST("GemmUniversal::run: CUDA host adapter is null"); + return Status::kErrorInternal; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); +#if defined(CUTLASS_ENABLE_SYCL) + // sycl::queue q = stream; // ? 
*stream : syclcompat::get_default_queue(); +#if !defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY) + using namespace syclcompat::experimental; + if constexpr (cute::is_same_v) { + auto event = launch>(launch_policy{ + sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)} + }, q, params); + EventManager::getInstance().addEvent(event); + } else { + auto event = launch>(launch_policy{ + sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)} +#if defined(SYCL_INTEL_TARGET) + , kernel_properties{sycl_exp::sub_group_size} +#endif + }, stream, params); + EventManager::getInstance().addEvent(event); + } +#else +#if defined (SYCL_INTEL_TARGET) + constexpr bool allow_subgroup_size_prop = true; +#else + constexpr bool allow_subgroup_size_prop = false; +#endif + auto kernel_props = [] { + constexpr bool is_device_agnostic = + cute::is_same_v; + if constexpr (!allow_subgroup_size_prop or is_device_agnostic) { + using EmptyProperties = decltype(sycl::ext::oneapi::experimental::properties()); + return syclcompat::experimental::kernel_properties{}; + } else { + return syclcompat::experimental::kernel_properties{ + sycl::ext::oneapi::experimental::sub_group_size + }; + } + }(); + syclcompat::experimental::launch_properties launch_props { + sycl::ext::oneapi::experimental::work_group_scratch_size(smem_size), + }; + syclcompat::experimental::launch_policy policy{ + sycl_grid, sycl_block, launch_props, kernel_props + }; + auto event = syclcompat::experimental::launch>(policy, stream, params); + EventManager::getInstance().addEvent(event); +#endif // !defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY) +#else +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching kernel with cutlass::kernel_launch"); +#endif + launch_result = cutlass::kernel_launch( + grid, block, smem_size, stream, params, launch_with_pdl); + if (launch_result != Status::kSuccess) { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports failure"); + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports success"); + } +#endif +#endif + } + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result && Status::kSuccess == launch_result) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: cudaGetLastError reports success"); +#endif + return Status::kSuccess; + } + else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run( + Arguments const& args, + void* workspace, + sycl::queue& stream, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false + ) { + Status status = initialize(args, workspace, stream, cuda_adapter); + + if (Status::kSuccess == status) { + status = run(params_, stream, cuda_adapter, launch_with_pdl); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. 
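+  /// (Typical one-shot call from host code, as a sketch: build Arguments, allocate
+  /// get_workspace_size(args) bytes of device workspace, then invoke
+  /// `op(args, workspace, queue);`. Re-launches with unchanged arguments can reuse
+  /// the cached params via the queue-only run()/operator() overloads below.)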
+ Status + operator()( + Arguments const& args, + void* workspace, + sycl::queue& stream, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false) { + return run(args, workspace, stream, cuda_adapter, launch_with_pdl); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run( + sycl::queue& stream, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false) { + return run(params_, stream, cuda_adapter, launch_with_pdl); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + operator()(sycl::queue& stream, CudaHostAdapter *cuda_adapter = nullptr, bool launch_with_pdl = false) { + return run(params_, stream, cuda_adapter, launch_with_pdl); + } +}; + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 2.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template +class GemmUniversalAdapter< + GemmKernel_, + cute::enable_if_t>::value>> +{ +public: + + using GemmKernel = GetUnderlyingKernel_t; + + static bool const kInternalTranspose = + !cutlass::epilogue::threadblock::detail::is_2x_evt_v && // 2.x EVT does not require internal transpose + cute::is_same::value; + + using ThreadblockShape = typename GemmKernel::Mma::Shape; + using WarpShape = typename GemmKernel::WarpShape; + using InstructionShape = typename GemmKernel::InstructionShape; + + // warp-level, arch-level (instruction), math operator + using WarpMmaOperator = typename GemmKernel::Mma::Policy::Operator; + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename WarpMmaOperator::MathOperator; + + // Operator class and arch tag extract bottom-up + // set it for top-level gemm device-level template + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + + // Type, layout, and complex transform deliberately exchanged with B + using MapArguments = kernel::detail::MapArguments< + typename GemmKernel::ElementA, + typename GemmKernel::LayoutA, + GemmKernel::kTransformA, + GemmKernel::kAlignmentA, + typename GemmKernel::ElementB, + typename GemmKernel::LayoutB, + GemmKernel::kTransformB, + GemmKernel::kAlignmentB, + typename GemmKernel::LayoutC, + kInternalTranspose + >; + + using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + static ComplexTransform const kTransformA = MapArguments::kTransformA; + static int const kAlignmentA = MapArguments::kAlignmentA; + + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + static ComplexTransform const kTransformB = MapArguments::kTransformB; + static int const kAlignmentB = MapArguments::kAlignmentB; + + using ElementC = typename GemmKernel::ElementC; + using LayoutC = typename MapArguments::LayoutC; + static int const kAlignmentC = GemmKernel::kAlignmentC; + + // C and D same type for 2.x kernel + using ElementD = ElementC; + using LayoutD = LayoutC; + + using TensorRefA = TensorRef; + using TensorRefB = TensorRef; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + + static int const kStages = GemmKernel::Mma::kStages; + + using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; + using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; + using 
ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; + using UnderlyingOperator = GemmUniversalBase; + using Arguments = typename UnderlyingOperator::Arguments; + +private: + + UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the GEMM. + GemmUniversalAdapter() { } + + /// Helper to construct a transposed equivalent for the underying GEMM operator + static Arguments to_underlying_arguments(Arguments const &args) { + if (kInternalTranspose) { + return args.transposed_problem(); + } + else { + return args; + } + } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args), cuda_adapter); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr) { + + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args), cuda_adapter); + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const &args) { + return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + return UnderlyingOperator::maximum_active_blocks(smem_capacity); + } + + /// Initializes GEMM state from arguments. + Status initialize( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr + ) { + + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream, cuda_adapter); + } + + /// Lightweight update given a subset of arguments. + Status update(Arguments const &args) { + + return underlying_operator_.update(to_underlying_arguments(args)); + } + + /// Runs the kernel using initialized state. + Status run( + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + + return underlying_operator_.run(stream, cuda_adapter); + } + + /// Runs the kernel using initialized state. + Status operator()( + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + + Status status = initialize(args, workspace, stream, cuda_adapter); + + if (status == Status::kSuccess) { + status = run(stream, cuda_adapter); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_base.h b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_base.h new file mode 100644 index 0000000..bc64b3d --- /dev/null +++ b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_base.h @@ -0,0 +1,539 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief The universal GEMM accommodates streamk, batched strided, and batched array variants. +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" +#include "cutlass/cuda_host_adapter.hpp" + +#include "cutlass/gemm/gemm.h" +#include "gemm_universal_k.h" + +#include "default_gemm_universal.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +template +class GemmUniversalBase { +public: + + using GemmKernel = GemmKernel_; + + /// Boolean indicating whether the CudaHostAdapter is enabled + static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; + + using ThreadblockShape = typename GemmKernel::Mma::Shape; + + using ElementA = typename GemmKernel::ElementA; + using LayoutA = typename GemmKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = GemmKernel::kTransformA; + + using ElementB = typename GemmKernel::ElementB; + using LayoutB = typename GemmKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = GemmKernel::kTransformB; + + using ElementC = typename GemmKernel::ElementC; + using LayoutC = typename GemmKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + + /// Numerical accumulation element type + using ElementAccumulator = typename GemmKernel::Mma::ElementC; + + using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; + using Operator = typename GemmKernel::Operator; + + /// Argument structure + using Arguments = typename GemmKernel::Arguments; + + + /// Index of 
the GEMM Kernel within the CudaHostAdapter + static int32_t const kGemmKernelIndex = 0; + + /// Kernel dynamic shared memory allocation requirement + /// Update the kernel function's shared memory configuration for the current device + static constexpr size_t kSharedStorageSize = sizeof(typename GemmKernel::SharedStorage); + +protected: + + // + // Device properties (uniform across all instances of the current thread) + // + + // Device ordinal + CUTLASS_THREAD_LOCAL static int device_ordinal_; + + /// Device SM count + CUTLASS_THREAD_LOCAL static int device_sms_; + + /// Kernel SM occupancy (in thread blocks) + CUTLASS_THREAD_LOCAL static int sm_occupancy_; + +protected: + + /// Initialize static thread-local members for the thread's current device, + /// if necessary. + static Status init_device_props() + { + CUTLASS_TRACE_HOST("GemmUniversalBase::init_device_props()"); + + cudaError_t cudart_result; + + // Get current device ordinal + int current_ordinal; + cudart_result = cudaGetDevice(¤t_ordinal); + if (cudart_result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(cudart_result)); + return Status::kErrorInternal; + } + + // Done if matches the current static member + if (current_ordinal == device_ordinal_) { + // Already initialized + return Status::kSuccess; + } + + // Update SM count member + cudart_result = cudaDeviceGetAttribute (&device_sms_, cudaDevAttrMultiProcessorCount, current_ordinal); + if (cudart_result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(cudart_result)); + return Status::kErrorInternal; + } + + // If requires more than 48KB: configure for extended, dynamic shared memory + if constexpr (kSharedStorageSize >= (48 << 10)) + { + cudart_result = cudaFuncSetAttribute( + Kernel2, + cudaFuncAttributeMaxDynamicSharedMemorySize, + kSharedStorageSize); + if (cudart_result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(cudart_result)); + return Status::kErrorInternal; + } + } + + // Update SM occupancy member + cudart_result = cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( + &sm_occupancy_, + Kernel2, + GemmKernel::kThreadCount, + kSharedStorageSize, + cudaOccupancyDisableCachingOverride); + if (cudart_result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags() returned error " << cudaGetErrorString(cudart_result)); + return Status::kErrorInternal; + } + + // Update device ordinal member on success + device_ordinal_ = current_ordinal; + + CUTLASS_TRACE_HOST(" " + "device_ordinal: (" << device_ordinal_ << "), " + "device_sms: (" << device_sms_ << "), " + "sm_occupancy: (" << sm_occupancy_ << ") " + "smem_size: (" << kSharedStorageSize << ") " + "GemmKernel::kThreadCount: (" << GemmKernel::kThreadCount << ")"); + + return Status::kSuccess; + } + + +protected: + + // + // Instance data members + // + + /// Kernel parameters + typename GemmKernel::Params params_; + + + /// Initialize params member + Status init_params(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr) + { + int32_t device_sms = 0; + int32_t sm_occupancy = 0; + + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + + // + // Occupancy query using CudaHostAdapter::query_occupancy(). 
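+      // Whichever branch runs, the recovered device_sms / sm_occupancy pair is
+      // consumed only at the bottom of this function, where it is forwarded to
+      // GemmKernel::Params(args, device_sms, sm_occupancy); the kernel-side
+      // Params presumably uses it to size the persistent grid / split-K
+      // partitioning.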
+ // + + if (cuda_adapter) { + + Status status = cuda_adapter->query_occupancy( + &device_sms, + &sm_occupancy, + kGemmKernelIndex, + GemmKernel::kThreadCount, + kSharedStorageSize); + + CUTLASS_ASSERT(status == Status::kSuccess); + + if (status != Status::kSuccess) { + return status; + } + } + else { + return Status::kErrorInternal; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); + + // Initialize static device properties, if necessary + Status result = init_device_props(); + + if (result != Status::kSuccess) { + return result; + } + + // + // Use thread-local static members for occupancy query initialized by call to + // `init_device_props()` + // + + device_sms = device_sms_; + sm_occupancy = sm_occupancy_; + } + + // Initialize params member + params_ = typename GemmKernel::Params(args, device_sms, sm_occupancy); + return Status::kSuccess; + } + +public: + + //--------------------------------------------------------------------------------------------- + // Stateless API + //--------------------------------------------------------------------------------------------- + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr) + { + CUTLASS_TRACE_HOST("GemmUniversalBase::can_implement()"); + + if (!kEnableCudaHostAdapter || cuda_adapter) { + + dim3 grid = get_grid_shape(args, cuda_adapter); + + if (!(grid.y <= std::numeric_limits::max() && + grid.z <= std::numeric_limits::max())) + { + return Status::kErrorInvalidProblem; + } + } + else { + // + // With a null host adapter, a conservative grid shape is computed and required to conform to CUDA grid + // dimension limits. + // + + int64_t logicalGridM = (int64_t(args.problem_size.m()) + ThreadblockShape::kM - 1) / ThreadblockShape::kM; + int64_t logicalGridN = (int64_t(args.problem_size.n()) + ThreadblockShape::kN - 1) / ThreadblockShape::kN; + int32_t logicalGridL = args.batch_count; + + if ((int64_t(std::numeric_limits::max()) < logicalGridM) || + (int64_t(std::numeric_limits::max()) < logicalGridN) || + (int32_t(std::numeric_limits::max()) < logicalGridL)) { + + return Status::kErrorInvalidProblem; + } + + } + + return GemmKernel::can_implement(args); + } + + + /// Returns the workspace size (in bytes) needed for the problem + /// geometry expressed by these arguments + static size_t get_workspace_size(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr) + { + CUTLASS_TRACE_HOST("GemmUniversalBase::get_workspace_size()"); + + // Initialize parameters from args + GemmUniversalBase base; + if (base.init_params(args, cuda_adapter) != Status::kSuccess) { + return 0; + } + + // Get size from parameters + size_t workspace_bytes = base.params_.get_workspace_size(); + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + return workspace_bytes; + } + + + /// Returns the grid extents in thread blocks to launch + static dim3 get_grid_shape(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr) + { + CUTLASS_TRACE_HOST("GemmUniversalBase::get_grid_shape()"); + + // Initialize parameters from args + GemmUniversalBase base; + if (base.init_params(args, cuda_adapter) != Status::kSuccess) { + return dim3(0,0,0); + } + + // Get dims from parameters + dim3 grid_dims = base.params_.get_grid_dims(); + + CUTLASS_TRACE_HOST( + " tiled_shape: " << base.params_.get_tiled_shape() << "\n" + << " grid_dims: {" << grid_dims << "}"); + + return grid_dims; + } + + + /// Returns the maximum number of active thread blocks 
per multiprocessor + static int maximum_active_blocks(CudaHostAdapter *cuda_adapter = nullptr) + { + CUTLASS_TRACE_HOST("GemmUniversalBase::maximum_active_blocks()"); + + int32_t device_sms = 0; + int32_t sm_occupancy = 0; + + + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + + if (cuda_adapter) { + + Status status = cuda_adapter->query_occupancy( + &device_sms, + &sm_occupancy, + kGemmKernelIndex, + GemmKernel::kThreadCount, + kSharedStorageSize); + + CUTLASS_ASSERT(status == Status::kSuccess); + + if (status != Status::kSuccess) { + return -1; + } + } + else { + return -1; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); + // Initialize static device properties, if necessary + if (init_device_props() != Status::kSuccess) { + return -1; + } + + sm_occupancy = sm_occupancy_; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << sm_occupancy_); + return sm_occupancy; + } + + + //--------------------------------------------------------------------------------------------- + // Stateful API + //--------------------------------------------------------------------------------------------- + + /// Initializes GEMM state from arguments and workspace memory + Status initialize( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) + { + CUTLASS_TRACE_HOST("GemmUniversalBase::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + // Initialize parameters from args + Status result = init_params(args, cuda_adapter); + if (result != Status::kSuccess) { + return result; + } + + // Assign and prepare workspace memory + if (args.mode == GemmUniversalMode::kGemm) { + return params_.init_workspace(workspace, stream); + } + + return Status::kSuccess; + } + + + /// Lightweight update given a subset of arguments. + Status update(Arguments const &args) + { + CUTLASS_TRACE_HOST("GemmUniversalBase()::update()"); + params_.update(args); + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) + { + CUTLASS_TRACE_HOST("GemmUniversalBase::run()"); + + // Configure grid and block dimensions + dim3 block(GemmKernel::kThreadCount, 1, 1); + dim3 grid = params_.get_grid_dims(); + + // Launch kernel + CUTLASS_TRACE_HOST(" " + "grid: (" << grid << "), " + "block: (" << block << "), " + "SMEM: (" << kSharedStorageSize << ")"); + + cutlass::arch::synclog_setup(); + + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + void* kernel_params[] = {¶ms_}; + return cuda_adapter->launch(grid, block, kSharedStorageSize, stream, kernel_params, 0); + } + else { + return Status::kErrorInternal; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); + +#if defined(CUTLASS_ENABLE_SYCL) + const auto sycl_block = syclcompat::dim3(block.x, block.y, block.z); + const auto sycl_grid = syclcompat::dim3(grid.x, grid.y, grid.z); + + sycl::queue q = stream ? 
*stream : syclcompat::get_default_queue(); + syclcompat::experimental::launch>( + syclcompat::experimental::launch_policy{ + sycl_grid, sycl_block, +#if defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY) + sycl::ext::oneapi::experimental::work_group_scratch_size(kSharedStorageSize) +#else + syclcompat::experimental::local_mem_size{static_cast(kSharedStorageSize)} +#endif + }, + q, params_); +#else + Kernel2<<>>(params_); +#endif + + // Query for errors + cudaError_t result = cudaGetLastError(); + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) + { + return run(stream, cuda_adapter); + } + + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) + { + Status status = initialize(args, workspace, stream, cuda_adapter); + + if (status == Status::kSuccess) { + status = run(stream, cuda_adapter); + } + + return status; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Static initializers +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Device ordinal +template +CUTLASS_THREAD_LOCAL int GemmUniversalBase::device_ordinal_ = -1; + +/// Device SM count +template +CUTLASS_THREAD_LOCAL int GemmUniversalBase::device_sms_ = -1; + +/// Kernel SM occupancy (in thread blocks) +template +CUTLASS_THREAD_LOCAL int GemmUniversalBase::sm_occupancy_ = -1; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_k.h b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_k.h new file mode 100644 index 0000000..b2aee30 --- /dev/null +++ b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_k.h @@ -0,0 +1,702 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/fast_math.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" +#include "gemm_universal.hpp" + +#include "cutlass/layout/matrix.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/params_universal_base.h" +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_ ///! Threadblock swizzling function +> +class GemmUniversal< + Mma_, + Epilogue_, + ThreadblockSwizzle_, + void, + // 3.x kernels use the first template argument to define the ProblemShape + // We use this invariant to SFINAE dispatch against either the 2.x API or the 3.x API + cute::enable_if_t::value || IsCutlass3ArrayKernel::value)> +> { +public: + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = 
const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); + + // + // Structures + // + + /// Argument structure + struct Arguments : UniversalArgumentsBase + { + // + // Data members + // + + typename EpilogueOutputOp::Params epilogue; + + void const * ptr_A; + void const * ptr_B; + void const * ptr_C; + void * ptr_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_C; + + typename LayoutA::Stride stride_a; + typename LayoutB::Stride stride_b; + typename LayoutC::Stride stride_c; + typename LayoutC::Stride stride_d; + + typename LayoutA::Stride::LongIndex lda; + typename LayoutB::Stride::LongIndex ldb; + typename LayoutC::Stride::LongIndex ldc; + typename LayoutC::Stride::LongIndex ldd; + + int const * ptr_gather_A_indices; + int const * ptr_gather_B_indices; + int const * ptr_scatter_D_indices; + + // + // Methods + // + + Arguments(): + ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr), + ptr_gather_A_indices(nullptr), + ptr_gather_B_indices(nullptr), + ptr_scatter_D_indices(nullptr) + {} + + /// constructs an arguments structure + Arguments( + GemmUniversalMode mode, + GemmCoord problem_size, + int batch_count, + typename EpilogueOutputOp::Params epilogue, + void const * ptr_A, + void const * ptr_B, + void const * ptr_C, + void * ptr_D, + int64_t batch_stride_A, + int64_t batch_stride_B, + int64_t batch_stride_C, + int64_t batch_stride_D, + typename LayoutA::Stride stride_a, + typename LayoutB::Stride stride_b, + typename LayoutC::Stride stride_c, + typename LayoutC::Stride stride_d, + int const *ptr_gather_A_indices = nullptr, + int const *ptr_gather_B_indices = nullptr, + int const *ptr_scatter_D_indices = nullptr) + : + UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), + epilogue(epilogue), + ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), + batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), + stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d), + ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices), + ptr_scatter_D_indices(ptr_scatter_D_indices) + { + lda = 0; + ldb = 0; + ldc = 0; + ldd = 0; + CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); + } + + /// constructs an arguments structure + Arguments( + GemmUniversalMode mode, + GemmCoord problem_size, + int batch_count, + typename EpilogueOutputOp::Params epilogue, + void const * ptr_A, + void const * ptr_B, + void const * ptr_C, + void * ptr_D, + int64_t batch_stride_A, + int64_t batch_stride_B, + int64_t batch_stride_C, + int64_t batch_stride_D, + typename LayoutA::Stride::LongIndex lda, + typename LayoutB::Stride::LongIndex ldb, + typename LayoutC::Stride::LongIndex ldc, + typename LayoutC::Stride::LongIndex ldd, + int const *ptr_gather_A_indices = nullptr, + int const *ptr_gather_B_indices = nullptr, + int const *ptr_scatter_D_indices = nullptr + ): + UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), + epilogue(epilogue), + ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), + batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), + lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), + ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices), + ptr_scatter_D_indices(ptr_scatter_D_indices) + { + stride_a = make_Coord(lda); + stride_b = make_Coord(ldb); + stride_c = make_Coord(ldc); + stride_d = make_Coord(ldd); + 
CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); + } + + /// Returns arguments for the transposed problem + Arguments transposed_problem() const + { + Arguments args(*this); + + std::swap(args.problem_size.m(), args.problem_size.n()); + std::swap(args.ptr_A, args.ptr_B); + std::swap(args.lda, args.ldb); + std::swap(args.stride_a, args.stride_b); + std::swap(args.batch_stride_A, args.batch_stride_B); + std::swap(args.ptr_gather_A_indices, args.ptr_gather_B_indices); + + return args; + } + }; + + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params : UniversalParamsBase< + ThreadblockSwizzle, + ThreadblockShape, + ElementA, + ElementB, + ElementC, + LayoutA, + LayoutB> + { + using ParamsBase = UniversalParamsBase< + ThreadblockSwizzle, + ThreadblockShape, + ElementA, + ElementB, + ElementC, + LayoutA, + LayoutB>; + + // + // Data members + // + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::Params params_D; + + typename EpilogueOutputOp::Params output_op; + + void * ptr_A; + void * ptr_B; + void * ptr_C; + void * ptr_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_C; + + int * ptr_gather_A_indices; + int * ptr_gather_B_indices; + int * ptr_scatter_D_indices; + + // + // Host dispatch API + // + + /// Default constructor + Params() = default; + + /// Constructor + Params( + Arguments const &args, /// GEMM application arguments + int device_sms, /// Number of SMs on the device + int sm_occupancy) /// Kernel SM occupancy (in thread blocks) + : + ParamsBase(args, device_sms, sm_occupancy), + params_A(args.lda ? make_Coord_with_padding(args.lda) : args.stride_a), + params_B(args.ldb ? make_Coord_with_padding(args.ldb) : args.stride_b), + params_C(args.ldc ? make_Coord_with_padding(args.ldc) : args.stride_c), + params_D(args.ldd ? make_Coord_with_padding(args.ldd) : args.stride_d), + output_op(args.epilogue), + ptr_A(const_cast(args.ptr_A)), + ptr_B(const_cast(args.ptr_B)), + ptr_C(const_cast(args.ptr_C)), + ptr_D(args.ptr_D), + batch_stride_A(args.batch_stride_A), + batch_stride_B(args.batch_stride_B), + batch_stride_C(args.batch_stride_C), + ptr_gather_A_indices(const_cast(args.ptr_gather_A_indices)), + ptr_gather_B_indices(const_cast(args.ptr_gather_B_indices)), + ptr_scatter_D_indices(const_cast(args.ptr_scatter_D_indices)) + {} + + /// Lightweight update given a subset of arguments. 
+ void update(Arguments const &args) + { + CUTLASS_TRACE_HOST("GemmUniversal::Params::update()"); + + // Update input/output pointers + ptr_A = const_cast(args.ptr_A); + ptr_B = const_cast(args.ptr_B); + ptr_C = const_cast(args.ptr_C); + ptr_D = args.ptr_D; + + batch_stride_A = args.batch_stride_A; + batch_stride_B = args.batch_stride_B; + batch_stride_C = args.batch_stride_C; + this->batch_stride_D = args.batch_stride_D; + + ptr_gather_A_indices = const_cast(args.ptr_gather_A_indices); + ptr_gather_B_indices = const_cast(args.ptr_gather_B_indices); + ptr_scatter_D_indices = const_cast(args.ptr_scatter_D_indices); + + output_op = args.epilogue; + } + + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + +public: + + // + // Host dispatch API + // + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size) + { + CUTLASS_TRACE_HOST("GemmUniversal::can_implement()"); + + static int const kAlignmentA = (cute::is_same>::value) + ? 32 + : (cute::is_same>::value) + ? 64 + : Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = (cute::is_same>::value) + ? 32 + : (cute::is_same>::value) + ? 64 + : Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = (cute::is_same>::value) + ? 32 + : (cute::is_same>::value) + ? 64 + : Epilogue::OutputTileIterator::kElementsPerAccess; + + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; + + if (cute::is_same::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } else if (cute::is_same::value) { + isAMisaligned = problem_size.m() % kAlignmentA; + } else if (cute::is_same>::value + || cute::is_same>::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } + + if (cute::is_same::value) { + isBMisaligned = problem_size.n() % kAlignmentB; + } else if (cute::is_same::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } else if (cute::is_same>::value + || cute::is_same>::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } + + if (cute::is_same::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } else if (cute::is_same::value) { + isCMisaligned = problem_size.m() % kAlignmentC; + } else if (cute::is_same>::value + || cute::is_same>::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } + + if (isAMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); + return Status::kErrorMisalignedOperand; + } + + CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return can_implement(args.problem_size); + } + + +public: + + // + // Device-only API + // + + // Factory invocation + CUTLASS_DEVICE + static void invoke( + Params const ¶ms, + SharedStorage &shared_storage) + { + GemmUniversal op; + op(params, shared_storage); + } + + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + ThreadblockSwizzle threadblock_swizzle; + run_with_swizzle(params, shared_storage, threadblock_swizzle); + } + + /// Executes one GEMM with 
an externally-provided swizzling function + CUTLASS_DEVICE + void run_with_swizzle(Params const ¶ms, SharedStorage &shared_storage, ThreadblockSwizzle& threadblock_swizzle) { + + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA *ptr_A = static_cast(params.ptr_A); + ElementB *ptr_B = static_cast(params.ptr_B); + + // + // Fetch pointers based on mode. + // + if (params.mode == GemmUniversalMode::kGemm || + params.mode == GemmUniversalMode::kGemmSplitKParallel) { + + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { + + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; + ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; + } + else if (params.mode == GemmUniversalMode::kArray) { + ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; + ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; + } + + syncthreads(); + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B{ + offset_k, + threadblock_tile_offset.n() * Mma::Shape::kN + }; + + // Compute position within threadblock + int thread_idx = ThreadIdxX(); + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + ptr_A, + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A, + params.ptr_gather_A_indices); + + typename Mma::IteratorB iterator_B( + params.params_B, + ptr_B, + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B, + params.ptr_gather_B_indices); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = canonical_warp_idx_sync(); + + int lane_idx = ThreadIdxX() % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma( + gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + ElementC *ptr_C = static_cast(params.ptr_C); + ElementC *ptr_D = static_cast(params.ptr_D); + + // + // Fetch pointers based on mode. + // + + // Construct the semaphore. 
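+    // The semaphore below is only exercised for serial split-K, i.e. kGemm
+    // mode with grid_tiled_shape.k() > 1: partition k waits until the per-tile
+    // lock reaches k, folds its partial result into D, then releases k + 1
+    // (the last partition resets the lock to 0 for the next grid). The other
+    // modes construct it but never fetch, wait on, or release it.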
+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + if (params.mode == GemmUniversalMode::kGemm) { + + // If performing a reduction via split-K, fetch the initial synchronization + if (params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + } + else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { + ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; + ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; + } + else if (params.mode == GemmUniversalMode::kArray) { + ptr_C = static_cast(params.ptr_C)[threadblock_tile_offset.k()]; + ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params.params_C, + ptr_C, + params.problem_size.mn(), + thread_idx, + threadblock_offset, + params.ptr_scatter_D_indices + ); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + ptr_D, + params.problem_size.mn(), + thread_idx, + threadblock_offset, + params.ptr_scatter_D_indices + ); + + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + } + + + // Execute the epilogue operator to update the destination tensor. + epilogue( + output_op, + iterator_D, + accumulators, + iterator_C); + + // + // Release the semaphore + // + + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/xe_gemm_array_cooperative.hpp b/csrc/xpu/cutlass_kernels/collective/gemm/xe_gemm_array_cooperative.hpp new file mode 100644 index 0000000..61bae38 --- /dev/null +++ b/csrc/xpu/cutlass_kernels/collective/gemm/xe_gemm_array_cooperative.hpp @@ -0,0 +1,348 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/workspace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cute/tensor.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileScheduler_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_, + cute::enable_if_t>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(cute::rank(typename ProblemShape::UnderlyingProblemShape{}) == 3 or cute::rank(typename ProblemShape::UnderlyingProblemShape{}) == 4, + "ProblemShape{} should be or "); + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::WorkgroupTileShape; + using WorkgroupTileShape = TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using InternalStrideA = typename CollectiveMainloop::InternalStrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using InternalStrideB = typename CollectiveMainloop::InternalStrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + 
using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using InternalStrideC = typename CollectiveEpilogue::InternalStrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using InternalStrideD = typename CollectiveEpilogue::InternalStrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(cute::is_same_v, + "Only Group Scheduler is supported with this code."); + using TileSchedulerTag = TileScheduler_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileScheduler_, ArchTag, TileShape, ClusterShape, 0, ProblemShape>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr int SubgroupSize = CollectiveMainloop::SubgroupSize; // sub_group size + static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock; + using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape; + using SubgroupTileShape = typename CollectiveMainloop::SubgroupTileShape; + + using MainloopTensors = typename CollectiveMainloop::MainloopTensors; + using EpilogueTensors = typename CollectiveEpilogue::EpilogueTensors; + + // Kernel level shared memory storage + struct SharedStorage { + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + EpilogueTensorStorage epilogue; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + static_assert(cute::is_same_v>); + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + void* workspace{nullptr}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. 
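+  // This is also where the host-side setup happens: if the caller left
+  // hw_info.sm_count unset the multiprocessor count is queried, the workspace
+  // pointer is handed to the group tile scheduler, and the mainloop/epilogue
+  // arguments are lowered to their Params form. A rough sketch of the host
+  // flow this kernel expects (GemmKernel stands for this specialisation;
+  // workspace is a caller-owned device allocation):
+  //
+  //   typename GemmKernel::Arguments args{/* mode, shapes, ptrs, hw_info */};
+  //   size_t ws_bytes = GemmKernel::get_workspace_size(args);
+  //   // allocate ws_bytes of device memory -> workspace
+  //   GemmKernel::initialize_workspace(args, workspace);
+  //   auto params = GemmKernel::to_underlying_arguments(args, workspace);
+  //   // launch with GemmKernel::get_grid_shape(params) work-groups of
+  //   // GemmKernel::get_block_shape() threads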
+ static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + + auto problem_shape = args.problem_shape; + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + + TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments( + problem_shape, TileShape{}, ClusterShape{}, hw_info, args.scheduler, workspace_ptr); + + return { + args.mode, + problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace_ptr), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace_ptr), + hw_info, + scheduler, + workspace + }; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = true; + + implementable = implementable && (args.mode == GemmUniversalMode::kGrouped || + (args.mode == GemmUniversalMode::kBatched && rank(typename ProblemShape::UnderlyingProblemShape{}) == 3)); + + implementable = implementable && TileScheduler::can_implement(args.scheduler); + + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + + return implementable; + } + + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_size = 0; + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, -1); + return workspace_size; + } + + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr, stream, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, -1); + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently + TileSchedulerArguments args{}; + args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN ? TileScheduler::RasterOrderOptions::AlongN : TileScheduler::RasterOrderOptions::AlongM; + return TileScheduler::get_grid_shape(params.scheduler, params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + // Preconditions + CUTE_STATIC_ASSERT(is_static::value); + + static_assert(cute::rank(InternalStrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. 
If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(InternalStrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(InternalStrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(InternalStrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + TileScheduler scheduler{params.scheduler}; + auto work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); + constexpr auto workgroup_shape = WorkgroupTileShape{}; // (BLK_M,BLK_N,BLK_K) + + int thread_idx = int(ThreadIdxX()); + constexpr auto subgroup_shape = SubgroupTileShape{}; // (SUB_M,SUB_N,SUB_K) + bool did_group_change = true; + int32_t curr_group = -1; + using ProblemShapeMNKL = Shape; + ProblemShapeMNKL problem_shape_MNKL; + MainloopTensors AB_tensors; + EpilogueTensors CD_tensors; + + if (work_tile_info.is_valid()) { + curr_group = work_tile_info.L_idx; + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_group), 1); + } + + while (work_tile_info.is_valid()) { + auto M = get<0>(problem_shape_MNKL); + auto N = get<1>(problem_shape_MNKL); + auto K = get<2>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + + Tensor mA_mkl = cute::get_xe_tensor(make_shape(M,K,L)); //(m,k,l) + Tensor mB_nkl = cute::get_xe_tensor(make_shape(N,K,L)); //(n,k,l) + + auto m_coord = work_tile_info.M_idx; + auto n_coord = work_tile_info.N_idx; + + auto gA_mkl = local_tile(mA_mkl, select<0,2>(workgroup_shape), make_coord(m_coord, _, 0)); + auto gB_nkl = local_tile(mB_nkl, select<1,2>(workgroup_shape), make_coord(n_coord, _, 0)); + + CollectiveMainloop collective_mma; + if(did_group_change) { + AB_tensors = collective_mma.update_tensor_shape_stride(params.mainloop, curr_group, problem_shape_MNKL); + } + auto tile_coord = make_coord(m_coord, n_coord, _, 0); + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. 
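+      // With the group scheduler a work item may cover only a slice of the K
+      // tiles for its output tile; work_tile_info.L_idx doubles as the group
+      // index (the expert, in the MoE use here), which is why the problem
+      // shape and tensors are rebuilt above only when did_group_change is set,
+      // and why the epilogue further down runs only when compute_epilogue()
+      // reports that this work item owns the tile's final K contribution.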
+ int work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, workgroup_shape); + int work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); + auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, make_shape(K)), make_shape(K)); + + TiledMma tiled_mma; + Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(workgroup_shape)); + + // Perform the collective scoped MMA + collective_mma( + accumulators, + gA_mkl, + gB_nkl, + accumulators, + k_tile_iter, work_k_tile_count, + tile_coord, + K, + thread_idx, + params.mainloop, + AB_tensors + ); + + TileScheduler::fixup( + params.scheduler, work_tile_info, accumulators, -1, -1); + + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { + CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue}; + + if(did_group_change) { + CD_tensors = epilogue.update_tensor_shape_stride(curr_group, problem_shape_MNKL); + did_group_change = false; + } + + epilogue( + problem_shape_MNKL, + subgroup_shape, + tile_coord, + accumulators, + tiled_mma, + thread_idx, + CD_tensors + ); + } + + // Get next work tile + auto [next_work_tile_info, temp] = scheduler.fetch_next_work(work_tile_info); + work_tile_info = next_work_tile_info; + + did_group_change = curr_group != work_tile_info.L_idx; + + if(did_group_change && work_tile_info.is_valid()) { + curr_group = work_tile_info.L_idx; + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_group), 1); + } + } + } + +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/csrc/xpu/cutlass_kernels/grouped_gemm_kernel.cpp b/csrc/xpu/cutlass_kernels/grouped_gemm_kernel.cpp index bec7d72..4a9ea80 100644 --- a/csrc/xpu/cutlass_kernels/grouped_gemm_kernel.cpp +++ b/csrc/xpu/cutlass_kernels/grouped_gemm_kernel.cpp @@ -74,8 +74,8 @@ // #include "cutlass/epilogue/fusion/xe_callbacks.hpp" // #include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/group_array_problem_shape.hpp" -#include "cutlass/gemm/device/gemm_universal.h" -#include "cutlass/gemm/device/gemm_universal_adapter.h" +// #include "cutlass/gemm/device/gemm_universal.h" +// #include "cutlass/gemm/device/gemm_universal_adapter.h" // #include "cutlass/gemm/collective/collective_mma.hpp" #include "cutlass/util/GPU_Clock.hpp" @@ -91,7 +91,8 @@ #include #include "cutlass/gemm/collective/collective_mma_decl.hpp" -// #include "./collective/gemm/gemm_universal.h" +#include "./collective/gemm/gemm_universal.h" +#include "./collective/gemm/gemm_universal_adapter.h" #include "./collective/gemm/xe_array_mma.hpp" #include "./collective/gemm/xe_array_epilogue.hpp" #include "./collective/gemm/xe_builder.hpp" From db6b2927ecbb03e4d9dabb0115cf58fb5c24aa25 Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Wed, 24 Sep 2025 07:38:29 +0000 Subject: [PATCH 43/47] clean up Signed-off-by: Ma, Liangliang --- .../cutlass_kernels/grouped_gemm_kernel.cpp | 22 +++---- tests/cutlass/profile_moe.py | 6 +- tests/cutlass/test_fused_moe.py | 66 ++----------------- vllm_xpu_kernels/fused_moe_interface.py | 34 +--------- 4 files changed, 19 insertions(+), 109 deletions(-) diff --git a/csrc/xpu/cutlass_kernels/grouped_gemm_kernel.cpp b/csrc/xpu/cutlass_kernels/grouped_gemm_kernel.cpp index 4a9ea80..bfd0060 100644 --- a/csrc/xpu/cutlass_kernels/grouped_gemm_kernel.cpp +++ b/csrc/xpu/cutlass_kernels/grouped_gemm_kernel.cpp @@ -288,9 +288,6 @@ void allocate(const 
Options &options) { fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1}; using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerXeGroup::RasterOrderOptions; - std::cout << "grouped_gemm arguments" << std::endl; - std::cout << "options.groups " << options.groups << std::endl; - // Per-GEMM problem shape info may only exist on the device. if (host_problem_shapes_available) { arguments = typename Gemm::Arguments { @@ -357,15 +354,13 @@ void allocate(const Options &options) { GPU_Clock timer; timer.start(); CUTLASS_CHECK(gemm_op.run(stream)); - stream.wait(); - // syclcompat::wait(); - // stream.throw_asynchronous(); - float cute_time = timer.seconds() * 1000; - double cute_average_time = double(cute_time) / double(1); - std::cout << " Avg runtimei : " << cute_average_time << " ms" << std::endl; - - - // syclcompat::wait(); + if (collect_gflops){ + stream.wait(); + float cute_time = timer.seconds() * 1000; + double cute_average_time = double(cute_time) / double(1); + std::cout << " Avg runtimei : " << cute_average_time << " ms" << std::endl; + } + if (collect_gflops){ std::cout << "collect_gflops:" << collect_gflops << std::endl; GPU_Clock timer; @@ -373,8 +368,7 @@ void allocate(const Options &options) { for (int iter = 0; iter < 100; ++iter) { CUTLASS_CHECK(gemm_op.run(stream)); } - syclcompat::wait(); - + stream.wait(); float cute_time = timer.seconds() * 1000; double cute_average_time = double(cute_time) / double(options.iterations); double gflops = options.gflops(cute_average_time / 1000.0, options.problem_sizes_host); diff --git a/tests/cutlass/profile_moe.py b/tests/cutlass/profile_moe.py index a5c1637..f846932 100644 --- a/tests/cutlass/profile_moe.py +++ b/tests/cutlass/profile_moe.py @@ -13,10 +13,10 @@ def test_moe_gemm(n_experts, intermediate_size, hidden_size, tokens, topk, dtype total_m = tokens * topk - input_a1 = torch.randn(total_m, hidden_size, dtype=dtype, device="xpu") + input_a1 = torch.randn(total_m, hidden_size, dtype=dtype, device="xpu") / 10 input_a2 = input_a1.clone() - w13 = torch.randn(n_experts, 2*intermediate_size, hidden_size, dtype=dtype, device="xpu") - w2 = torch.randn(n_experts, hidden_size, intermediate_size, dtype=dtype, device="xpu") + w13 = torch.randn(n_experts, 2*intermediate_size, hidden_size, dtype=dtype, device="xpu") / 10 + w2 = torch.randn(n_experts, hidden_size, intermediate_size, dtype=dtype, device="xpu") / 10 w13 = w13.transpose(1, 2).contiguous() w2 = w2.transpose(1, 2).contiguous() diff --git a/tests/cutlass/test_fused_moe.py b/tests/cutlass/test_fused_moe.py index 8a22f92..e11ddb5 100644 --- a/tests/cutlass/test_fused_moe.py +++ b/tests/cutlass/test_fused_moe.py @@ -1,52 +1,11 @@ -import pytest -import torch from math import ceil +import torch +import pytest from typing import Callable, Optional, Union from vllm_xpu_kernels.fused_moe_interface import cutlass_fused_moe, cutlass_grouped_gemm -NUM_EXPERTS = [8, 64, 192] -EP_SIZE = [1, 4] -TOP_KS = [2, 6] - -FUSED_MOE_MNK_FACTORS = [ - (1, 128, 128), - (1, 2048, 128), - (33, 2048, 128), - (222, 1024, 1024), - (32768, 128, 128), - (32768, 2048, 511), - (40000, 1024, 1024), -] - -FUSED_MOE_WN16_MNK_FACTORS = [ - (1, 128, 128), - (1, 1024, 1024), - (32, 2048, 128), - (32, 1024, 1024), - (222, 2048, 1024), -] - DEVICE = "xpu" - -def calculate_device_mem(m, k, n, e, topk, dtype): - total = 0 - x = m*k - w13 = e*2*n*k - w2 = e*k*n - topk_w = topk*m - topk_id = topk*m - expert_cache = x - gemm1_out = m*2*n - gemm2_out = m*k - total += x + w13 + w2 + topk_w + topk_id + 
expert_cache + gemm1_out + gemm2_out - byte_per_data = 4 - if dtype == torch.bfloat16: - byte_per_data = 2 - total_bytes_G = total * byte_per_data / 1000 / 1000 / 1000 - print("fused moe should take device memory: ", total_bytes_G, "G") - - def test_grouped_gemm(num_experts, n, k, token_per_group): # input input_A = torch.randn((sum(token_per_group), k), dtype=torch.bfloat16, device="xpu").contiguous() @@ -86,9 +45,9 @@ def test_grouped_gemm(num_experts, n, k, token_per_group): print("Max absolute difference:", max_diff) try: torch.testing.assert_close(output, ref, rtol=1e-2, atol=1e-2) - print("a 和 b 足够接近 ✅") + print("a and b close enough") except AssertionError as e: - print("a 和 b 有差异 ❌") + print("a and b diffs") print(e) def ref_fused_moe(x, @@ -99,8 +58,6 @@ def ref_fused_moe(x, num_per_tok, activation, num_experts): - # import ipdb - # ipdb.set_trace() expert_cache = torch.zeros_like(x) idxs = flat_expert_indices.argsort() counts = flat_expert_indices.bincount().cpu().numpy() @@ -131,19 +88,12 @@ def ref_fused_moe(x, return expert_cache - -# @pytest.mark.parametrize("m,n,k", FUSED_MOE_MNK_FACTORS) -# @pytest.mark.parametrize("e", NUM_EXPERTS) -# @pytest.mark.parametrize("topk", TOP_KS) -# @pytest.mark.parametrize("ep_size", EP_SIZE) -# @pytest.mark.parametrize("dtype", [torch.bfloat16]) def test_fused_moe( m: int, # num of tokens n: int, # intermediate_size k: int, # hidden_size e: int, topk: int, - ep_size: int, dtype: torch.dtype, ): # todo: seed @@ -187,14 +137,11 @@ def test_fused_moe( print("ref result", ref_out, ref_out.shape) print("kernel result", out, out.shape) - print(torch.allclose(out, ref_out, rtol=1, atol=1)) - max_diff = (out - ref_out).abs().max() - print("Max absolute difference:", max_diff) try: torch.testing.assert_close(out, ref_out, rtol=1e-2, atol=1e-2) - print("a 和 b 足够接近 ✅") + print("a and b close enough") except AssertionError as e: - print("a 和 b 有差异 ❌") + print("a and b diffs") print(e) @@ -205,7 +152,6 @@ def test_fused_moe( k = 5120, e = 16, topk = 1, - ep_size = 1, dtype = torch.bfloat16 ) # test_grouped_gemm(num_experts=2, n=5120, k=8192, token_per_group=[2,2]) diff --git a/vllm_xpu_kernels/fused_moe_interface.py b/vllm_xpu_kernels/fused_moe_interface.py index dc7dcb6..797e2af 100644 --- a/vllm_xpu_kernels/fused_moe_interface.py +++ b/vllm_xpu_kernels/fused_moe_interface.py @@ -8,7 +8,6 @@ def prepare_gemm_args(n, k, offset, A, B, D, alpha, beta, e): if not hasattr(prepare_gemm_args, "gemm_args"): - print("@cutlass fusedMoe allocate gemm args once") gemm_args = {} device = A.device ptr_A = torch.empty(e*8, dtype=torch.uint8, device=device).contiguous() @@ -23,9 +22,6 @@ def prepare_gemm_args(n, k, offset, A, B, D, alpha, beta, e): gemm_args["ptr_beta"] = ptr_beta prepare_gemm_args.gemm_args = gemm_args - # gemm_args = {} - - # problem_sizes = [] ptr_A = prepare_gemm_args.gemm_args["ptr_A"] ptr_B = prepare_gemm_args.gemm_args["ptr_B"] ptr_D = prepare_gemm_args.gemm_args["ptr_D"] @@ -39,10 +35,7 @@ def process_data_ptr(tensor, offset, addr_tensor, dim, group): mul = 2 if tensor.dtype == torch.float32: mul = 4 - # print("process data_ptr:", tensor.shape) - # print(tensor.data_ptr()) - # print(offset*mul) - # addr = tensor.data_ptr() + offset*mul + if dim == 1: addr = tensor[offset].data_ptr() elif dim == 2: @@ -67,19 +60,6 @@ def process_data_ptr(tensor, offset, addr_tensor, dim, group): groups += 1 total_elements_B += 1; - # problem_sizes = torch.tensor(problem_sizes, dtype=torch.int64, device='cpu').contiguous() - # ptr_A = torch.tensor(ptr_A, 
dtype=torch.uint8, device=device).contiguous() - # ptr_B = torch.tensor(ptr_B, dtype=torch.uint8, device=device).contiguous() - # ptr_D = torch.tensor(ptr_D, dtype=torch.uint8, device=device).contiguous() - # ptr_alpha = torch.tensor(ptr_alpha, dtype=torch.uint8, device=device).contiguous() - # ptr_beta = torch.tensor(ptr_beta, dtype=torch.uint8, device=device).contiguous() - - # gemm_args["problem_sizes"] = problem_sizes - # gemm_args["ptr_A"] = ptr_A - # gemm_args["ptr_B"] = ptr_B - # gemm_args["ptr_D"] = ptr_D - # gemm_args["ptr_alpha"] = ptr_alpha - # gemm_args["ptr_beta"] = ptr_beta prepare_gemm_args.gemm_args["groups"] = e # FIXME: groups return prepare_gemm_args.gemm_args @@ -98,7 +78,6 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ intermediate_size = list(w2.shape)[-1] total_input_size = token_cnt * n_experts_per_token if not hasattr(cutlass_fused_moe, "moe_buffer"): - print("@cutlass fusedMoe allocate moe_buffer once") moe_buffer = {} device = hidden_states.device moe_buffer["expert_cache"] = torch.empty((token_cnt* hidden_size), @@ -146,19 +125,12 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ while len(offset) < num_experts: offset.append(0) - # import ipdb - # ipdb.set_trace() ########### gemm1 ################## - print("@@@@@@@ cutlass fused moe enter") input_B = w13.transpose(-1, -2).contiguous().transpose(-1, -2) assert(list(input_A.shape)[0] == total_input_size) gemm_args = prepare_gemm_args(2*intermediate_size, hidden_size, offset, input_A, input_B, gemm1_output, alpha, beta, num_experts) - print("**python offset", offset) offset_t = torch.tensor(offset, dtype=torch.int64, device='cpu') - print("**python offset_t", offset_t) torch.ops._xpu_C.cutlass_grouped_gemm(offset=offset_t, N=2*intermediate_size, K=hidden_size, **gemm_args) - print("@@@@@@@ cutlass fused moe gemm1 done") - print("gemm1 out", gemm1_output) # act gate, up_ = torch.split(gemm1_output, intermediate_size, dim=1) act = torch.nn.SiLU() @@ -186,6 +158,4 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, n_experts_ expert_out, reduce='sum' ) - print("@@@@@@@ cutlass fused moe gemm2 done") - hidden_states.copy_(expert_cache) - return hidden_states + return expert_cache From d651d9daaff8d7b59592710f4a4105ced931b4c5 Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Wed, 24 Sep 2025 09:08:58 +0000 Subject: [PATCH 44/47] add test Signed-off-by: Ma, Liangliang --- CMakeLists.txt | 6 +- .../collective/gemm/xe_array_epilogue.hpp | 9 ++- .../collective/gemm/xe_array_mma.hpp | 1 - tests/cutlass/profile_moe.py | 2 +- tests/cutlass/test_fused_moe.py | 75 +++++++++++-------- 5 files changed, 52 insertions(+), 41 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 10b63fc..8faee83 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -171,12 +171,12 @@ if(VLLM_GPU_LANG STREQUAL "SYCL") set(CUTLASS_ENABLE_HEADERS_ONLY "ON" CACHE BOOL "Enable only the header library") # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. 
- set(CUTLASS_REVISION "dev" CACHE STRING "CUTLASS revision to use") + set(CUTLASS_REVISION "9baca2cff3a28590fcd03e55515e2d91ff2cbc8b" CACHE STRING "CUTLASS revision to use") # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided FetchContent_Declare( cutlass-sycl - GIT_REPOSITORY https://github.com/Liangliang-Ma/cutlass-sycl + GIT_REPOSITORY https://github.com/intel/cutlass-sycl # Please keep this in sync with CUTLASS_REVISION line above. GIT_TAG ${CUTLASS_REVISION} @@ -185,7 +185,7 @@ if(VLLM_GPU_LANG STREQUAL "SYCL") # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE - GIT_SHALLOW TRUE + GIT_SHALLOW FALSE ) # cutlass compilation flags diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_epilogue.hpp b/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_epilogue.hpp index 86cd234..8d59a77 100644 --- a/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_epilogue.hpp +++ b/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_epilogue.hpp @@ -369,7 +369,11 @@ class CollectiveEpilogue< auto m_sg = get_sub_group_id() / ATOM_N; auto n_sg = get_sub_group_id() % ATOM_N; - using EpilogueTile = decltype(get<0>(params.xe_store_d.get_layoutS_MN()).shape()); + // Get the layout and reconstruct the MN mapping equivalent to the old get_layoutS_MN() + auto layoutS_TV = params.xe_store_d.get_layoutS_TV(); + auto mn_shape = shape(typename decltype(params.xe_store_d)::Tiler_MN{}); + auto layoutS_MN = right_inverse(layoutS_TV).with_shape(mn_shape); + using EpilogueTile = decltype(layoutS_MN.shape()); auto sg_local_m_coord = get_sub_group_id() / ATOM_N; auto sg_local_n_coord = get_sub_group_id() % ATOM_N; @@ -406,8 +410,7 @@ class CollectiveEpilogue< Tensor cD = local_tile(mD_crd, take<0,2>(SubgroupTileShape{}), make_coord(sg_m_coord, sg_n_coord)); Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) Tensor tRS_cD_mn = thread_g2r.partition_S(flat_divide(cD_mn, EpilogueTile{})); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) - - Tensor tRS_cD = make_counting_tensor(tRS_cD_mn.layout()); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) + Tensor tRS_cD = make_coord_tensor(tRS_cD_mn.layout()); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) // Get the fusion callbacks // Arguments passed here relate to sub-group tiles, rather than CTA (work-group) tiles diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_mma.hpp b/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_mma.hpp index 42c2f8f..3a1e84a 100644 --- a/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_mma.hpp +++ b/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_mma.hpp @@ -36,7 +36,6 @@ #include "cute/algorithm/functional.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tests/cutlass/profile_moe.py b/tests/cutlass/profile_moe.py index f846932..c13019c 100644 --- a/tests/cutlass/profile_moe.py +++ b/tests/cutlass/profile_moe.py @@ -70,7 +70,7 @@ def test_moe_gemm(n_experts, intermediate_size, hidden_size, tokens, topk, dtype print("ipex out: \n", ipex_o3, ipex_o3.shape) print("cutlass out: \n", cutlass_o3, cutlass_o3.shape) - torch.testing.assert_close(ipex_o3.to(float), cutlass_o3.to(float), rtol=1e-2, 
atol=1e-2) + torch.testing.assert_close(ipex_o3.to(float), cutlass_o3.to(float), rtol=0, atol=2e-2) if __name__ == "__main__": test_moe_gemm(n_experts=16, intermediate_size=8192, hidden_size=5120, topk=1, dtype=torch.bfloat16, tokens=16*512) diff --git a/tests/cutlass/test_fused_moe.py b/tests/cutlass/test_fused_moe.py index e11ddb5..9bc353a 100644 --- a/tests/cutlass/test_fused_moe.py +++ b/tests/cutlass/test_fused_moe.py @@ -1,25 +1,47 @@ -from math import ceil import torch import pytest +from math import ceil from typing import Callable, Optional, Union from vllm_xpu_kernels.fused_moe_interface import cutlass_fused_moe, cutlass_grouped_gemm +from tests.utils import seed_everything +import random DEVICE = "xpu" -def test_grouped_gemm(num_experts, n, k, token_per_group): +# shape for Llama-4-scout +FUSED_MOE_MNK_FACTORS = [ + (1, 5120, 8192), + (4, 5120, 8192), + (16, 5120, 8192), + (8192, 5120, 8192), +] +NUM_EXPERTS = [16] +TOP_KS = [1] + +def random_partition(size_a: int, target: int): + cuts = sorted(random.sample(range(target + size_a - 1), size_a - 1)) + cuts = [-1] + cuts + [target + size_a - 1] + result = [cuts[i+1] - cuts[i] - 1 for i in range(size_a)] + return result + +@pytest.mark.parametrize("m,n,k", FUSED_MOE_MNK_FACTORS) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +def test_grouped_gemm(m, n, k, e, topk, dtype): + seed_everything(7) + num_experts = e + token_per_group = random_partition(e, m*topk) + print(token_per_group) # input - input_A = torch.randn((sum(token_per_group), k), dtype=torch.bfloat16, device="xpu").contiguous() + input_A = torch.randn((sum(token_per_group), k), dtype=dtype, device=DEVICE).contiguous() ref_A = input_A.clone() # weight - input_B = torch.randn((num_experts, n, k), dtype=torch.bfloat16, device="xpu") + input_B = torch.randn((num_experts, n, k), dtype=dtype, device=DEVICE) input_B = input_B.transpose(-1, -2).contiguous().transpose(-1, -2) # output offset - output = torch.empty((sum(token_per_group), n), dtype=torch.bfloat16, device="xpu") - - print("input A is ", input_A) - print("input_B is ", input_B) - print("K sum is ", input_B[0].sum(dim=-1)) + output = torch.empty((sum(token_per_group), n), dtype=dtype, device=DEVICE) cutlass_grouped_gemm(input_A, input_B, output, token_per_group, n, k, num_experts) torch.xpu.synchronize() # ref gg @@ -30,19 +52,12 @@ def test_grouped_gemm(num_experts, n, k, token_per_group): if cur_token_num == 0: continue input = ref_A[pre_token_sum:pre_token_sum + cur_token_num, :] - # print("refA ptr",i, ":", hex(input.data_ptr())) - print("refA ", ref_A, ref_A.shape) weight = input_B[i, :, :] expert_output = input @ weight.T ref.append(expert_output) pre_token_sum += cur_token_num ref = torch.cat(ref, dim=0) - print("kernel:", output) - print("reference:", ref) - print(torch.allclose(output, ref, rtol=1e-2, atol=1e-2)) - max_diff = (output - ref).abs().max() - print("Max absolute difference:", max_diff) try: torch.testing.assert_close(output, ref, rtol=1e-2, atol=1e-2) print("a and b close enough") @@ -88,7 +103,7 @@ def ref_fused_moe(x, return expert_cache -def test_fused_moe( +def check_fused_moe( m: int, # num of tokens n: int, # intermediate_size k: int, # hidden_size @@ -96,7 +111,7 @@ def test_fused_moe( topk: int, dtype: torch.dtype, ): - # todo: seed + seed_everything(7) verbose = False # Setup test data a = torch.randn((m, k), device=DEVICE, dtype=dtype) / 10 @@ -137,21 +152,15 @@ def test_fused_moe( print("ref 
result", ref_out, ref_out.shape) print("kernel result", out, out.shape) - try: - torch.testing.assert_close(out, ref_out, rtol=1e-2, atol=1e-2) - print("a and b close enough") - except AssertionError as e: - print("a and b diffs") - print(e) if __name__ == "__main__": - test_fused_moe( - m = 4, - n = 8192, - k = 5120, - e = 16, - topk = 1, - dtype = torch.bfloat16 - ) - # test_grouped_gemm(num_experts=2, n=5120, k=8192, token_per_group=[2,2]) + # check_fused_moe( + # m = 4, + # n = 8192, + # k = 5120, + # e = 16, + # topk = 1, + # dtype = torch.bfloat16 + # ) + test_grouped_gemm(num_experts=16, n=5120, k=8192, token_per_group=[512]*16) From a29cfa602bc6e2291bcb4847b4ff2f3619971520 Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Wed, 24 Sep 2025 09:18:09 +0000 Subject: [PATCH 45/47] clean up Signed-off-by: Ma, Liangliang --- csrc/core/registration.h | 3 - mll_build.sh | 3 - tests/cutlass/profile_moe.py | 76 ------------------- .../{cutlass => fused_moe}/test_fused_moe.py | 11 --- 4 files changed, 93 deletions(-) delete mode 100644 mll_build.sh delete mode 100644 tests/cutlass/profile_moe.py rename tests/{cutlass => fused_moe}/test_fused_moe.py (95%) diff --git a/csrc/core/registration.h b/csrc/core/registration.h index 5ee3e56..576b5e1 100644 --- a/csrc/core/registration.h +++ b/csrc/core/registration.h @@ -1,6 +1,4 @@ #pragma once -#pragma push_macro("printf") -#undef printf #include #define _CONCAT(A, B) A##B @@ -33,4 +31,3 @@ nullptr}; \ return PyModule_Create(&module); \ } -#pragma pop_macro("printf") diff --git a/mll_build.sh b/mll_build.sh deleted file mode 100644 index cdb80cb..0000000 --- a/mll_build.sh +++ /dev/null @@ -1,3 +0,0 @@ -clear -# python3 setup.py clean -VLLM_TARGET_DEVICE=xpu python3 setup.py develop diff --git a/tests/cutlass/profile_moe.py b/tests/cutlass/profile_moe.py deleted file mode 100644 index c13019c..0000000 --- a/tests/cutlass/profile_moe.py +++ /dev/null @@ -1,76 +0,0 @@ -import torch -import torch.profiler -import intel_extension_for_pytorch -from vllm_xpu_kernels.fused_moe_interface import cutlass_grouped_gemm - -def linear_silu_mul(x): - half = x.shape[-1] // 2 - output_shape = x.shape[:-1] + (half,) - out = torch.empty(output_shape, dtype=x.dtype, device=x.device) - return torch.ops.torch_ipex.silu_and_mul(x, out) - -def test_moe_gemm(n_experts, intermediate_size, hidden_size, tokens, topk, dtype): - - total_m = tokens * topk - - input_a1 = torch.randn(total_m, hidden_size, dtype=dtype, device="xpu") / 10 - input_a2 = input_a1.clone() - w13 = torch.randn(n_experts, 2*intermediate_size, hidden_size, dtype=dtype, device="xpu") / 10 - w2 = torch.randn(n_experts, hidden_size, intermediate_size, dtype=dtype, device="xpu") / 10 - w13 = w13.transpose(1, 2).contiguous() - w2 = w2.transpose(1, 2).contiguous() - - offset = [int(total_m//n_experts)] * n_experts - rows_for_experts = torch.tensor(offset, device="xpu", dtype=torch.int32) - rows_for_experts_ipex = rows_for_experts.to(torch.int32).to("cpu") - - with torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.XPU # ✅ 使用 XPU - ], - # record_shapes=True, - with_stack=False, - # profile_memory=True, - with_flops=True, - ) as prof: - - with torch.profiler.record_function("#####ipex_moegemm_1"): - ipex_o1 = torch.xpu.moe_gemm(input_a1, w13, rows_for_experts, rows_for_experts_ipex, n_experts) - torch.xpu.synchronize() - - ipex_o2 = linear_silu_mul(ipex_o1) - torch.xpu.synchronize() - - with torch.profiler.record_function("#####ipex_moegemm_2"): - 
ipex_o3 = torch.xpu.moe_gemm(ipex_o2, w2, rows_for_experts, rows_for_experts_ipex, n_experts) - torch.xpu.synchronize() - - - w13 = w13.transpose(1, 2) - w2 = w2.transpose(1, 2) - - cutlass_o1 = torch.empty((total_m, 2*intermediate_size), dtype=dtype, device="xpu") - cutlass_o3 = torch.empty((total_m, hidden_size), dtype=dtype, device="xpu") - with torch.profiler.record_function("@@@@@cutlass_moegemm_1"): - cutlass_grouped_gemm(input_a2, w13, cutlass_o1, offset, 2*intermediate_size, hidden_size, n_experts) - torch.xpu.synchronize() - cutlass_o2 = linear_silu_mul(cutlass_o1) - torch.xpu.synchronize() - - with torch.profiler.record_function("@@@@@cutlass_moegemm_2"): - cutlass_grouped_gemm(cutlass_o2, w2, cutlass_o3, offset, hidden_size, intermediate_size, n_experts) - torch.xpu.synchronize() - - - print(prof.key_averages().table( - sort_by="self_xpu_time_total", # 以XPU耗时排序 - row_limit=-1 - )) - - print("ipex out: \n", ipex_o3, ipex_o3.shape) - print("cutlass out: \n", cutlass_o3, cutlass_o3.shape) - torch.testing.assert_close(ipex_o3.to(float), cutlass_o3.to(float), rtol=0, atol=2e-2) - -if __name__ == "__main__": - test_moe_gemm(n_experts=16, intermediate_size=8192, hidden_size=5120, topk=1, dtype=torch.bfloat16, tokens=16*512) diff --git a/tests/cutlass/test_fused_moe.py b/tests/fused_moe/test_fused_moe.py similarity index 95% rename from tests/cutlass/test_fused_moe.py rename to tests/fused_moe/test_fused_moe.py index 9bc353a..bd145a1 100644 --- a/tests/cutlass/test_fused_moe.py +++ b/tests/fused_moe/test_fused_moe.py @@ -153,14 +153,3 @@ def check_fused_moe( print("ref result", ref_out, ref_out.shape) print("kernel result", out, out.shape) - -if __name__ == "__main__": - # check_fused_moe( - # m = 4, - # n = 8192, - # k = 5120, - # e = 16, - # topk = 1, - # dtype = torch.bfloat16 - # ) - test_grouped_gemm(num_experts=16, n=5120, k=8192, token_per_group=[512]*16) From c66f152db527f30e0da5c96160a918cd2a0b82b3 Mon Sep 17 00:00:00 2001 From: "Ma, Liangliang" Date: Wed, 24 Sep 2025 09:20:49 +0000 Subject: [PATCH 46/47] fix format Signed-off-by: Ma, Liangliang --- .../collective/gemm/default_gemm_universal.h | 218 ++--- .../collective/gemm/gemm_universal.h | 240 ++--- .../collective/gemm/gemm_universal.hpp | 34 +- .../collective/gemm/gemm_universal_adapter.h | 684 +++++++------- .../collective/gemm/gemm_universal_base.h | 293 +++--- .../collective/gemm/gemm_universal_k.h | 555 +++++------ .../collective/gemm/xe_array_epilogue.hpp | 482 +++++----- .../collective/gemm/xe_array_mma.hpp | 269 +++--- .../collective/gemm/xe_builder.hpp | 389 ++++---- .../collective/gemm/xe_callbacks.hpp | 871 +++++++++--------- .../gemm/xe_gemm_array_cooperative.hpp | 290 +++--- csrc/xpu/cutlass_kernels/grouped_gemm.hpp | 50 +- .../cutlass_kernels/grouped_gemm_kernel.cpp | 379 ++++---- csrc/xpu/cutlass_kernels/helper.h | 123 ++- csrc/xpu/ops.h | 15 +- csrc/xpu/torch_bindings.cpp | 10 +- tests/fused_moe/test_fused_moe.py | 76 +- vllm_xpu_kernels/fused_moe_interface.py | 116 ++- 18 files changed, 2447 insertions(+), 2647 deletions(-) diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/default_gemm_universal.h b/csrc/xpu/cutlass_kernels/collective/gemm/default_gemm_universal.h index 6b9b15d..f2743bf 100644 --- a/csrc/xpu/cutlass_kernels/collective/gemm/default_gemm_universal.h +++ b/csrc/xpu/cutlass_kernels/collective/gemm/default_gemm_universal.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,25 +18,27 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file \brief - Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with - the appropriate threadblock-scoped epilogue. + Default kernel-level GEMM definitions combine threadblock-scoped matrix + multiply-add with the appropriate threadblock-scoped epilogue. - Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are - accommodated by exchanging A and B operands and assuming transposed layouts. Partial - specializations here choose 'device::GemmTransposed' to implement this functionality. + Note, CUTLASS epilogues universally target row-major outputs. Column-major + outputs are accommodated by exchanging A and B operands and assuming + transposed layouts. Partial specializations here choose + 'device::GemmTransposed' to implement this functionality. 
*/ @@ -119,8 +121,7 @@ template < /// Permute operand B typename PermuteBLayout_ = layout::NoPermute, /// - typename Enable = void - > + typename Enable = void> struct DefaultGemmUniversal; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -178,85 +179,38 @@ template < /// Permute operand A typename PermuteALayout, /// Permute operand B - typename PermuteBLayout -> + typename PermuteBLayout> struct DefaultGemmUniversal< - ElementA, - LayoutA, - ComplexTransform::kNone, // transform A - kAlignmentA, - ElementB, - LayoutB, - ComplexTransform::kNone, // transform B - kAlignmentB, - ElementC, - LayoutC, - ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - Operator, - SharedMemoryClear, - GatherA, - GatherB, - ScatterD, - PermuteDLayout, - PermuteALayout, - PermuteBLayout, - typename platform::enable_if< ! cutlass::is_complex::value>::type -> { - + ElementA, LayoutA, + ComplexTransform::kNone, // transform A + kAlignmentA, ElementB, LayoutB, + ComplexTransform::kNone, // transform B + kAlignmentB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag, + ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, + ThreadblockSwizzle, Stages, Operator, SharedMemoryClear, GatherA, GatherB, + ScatterD, PermuteDLayout, PermuteALayout, PermuteBLayout, + typename platform::enable_if< + !cutlass::is_complex::value>::type> { using DefaultGemmKernel = typename kernel::DefaultGemm< - ElementA, - LayoutA, - kAlignmentA, - ElementB, - LayoutB, - kAlignmentB, - ElementC, - LayoutC, - ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - true, - Operator, - SharedMemoryClear, - GatherA, - GatherB, - ScatterD, - PermuteDLayout, - PermuteALayout, - PermuteBLayout - >::GemmKernel; + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC, + LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, + WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, + true, Operator, SharedMemoryClear, GatherA, GatherB, ScatterD, + PermuteDLayout, PermuteALayout, PermuteBLayout>::GemmKernel; /// Universal kernel without StreamkFeature member type template - class SelectBase : - public kernel::GemmUniversal< - typename DefaultGemmKernel::Mma, - typename DefaultGemmKernel::Epilogue, - SwizzleT> - {}; + class SelectBase + : public kernel::GemmUniversal {}; /// Universal kernel with StreamkFeature member type template - class SelectBase : - public kernel::GemmUniversalStreamk< - typename DefaultGemmKernel::Mma, - typename DefaultGemmKernel::Epilogue, - SwizzleT> - {}; + class SelectBase + : public kernel::GemmUniversalStreamk< + typename DefaultGemmKernel::Mma, + typename DefaultGemmKernel::Epilogue, SwizzleT> {}; /// Select kernel by ThreadblockSwizzle's support for StreamkFeature using GemmKernel = SelectBase; @@ -310,78 +264,34 @@ template < /// Operation performed by GEMM typename Operator, /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear - > + SharedMemoryClearOption SharedMemoryClear> struct DefaultGemmUniversal< - ElementA, - LayoutA, - TransformA, - kAlignmentA, - ElementB, - LayoutB, - TransformB, - kAlignmentB, - ElementC, - LayoutC, - ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - 
ThreadblockSwizzle, - Stages, - Operator, - SharedMemoryClear, - false, - false, - false, - layout::NoPermute, - layout::NoPermute, - layout::NoPermute, - typename platform::enable_if::value>::type -> { - + ElementA, LayoutA, TransformA, kAlignmentA, ElementB, LayoutB, TransformB, + kAlignmentB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag, + ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, + ThreadblockSwizzle, Stages, Operator, SharedMemoryClear, false, false, + false, layout::NoPermute, layout::NoPermute, layout::NoPermute, + typename platform::enable_if< + cutlass::is_complex::value>::type> { using DefaultGemmKernel = typename kernel::DefaultGemmComplex< - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - TransformA, - TransformB, - Operator, - false - >::GemmKernel; + ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, + ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, + InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, + TransformA, TransformB, Operator, false>::GemmKernel; /// Universal kernel without StreamkFeature member type template - class SelectBase : - public kernel::GemmUniversal< - typename DefaultGemmKernel::Mma, - typename DefaultGemmKernel::Epilogue, - SwizzleT> - {}; + class SelectBase + : public kernel::GemmUniversal {}; /// Universal kernel with StreamkFeature member type template - class SelectBase : - public kernel::GemmUniversalStreamk< - typename DefaultGemmKernel::Mma, - typename DefaultGemmKernel::Epilogue, - SwizzleT> - {}; + class SelectBase + : public kernel::GemmUniversalStreamk< + typename DefaultGemmKernel::Mma, + typename DefaultGemmKernel::Epilogue, SwizzleT> {}; /// Select kernel by ThreadblockSwizzle's support for StreamkFeature using GemmKernel = SelectBase; diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal.h b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal.h index 8b9dd35..6dff3ba 100644 --- a/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal.h +++ b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file @@ -58,14 +59,15 @@ namespace device { ///////////////////////////////////////////////////////////////////////////////////////////////// -/*! - GemmUniversal is a stateful, reusable GEMM handle. Once initialized for a given GEMM computation - (problem geometry and data references), it can be reused across different GEMM problems having the - geometry. (Once initialized, details regarding problem geometry and references to workspace memory - cannot be updated.) +/*! + GemmUniversal is a stateful, reusable GEMM handle. Once initialized for a + given GEMM computation (problem geometry and data references), it can be + reused across different GEMM problems having the geometry. (Once initialized, + details regarding problem geometry and references to workspace memory cannot + be updated.) - The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and - batched array variants. + The universal GEMM accommodates serial reductions, parallel reductions, + batched strided, and batched array variants. 
*/ template < /// Element type for A matrix operand @@ -105,7 +107,8 @@ template < OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, ElementAccumulator_>::EpilogueOutputOp, /// Threadblock-level swizzling operator - typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, + typename ThreadblockSwizzle_ = + threadblock::GemmIdentityThreadblockSwizzle<>, /// Number of stages used in the pipelined mainloop int Stages = DefaultGemmConfiguration -class GemmUniversal : - public GemmUniversalBase< - typename kernel::DefaultGemmUniversal< - ElementA_, - LayoutA_, - TransformA, - AlignmentA, - ElementB_, - LayoutB_, - TransformB, - AlignmentB, - ElementC_, - LayoutC_, - ElementAccumulator_, - OperatorClass_, - ArchTag_, - ThreadblockShape_, - WarpShape_, - InstructionShape_, - EpilogueOutputOp_, - ThreadblockSwizzle_, - Stages, - Operator_, - SharedMemoryClearOption::kNone, - GatherA, - GatherB, - ScatterD, - PermuteDLayout_, - PermuteALayout_, - PermuteBLayout_ - >::GemmKernel - > { - + typename PermuteBLayout_ = layout::NoPermute> +class GemmUniversal + : public GemmUniversalBase::GemmKernel> { public: - using ElementAccumulator = ElementAccumulator_; using OperatorClass = OperatorClass_; using ArchTag = ArchTag_; @@ -193,37 +169,13 @@ class GemmUniversal : static ComplexTransform const kTransformA = TransformA; static ComplexTransform const kTransformB = TransformB; - using Base = GemmUniversalBase< - typename kernel::DefaultGemmUniversal< - ElementA_, - LayoutA_, - TransformA, - AlignmentA, - ElementB_, - LayoutB_, - TransformB, - AlignmentB, - ElementC_, - LayoutC_, - ElementAccumulator_, - OperatorClass_, - ArchTag_, - ThreadblockShape_, - WarpShape_, - InstructionShape_, - EpilogueOutputOp_, - ThreadblockSwizzle_, - Stages, - Operator_, - SharedMemoryClearOption::kNone, - GatherA, - GatherB, - ScatterD, - PermuteDLayout_, - PermuteALayout_, - PermuteBLayout_ - >::GemmKernel - >; + using Base = GemmUniversalBase::GemmKernel>; using Arguments = typename Base::Arguments; using GemmKernel = typename Base::GemmKernel; @@ -231,7 +183,8 @@ class GemmUniversal : //////////////////////////////////////////////////////////////////////////////// -/// Partial specialization for column-major output exchanges problem size and operand. +/// Partial specialization for column-major output exchanges problem size and +/// operand. 
template < /// Element type for A matrix operand typename ElementA_, @@ -284,17 +237,15 @@ template < /// Permute operand A typename PermuteALayout_, /// Permute operand B - typename PermuteBLayout_ -> -class GemmUniversal { + typename PermuteBLayout_> +class GemmUniversal< + ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, + layout::ColumnMajor, // partially specialized on LayoutC + ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, + WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, + Stages, AlignmentA, AlignmentB, Operator_, TransformA, TransformB, GatherA, + GatherB, ScatterD, PermuteDLayout_, PermuteALayout_, PermuteBLayout_> { public: - using ElementA = ElementA_; using LayoutA = LayoutA_; using TensorRefA = TensorRef; @@ -323,34 +274,14 @@ class GemmUniversal::type, - ElementA, - typename layout::LayoutTranspose::type, - ElementC, - layout::RowMajor, - ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - kAlignmentB, - kAlignmentA, - Operator, - kTransformB, - kTransformA, - GatherB, - GatherA, - ScatterD, - PermuteDLayout, - PermuteBLayout, - PermuteALayout - >::Base; + using UnderlyingOperator = typename GemmUniversal< + ElementB, typename layout::LayoutTranspose::type, ElementA, + typename layout::LayoutTranspose::type, ElementC, + layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, + ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, + ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, Operator, + kTransformB, kTransformA, GatherB, GatherA, ScatterD, PermuteDLayout, + PermuteBLayout, PermuteALayout>::Base; using GemmKernel = typename UnderlyingOperator::GemmKernel; static int const kAlignmentC = EpilogueOutputOp::kCount; @@ -358,34 +289,32 @@ class GemmUniversal -struct IsCutlass3ArrayKernel : cute::false_type { }; +struct IsCutlass3ArrayKernel : cute::false_type {}; template -struct IsCutlass3ArrayKernel> - : cute::true_type { }; +struct IsCutlass3ArrayKernel< + ProblemShape, cute::void_t> + : cute::true_type {}; //////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass::gemm::kernel +} // namespace cutlass::gemm::kernel //////////////////////////////////////////////////////////////////////////////// #include "xe_gemm_array_cooperative.hpp" diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_adapter.h b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_adapter.h index 3633072..90e8477 100644 --- a/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_adapter.h +++ b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_adapter.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. 
Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,20 +18,21 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file - \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and - batched array variants. + \brief The universal GEMM accommodates serial reductions, parallel reductions, + batched strided, and batched array variants. */ #pragma once @@ -46,9 +47,9 @@ #include "cutlass/kernel_launch.h" #if !defined(__CUDACC_RTC__) -#include "cutlass/cluster_launch.hpp" -#include "cutlass/trace.h" -#endif // !defined(__CUDACC_RTC__) + #include "cutlass/cluster_launch.hpp" + #include "cutlass/trace.h" +#endif // !defined(__CUDACC_RTC__) // 2.x #include "gemm_universal_base.h" @@ -60,7 +61,7 @@ #include "gemm_universal.hpp" #if defined(CUTLASS_ENABLE_SYCL) -#include "cutlass/util/sycl_event_manager.hpp" + #include "cutlass/util/sycl_event_manager.hpp" #endif //////////////////////////////////////////////////////////////////////////////// @@ -73,14 +74,15 @@ namespace cutlass::gemm::device { GemmUniversalAdapter is a stateful, reusable GEMM handle built around a kernel of type cutlass::gemm::kernel::Gemm or cutlass::gemm::kernel::GemmUniversal. - It manages the lifetime of the underlying `kernel::Params` struct, and exposes APIs - to create it from the host facing arguments. For power users, new static methods - are exposed in 3.x APIs that bypass the stateful methods or args->params lowering. + It manages the lifetime of the underlying `kernel::Params` struct, and exposes + APIs to create it from the host facing arguments. For power users, new static + methods are exposed in 3.x APIs that bypass the stateful methods or + args->params lowering. 
It supports kernel types that implement both the 2.x and 3.0 APIs, - however, this is done by specializing the implementation of GemmUniversalAdapter - on the two kernel API types, and thus, GemmUniversalAdapter's behaviour might - differ between the two specializations. + however, this is done by specializing the implementation of + GemmUniversalAdapter on the two kernel API types, and thus, + GemmUniversalAdapter's behaviour might differ between the two specializations. */ template class GemmUniversalAdapter; @@ -102,26 +104,26 @@ template struct has_Stages : cute::false_type {}; template -struct has_Stages> : cute::true_type {}; +struct has_Stages> + : cute::true_type {}; -template +template constexpr int stages_member(DispatchPolicy) { if constexpr (has_Stages::value) { return DispatchPolicy::Stages; - } - else { + } else { return 0; } } -} // namespace detail +} // namespace detail template -class GemmUniversalAdapter< - GemmKernel_, - cute::enable_if_t>::value>> -{ -public: +class GemmUniversalAdapter>::value>> { + public: using GemmKernel = GetUnderlyingKernel_t; using TileShape = typename GemmKernel::TileShape; using ElementA = typename GemmKernel::ElementA; @@ -134,40 +136,52 @@ class GemmUniversalAdapter< using CollectiveEpilogue = typename GemmKernel::CollectiveEpilogue; // Map back to 2.x type as best as possible - using LayoutA = gemm::detail::StrideToLayoutTagA_t; - using LayoutB = gemm::detail::StrideToLayoutTagB_t; - using LayoutC = gemm::detail::StrideToLayoutTagC_t; - using LayoutD = gemm::detail::StrideToLayoutTagC_t; + using LayoutA = + gemm::detail::StrideToLayoutTagA_t; + using LayoutB = + gemm::detail::StrideToLayoutTagB_t; + using LayoutC = + gemm::detail::StrideToLayoutTagC_t; + using LayoutD = + gemm::detail::StrideToLayoutTagC_t; static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; - static ComplexTransform const kTransformA = cute::is_same_v ? - ComplexTransform::kConjugate : ComplexTransform::kNone; - static ComplexTransform const kTransformB = cute::is_same_v ? - ComplexTransform::kConjugate : ComplexTransform::kNone; + static ComplexTransform const kTransformA = + cute::is_same_v + ? ComplexTransform::kConjugate + : ComplexTransform::kNone; + static ComplexTransform const kTransformB = + cute::is_same_v + ? 
ComplexTransform::kConjugate + : ComplexTransform::kNone; // Legacy: Assume MultiplyAdd only since we do not use this tag type in 3.0 using MathOperator = cutlass::arch::OpMultiplyAdd; - using OperatorClass = cutlass::detail::get_operator_class_t; + using OperatorClass = cutlass::detail::get_operator_class_t< + typename CollectiveMainloop::TiledMma>; using ArchTag = typename GemmKernel::ArchTag; // NOTE: Assume identity swizzle for now - using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + using ThreadblockSwizzle = + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // Assume TiledMma's ShapeMNK is the same as 2.x's ThreadblockShape - using ThreadblockShape = cutlass::gemm::GemmShape< - cute::size<0>(TileShape{}), - cute::size<1>(TileShape{}), - cute::size<2>(TileShape{})>; + using ThreadblockShape = cutlass::gemm::GemmShape(TileShape{}), + cute::size<1>(TileShape{}), + cute::size<2>(TileShape{})>; using ClusterShape = cutlass::gemm::GemmShape< cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})>; - // Instruction shape is easy too, since we get that directly from our TiledMma's atom shape + // Instruction shape is easy too, since we get that directly from our + // TiledMma's atom shape using InstructionShape = cutlass::gemm::GemmShape< cute::size<0>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), cute::size<1>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), @@ -178,28 +192,42 @@ class GemmUniversalAdapter< // Warp shape is not a primary API type in 3.x // But we can best approximate it by inspecting the TiledMma - // For this, we make the assumption that we always have 4 warps along M, and rest along N, none along K - // We also always round up the warp count to 4 if the tiled mma is smaller than 128 threads - static constexpr int WarpsInMma = cute::max(4, CUTE_STATIC_V(cute::size(typename GemmKernel::TiledMma{})) / 32); + // For this, we make the assumption that we always have 4 warps along M, and + // rest along N, none along K We also always round up the warp count to 4 if + // the tiled mma is smaller than 128 threads + static constexpr int WarpsInMma = cute::max( + 4, CUTE_STATIC_V(cute::size(typename GemmKernel::TiledMma{})) / 32); static constexpr int WarpsInMmaM = 4; static constexpr int WarpsInMmaN = cute::ceil_div(WarpsInMma, WarpsInMmaM); using WarpCount = cutlass::gemm::GemmShape; - using WarpShape = cutlass::gemm::GemmShape< - CUTE_STATIC_V(cute::tile_size<0>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaM, - CUTE_STATIC_V(cute::tile_size<1>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaN, - CUTE_STATIC_V(cute::tile_size<2>(typename CollectiveMainloop::TiledMma{}))>; - - static int constexpr kStages = detail::stages_member(typename CollectiveMainloop::DispatchPolicy{}); + using WarpShape = + cutlass::gemm::GemmShape( + typename CollectiveMainloop::TiledMma{})) / + WarpsInMmaM, + CUTE_STATIC_V(cute::tile_size<1>( + typename CollectiveMainloop::TiledMma{})) / + WarpsInMmaN, + CUTE_STATIC_V(cute::tile_size<2>( + typename CollectiveMainloop::TiledMma{}))>; + + static int constexpr kStages = + detail::stages_member(typename CollectiveMainloop::DispatchPolicy{}); // Inspect TiledCopy for A and B to compute the alignment size - static int constexpr kAlignmentA = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< - typename 
CollectiveMainloop::GmemTiledCopyA, ElementA, typename CollectiveMainloop::TiledMma::ValTypeA>(); - static int constexpr kAlignmentB = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< - typename CollectiveMainloop::GmemTiledCopyB, ElementB, typename CollectiveMainloop::TiledMma::ValTypeB>(); - static int constexpr kAlignmentC = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< - typename CollectiveEpilogue::GmemTiledCopyC, ElementC>(); - static int constexpr kAlignmentD = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< - typename CollectiveEpilogue::GmemTiledCopyD, ElementD>(); + static int constexpr kAlignmentA = + cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveMainloop::GmemTiledCopyA, ElementA, + typename CollectiveMainloop::TiledMma::ValTypeA>(); + static int constexpr kAlignmentB = + cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveMainloop::GmemTiledCopyB, ElementB, + typename CollectiveMainloop::TiledMma::ValTypeB>(); + static int constexpr kAlignmentC = + cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveEpilogue::GmemTiledCopyC, ElementC>(); + static int constexpr kAlignmentD = + cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveEpilogue::GmemTiledCopyD, ElementD>(); using EpilogueOutputOp = typename CollectiveEpilogue::ThreadEpilogueOp; @@ -212,35 +240,29 @@ class GemmUniversalAdapter< /// Argument structure: Kernel API using Params = typename GemmKernel::Params; -private: - + private: /// Kernel API parameters object Params params_; -public: - + public: /// Access the Params structure - Params const& params() const { - return params_; - } + Params const& params() const { return params_; } /// Determines whether the GEMM can execute the given problem. 
- static Status - can_implement(Arguments const& args) { + static Status can_implement(Arguments const& args) { if (GemmKernel::can_implement(args)) { return Status::kSuccess; - } - else { + } else { return Status::kInvalid; } } /// Gets the workspace size - static size_t - get_workspace_size(Arguments const& args) { + static size_t get_workspace_size(Arguments const& args) { size_t workspace_bytes = 0; if (args.mode == GemmUniversalMode::kGemmSplitKParallel) { - workspace_bytes += sizeof(int) * size_t(cute::size<0>(TileShape{})) * size_t(cute::size<1>(TileShape{})); + workspace_bytes += sizeof(int) * size_t(cute::size<0>(TileShape{})) * + size_t(cute::size<1>(TileShape{})); } workspace_bytes += GemmKernel::get_workspace_size(args); @@ -251,15 +273,13 @@ class GemmUniversalAdapter< } /// Computes the grid shape - static dim3 - get_grid_shape(Arguments const& args, void* workspace = nullptr) { + static dim3 get_grid_shape(Arguments const& args, void* workspace = nullptr) { auto tmp_params = GemmKernel::to_underlying_arguments(args, workspace); return GemmKernel::get_grid_shape(tmp_params); } /// Computes the grid shape - static dim3 - get_grid_shape(Params const& params) { + static dim3 get_grid_shape(Params const& params) { return GemmKernel::get_grid_shape(params); } @@ -273,31 +293,27 @@ class GemmUniversalAdapter< cudaError_t result; if (smem_size >= (48 << 10)) { CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); - result = cudaFuncSetAttribute( - device_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); + result = cudaFuncSetAttribute(device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); if (cudaSuccess != result) { - result = cudaGetLastError(); // to clear the error bit - CUTLASS_TRACE_HOST( - " cudaFuncSetAttribute() returned error: " - << cudaGetErrorString(result)); + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); return -1; } } // query occupancy after setting smem size result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, - device_kernel, - GemmKernel::MaxThreadsPerBlock, - smem_size); + &max_active_blocks, device_kernel, + GemmKernel::MaxThreadsPerBlock, smem_size); if (cudaSuccess != result) { - result = cudaGetLastError(); // to clear the error bit + result = cudaGetLastError(); // to clear the error bit CUTLASS_TRACE_HOST( - " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " - << cudaGetErrorString(result)); + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << cudaGetErrorString(result)); return -1; } @@ -306,29 +322,27 @@ class GemmUniversalAdapter< } /// Initializes GEMM state from arguments. - Status - initialize( - Arguments const& args, - void* workspace = nullptr, - cudaStream_t stream = nullptr, - CudaHostAdapter* cuda_adapter = nullptr) { - + Status initialize(Arguments const& args, void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { CUTLASS_TRACE_HOST("GemmUniversal::initialize() - workspace " - << workspace << ", stream: " << (stream ? "non-null" : "null")); + << workspace + << ", stream: " << (stream ? 
"non-null" : "null")); // Initialize the workspace - Status status = GemmKernel::initialize_workspace(args, workspace, stream, cuda_adapter); + Status status = + GemmKernel::initialize_workspace(args, workspace, stream, cuda_adapter); if (status != Status::kSuccess) { return status; } // Initialize the Params structure params_ = GemmKernel::to_underlying_arguments(args, workspace); - // Don't set the function attributes - require the CudaHostAdapter to set it. + // Don't set the function attributes - require the CudaHostAdapter to set + // it. if constexpr (kEnableCudaHostAdapter) { CUTLASS_ASSERT(cuda_adapter); return Status::kSuccess; - } - else { + } else { // // Account for dynamic smem capacity if needed // @@ -341,11 +355,11 @@ class GemmUniversalAdapter< CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); cudaError_t result = cudaFuncSetAttribute( device_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); if (cudaSuccess != result) { - result = cudaGetLastError(); // to clear the error bit - CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); return Status::kErrorInternal; } } @@ -354,9 +368,9 @@ class GemmUniversalAdapter< return Status::kSuccess; } - /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. - Status - update(Arguments const& args, void* workspace = nullptr) { + /// Update API is preserved in 3.0, but does not guarantee a lightweight + /// update of params. + Status update(Arguments const& args, void* workspace = nullptr) { CUTLASS_TRACE_HOST("GemmUniversal()::update() - workspace: " << workspace); size_t workspace_bytes = get_workspace_size(args); @@ -368,13 +382,12 @@ class GemmUniversalAdapter< return Status::kSuccess; } - /// Primary run() entry point API that is static allowing users to create and manage their own params. - /// Supplied params struct must be construct by calling GemmKernel::to_underlying_arguments() - static Status - run(Params& params, - sycl::queue& stream, - CudaHostAdapter *cuda_adapter = nullptr, - bool launch_with_pdl = false) { + /// Primary run() entry point API that is static allowing users to create and + /// manage their own params. 
Supplied params struct must be construct by + /// calling GemmKernel::to_underlying_arguments() + static Status run(Params& params, sycl::queue& stream, + CudaHostAdapter* cuda_adapter = nullptr, + bool launch_with_pdl = false) { CUTLASS_TRACE_HOST("GemmUniversal::run()"); dim3 const block = GemmKernel::get_block_shape(); dim3 const grid = get_grid_shape(params); @@ -387,7 +400,7 @@ class GemmUniversalAdapter< // configure smem size and carveout int smem_size = GemmKernel::SharedStorageSize; - Status launch_result{ Status::kSuccess }; + Status launch_result{Status::kSuccess}; // Use extended launch API only for mainloops that use it if constexpr (GemmKernel::ArchTag::kMinComputeCapability >= 90) { #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) @@ -395,23 +408,25 @@ class GemmUniversalAdapter< #endif #if !defined(CUTLASS_ENABLE_SYCL) [[maybe_unused]] constexpr bool is_static_1x1x1 = - cute::is_static_v and - cute::size(typename GemmKernel::DispatchPolicy::ClusterShape{}) == 1; - [[maybe_unused]] dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), - cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), - cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})); - + cute::is_static_v< + typename GemmKernel::DispatchPolicy::ClusterShape> and + cute::size(typename GemmKernel::DispatchPolicy::ClusterShape{}) == 1; + [[maybe_unused]] dim3 cluster( + cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})); + // Dynamic cluster support - [[maybe_unused]] dim3 fallback_cluster = dim3{0,0,0}; - if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 100 - || GemmKernel::ArchTag::kMinComputeCapability == 101 - ) { - if constexpr (!cute::is_static_v) { + [[maybe_unused]] dim3 fallback_cluster = dim3{0, 0, 0}; + if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 100 || + GemmKernel::ArchTag::kMinComputeCapability == 101) { + if constexpr (!cute::is_static_v< + typename GemmKernel::DispatchPolicy::ClusterShape>) { fallback_cluster = params.hw_info.cluster_shape_fallback; cluster = params.hw_info.cluster_shape; } } - + [[maybe_unused]] void* kernel_params[] = {¶ms}; if constexpr (kEnableCudaHostAdapter) { @@ -422,106 +437,101 @@ class GemmUniversalAdapter< if (cuda_adapter) { if (launch_with_pdl) { CUTLASS_TRACE_HOST( - "GemmUniversal::run() does not support launching with PDL and a custom cuda adapter."); + "GemmUniversal::run() does not support launching with PDL and " + "a custom cuda adapter."); return Status::kErrorInternal; } -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("GemmUniversal::run: Launching kernel with CUDA host adapter"); -#endif + #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST( + "GemmUniversal::run: Launching kernel with CUDA host adapter"); + #endif if constexpr (is_static_1x1x1) { - launch_result = cuda_adapter->launch(grid, - block, - smem_size, - stream, - kernel_params, - 0); - } - else { - launch_result = cuda_adapter->launch(grid, - cluster, - fallback_cluster, - block, - smem_size, - stream, - kernel_params, - 0); + launch_result = cuda_adapter->launch(grid, block, smem_size, stream, + kernel_params, 0); + } else { + launch_result = + cuda_adapter->launch(grid, cluster, fallback_cluster, block, + smem_size, stream, kernel_params, 0); } - } - else { - CUTLASS_TRACE_HOST("GemmUniversal::run: kEnableCudaHostAdapter is true, but CUDA host adapter is 
null"); + } else { + CUTLASS_TRACE_HOST( + "GemmUniversal::run: kEnableCudaHostAdapter is true, but CUDA " + "host adapter is null"); return Status::kErrorInternal; } - } - else { + } else { CUTLASS_ASSERT(cuda_adapter == nullptr); - [[maybe_unused]] void const* kernel = (void const*) device_kernel; - static constexpr bool kClusterLaunch = GemmKernel::ArchTag::kMinComputeCapability == 90; + [[maybe_unused]] void const* kernel = + (void const*)device_kernel; + static constexpr bool kClusterLaunch = + GemmKernel::ArchTag::kMinComputeCapability == 90; if constexpr (kClusterLaunch) { if constexpr (is_static_1x1x1) { -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("GemmUniversal::run: Launching static 1x1x1 kernel"); -#endif + #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST( + "GemmUniversal::run: Launching static 1x1x1 kernel"); + #endif launch_result = cutlass::kernel_launch( - grid, block, smem_size, stream, params, launch_with_pdl); + grid, block, smem_size, stream, params, launch_with_pdl); if (launch_result != Status::kSuccess) { - CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports failure"); + CUTLASS_TRACE_HOST( + "GemmUniversal::run: cutlass::kernel_launch reports failure"); } -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) else { - CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports success"); + CUTLASS_TRACE_HOST( + "GemmUniversal::run: cutlass::kernel_launch reports success"); } -#endif - } - else { -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("GemmUniversal::run: Launching dynamic cluster kernel"); -#endif - launch_result = ClusterLauncher::launch( - grid, cluster, block, smem_size, stream, kernel, kernel_params, launch_with_pdl); + #endif + } else { + #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST( + "GemmUniversal::run: Launching dynamic cluster kernel"); + #endif + launch_result = + ClusterLauncher::launch(grid, cluster, block, smem_size, stream, + kernel, kernel_params, launch_with_pdl); } } - + else { - if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 100 - || GemmKernel::ArchTag::kMinComputeCapability == 101 - || GemmKernel::ArchTag::kMinComputeCapability == 120 - ) { + if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 100 || + GemmKernel::ArchTag::kMinComputeCapability == 101 || + GemmKernel::ArchTag::kMinComputeCapability == 120) { if constexpr (is_static_1x1x1) { -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("GemmUniversal::run: Launching static 1x1x1 kernel"); -#endif - launch_result = cutlass::kernel_launch(grid, block, smem_size, stream, params, launch_with_pdl); + #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST( + "GemmUniversal::run: Launching static 1x1x1 kernel"); + #endif + launch_result = cutlass::kernel_launch( + grid, block, smem_size, stream, params, launch_with_pdl); if (launch_result != Status::kSuccess) { - CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports failure"); + CUTLASS_TRACE_HOST( + "GemmUniversal::run: cutlass::kernel_launch reports " + "failure"); } -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) else { - CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports success"); + CUTLASS_TRACE_HOST( + "GemmUniversal::run: cutlass::kernel_launch reports " + "success"); } -#endif - } - else { -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("GemmUniversal::run: Launching kernel with fall-back cluster"); -#endif + #endif + } else { + #if 
(CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST( + "GemmUniversal::run: Launching kernel with fall-back " + "cluster"); + #endif launch_result = ClusterLauncher::launch_with_fallback_cluster( - grid, - cluster, - fallback_cluster, - block, - smem_size, - stream, - kernel, - kernel_params, - launch_with_pdl); + grid, cluster, fallback_cluster, block, smem_size, stream, + kernel, kernel_params, launch_with_pdl); } } } - } #endif - } - else { + } else { launch_result = Status::kSuccess; cutlass::arch::synclog_setup(); @@ -530,79 +540,93 @@ class GemmUniversalAdapter< if (cuda_adapter) { void* kernel_params[] = {¶ms}; #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("GemmUniversal::run: Launching kernel with CUDA host adapter"); + CUTLASS_TRACE_HOST( + "GemmUniversal::run: Launching kernel with CUDA host adapter"); #endif - launch_result = cuda_adapter->launch( - grid, block, smem_size, stream, kernel_params, 0 - ); + launch_result = cuda_adapter->launch(grid, block, smem_size, stream, + kernel_params, 0); - } - else { + } else { CUTLASS_TRACE_HOST("GemmUniversal::run: CUDA host adapter is null"); return Status::kErrorInternal; } - } - else { + } else { CUTLASS_ASSERT(cuda_adapter == nullptr); #if defined(CUTLASS_ENABLE_SYCL) - // sycl::queue q = stream; // ? *stream : syclcompat::get_default_queue(); -#if !defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY) + // sycl::queue q = stream; // ? *stream : + // syclcompat::get_default_queue(); + #if !defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY) using namespace syclcompat::experimental; if constexpr (cute::is_same_v) { - auto event = launch>(launch_policy{ - sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)} - }, q, params); + auto event = launch>( + launch_policy{sycl_grid, sycl_block, + local_mem_size { + static_cast(smem_size) + }}, + q, params); EventManager::getInstance().addEvent(event); } else { - auto event = launch>(launch_policy{ - sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)} -#if defined(SYCL_INTEL_TARGET) - , kernel_properties{sycl_exp::sub_group_size} -#endif - }, stream, params); + auto event = launch>( + launch_policy{ + sycl_grid, sycl_block, + local_mem_size{static_cast(smem_size)} + #if defined(SYCL_INTEL_TARGET) + , + kernel_properties { + sycl_exp::sub_group_size + } + #endif + }, + stream, params); EventManager::getInstance().addEvent(event); } -#else -#if defined (SYCL_INTEL_TARGET) + #else + #if defined(SYCL_INTEL_TARGET) constexpr bool allow_subgroup_size_prop = true; -#else + #else constexpr bool allow_subgroup_size_prop = false; -#endif + #endif auto kernel_props = [] { constexpr bool is_device_agnostic = - cute::is_same_v; + cute::is_same_v; if constexpr (!allow_subgroup_size_prop or is_device_agnostic) { - using EmptyProperties = decltype(sycl::ext::oneapi::experimental::properties()); - return syclcompat::experimental::kernel_properties{}; + using EmptyProperties = + decltype(sycl::ext::oneapi::experimental::properties()); + return syclcompat::experimental::kernel_properties< + EmptyProperties>{}; } else { return syclcompat::experimental::kernel_properties{ - sycl::ext::oneapi::experimental::sub_group_size - }; + sycl::ext::oneapi::experimental::sub_group_size< + DispatchPolicy::SubgroupSize>}; } }(); - syclcompat::experimental::launch_properties launch_props { - sycl::ext::oneapi::experimental::work_group_scratch_size(smem_size), + syclcompat::experimental::launch_properties launch_props{ + sycl::ext::oneapi::experimental::work_group_scratch_size(smem_size), }; 
syclcompat::experimental::launch_policy policy{ - sycl_grid, sycl_block, launch_props, kernel_props - }; - auto event = syclcompat::experimental::launch>(policy, stream, params); + sycl_grid, sycl_block, launch_props, kernel_props}; + auto event = + syclcompat::experimental::launch>( + policy, stream, params); EventManager::getInstance().addEvent(event); -#endif // !defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY) + #endif // !defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY) #else -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("GemmUniversal::run: Launching kernel with cutlass::kernel_launch"); -#endif + #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST( + "GemmUniversal::run: Launching kernel with cutlass::kernel_launch"); + #endif launch_result = cutlass::kernel_launch( - grid, block, smem_size, stream, params, launch_with_pdl); + grid, block, smem_size, stream, params, launch_with_pdl); if (launch_result != Status::kSuccess) { - CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports failure"); + CUTLASS_TRACE_HOST( + "GemmUniversal::run: cutlass::kernel_launch reports failure"); } -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) else { - CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports success"); + CUTLASS_TRACE_HOST( + "GemmUniversal::run: cutlass::kernel_launch reports success"); } -#endif + #endif #endif } } @@ -610,29 +634,26 @@ class GemmUniversalAdapter< cudaError_t result = cudaGetLastError(); if (cudaSuccess == result && Status::kSuccess == launch_result) { #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("GemmUniversal::run: cudaGetLastError reports success"); + CUTLASS_TRACE_HOST( + "GemmUniversal::run: cudaGetLastError reports success"); #endif return Status::kSuccess; - } - else { + } else { CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); return Status::kErrorInternal; } } // - // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // Non-static launch overloads that first create and set the internal params + // struct of this kernel handle. // - /// Launches the kernel after first constructing Params internal state from supplied arguments. - Status - run( - Arguments const& args, - void* workspace, - sycl::queue& stream, - CudaHostAdapter *cuda_adapter = nullptr, - bool launch_with_pdl = false - ) { + /// Launches the kernel after first constructing Params internal state from + /// supplied arguments. + Status run(Arguments const& args, void* workspace, sycl::queue& stream, + CudaHostAdapter* cuda_adapter = nullptr, + bool launch_with_pdl = false) { Status status = initialize(args, workspace, stream, cuda_adapter); if (Status::kSuccess == status) { @@ -641,29 +662,26 @@ class GemmUniversalAdapter< return status; } - /// Launches the kernel after first constructing Params internal state from supplied arguments. - Status - operator()( - Arguments const& args, - void* workspace, - sycl::queue& stream, - CudaHostAdapter *cuda_adapter = nullptr, - bool launch_with_pdl = false) { + /// Launches the kernel after first constructing Params internal state from + /// supplied arguments. + Status operator()(Arguments const& args, void* workspace, sycl::queue& stream, + CudaHostAdapter* cuda_adapter = nullptr, + bool launch_with_pdl = false) { return run(args, workspace, stream, cuda_adapter, launch_with_pdl); } - /// Overload that allows a user to re-launch the same kernel without updating internal params struct. 
- Status - run( - sycl::queue& stream, - CudaHostAdapter *cuda_adapter = nullptr, - bool launch_with_pdl = false) { + /// Overload that allows a user to re-launch the same kernel without updating + /// internal params struct. + Status run(sycl::queue& stream, CudaHostAdapter* cuda_adapter = nullptr, + bool launch_with_pdl = false) { return run(params_, stream, cuda_adapter, launch_with_pdl); } - /// Overload that allows a user to re-launch the same kernel without updating internal params struct. - Status - operator()(sycl::queue& stream, CudaHostAdapter *cuda_adapter = nullptr, bool launch_with_pdl = false) { + /// Overload that allows a user to re-launch the same kernel without updating + /// internal params struct. + Status operator()(sycl::queue& stream, + CudaHostAdapter* cuda_adapter = nullptr, + bool launch_with_pdl = false) { return run(params_, stream, cuda_adapter, launch_with_pdl); } }; @@ -674,16 +692,17 @@ class GemmUniversalAdapter< template class GemmUniversalAdapter< - GemmKernel_, - cute::enable_if_t>::value>> -{ -public: - + GemmKernel_, cute::enable_if_t>::value>> { + public: using GemmKernel = GetUnderlyingKernel_t; static bool const kInternalTranspose = - !cutlass::epilogue::threadblock::detail::is_2x_evt_v && // 2.x EVT does not require internal transpose - cute::is_same::value; + !cutlass::epilogue::threadblock::detail::is_2x_evt_v< + typename GemmKernel::Epilogue> && // 2.x EVT does not require + // internal transpose + cute::is_same::value; using ThreadblockShape = typename GemmKernel::Mma::Shape; using WarpShape = typename GemmKernel::WarpShape; @@ -701,17 +720,11 @@ class GemmUniversalAdapter< // Type, layout, and complex transform deliberately exchanged with B using MapArguments = kernel::detail::MapArguments< - typename GemmKernel::ElementA, - typename GemmKernel::LayoutA, - GemmKernel::kTransformA, - GemmKernel::kAlignmentA, - typename GemmKernel::ElementB, - typename GemmKernel::LayoutB, - GemmKernel::kTransformB, - GemmKernel::kAlignmentB, - typename GemmKernel::LayoutC, - kInternalTranspose - >; + typename GemmKernel::ElementA, typename GemmKernel::LayoutA, + GemmKernel::kTransformA, GemmKernel::kAlignmentA, + typename GemmKernel::ElementB, typename GemmKernel::LayoutB, + GemmKernel::kTransformB, GemmKernel::kAlignmentB, + typename GemmKernel::LayoutC, kInternalTranspose>; using ElementA = typename MapArguments::ElementA; using LayoutA = typename MapArguments::LayoutA; @@ -744,39 +757,39 @@ class GemmUniversalAdapter< using UnderlyingOperator = GemmUniversalBase; using Arguments = typename UnderlyingOperator::Arguments; -private: - + private: UnderlyingOperator underlying_operator_; -public: - + public: /// Constructs the GEMM. - GemmUniversalAdapter() { } + GemmUniversalAdapter() {} - /// Helper to construct a transposed equivalent for the underying GEMM operator - static Arguments to_underlying_arguments(Arguments const &args) { + /// Helper to construct a transposed equivalent for the underying GEMM + /// operator + static Arguments to_underlying_arguments(Arguments const& args) { if (kInternalTranspose) { return args.transposed_problem(); - } - else { + } else { return args; } } /// Determines whether the GEMM can execute the given problem. 
- static Status can_implement(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr) { - - return UnderlyingOperator::can_implement(to_underlying_arguments(args), cuda_adapter); + static Status can_implement(Arguments const& args, + CudaHostAdapter* cuda_adapter = nullptr) { + return UnderlyingOperator::can_implement(to_underlying_arguments(args), + cuda_adapter); } /// Gets the workspace size - static size_t get_workspace_size(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr) { - - return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args), cuda_adapter); + static size_t get_workspace_size(Arguments const& args, + CudaHostAdapter* cuda_adapter = nullptr) { + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args), + cuda_adapter); } /// Computes the grid shape - static dim3 get_grid_shape(Arguments const &args) { + static dim3 get_grid_shape(Arguments const& args) { return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); } @@ -786,45 +799,34 @@ class GemmUniversalAdapter< } /// Initializes GEMM state from arguments. - Status initialize( - Arguments const &args, - void *workspace = nullptr, - cudaStream_t stream = nullptr, - CudaHostAdapter *cuda_adapter = nullptr - ) { - - return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream, cuda_adapter); + Status initialize(Arguments const& args, void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + return underlying_operator_.initialize(to_underlying_arguments(args), + workspace, stream, cuda_adapter); } /// Lightweight update given a subset of arguments. - Status update(Arguments const &args) { - + Status update(Arguments const& args) { return underlying_operator_.update(to_underlying_arguments(args)); } /// Runs the kernel using initialized state. - Status run( - cudaStream_t stream = nullptr, - CudaHostAdapter *cuda_adapter = nullptr) { - + Status run(cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { return underlying_operator_.run(stream, cuda_adapter); } /// Runs the kernel using initialized state. - Status operator()( - cudaStream_t stream = nullptr, - CudaHostAdapter *cuda_adapter = nullptr) { - + Status operator()(cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { return run(stream); } /// Runs the kernel using initialized state. 
- Status operator()( - Arguments const &args, - void *workspace = nullptr, - cudaStream_t stream = nullptr, - CudaHostAdapter *cuda_adapter = nullptr) { - + Status operator()(Arguments const& args, void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { Status status = initialize(args, workspace, stream, cuda_adapter); if (status == Status::kSuccess) { @@ -837,6 +839,6 @@ class GemmUniversalAdapter< //////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass::gemm::device +} // namespace cutlass::gemm::device //////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_base.h b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_base.h index bc64b3d..b909318 100644 --- a/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_base.h +++ b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_base.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,27 +18,29 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! 
\file - \brief The universal GEMM accommodates streamk, batched strided, and batched array variants. + \brief The universal GEMM accommodates streamk, batched strided, and batched + array variants. */ #pragma once #if defined(__CUDACC_RTC__) -#include + #include #else -#include + #include #endif #include "cutlass/cutlass.h" @@ -63,11 +65,9 @@ namespace device { ///////////////////////////////////////////////////////////////////////////////////////////////// - template class GemmUniversalBase { -public: - + public: using GemmKernel = GemmKernel_; /// Boolean indicating whether the CudaHostAdapter is enabled @@ -100,16 +100,16 @@ class GemmUniversalBase { /// Argument structure using Arguments = typename GemmKernel::Arguments; - /// Index of the GEMM Kernel within the CudaHostAdapter static int32_t const kGemmKernelIndex = 0; /// Kernel dynamic shared memory allocation requirement - /// Update the kernel function's shared memory configuration for the current device - static constexpr size_t kSharedStorageSize = sizeof(typename GemmKernel::SharedStorage); - -protected: + /// Update the kernel function's shared memory configuration for the current + /// device + static constexpr size_t kSharedStorageSize = + sizeof(typename GemmKernel::SharedStorage); + protected: // // Device properties (uniform across all instances of the current thread) // @@ -123,12 +123,10 @@ class GemmUniversalBase { /// Kernel SM occupancy (in thread blocks) CUTLASS_THREAD_LOCAL static int sm_occupancy_; -protected: - + protected: /// Initialize static thread-local members for the thread's current device, /// if necessary. - static Status init_device_props() - { + static Status init_device_props() { CUTLASS_TRACE_HOST("GemmUniversalBase::init_device_props()"); cudaError_t cudart_result; @@ -137,7 +135,8 @@ class GemmUniversalBase { int current_ordinal; cudart_result = cudaGetDevice(¤t_ordinal); if (cudart_result != cudaSuccess) { - CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(cudart_result)); + CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " + << cudaGetErrorString(cudart_result)); return Status::kErrorInternal; } @@ -148,53 +147,62 @@ class GemmUniversalBase { } // Update SM count member - cudart_result = cudaDeviceGetAttribute (&device_sms_, cudaDevAttrMultiProcessorCount, current_ordinal); + cudart_result = cudaDeviceGetAttribute( + &device_sms_, cudaDevAttrMultiProcessorCount, current_ordinal); if (cudart_result != cudaSuccess) { - CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(cudart_result)); + CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " + << cudaGetErrorString(cudart_result)); return Status::kErrorInternal; } // If requires more than 48KB: configure for extended, dynamic shared memory - if constexpr (kSharedStorageSize >= (48 << 10)) - { + if constexpr (kSharedStorageSize >= (48 << 10)) { cudart_result = cudaFuncSetAttribute( - Kernel2, - cudaFuncAttributeMaxDynamicSharedMemorySize, - kSharedStorageSize); + Kernel2, cudaFuncAttributeMaxDynamicSharedMemorySize, + kSharedStorageSize); if (cudart_result != cudaSuccess) { - CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(cudart_result)); + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " + << cudaGetErrorString(cudart_result)); return Status::kErrorInternal; } } // Update SM occupancy member cudart_result = cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( - &sm_occupancy_, - Kernel2, - GemmKernel::kThreadCount, - 
kSharedStorageSize, - cudaOccupancyDisableCachingOverride); + &sm_occupancy_, Kernel2, GemmKernel::kThreadCount, + kSharedStorageSize, cudaOccupancyDisableCachingOverride); if (cudart_result != cudaSuccess) { - CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags() returned error " << cudaGetErrorString(cudart_result)); + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags() returned " + "error " + << cudaGetErrorString(cudart_result)); return Status::kErrorInternal; } // Update device ordinal member on success device_ordinal_ = current_ordinal; - CUTLASS_TRACE_HOST(" " - "device_ordinal: (" << device_ordinal_ << "), " - "device_sms: (" << device_sms_ << "), " - "sm_occupancy: (" << sm_occupancy_ << ") " - "smem_size: (" << kSharedStorageSize << ") " - "GemmKernel::kThreadCount: (" << GemmKernel::kThreadCount << ")"); + CUTLASS_TRACE_HOST( + " " + "device_ordinal: (" + << device_ordinal_ + << "), " + "device_sms: (" + << device_sms_ + << "), " + "sm_occupancy: (" + << sm_occupancy_ + << ") " + "smem_size: (" + << kSharedStorageSize + << ") " + "GemmKernel::kThreadCount: (" + << GemmKernel::kThreadCount << ")"); return Status::kSuccess; } - -protected: - + protected: // // Instance data members // @@ -202,10 +210,9 @@ class GemmUniversalBase { /// Kernel parameters typename GemmKernel::Params params_; - /// Initialize params member - Status init_params(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr) - { + Status init_params(Arguments const& args, + CudaHostAdapter* cuda_adapter = nullptr) { int32_t device_sms = 0; int32_t sm_occupancy = 0; @@ -217,25 +224,19 @@ class GemmUniversalBase { // if (cuda_adapter) { - Status status = cuda_adapter->query_occupancy( - &device_sms, - &sm_occupancy, - kGemmKernelIndex, - GemmKernel::kThreadCount, - kSharedStorageSize); + &device_sms, &sm_occupancy, kGemmKernelIndex, + GemmKernel::kThreadCount, kSharedStorageSize); CUTLASS_ASSERT(status == Status::kSuccess); if (status != Status::kSuccess) { return status; } - } - else { + } else { return Status::kErrorInternal; } - } - else { + } else { CUTLASS_ASSERT(cuda_adapter == nullptr); // Initialize static device properties, if necessary @@ -246,11 +247,11 @@ class GemmUniversalBase { } // - // Use thread-local static members for occupancy query initialized by call to - // `init_device_props()` + // Use thread-local static members for occupancy query initialized by call + // to `init_device_props()` // - device_sms = device_sms_; + device_sms = device_sms_; sm_occupancy = sm_occupancy_; } @@ -259,54 +260,51 @@ class GemmUniversalBase { return Status::kSuccess; } -public: - + public: //--------------------------------------------------------------------------------------------- // Stateless API //--------------------------------------------------------------------------------------------- /// Determines whether the GEMM can execute the given problem. 
- static Status can_implement(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr) - { + static Status can_implement(Arguments const& args, + CudaHostAdapter* cuda_adapter = nullptr) { CUTLASS_TRACE_HOST("GemmUniversalBase::can_implement()"); if (!kEnableCudaHostAdapter || cuda_adapter) { - dim3 grid = get_grid_shape(args, cuda_adapter); if (!(grid.y <= std::numeric_limits::max() && - grid.z <= std::numeric_limits::max())) - { + grid.z <= std::numeric_limits::max())) { return Status::kErrorInvalidProblem; } - } - else { + } else { // - // With a null host adapter, a conservative grid shape is computed and required to conform to CUDA grid - // dimension limits. + // With a null host adapter, a conservative grid shape is computed and + // required to conform to CUDA grid dimension limits. // - int64_t logicalGridM = (int64_t(args.problem_size.m()) + ThreadblockShape::kM - 1) / ThreadblockShape::kM; - int64_t logicalGridN = (int64_t(args.problem_size.n()) + ThreadblockShape::kN - 1) / ThreadblockShape::kN; + int64_t logicalGridM = + (int64_t(args.problem_size.m()) + ThreadblockShape::kM - 1) / + ThreadblockShape::kM; + int64_t logicalGridN = + (int64_t(args.problem_size.n()) + ThreadblockShape::kN - 1) / + ThreadblockShape::kN; int32_t logicalGridL = args.batch_count; if ((int64_t(std::numeric_limits::max()) < logicalGridM) || (int64_t(std::numeric_limits::max()) < logicalGridN) || (int32_t(std::numeric_limits::max()) < logicalGridL)) { - return Status::kErrorInvalidProblem; } - } return GemmKernel::can_implement(args); } - /// Returns the workspace size (in bytes) needed for the problem /// geometry expressed by these arguments - static size_t get_workspace_size(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr) - { + static size_t get_workspace_size(Arguments const& args, + CudaHostAdapter* cuda_adapter = nullptr) { CUTLASS_TRACE_HOST("GemmUniversalBase::get_workspace_size()"); // Initialize parameters from args @@ -322,61 +320,51 @@ class GemmUniversalBase { return workspace_bytes; } - /// Returns the grid extents in thread blocks to launch - static dim3 get_grid_shape(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr) - { + static dim3 get_grid_shape(Arguments const& args, + CudaHostAdapter* cuda_adapter = nullptr) { CUTLASS_TRACE_HOST("GemmUniversalBase::get_grid_shape()"); // Initialize parameters from args GemmUniversalBase base; if (base.init_params(args, cuda_adapter) != Status::kSuccess) { - return dim3(0,0,0); + return dim3(0, 0, 0); } // Get dims from parameters dim3 grid_dims = base.params_.get_grid_dims(); - CUTLASS_TRACE_HOST( - " tiled_shape: " << base.params_.get_tiled_shape() << "\n" - << " grid_dims: {" << grid_dims << "}"); + CUTLASS_TRACE_HOST(" tiled_shape: " + << base.params_.get_tiled_shape() << "\n" + << " grid_dims: {" << grid_dims << "}"); return grid_dims; } - /// Returns the maximum number of active thread blocks per multiprocessor - static int maximum_active_blocks(CudaHostAdapter *cuda_adapter = nullptr) - { + static int maximum_active_blocks(CudaHostAdapter* cuda_adapter = nullptr) { CUTLASS_TRACE_HOST("GemmUniversalBase::maximum_active_blocks()"); - int32_t device_sms = 0; + int32_t device_sms = 0; int32_t sm_occupancy = 0; - if constexpr (kEnableCudaHostAdapter) { CUTLASS_ASSERT(cuda_adapter); if (cuda_adapter) { - Status status = cuda_adapter->query_occupancy( - &device_sms, - &sm_occupancy, - kGemmKernelIndex, - GemmKernel::kThreadCount, - kSharedStorageSize); + &device_sms, &sm_occupancy, kGemmKernelIndex, + 
GemmKernel::kThreadCount, kSharedStorageSize); CUTLASS_ASSERT(status == Status::kSuccess); if (status != Status::kSuccess) { - return -1; + return -1; } - } - else { + } else { return -1; } - } - else { + } else { CUTLASS_ASSERT(cuda_adapter == nullptr); // Initialize static device properties, if necessary if (init_device_props() != Status::kSuccess) { @@ -390,20 +378,17 @@ class GemmUniversalBase { return sm_occupancy; } - //--------------------------------------------------------------------------------------------- // Stateful API //--------------------------------------------------------------------------------------------- /// Initializes GEMM state from arguments and workspace memory - Status initialize( - Arguments const &args, - void *workspace = nullptr, - cudaStream_t stream = nullptr, - CudaHostAdapter *cuda_adapter = nullptr) - { + Status initialize(Arguments const& args, void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { CUTLASS_TRACE_HOST("GemmUniversalBase::initialize() - workspace " - << workspace << ", stream: " << (stream ? "non-null" : "null")); + << workspace + << ", stream: " << (stream ? "non-null" : "null")); // Initialize parameters from args Status result = init_params(args, cuda_adapter); @@ -419,18 +404,16 @@ class GemmUniversalBase { return Status::kSuccess; } - /// Lightweight update given a subset of arguments. - Status update(Arguments const &args) - { + Status update(Arguments const& args) { CUTLASS_TRACE_HOST("GemmUniversalBase()::update()"); params_.update(args); return Status::kSuccess; } /// Runs the kernel using initialized state. - Status run(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) - { + Status run(cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { CUTLASS_TRACE_HOST("GemmUniversalBase::run()"); // Configure grid and block dimensions @@ -438,10 +421,16 @@ class GemmUniversalBase { dim3 grid = params_.get_grid_dims(); // Launch kernel - CUTLASS_TRACE_HOST(" " - "grid: (" << grid << "), " - "block: (" << block << "), " - "SMEM: (" << kSharedStorageSize << ")"); + CUTLASS_TRACE_HOST( + " " + "grid: (" + << grid + << "), " + "block: (" + << block + << "), " + "SMEM: (" + << kSharedStorageSize << ")"); cutlass::arch::synclog_setup(); @@ -449,13 +438,12 @@ class GemmUniversalBase { CUTLASS_ASSERT(cuda_adapter); if (cuda_adapter) { void* kernel_params[] = {¶ms_}; - return cuda_adapter->launch(grid, block, kSharedStorageSize, stream, kernel_params, 0); - } - else { + return cuda_adapter->launch(grid, block, kSharedStorageSize, stream, + kernel_params, 0); + } else { return Status::kErrorInternal; } - } - else { + } else { CUTLASS_ASSERT(cuda_adapter == nullptr); #if defined(CUTLASS_ENABLE_SYCL) @@ -464,15 +452,17 @@ class GemmUniversalBase { sycl::queue q = stream ? 
*stream : syclcompat::get_default_queue(); syclcompat::experimental::launch>( - syclcompat::experimental::launch_policy{ - sycl_grid, sycl_block, -#if defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY) - sycl::ext::oneapi::experimental::work_group_scratch_size(kSharedStorageSize) -#else - syclcompat::experimental::local_mem_size{static_cast(kSharedStorageSize)} -#endif - }, - q, params_); + syclcompat::experimental::launch_policy{ + sycl_grid, sycl_block, + #if defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY) + sycl::ext::oneapi::experimental::work_group_scratch_size( + kSharedStorageSize) + #else + syclcompat::experimental::local_mem_size{ + static_cast(kSharedStorageSize)} + #endif + }, + q, params_); #else Kernel2<<>>(params_); #endif @@ -480,7 +470,8 @@ class GemmUniversalBase { // Query for errors cudaError_t result = cudaGetLastError(); if (result != cudaSuccess) { - CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + CUTLASS_TRACE_HOST(" grid launch failed with error " + << cudaGetErrorString(result)); return Status::kErrorInternal; } } @@ -488,21 +479,16 @@ class GemmUniversalBase { return Status::kSuccess; } - /// Runs the kernel using initialized state. - Status operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) - { + Status operator()(cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { return run(stream, cuda_adapter); } - /// Runs the kernel using initialized state. - Status operator()( - Arguments const &args, - void *workspace = nullptr, - cudaStream_t stream = nullptr, - CudaHostAdapter *cuda_adapter = nullptr) - { + Status operator()(Arguments const& args, void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { Status status = initialize(args, workspace, stream, cuda_adapter); if (status == Status::kSuccess) { @@ -513,7 +499,6 @@ class GemmUniversalBase { } }; - ///////////////////////////////////////////////////////////////////////////////////////////////// /// Static initializers ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -532,8 +517,8 @@ CUTLASS_THREAD_LOCAL int GemmUniversalBase::sm_occupancy_ = -1; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace device -} // namespace gemm -} // namespace cutlass +} // namespace device +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_k.h b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_k.h index b2aee30..19871ee 100644 --- a/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_k.h +++ b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_k.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. 
Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ @@ -57,22 +58,18 @@ namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// -template < - typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate - typename Epilogue_, ///! Epilogue - typename ThreadblockSwizzle_ ///! 
Threadblock swizzling function -> +template class GemmUniversal< - Mma_, - Epilogue_, - ThreadblockSwizzle_, - void, - // 3.x kernels use the first template argument to define the ProblemShape - // We use this invariant to SFINAE dispatch against either the 2.x API or the 3.x API - cute::enable_if_t::value || IsCutlass3ArrayKernel::value)> -> { -public: - + Mma_, Epilogue_, ThreadblockSwizzle_, void, + // 3.x kernels use the first template argument to define the ProblemShape + // We use this invariant to SFINAE dispatch against either the 2.x API or + // the 3.x API + cute::enable_if_t::value || + IsCutlass3ArrayKernel::value)>> { + public: using Mma = Mma_; using Epilogue = Epilogue_; using EpilogueOutputOp = typename Epilogue::OutputOp; @@ -98,32 +95,33 @@ class GemmUniversal< static int const kStages = Mma::kStages; static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + static int const kAlignmentC = + Epilogue::OutputTileIterator::kElementsPerAccess; /// Warp count (concept: GemmShape) using WarpCount = typename Mma::WarpCount; static int const kThreadCount = 32 * WarpCount::kCount; /// Split-K preserves splits that are 128b aligned - static int const kSplitKAlignment = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); + static int const kSplitKAlignment = const_max( + 128 / sizeof_bits::value, 128 / sizeof_bits::value); // // Structures // /// Argument structure - struct Arguments : UniversalArgumentsBase - { + struct Arguments : UniversalArgumentsBase { // // Data members // typename EpilogueOutputOp::Params epilogue; - void const * ptr_A; - void const * ptr_B; - void const * ptr_C; - void * ptr_D; + void const* ptr_A; + void const* ptr_B; + void const* ptr_C; + void* ptr_D; int64_t batch_stride_A; int64_t batch_stride_B; @@ -139,98 +137,103 @@ class GemmUniversal< typename LayoutC::Stride::LongIndex ldc; typename LayoutC::Stride::LongIndex ldd; - int const * ptr_gather_A_indices; - int const * ptr_gather_B_indices; - int const * ptr_scatter_D_indices; + int const* ptr_gather_A_indices; + int const* ptr_gather_B_indices; + int const* ptr_scatter_D_indices; // // Methods // - Arguments(): - ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr), - ptr_gather_A_indices(nullptr), - ptr_gather_B_indices(nullptr), - ptr_scatter_D_indices(nullptr) - {} + Arguments() + : ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + ptr_gather_A_indices(nullptr), + ptr_gather_B_indices(nullptr), + ptr_scatter_D_indices(nullptr) {} /// constructs an arguments structure - Arguments( - GemmUniversalMode mode, - GemmCoord problem_size, - int batch_count, - typename EpilogueOutputOp::Params epilogue, - void const * ptr_A, - void const * ptr_B, - void const * ptr_C, - void * ptr_D, - int64_t batch_stride_A, - int64_t batch_stride_B, - int64_t batch_stride_C, - int64_t batch_stride_D, - typename LayoutA::Stride stride_a, - typename LayoutB::Stride stride_b, - typename LayoutC::Stride stride_c, - typename LayoutC::Stride stride_d, - int const *ptr_gather_A_indices = nullptr, - int const *ptr_gather_B_indices = nullptr, - int const *ptr_scatter_D_indices = nullptr) - : - UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), - epilogue(epilogue), - ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), - batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), 
batch_stride_C(batch_stride_C), - stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d), - ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices), - ptr_scatter_D_indices(ptr_scatter_D_indices) - { + Arguments(GemmUniversalMode mode, GemmCoord problem_size, int batch_count, + typename EpilogueOutputOp::Params epilogue, void const* ptr_A, + void const* ptr_B, void const* ptr_C, void* ptr_D, + int64_t batch_stride_A, int64_t batch_stride_B, + int64_t batch_stride_C, int64_t batch_stride_D, + typename LayoutA::Stride stride_a, + typename LayoutB::Stride stride_b, + typename LayoutC::Stride stride_c, + typename LayoutC::Stride stride_d, + int const* ptr_gather_A_indices = nullptr, + int const* ptr_gather_B_indices = nullptr, + int const* ptr_scatter_D_indices = nullptr) + : UniversalArgumentsBase(mode, problem_size, batch_count, + batch_stride_D), + epilogue(epilogue), + ptr_A(ptr_A), + ptr_B(ptr_B), + ptr_C(ptr_C), + ptr_D(ptr_D), + batch_stride_A(batch_stride_A), + batch_stride_B(batch_stride_B), + batch_stride_C(batch_stride_C), + stride_a(stride_a), + stride_b(stride_b), + stride_c(stride_c), + stride_d(stride_d), + ptr_gather_A_indices(ptr_gather_A_indices), + ptr_gather_B_indices(ptr_gather_B_indices), + ptr_scatter_D_indices(ptr_scatter_D_indices) { lda = 0; ldb = 0; ldc = 0; ldd = 0; - CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); + CUTLASS_TRACE_HOST( + "GemmUniversal::Arguments::Arguments() - problem_size: " + << problem_size); } /// constructs an arguments structure - Arguments( - GemmUniversalMode mode, - GemmCoord problem_size, - int batch_count, - typename EpilogueOutputOp::Params epilogue, - void const * ptr_A, - void const * ptr_B, - void const * ptr_C, - void * ptr_D, - int64_t batch_stride_A, - int64_t batch_stride_B, - int64_t batch_stride_C, - int64_t batch_stride_D, - typename LayoutA::Stride::LongIndex lda, - typename LayoutB::Stride::LongIndex ldb, - typename LayoutC::Stride::LongIndex ldc, - typename LayoutC::Stride::LongIndex ldd, - int const *ptr_gather_A_indices = nullptr, - int const *ptr_gather_B_indices = nullptr, - int const *ptr_scatter_D_indices = nullptr - ): - UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), - epilogue(epilogue), - ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), - batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), - lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), - ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices), - ptr_scatter_D_indices(ptr_scatter_D_indices) - { + Arguments(GemmUniversalMode mode, GemmCoord problem_size, int batch_count, + typename EpilogueOutputOp::Params epilogue, void const* ptr_A, + void const* ptr_B, void const* ptr_C, void* ptr_D, + int64_t batch_stride_A, int64_t batch_stride_B, + int64_t batch_stride_C, int64_t batch_stride_D, + typename LayoutA::Stride::LongIndex lda, + typename LayoutB::Stride::LongIndex ldb, + typename LayoutC::Stride::LongIndex ldc, + typename LayoutC::Stride::LongIndex ldd, + int const* ptr_gather_A_indices = nullptr, + int const* ptr_gather_B_indices = nullptr, + int const* ptr_scatter_D_indices = nullptr) + : UniversalArgumentsBase(mode, problem_size, batch_count, + batch_stride_D), + epilogue(epilogue), + ptr_A(ptr_A), + ptr_B(ptr_B), + ptr_C(ptr_C), + ptr_D(ptr_D), + batch_stride_A(batch_stride_A), + batch_stride_B(batch_stride_B), + batch_stride_C(batch_stride_C), + lda(lda), + 
ldb(ldb), + ldc(ldc), + ldd(ldd), + ptr_gather_A_indices(ptr_gather_A_indices), + ptr_gather_B_indices(ptr_gather_B_indices), + ptr_scatter_D_indices(ptr_scatter_D_indices) { stride_a = make_Coord(lda); stride_b = make_Coord(ldb); stride_c = make_Coord(ldc); stride_d = make_Coord(ldd); - CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); + CUTLASS_TRACE_HOST( + "GemmUniversal::Arguments::Arguments() - problem_size: " + << problem_size); } /// Returns arguments for the transposed problem - Arguments transposed_problem() const - { + Arguments transposed_problem() const { Arguments args(*this); std::swap(args.problem_size.m(), args.problem_size.n()); @@ -244,29 +247,17 @@ class GemmUniversal< } }; - // // Structure for precomputing values in host memory and passing to kernels // /// Parameters structure - struct Params : UniversalParamsBase< - ThreadblockSwizzle, - ThreadblockShape, - ElementA, - ElementB, - ElementC, - LayoutA, - LayoutB> - { - using ParamsBase = UniversalParamsBase< - ThreadblockSwizzle, - ThreadblockShape, - ElementA, - ElementB, - ElementC, - LayoutA, - LayoutB>; + struct Params + : UniversalParamsBase { + using ParamsBase = + UniversalParamsBase; // // Data members @@ -279,18 +270,18 @@ class GemmUniversal< typename EpilogueOutputOp::Params output_op; - void * ptr_A; - void * ptr_B; - void * ptr_C; - void * ptr_D; + void* ptr_A; + void* ptr_B; + void* ptr_C; + void* ptr_D; int64_t batch_stride_A; int64_t batch_stride_B; int64_t batch_stride_C; - int * ptr_gather_A_indices; - int * ptr_gather_B_indices; - int * ptr_scatter_D_indices; + int* ptr_gather_A_indices; + int* ptr_gather_B_indices; + int* ptr_scatter_D_indices; // // Host dispatch API @@ -300,38 +291,42 @@ class GemmUniversal< Params() = default; /// Constructor - Params( - Arguments const &args, /// GEMM application arguments - int device_sms, /// Number of SMs on the device - int sm_occupancy) /// Kernel SM occupancy (in thread blocks) - : - ParamsBase(args, device_sms, sm_occupancy), - params_A(args.lda ? make_Coord_with_padding(args.lda) : args.stride_a), - params_B(args.ldb ? make_Coord_with_padding(args.ldb) : args.stride_b), - params_C(args.ldc ? make_Coord_with_padding(args.ldc) : args.stride_c), - params_D(args.ldd ? make_Coord_with_padding(args.ldd) : args.stride_d), - output_op(args.epilogue), - ptr_A(const_cast(args.ptr_A)), - ptr_B(const_cast(args.ptr_B)), - ptr_C(const_cast(args.ptr_C)), - ptr_D(args.ptr_D), - batch_stride_A(args.batch_stride_A), - batch_stride_B(args.batch_stride_B), - batch_stride_C(args.batch_stride_C), - ptr_gather_A_indices(const_cast(args.ptr_gather_A_indices)), - ptr_gather_B_indices(const_cast(args.ptr_gather_B_indices)), - ptr_scatter_D_indices(const_cast(args.ptr_scatter_D_indices)) - {} + Params(Arguments const& args, /// GEMM application arguments + int device_sms, /// Number of SMs on the device + int sm_occupancy) /// Kernel SM occupancy (in thread blocks) + : ParamsBase(args, device_sms, sm_occupancy), + params_A(args.lda + ? make_Coord_with_padding(args.lda) + : args.stride_a), + params_B(args.ldb + ? make_Coord_with_padding(args.ldb) + : args.stride_b), + params_C(args.ldc + ? make_Coord_with_padding(args.ldc) + : args.stride_c), + params_D(args.ldd + ? 
make_Coord_with_padding(args.ldd) + : args.stride_d), + output_op(args.epilogue), + ptr_A(const_cast(args.ptr_A)), + ptr_B(const_cast(args.ptr_B)), + ptr_C(const_cast(args.ptr_C)), + ptr_D(args.ptr_D), + batch_stride_A(args.batch_stride_A), + batch_stride_B(args.batch_stride_B), + batch_stride_C(args.batch_stride_C), + ptr_gather_A_indices(const_cast(args.ptr_gather_A_indices)), + ptr_gather_B_indices(const_cast(args.ptr_gather_B_indices)), + ptr_scatter_D_indices(const_cast(args.ptr_scatter_D_indices)) {} /// Lightweight update given a subset of arguments. - void update(Arguments const &args) - { + void update(Arguments const& args) { CUTLASS_TRACE_HOST("GemmUniversal::Params::update()"); // Update input/output pointers - ptr_A = const_cast(args.ptr_A); - ptr_B = const_cast(args.ptr_B); - ptr_C = const_cast(args.ptr_C); + ptr_A = const_cast(args.ptr_A); + ptr_B = const_cast(args.ptr_B); + ptr_C = const_cast(args.ptr_C); ptr_D = args.ptr_D; batch_stride_A = args.batch_stride_A; @@ -339,13 +334,12 @@ class GemmUniversal< batch_stride_C = args.batch_stride_C; this->batch_stride_D = args.batch_stride_D; - ptr_gather_A_indices = const_cast(args.ptr_gather_A_indices); - ptr_gather_B_indices = const_cast(args.ptr_gather_B_indices); - ptr_scatter_D_indices = const_cast(args.ptr_scatter_D_indices); + ptr_gather_A_indices = const_cast(args.ptr_gather_A_indices); + ptr_gather_B_indices = const_cast(args.ptr_gather_B_indices); + ptr_scatter_D_indices = const_cast(args.ptr_scatter_D_indices); output_op = args.epilogue; } - }; /// Shared memory storage structure @@ -354,40 +348,30 @@ class GemmUniversal< typename Epilogue::SharedStorage epilogue; }; - -public: - + public: // // Host dispatch API // /// Determines whether kernel satisfies alignment - static Status can_implement( - cutlass::gemm::GemmCoord const & problem_size) - { + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) { CUTLASS_TRACE_HOST("GemmUniversal::can_implement()"); - static int const kAlignmentA = (cute::is_same>::value) - ? 32 - : (cute::is_same>::value) - ? 64 - : Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB = (cute::is_same>::value) - ? 32 - : (cute::is_same>::value) - ? 64 - : Mma::IteratorB::AccessType::kElements; - static int const kAlignmentC = (cute::is_same>::value) - ? 32 - : (cute::is_same>::value) - ? 64 - : Epilogue::OutputTileIterator::kElementsPerAccess; + static int const kAlignmentA = + (cute::is_same>::value) ? 32 + : (cute::is_same>::value) + ? 64 + : Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = + (cute::is_same>::value) ? 32 + : (cute::is_same>::value) + ? 64 + : Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = + (cute::is_same>::value) ? 32 + : (cute::is_same>::value) + ? 
64 + : Epilogue::OutputTileIterator::kElementsPerAccess; bool isAMisaligned = false; bool isBMisaligned = false; @@ -397,8 +381,10 @@ class GemmUniversal< isAMisaligned = problem_size.k() % kAlignmentA; } else if (cute::is_same::value) { isAMisaligned = problem_size.m() % kAlignmentA; - } else if (cute::is_same>::value - || cute::is_same>::value) { + } else if (cute::is_same>::value || + cute::is_same>::value) { isAMisaligned = problem_size.k() % kAlignmentA; } @@ -406,8 +392,8 @@ class GemmUniversal< isBMisaligned = problem_size.n() % kAlignmentB; } else if (cute::is_same::value) { isBMisaligned = problem_size.k() % kAlignmentB; - } else if (cute::is_same>::value - || cute::is_same>::value) { + } else if (cute::is_same>::value || + cute::is_same>::value) { isBMisaligned = problem_size.k() % kAlignmentB; } @@ -415,8 +401,10 @@ class GemmUniversal< isCMisaligned = problem_size.n() % kAlignmentC; } else if (cute::is_same::value) { isCMisaligned = problem_size.m() % kAlignmentC; - } else if (cute::is_same>::value - || cute::is_same>::value) { + } else if (cute::is_same>::value || + cute::is_same>::value) { isCMisaligned = problem_size.n() % kAlignmentC; } @@ -440,109 +428,90 @@ class GemmUniversal< return Status::kSuccess; } - static Status can_implement(Arguments const &args) { + static Status can_implement(Arguments const& args) { return can_implement(args.problem_size); } - -public: - + public: // // Device-only API // // Factory invocation CUTLASS_DEVICE - static void invoke( - Params const ¶ms, - SharedStorage &shared_storage) - { + static void invoke(Params const& params, SharedStorage& shared_storage) { GemmUniversal op; op(params, shared_storage); } - /// Executes one GEMM CUTLASS_DEVICE - void operator()(Params const ¶ms, SharedStorage &shared_storage) { + void operator()(Params const& params, SharedStorage& shared_storage) { ThreadblockSwizzle threadblock_swizzle; run_with_swizzle(params, shared_storage, threadblock_swizzle); } /// Executes one GEMM with an externally-provided swizzling function CUTLASS_DEVICE - void run_with_swizzle(Params const ¶ms, SharedStorage &shared_storage, ThreadblockSwizzle& threadblock_swizzle) { - + void run_with_swizzle(Params const& params, SharedStorage& shared_storage, + ThreadblockSwizzle& threadblock_swizzle) { cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); // Early exit if CTA is out of range if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || - params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { - + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { return; } int offset_k = 0; int problem_size_k = params.problem_size.k(); - ElementA *ptr_A = static_cast(params.ptr_A); - ElementB *ptr_B = static_cast(params.ptr_B); + ElementA* ptr_A = static_cast(params.ptr_A); + ElementB* ptr_B = static_cast(params.ptr_B); // // Fetch pointers based on mode. 
// if (params.mode == GemmUniversalMode::kGemm || - params.mode == GemmUniversalMode::kGemmSplitKParallel) { - + params.mode == GemmUniversalMode::kGemmSplitKParallel) { if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { - problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; } offset_k = threadblock_tile_offset.k() * params.gemm_k_size; - } - else if (params.mode == GemmUniversalMode::kBatched) { + } else if (params.mode == GemmUniversalMode::kBatched) { ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; - } - else if (params.mode == GemmUniversalMode::kArray) { - ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; - ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; + } else if (params.mode == GemmUniversalMode::kArray) { + ptr_A = static_cast( + params.ptr_A)[threadblock_tile_offset.k()]; + ptr_B = static_cast( + params.ptr_B)[threadblock_tile_offset.k()]; } syncthreads(); // Compute initial location in logical coordinates cutlass::MatrixCoord tb_offset_A{ - threadblock_tile_offset.m() * Mma::Shape::kM, - offset_k, + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, }; cutlass::MatrixCoord tb_offset_B{ - offset_k, - threadblock_tile_offset.n() * Mma::Shape::kN - }; + offset_k, threadblock_tile_offset.n() * Mma::Shape::kN}; // Compute position within threadblock int thread_idx = ThreadIdxX(); // Construct iterators to A and B operands typename Mma::IteratorA iterator_A( - params.params_A, - ptr_A, - {params.problem_size.m(), problem_size_k}, - thread_idx, - tb_offset_A, - params.ptr_gather_A_indices); + params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, + thread_idx, tb_offset_A, params.ptr_gather_A_indices); typename Mma::IteratorB iterator_B( - params.params_B, - ptr_B, - {problem_size_k, params.problem_size.n()}, - thread_idx, - tb_offset_B, - params.ptr_gather_B_indices); + params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, + thread_idx, tb_offset_B, params.ptr_gather_B_indices); // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. 
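For readers following this reformatting of the 2.x-style universal GEMM, the host-side call sequence exercised by these files is unchanged. The sketch below is illustrative only: `MyGemmKernel` stands in for a concrete kernel instantiation (not defined in this patch), the leading dimensions assume row-major operands, and the `{alpha, beta}` epilogue parameters assume a LinearCombination-style output op; only the `GemmUniversalBase` entry points shown in this patch are used.

#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_base.h"

// Hypothetical driver: MyGemmKernel must name a fully specialized 2.x GEMM
// kernel type; pointers, strides, and scalars below are placeholders.
cutlass::Status run_universal_gemm(void const* A, void const* B, void const* C,
                                   void* D, int M, int N, int K, float alpha,
                                   float beta, cudaStream_t stream) {
  using Gemm = cutlass::gemm::device::GemmUniversalBase<MyGemmKernel>;
  using Arguments = typename Gemm::Arguments;

  // Plain (non-batched) GEMM: batch_count == 1, batch strides of 0.
  // Leading dimensions assume row-major A (MxK), B (KxN), C/D (MxN).
  Arguments args(cutlass::gemm::GemmUniversalMode::kGemm,
                 cutlass::gemm::GemmCoord(M, N, K),
                 /*batch_count=*/1,
                 {alpha, beta},  // epilogue params (LinearCombination assumed)
                 A, B, C, D,
                 /*batch_stride_A=*/0, /*batch_stride_B=*/0,
                 /*batch_stride_C=*/0, /*batch_stride_D=*/0,
                 /*lda=*/K, /*ldb=*/N, /*ldc=*/N, /*ldd=*/N);

  // Static checks and workspace query, as exposed by GemmUniversalBase.
  if (Gemm::can_implement(args) != cutlass::Status::kSuccess) {
    return cutlass::Status::kErrorNotSupported;
  }
  size_t workspace_bytes = Gemm::get_workspace_size(args);
  void* workspace = nullptr;  // allocate workspace_bytes on device if > 0

  // Stateful flow: initialize params, then launch with the initialized state.
  Gemm gemm_op;
  cutlass::Status status = gemm_op.initialize(args, workspace, stream);
  if (status == cutlass::Status::kSuccess) {
    status = gemm_op.run(stream);
  }
  return status;
}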
@@ -562,15 +531,11 @@ class GemmUniversal< accumulators.clear(); // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + int gemm_k_iterations = + (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; // Compute threadblock-scoped matrix multiply-add - mma( - gemm_k_iterations, - accumulators, - iterator_A, - iterator_B, - accumulators); + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); // // Epilogue @@ -582,18 +547,19 @@ class GemmUniversal< // Masked tile iterators constructed from members // - threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - //assume identity swizzle + // assume identity swizzle MatrixCoord threadblock_offset( - threadblock_tile_offset.m() * Mma::Shape::kM, - threadblock_tile_offset.n() * Mma::Shape::kN - ); + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN); - int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + int block_idx = threadblock_tile_offset.m() + + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); - ElementC *ptr_C = static_cast(params.ptr_C); - ElementC *ptr_D = static_cast(params.ptr_D); + ElementC* ptr_C = static_cast(params.ptr_C); + ElementC* ptr_D = static_cast(params.ptr_D); // // Fetch pointers based on mode. @@ -603,59 +569,47 @@ class GemmUniversal< Semaphore semaphore(params.semaphore + block_idx, thread_idx); if (params.mode == GemmUniversalMode::kGemm) { - - // If performing a reduction via split-K, fetch the initial synchronization + // If performing a reduction via split-K, fetch the initial + // synchronization if (params.grid_tiled_shape.k() > 1) { - // Fetch the synchronization lock initially but do not block. semaphore.fetch(); - // Indicate which position in a serial reduction the output operator is currently updating - output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + // Indicate which position in a serial reduction the output operator is + // currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), + params.grid_tiled_shape.k()); } - } - else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { + } else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; - } - else if (params.mode == GemmUniversalMode::kBatched) { + } else if (params.mode == GemmUniversalMode::kBatched) { ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; - } - else if (params.mode == GemmUniversalMode::kArray) { - ptr_C = static_cast(params.ptr_C)[threadblock_tile_offset.k()]; - ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; + } else if (params.mode == GemmUniversalMode::kArray) { + ptr_C = static_cast( + params.ptr_C)[threadblock_tile_offset.k()]; + ptr_D = static_cast( + params.ptr_D)[threadblock_tile_offset.k()]; } // Tile iterator loading from source tensor. 
typename Epilogue::OutputTileIterator iterator_C( - params.params_C, - ptr_C, - params.problem_size.mn(), - thread_idx, - threadblock_offset, - params.ptr_scatter_D_indices - ); + params.params_C, ptr_C, params.problem_size.mn(), thread_idx, + threadblock_offset, params.ptr_scatter_D_indices); // Tile iterator writing to destination tensor. typename Epilogue::OutputTileIterator iterator_D( - params.params_D, - ptr_D, - params.problem_size.mn(), - thread_idx, - threadblock_offset, - params.ptr_scatter_D_indices - ); - - Epilogue epilogue( - shared_storage.epilogue, - thread_idx, - warp_idx, - lane_idx); - - // Wait on the semaphore - this latency may have been covered by iterator construction - if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { - - // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + params.params_D, ptr_D, params.problem_size.mn(), thread_idx, + threadblock_offset, params.ptr_scatter_D_indices); + + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator + // construction + if (params.mode == GemmUniversalMode::kGemm && + params.grid_tiled_shape.k() > 1) { + // For subsequent threadblocks, the source matrix is held in the 'D' + // tensor. if (threadblock_tile_offset.k()) { iterator_C = iterator_D; } @@ -663,27 +617,20 @@ class GemmUniversal< semaphore.wait(threadblock_tile_offset.k()); } - // Execute the epilogue operator to update the destination tensor. - epilogue( - output_op, - iterator_D, - accumulators, - iterator_C); + epilogue(output_op, iterator_D, accumulators, iterator_C); // // Release the semaphore // - if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { - + if (params.mode == GemmUniversalMode::kGemm && + params.grid_tiled_shape.k() > 1) { int lock = 0; if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { - // The final threadblock resets the semaphore for subsequent grids. lock = 0; - } - else { + } else { // Otherwise, the semaphore is incremented lock = threadblock_tile_offset.k() + 1; } @@ -695,8 +642,8 @@ class GemmUniversal< ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace kernel -} // namespace gemm -} // namespace cutlass +} // namespace kernel +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_epilogue.hpp b/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_epilogue.hpp index 8d59a77..7ee0efc 100644 --- a/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_epilogue.hpp +++ b/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_epilogue.hpp @@ -5,8 +5,8 @@ * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. 
Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file @@ -52,17 +53,15 @@ namespace cutlass::epilogue::collective { ///////////////////////////////////////////////////////////////////////////////////////////////// -template < - class DispatchPolicy, - class... 
Args -> +template class CollectiveEpilogue { - static_assert(cutlass::detail::dependent_false, "Could not find an epilogue specialization."); + static_assert(cutlass::detail::dependent_false, + "Could not find an epilogue specialization."); }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass::epilogue::collective +} // namespace cutlass::epilogue::collective ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -74,36 +73,15 @@ namespace collective { ///////////////////////////////////////////////////////////////////////////////////////////////// -template < - class CtaTileMNK_, - class ElementC_, - class StrideC_, - class ElementD_, - class StrideD_, - class FusionCallbacks_, - class CopyOpG2R_, - class SmemLayoutAtomC_, - class CopyOpS2R_, - class CopyOpR2G_, - class SmemLayoutAtomD_, - class CopyOpR2S_ -> -class CollectiveEpilogue< - IntelXeXMX16Group, - CtaTileMNK_, - ElementC_, - StrideC_, - ElementD_, - StrideD_, - FusionCallbacks_, - CopyOpG2R_, - SmemLayoutAtomC_, - CopyOpS2R_, - CopyOpR2G_, - SmemLayoutAtomD_, - CopyOpR2S_ -> { -public: +template +class CollectiveEpilogue { + public: // // Type Aliases // @@ -124,52 +102,69 @@ class CollectiveEpilogue< using SmemLayoutAtomD = SmemLayoutAtomD_; using CopyOpR2S = CopyOpR2S_; - using ThreadEpilogueOp = typename fusion::FusionCallbacksTraits::Operation; + using ThreadEpilogueOp = + typename fusion::FusionCallbacksTraits::Operation; using GmemTiledCopyC = CopyOpG2R; - using GmemTiledCopyD = cute::conditional_t && not cute::is_void_v, + using GmemTiledCopyD = cute::conditional_t && + not cute::is_void_v, CopyOpR2G, XE_2D_U32x8x16_ST_N>; using ElementOutput = ElementD; using ElementCompute = ElementAccumulator; using ElementSource = typename FusionCallbacks::ElementSource; using ElementScalar = typename FusionCallbacks::ElementScalar; - static constexpr FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest; + static constexpr FloatRoundStyle RoundStyle = + FloatRoundStyle::round_to_nearest; - static_assert(cute::is_same_v>, - "Only Linear Combination Epilogue is supported for Grouped GEMM at the moment."); + static_assert( + cute::is_same_v< + typename FusionCallbacks::Operation, + fusion::LinearCombination>, + "Only Linear Combination Epilogue is supported for Grouped GEMM at the " + "moment."); static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; - static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); - static_assert(cute::rank(InternalStrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); - static_assert(cute::rank(InternalStrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); - - static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); - static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); - static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); - static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + static_assert(cute::rank(CtaTileMNK{}) == 3, + "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(InternalStrideC{}) == 3, + "StrideC must be rank-3: [M, N, L]"); + static_assert(cute::rank(InternalStrideD{}) == 3, + "StrideD must be rank-3: [M, N, L]"); + + static_assert(std::is_same_v, + "Copy operation to shared memory is not supported"); + static_assert(std::is_same_v, + "Copy operation to shared memory is not 
supported"); + static_assert(std::is_same_v, + "Copy operation to shared memory is not supported"); + static_assert(std::is_same_v, + "Copy operation to shared memory is not supported"); using CopyThreadShape = Shape<_1, Int>; using Trait_C = Copy_Traits; - using XE_Copy_C = decltype(make_tiled_copy(Copy_Atom{}, - Layout{}, - make_layout(shape_div(typename Trait_C::BlockShape{}, CopyThreadShape{})))); + using XE_Copy_C = decltype(make_tiled_copy( + Copy_Atom{}, Layout{}, + make_layout( + shape_div(typename Trait_C::BlockShape{}, CopyThreadShape{})))); using Trait_D = Copy_Traits; - using XE_Copy_D = decltype(make_tiled_copy(Copy_Atom{}, - Layout{}, - make_layout(shape_div(typename Trait_D::BlockShape{}, CopyThreadShape{})))); -private: + using XE_Copy_D = decltype(make_tiled_copy( + Copy_Atom{}, Layout{}, + make_layout( + shape_div(typename Trait_D::BlockShape{}, CopyThreadShape{})))); + + private: // constexpr static bool is_source_supported = not cute::is_void_v; constexpr static bool is_source_supported = false; - constexpr static bool is_destination_supported = not cute::is_void_v && not cute::is_void_v; - -public: + constexpr static bool is_destination_supported = + not cute::is_void_v && not cute::is_void_v; + public: using EmptyType = cute::tuple<>; using SmemCStorage = EmptyType; using SmemDStorage = EmptyType; - struct TensorStorageImpl: cute::tuple { + struct TensorStorageImpl : cute::tuple { using FusionStorage = typename FusionCallbacks::SharedStorage; FusionStorage thread; }; @@ -181,8 +176,12 @@ class CollectiveEpilogue< }; using TensorStorage = typename SharedStorage::TensorStorage; - using TensorC = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), InternalStrideC{})); //(m, n) - using TensorD = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), InternalStrideD{})); //(m, n) + using TensorC = + decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), + make_shape(0, 0, 0), InternalStrideC{})); //(m, n) + using TensorD = + decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), + make_shape(0, 0, 0), InternalStrideD{})); //(m, n) using EpilogueTensors = cute::tuple; // Host side epilogue arguments @@ -210,58 +209,59 @@ class CollectiveEpilogue< // template - static constexpr Params - to_underlying_arguments( - ProblemShape const& problem_shape, - Arguments const& args, + static constexpr Params to_underlying_arguments( + ProblemShape const& problem_shape, Arguments const& args, [[maybe_unused]] void* workspace) { - // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) - auto problem_shape_MNL = repeat_like(typename ProblemShape::UnderlyingProblemShape{}, int32_t(1)); + // Optionally append 1s until problem shape is rank-4 in case its is only + // rank-3 (MNK) + auto problem_shape_MNL = repeat_like( + typename ProblemShape::UnderlyingProblemShape{}, int32_t(1)); auto [M, N, L] = problem_shape_MNL; XE_Copy_C xe_load_c = {}; if constexpr (is_source_supported) { - ElementC const* ptr_C_first_batch = reinterpret_cast(args.ptr_C); - TensorC mC_mnl = make_tensor(make_gmem_ptr(ptr_C_first_batch), make_layout(make_shape(M, N, L), InternalStrideC{})); + ElementC const* ptr_C_first_batch = + reinterpret_cast(args.ptr_C); + TensorC mC_mnl = + make_tensor(make_gmem_ptr(ptr_C_first_batch), + make_layout(make_shape(M, N, L), InternalStrideC{})); xe_load_c = {xe_load_c.with(mC_mnl)}; } XE_Copy_D xe_store_d = {}; if constexpr (is_destination_supported) { ElementD* ptr_D_first_batch = 
reinterpret_cast(args.ptr_D); - TensorD mD_mnl = make_tensor(make_gmem_ptr(ptr_D_first_batch), make_layout(make_shape(M, N, L), InternalStrideD{})); + TensorD mD_mnl = + make_tensor(make_gmem_ptr(ptr_D_first_batch), + make_layout(make_shape(M, N, L), InternalStrideD{})); xe_store_d = {xe_store_d.with(mD_mnl)}; } - return { - FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), - xe_load_c, - xe_store_d, - args.ptr_C, - args.dC, - args.ptr_D, - args.dD - }; + return {FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, + workspace), + xe_load_c, + xe_store_d, + args.ptr_C, + args.dC, + args.ptr_D, + args.dD}; } template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + static size_t get_workspace_size(ProblemShape const& problem_shape, + Arguments const& args) { return 0; } template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { + static cutlass::Status initialize_workspace( + ProblemShape const& problem_shape, Arguments const& args, void* workspace, + cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { return Status::kSuccess; } template - static bool - can_implement( - ProblemShape problem_shape, - Arguments const& args) { + static bool can_implement(ProblemShape problem_shape, Arguments const& args) { constexpr int copy_alignment_bits = 128; constexpr int batch_alignment_bits = 512; @@ -269,99 +269,118 @@ class CollectiveEpilogue< bool fusion_implementable = true; for (int i = 0; i < problem_shape.groups(); ++i) { - auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1); - auto [M,N,K,L] = problem_shape_MNKL; + auto problem_shape_MNKL = + append<4>(problem_shape.get_host_problem_shape(i), 1); + auto [M, N, K, L] = problem_shape_MNKL; if constexpr (is_destination_supported) { - constexpr int min_aligned_elements_D = copy_alignment_bits / sizeof_bits::value; - implementable &= cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideD{}); + constexpr int min_aligned_elements_D = + copy_alignment_bits / sizeof_bits::value; + implementable &= + cutlass::detail::check_alignment( + cute::make_shape(M, N, L), InternalStrideD{}); if (L > 1) { - constexpr int min_batch_aligned_elements_D = batch_alignment_bits / sizeof_bits::value; - implementable &= get<2>(InternalStrideD{}) % min_batch_aligned_elements_D == 0; + constexpr int min_batch_aligned_elements_D = + batch_alignment_bits / sizeof_bits::value; + implementable &= + get<2>(InternalStrideD{}) % min_batch_aligned_elements_D == 0; } } if constexpr (is_source_supported) { - constexpr int min_aligned_elements_C = copy_alignment_bits / sizeof_bits::value; - implementable &= cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideC{}); + constexpr int min_aligned_elements_C = + copy_alignment_bits / sizeof_bits::value; + implementable &= + cutlass::detail::check_alignment( + cute::make_shape(M, N, L), InternalStrideC{}); if (L > 1) { - constexpr int min_batch_aligned_elements_C = batch_alignment_bits / sizeof_bits::value; - implementable &= get<2>(InternalStrideC{}) % min_batch_aligned_elements_C == 0; + constexpr int min_batch_aligned_elements_C = + batch_alignment_bits / sizeof_bits::value; + implementable &= + get<2>(InternalStrideC{}) % min_batch_aligned_elements_C == 0; } } - fusion_implementable = fusion_implementable && 
FusionCallbacks::can_implement(problem_shape_MNKL, args.thread); + fusion_implementable = + fusion_implementable && + FusionCallbacks::can_implement(problem_shape_MNKL, args.thread); } if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for XE 2D copy.\n"); + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment " + "requirements for XE 2D copy.\n"); } if (!fusion_implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements " + "for FusionCallbacks.\n"); } return implementable && fusion_implementable; } CUTLASS_HOST_DEVICE - CollectiveEpilogue(Params const& params_, TensorStorage const& shared_storage_) - : params(params_), fusion_callbacks(params_.thread, shared_storage_.thread) {} + CollectiveEpilogue(Params const& params_, + TensorStorage const& shared_storage_) + : params(params_), + fusion_callbacks(params_.thread, shared_storage_.thread) {} CUTLASS_DEVICE - bool - is_producer_load_needed() const { + bool is_producer_load_needed() const { return fusion_callbacks.is_producer_load_needed(); } - template< - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class Accumulator, - class TiledMma, - class LoadStoreTensor - > - CUTLASS_DEVICE void - operator() ( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_MNK, - TileCoordMNKL tile_coord_mnkl, - Accumulator accumulators, - TiledMma tiled_mma, - int thread_idx, - LoadStoreTensor const& load_store_tensors) { - - (void) tiled_mma; + template + CUTLASS_DEVICE void operator()(ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + Accumulator accumulators, TiledMma tiled_mma, + int thread_idx, + LoadStoreTensor const& load_store_tensors) { + (void)tiled_mma; using namespace cute; - static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); - static_assert(cute::rank(InternalStrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); - static_assert(cute::rank(InternalStrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); + static_assert(cute::rank(CtaTileMNK{}) == 3, + "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(InternalStrideC{}) == 3, + "StrideC must be rank-3: [M, N, L]"); + static_assert(cute::rank(InternalStrideD{}) == 3, + "StrideD must be rank-3: [M, N, L]"); using MmaAtomShape = typename TiledMma::AtomShape_MNK; static constexpr auto BLK_M = get<0>(CtaTileMNK{}); static constexpr auto BLK_N = get<1>(CtaTileMNK{}); static constexpr auto BLK_K = get<2>(CtaTileMNK{}); - // static_assert(is_same_v, "assertation fail"); - static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); - static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); - static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); - + // static_assert(is_same_v, + // "assertation fail"); + static constexpr auto ATOM_M = + get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_N = + get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_K = + get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static_assert( - BLK_M % ATOM_M == 0 && - BLK_N % ATOM_N == 0 && - BLK_K % ATOM_K == 0, - "expected CTATileMNK to be evenly divided by 
TiledMma::ThrLayoutVMNK"); + BLK_M % ATOM_M == 0 && BLK_N % ATOM_N == 0 && BLK_K % ATOM_K == 0, + "expected CTATileMNK to be evenly divided by TiledMma::ThrLayoutVMNK"); static constexpr auto SG_M = BLK_M / ATOM_M; static constexpr auto SG_N = BLK_N / ATOM_N; static constexpr auto SG_K = BLK_K / ATOM_K; - using SubgroupTileShape = Shape; + using SubgroupTileShape = + Shape; - static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // A frags per sub_group - static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group + static constexpr int FragsM = + get<0>(SubgroupTileShape{}) / + get<0>(MmaAtomShape()); // A frags per sub_group + static constexpr int FragsN = + get<1>(SubgroupTileShape{}) / + get<1>(MmaAtomShape()); // B frags per sub_group - static constexpr int FragmentSize = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize; + static constexpr int FragmentSize = + (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize; // Indexing variables auto [M, N, K, L] = problem_shape_mnkl; @@ -369,7 +388,8 @@ class CollectiveEpilogue< auto m_sg = get_sub_group_id() / ATOM_N; auto n_sg = get_sub_group_id() % ATOM_N; - // Get the layout and reconstruct the MN mapping equivalent to the old get_layoutS_MN() + // Get the layout and reconstruct the MN mapping equivalent to the old + // get_layoutS_MN() auto layoutS_TV = params.xe_store_d.get_layoutS_TV(); auto mn_shape = shape(typename decltype(params.xe_store_d)::Tiler_MN{}); auto layoutS_MN = right_inverse(layoutS_TV).with_shape(mn_shape); @@ -382,78 +402,95 @@ class CollectiveEpilogue< auto sg_n_coord = n_coord * ATOM_N + sg_local_n_coord; auto sg_coord = make_coord(sg_m_coord, sg_n_coord, k_coord, l_coord); - bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); - + bool is_C_load_needed = + is_source_supported && fusion_callbacks.is_C_load_needed(); + // Represent the full output tensor - Tensor mD_mnl = cute::get_xe_tensor(make_shape(M,N,L)); + Tensor mD_mnl = cute::get_xe_tensor(make_shape(M, N, L)); // Tile the output tensor per WG and select the tile for current WG - Tensor g_wg_D = local_tile(mD_mnl, take<0,2>(CtaTileMNK{}), make_coord(m_coord,n_coord,l_coord)); // (BLK_M,BLK_N) - + Tensor g_wg_D = + local_tile(mD_mnl, take<0, 2>(CtaTileMNK{}), + make_coord(m_coord, n_coord, l_coord)); // (BLK_M,BLK_N) + // Tile the output tensor per SG and select tile for the current SG - Tensor gD = local_tile(g_wg_D, take<0,2>(SubgroupTileShape{}), make_coord(m_sg,n_sg)); // (SG_M,SG_N) + Tensor gD = local_tile(g_wg_D, take<0, 2>(SubgroupTileShape{}), + make_coord(m_sg, n_sg)); // (SG_M,SG_N) auto thread_xe_store_d = params.xe_store_d.get_thread_slice(thread_idx); Tensor tCgD = thread_xe_store_d.partition_D(gD); - Tensor trC = make_tensor(Shape>{}); - Tensor trD_compute = make_tensor(Shape>{}); + Tensor trC = + make_tensor(Shape>{}); + Tensor trD_compute = + make_tensor(Shape>{}); - // Because Sm90 uses shared memory, they are not tied to using the same accumulator values - // for MMA and Epilogue. But because we are operating directly in the accumulators, we need to be - // sure that we are operating on the same values. + // Because Sm90 uses shared memory, they are not tied to using the same + // accumulator values for MMA and Epilogue. But because we are operating + // directly in the accumulators, we need to be sure that we are operating on + // the same values. 
ThrCopy thread_g2r = params.xe_load_c.get_slice(thread_idx); // OOB predication for tile quantization "residue" // Absolute coordinate tensors (dynamic) - Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) - Tensor cD = local_tile(mD_crd, take<0,2>(SubgroupTileShape{}), make_coord(sg_m_coord, sg_n_coord)); - Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) - Tensor tRS_cD_mn = thread_g2r.partition_S(flat_divide(cD_mn, EpilogueTile{})); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) - Tensor tRS_cD = make_coord_tensor(tRS_cD_mn.layout()); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) + Tensor mD_crd = make_identity_tensor(make_shape(M, N)); // (M,N) + Tensor cD = local_tile(mD_crd, take<0, 2>(SubgroupTileShape{}), + make_coord(sg_m_coord, sg_n_coord)); + Tensor cD_mn = local_tile(mD_crd, take<0, 2>(CtaTileMNK{}), + make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) + Tensor tRS_cD_mn = thread_g2r.partition_S( + flat_divide(cD_mn, EpilogueTile{})); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) + Tensor tRS_cD = + make_coord_tensor(tRS_cD_mn.layout()); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) // Get the fusion callbacks - // Arguments passed here relate to sub-group tiles, rather than CTA (work-group) tiles + // Arguments passed here relate to sub-group tiles, rather than CTA + // (work-group) tiles constexpr bool RefSrc = true; - auto residue_mn = make_coord(M, N); //TODO(Codeplay): this is not correct + auto residue_mn = make_coord(M, N); // TODO(Codeplay): this is not correct auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ - problem_shape_mnkl, - SubgroupTileShape{}, - sg_coord, - tiled_mma, - EpilogueTile{}, - params.xe_store_d, - cD, - residue_mn, - tRS_cD, - residue_mn, - trC, - thread_idx, - }; - auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + problem_shape_mnkl, + SubgroupTileShape{}, + sg_coord, + tiled_mma, + EpilogueTile{}, + params.xe_store_d, + cD, + residue_mn, + tRS_cD, + residue_mn, + trC, + thread_idx, + }; + auto cst_callbacks = + fusion_callbacks.template get_consumer_store_callbacks( + cst_args); cst_callbacks.begin(); auto acc_frag = recast>(accumulators); - auto trD_compute_frag = recast>(trD_compute); + auto trD_compute_frag = + recast>(trD_compute); Tensor trD = make_tensor(Shape>{}); auto trD_frag = recast>(trD); - constexpr int ValuesLoaded = - FragsM * FragsN * FragmentSize * SubgroupSize * ATOM_M * ATOM_N * ATOM_K; + constexpr int ValuesLoaded = FragsM * FragsN * FragmentSize * SubgroupSize * + ATOM_M * ATOM_N * ATOM_K; constexpr int MN = get<0>(CtaTileMNK{}) * get<1>(CtaTileMNK{}); - static_assert(ValuesLoaded == MN, "the total elements loaded by all threads should be the same as MxN" ); + static_assert( + ValuesLoaded == MN, + "the total elements loaded by all threads should be the same as MxN"); - auto synchronize = [&] () {}; + auto synchronize = [&]() {}; CUTLASS_PRAGMA_UNROLL for (int epi_n = 0; epi_n < FragsN; epi_n++) { CUTLASS_PRAGMA_UNROLL for (int epi_m = 0; epi_m < FragsM; epi_m++) { - if (is_C_load_needed) { - //cordinates for C and D are the same - copy(params.xe_load_c.with(get<0>(load_store_tensors)), tCgD(_, epi_m, epi_n), trC); + // cordinates for C and D are the same + copy(params.xe_load_c.with(get<0>(load_store_tensors)), + tCgD(_, epi_m, epi_n), trC); } cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed); @@ -462,16 +499,23 @@ class CollectiveEpilogue< CUTLASS_PRAGMA_UNROLL for (int epi_v = 0; epi_v < size<0>(trD_compute_frag); ++epi_v) { - 
trD_compute_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n); + trD_compute_frag(epi_v) = + cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n); } - cst_callbacks.reduce(nullptr, synchronize, epi_m, epi_n, (epi_m == FragsM - 1 && epi_n == FragsN - 1), trD_compute_frag); - + cst_callbacks.reduce(nullptr, synchronize, epi_m, epi_n, + (epi_m == FragsM - 1 && epi_n == FragsN - 1), + trD_compute_frag); + if constexpr (is_destination_supported) { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(trD_compute_frag); ++i) { - trD_frag(i) = cutlass::NumericArrayConverter{}(trD_compute_frag(i)); + trD_frag(i) = + cutlass::NumericArrayConverter{}( + trD_compute_frag(i)); } - copy(params.xe_store_d.with(get<1>(load_store_tensors)), trD, tCgD(_, epi_m, epi_n)); + copy(params.xe_store_d.with(get<1>(load_store_tensors)), trD, + tCgD(_, epi_m, epi_n)); } } } @@ -481,34 +525,38 @@ class CollectiveEpilogue< template CUTLASS_DEVICE auto update_tensor_shape_stride( - int32_t const& next_group, - ProblemShape_MNKL const& problem_shape_mnkl) { - auto [M, N, K, L] = problem_shape_mnkl; + int32_t const& next_group, ProblemShape_MNKL const& problem_shape_mnkl) { + auto [M, N, K, L] = problem_shape_mnkl; - TensorC mC_mnl; - TensorD mD_mnl; - if constexpr (is_source_supported) { - ElementC const* ptr_C_curr_batch = reinterpret_cast(params.ptr_C[next_group]); - mC_mnl = make_tensor(make_gmem_ptr(ptr_C_curr_batch), make_layout(make_shape(M, N, L), params.dC[next_group])); - } + TensorC mC_mnl; + TensorD mD_mnl; + if constexpr (is_source_supported) { + ElementC const* ptr_C_curr_batch = + reinterpret_cast(params.ptr_C[next_group]); + mC_mnl = + make_tensor(make_gmem_ptr(ptr_C_curr_batch), + make_layout(make_shape(M, N, L), params.dC[next_group])); + } - if constexpr (is_destination_supported) { - ElementD* ptr_D_curr_batch = reinterpret_cast(params.ptr_D[next_group]); - mD_mnl = make_tensor(make_gmem_ptr(ptr_D_curr_batch), make_layout(make_shape(M, N, L), params.dD[next_group])); - } - return cute::make_tuple(mC_mnl, mD_mnl); + if constexpr (is_destination_supported) { + ElementD* ptr_D_curr_batch = + reinterpret_cast(params.ptr_D[next_group]); + mD_mnl = + make_tensor(make_gmem_ptr(ptr_D_curr_batch), + make_layout(make_shape(M, N, L), params.dD[next_group])); } + return cute::make_tuple(mC_mnl, mD_mnl); + } -private: + private: Params const& params; FusionCallbacks fusion_callbacks; }; - ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace collective -} // namespace epilogue -} // namespace cutlass +} // namespace collective +} // namespace epilogue +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_mma.hpp b/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_mma.hpp index 3a1e84a..a2abb4b 100644 --- a/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_mma.hpp +++ b/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_mma.hpp @@ -5,8 +5,8 @@ * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. 
Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ #pragma once @@ -43,11 +44,15 @@ namespace cutlass::gemm::collective { using namespace cute; ///////////////////////////////////////////////////////////////////////////////////////////////// -template -struct CollectiveMma, TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, - GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, GmemTiledCopyB_, SmemLayoutAtomB_, +template +struct CollectiveMma, TileShape_, + ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, + GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, + TransformA_, GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_> { // // Type Aliases @@ -72,10 +77,14 @@ struct CollectiveMma, TileShape_, El using TransformB = TransformB_; using ArchTag = typename DispatchPolicy::ArchTag; - static_assert(platform::is_same::value, "MainloopIntelXeXMX16Array requires that A and B have same type."); + static_assert( + platform::is_same::value, + "MainloopIntelXeXMX16Array requires that A and B have same type."); - static_assert(std::is_same_v, "Transformation for A is not currently supported on Intel PVC"); - static_assert(std::is_same_v, "Transformation for B is not currently supported on Intel PVC"); + static_assert(std::is_same_v, + "Transformation for A is not currently supported on Intel PVC"); + static_assert(std::is_same_v, + "Transformation for B is not currently supported on Intel PVC"); static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; @@ -85,23 +94,33 @@ struct CollectiveMma, TileShape_, El static constexpr auto BLK_N = get<1>(WorkgroupTileShape{}); static constexpr auto BLK_K = get<2>(WorkgroupTileShape{}); - static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); - static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); - static constexpr auto ATOM_K = 
get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_M = + get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_N = + get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_K = + get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); static constexpr auto SG_M = ceil_div(BLK_M, ATOM_M); static constexpr auto SG_N = ceil_div(BLK_N, ATOM_N); static constexpr auto SG_K = ceil_div(BLK_K, ATOM_K); - using SubgroupTileShape = Shape; + using SubgroupTileShape = + Shape; static constexpr auto Num_SGs = ATOM_N * ATOM_M * ATOM_K; static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}); - using Copy_A = typename Copy_Traits::template DefaultTiledCopy; - using Copy_B = typename Copy_Traits::template DefaultTiledCopy; - - using TensorMKL = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), InternalStrideA{})); //(m, k) - using TensorNKL = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), InternalStrideB{})); //(n, k) + using Copy_A = typename Copy_Traits< + GmemTiledCopyA, InternalStrideA>::template DefaultTiledCopy; + using Copy_B = typename Copy_Traits< + GmemTiledCopyB, InternalStrideB>::template DefaultTiledCopy; + + using TensorMKL = + decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), + make_shape(0, 0, 0), InternalStrideA{})); //(m, k) + using TensorNKL = + decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), + make_shape(0, 0, 0), InternalStrideB{})); //(n, k) using MainloopTensors = cute::tuple; // Host side kernel arguments struct Arguments { @@ -125,68 +144,80 @@ struct CollectiveMma, TileShape_, El CollectiveMma() = default; template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - (void) workspace; - - auto problem_shape_MNK = repeat_like(typename ProblemShape::UnderlyingProblemShape{}, int32_t(1));; + static constexpr Params to_underlying_arguments( + ProblemShape const& problem_shape, Arguments const& args, + void* workspace) { + (void)workspace; + + auto problem_shape_MNK = repeat_like( + typename ProblemShape::UnderlyingProblemShape{}, int32_t(1)); + ; auto init_M = get<0>(problem_shape_MNK); auto init_N = get<1>(problem_shape_MNK); auto init_K = get<2>(problem_shape_MNK); - return Params{ - args.ptr_A, - args.dA, - args.ptr_B, - args.dB - }; + return Params{args.ptr_A, args.dA, args.ptr_B, args.dB}; } - template - static bool - can_implement( - ProblemShape problem_shapes, - Arguments const& args) { + template + static bool can_implement(ProblemShape problem_shapes, + Arguments const& args) { constexpr int copy_alignment_bits = 128; constexpr int batch_alignment_bits = 512; auto problem_shape_MNKL = append<4>(problem_shapes, 1); - auto [M,N,K,L] = problem_shape_MNKL; + auto [M, N, K, L] = problem_shape_MNKL; bool implementable = true; - constexpr int min_aligned_elements_A = copy_alignment_bits / sizeof_bits::value; - constexpr int min_aligned_elements_B = copy_alignment_bits / sizeof_bits::value; - constexpr int min_batch_aligned_elements_A = batch_alignment_bits / sizeof_bits::value; - constexpr int min_batch_aligned_elements_B = batch_alignment_bits / sizeof_bits::value; + constexpr int min_aligned_elements_A = + copy_alignment_bits / sizeof_bits::value; + constexpr int min_aligned_elements_B = + copy_alignment_bits / sizeof_bits::value; + constexpr int min_batch_aligned_elements_A = + batch_alignment_bits / sizeof_bits::value; + constexpr int 
min_batch_aligned_elements_B = + batch_alignment_bits / sizeof_bits::value; for (int i = 0; i < problem_shapes.groups(); i++) { - auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); - auto [M,N,K,L] = problem_shape_MNKL; + auto problem_shape_MNKL = + append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M, N, K, L] = problem_shape_MNKL; - implementable &= cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); - implementable &= cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); + implementable &= cutlass::detail::check_alignment( + cute::make_shape(M, K, L), InternalStrideA{}); + implementable &= cutlass::detail::check_alignment( + cute::make_shape(N, K, L), InternalStrideB{}); if (L > 1) { - implementable &= get<2>(InternalStrideA{}) % min_batch_aligned_elements_A == 0; - implementable &= get<2>(InternalStrideB{}) % min_batch_aligned_elements_B == 0; + implementable &= + get<2>(InternalStrideA{}) % min_batch_aligned_elements_A == 0; + implementable &= + get<2>(InternalStrideB{}) % min_batch_aligned_elements_B == 0; } } if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for XE 2D copy.\n"); + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment " + "requirements for XE 2D copy.\n"); } return implementable; } /// Perform a subgroup-scoped matrix multiply-accumulate - template - CUTLASS_DEVICE void operator()(FrgTensorD &accum, TensorA gA, TensorB gB, FrgTensorC const &src_accum, - KTileIterator k_tile_iter, int const& k_tile_count, - BlkCoord const &blk_coord, int const &K_start, int const& thread_idx, - Params const &mainloop, LoadTensors const& load_tensors) { - static_assert(is_rmem::value, "D tensor must be rmem resident."); - static_assert(is_rmem::value, "C tensor must be rmem resident."); + template + CUTLASS_DEVICE void operator()(FrgTensorD& accum, TensorA gA, TensorB gB, + FrgTensorC const& src_accum, + KTileIterator k_tile_iter, + int const& k_tile_count, + BlkCoord const& blk_coord, int const& K_start, + int const& thread_idx, Params const& mainloop, + LoadTensors const& load_tensors) { + static_assert(is_rmem::value, + "D tensor must be rmem resident."); + static_assert(is_rmem::value, + "C tensor must be rmem resident."); (void)thread_idx; @@ -199,28 +230,36 @@ struct CollectiveMma, TileShape_, El // Instantiate the MMA object and get thread slice TiledMma tiled_mma; // TODO(Codeplay): see if we can make this nicer - // To make all work items in a subgroup have the same global tensors pass in the index of work item 0 in each subgroup + // To make all work items in a subgroup have the same global tensors pass in + // the index of work item 0 in each subgroup auto sg = syclcompat::get_nd_item<1>().get_sub_group(); - auto first_thread_in_sg_idx = sg.get_group_linear_id() * DispatchPolicy::SubgroupSize; + auto first_thread_in_sg_idx = + sg.get_group_linear_id() * DispatchPolicy::SubgroupSize; auto thr_mma = tiled_mma.get_slice(first_thread_in_sg_idx); // Partition global counting tensors for MMA Tensor tCgA = thr_mma.partition_A(gA); Tensor tCgB = thr_mma.partition_B(gB); - Tensor tCrA = make_tensor(make_fragment_layout(tiled_copy_a, tCgA(_,_,_,0).shape())); - Tensor tCrB = make_tensor(make_fragment_layout(tiled_copy_b, tCgB(_,_,_,0).shape())); - + Tensor tCrA = make_tensor( + make_fragment_layout(tiled_copy_a, tCgA(_, _, _, 0).shape())); + Tensor tCrB = make_tensor( + 
make_fragment_layout(tiled_copy_b, tCgB(_, _, _, 0).shape())); + // Retile registers for copies Tensor tArA = thr_copy_A.retile_D(tCrA); Tensor tBrB = thr_copy_B.retile_D(tCrB); - + // Retile global counting tensors for copies Tensor tAgA = thr_copy_A.retile_S(tCgA); Tensor tBgB = thr_copy_B.retile_S(tCgB); - auto tiled_prefetch_a = cute::prefetch_selector,Int>, Num_SGs>(tiled_copy_a); - auto tiled_prefetch_b = cute::prefetch_selector,Int>, Num_SGs>(tiled_copy_b); + auto tiled_prefetch_a = + cute::prefetch_selector, Int>, Num_SGs>( + tiled_copy_a); + auto tiled_prefetch_b = + cute::prefetch_selector, Int>, Num_SGs>( + tiled_copy_b); auto thr_prefetch_A = tiled_prefetch_a.get_slice(thread_idx); auto thr_prefetch_B = tiled_prefetch_b.get_slice(thread_idx); @@ -230,20 +269,36 @@ struct CollectiveMma, TileShape_, El #if CUTLASS_ENABLE_DEBUG_PRINTS if (cutlass::thread(LOG_THREAD, LOG_GROUP)) { - print("======================= A: \n"); - print(" gA : "); print(gA); print("\n"); - print("tCgA : "); print(tCgA); print("\n"); - print("tAgA : "); print(tAgA); print("\n"); - - print("===================== B :\n"); - print(" gB : "); print(gB); print("\n"); - print("tCgB : "); print(tCgB); print("\n"); - print("tBgB : "); print(tBgB); print("\n"); - - print("===================== Config: \n"); - print(" threads per workgroup : "); print(MaxThreadsPerBlock); print("\n"); - print(" SubgroupTileShape : "); print(SubgroupTileShape{}); print("\n"); - } + print("======================= A: \n"); + print(" gA : "); + print(gA); + print("\n"); + print("tCgA : "); + print(tCgA); + print("\n"); + print("tAgA : "); + print(tAgA); + print("\n"); + + print("===================== B :\n"); + print(" gB : "); + print(gB); + print("\n"); + print("tCgB : "); + print(tCgB); + print("\n"); + print("tBgB : "); + print(tBgB); + print("\n"); + + print("===================== Config: \n"); + print(" threads per workgroup : "); + print(MaxThreadsPerBlock); + print("\n"); + print(" SubgroupTileShape : "); + print(SubgroupTileShape{}); + print("\n"); + } #endif // @@ -259,11 +314,12 @@ struct CollectiveMma, TileShape_, El prefetch(tiled_prefetch_b, pBgB(_, _, _, prefetch_k)); } - for (int k_tile = k_start_idx; k_tile < k_tile_count + k_start_idx; k_tile++, prefetch_k++) { + for (int k_tile = k_start_idx; k_tile < k_tile_count + k_start_idx; + k_tile++, prefetch_k++) { barrier_arrive(barrier_scope); // Copy gmem to rmem for the first k_tile - copy(tiled_copy_a, tAgA(_,_,_,k_tile), tArA); - copy(tiled_copy_b, tBgB(_,_,_,k_tile), tBrB); + copy(tiled_copy_a, tAgA(_, _, _, k_tile), tArA); + copy(tiled_copy_b, tBgB(_, _, _, k_tile), tBrB); if (prefetch_k < k_tile_count) { prefetch(tiled_prefetch_a, pAgA(_, _, _, prefetch_k)); @@ -277,23 +333,28 @@ struct CollectiveMma, TileShape_, El template CUTLASS_DEVICE auto update_tensor_shape_stride( - Params const& mainloop_params, - int32_t const& next_group, - ProblemShape_MNKL const& problem_shape_mnkl) { - const int32_t M = get<0>(problem_shape_mnkl); - const int32_t N = get<1>(problem_shape_mnkl); - const int32_t K = get<2>(problem_shape_mnkl); - - ElementA const* ptr_A_curr_batch = reinterpret_cast(mainloop_params.ptr_A[next_group]); - ElementB const* ptr_B_curr_batch = reinterpret_cast(mainloop_params.ptr_B[next_group]); - - Tensor mA = make_tensor(make_gmem_ptr(ptr_A_curr_batch), make_shape(M, K,(int32_t)1), mainloop_params.dA[next_group]); - Tensor mB = make_tensor(make_gmem_ptr(ptr_B_curr_batch), make_shape(N, K,(int32_t)1), mainloop_params.dB[next_group]); - - return 
cute::make_tuple(mA, mB); - } + Params const& mainloop_params, int32_t const& next_group, + ProblemShape_MNKL const& problem_shape_mnkl) { + const int32_t M = get<0>(problem_shape_mnkl); + const int32_t N = get<1>(problem_shape_mnkl); + const int32_t K = get<2>(problem_shape_mnkl); + + ElementA const* ptr_A_curr_batch = + reinterpret_cast(mainloop_params.ptr_A[next_group]); + ElementB const* ptr_B_curr_batch = + reinterpret_cast(mainloop_params.ptr_B[next_group]); + + Tensor mA = make_tensor(make_gmem_ptr(ptr_A_curr_batch), + make_shape(M, K, (int32_t)1), + mainloop_params.dA[next_group]); + Tensor mB = make_tensor(make_gmem_ptr(ptr_B_curr_batch), + make_shape(N, K, (int32_t)1), + mainloop_params.dB[next_group]); + + return cute::make_tuple(mA, mB); + } }; -} // namespace cutlass::gemm::collective +} // namespace cutlass::gemm::collective ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/xe_builder.hpp b/csrc/xpu/cutlass_kernels/collective/gemm/xe_builder.hpp index abaa091..ca749c3 100644 --- a/csrc/xpu/cutlass_kernels/collective/gemm/xe_builder.hpp +++ b/csrc/xpu/cutlass_kernels/collective/gemm/xe_builder.hpp @@ -5,8 +5,8 @@ * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ @@ -43,253 +44,191 @@ namespace cutlass::epilogue::collective { ///////////////////////////////////////////////////////////////////////////////////////////////// -// Used to specify epilogue subtile shape or dispatch to automatic computation of subtile shape +// Used to specify epilogue subtile shape or dispatch to automatic computation +// of subtile shape struct EpilogueTileAuto {}; // Used to let the builder pick the epilogue schedule automatically. -// Can be overridden with kernel schedule tags in cutlass/gemm/dispatch_policy.hpp +// Can be overridden with kernel schedule tags in +// cutlass/gemm/dispatch_policy.hpp struct EpilogueScheduleAuto {}; template < - class ArchTag, - class OpClass, - class TileShape_MNK, - class ClusterShape_MNK, - class EpilogueTileType, - class ElementAccumulator, - class ElementCompute, - class ElementC, - class GmemLayoutTagC, - int AlignmentC, - class ElementD, - class GmemLayoutTagD, - int AlignmentD, - class EpilogueScheduleType, - class FusionOpOrCallbacks = cutlass::epilogue::fusion::LinearCombination, - class Enable = void -> + class ArchTag, class OpClass, class TileShape_MNK, class ClusterShape_MNK, + class EpilogueTileType, class ElementAccumulator, class ElementCompute, + class ElementC, class GmemLayoutTagC, int AlignmentC, class ElementD, + class GmemLayoutTagD, int AlignmentD, class EpilogueScheduleType, + class FusionOpOrCallbacks = cutlass::epilogue::fusion::LinearCombination< + ElementD, ElementCompute, ElementC, ElementCompute>, + class Enable = void> struct CollectiveBuilder { static_assert(cutlass::detail::dependent_false, - "Could not build a collective epilogue for given parameters."); + "Could not build a collective epilogue for given parameters."); }; -// helper sub-builder for epilogue fusion callbacks (for internal use by CollectiveBuilder only) +// helper sub-builder for epilogue fusion callbacks (for internal use by +// CollectiveBuilder only) namespace detail { // callbacks builder with operation tag -template< - class DispatchPolicy, - class FusionOp, - class TileShape_MNK, - class EpilogueTile_MN, - class ElementAccumulator, - class AccLoadOp = cute::DefaultCopy, - class = void -> +template struct CallbacksBuilder { - using Callbacks = fusion::FusionCallbacks; + using Callbacks = fusion::FusionCallbacks; }; // callbacks builder with callbacks passthrough -template < - class DispatchPolicy, - class FusionCallbacks, - class TileShape_MNK, - class EpilogueTile_MN, - class AccLoadOp, - class ElementAccumulator -> -struct CallbacksBuilder< - DispatchPolicy, - FusionCallbacks, - TileShape_MNK, - EpilogueTile_MN, - ElementAccumulator, - AccLoadOp, - cute::enable_if_t> -> { +template +struct CallbacksBuilder>> { using Callbacks = FusionCallbacks; }; -} // namespace detail +} // namespace detail ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass::epilogue::collective +} // namespace cutlass::epilogue::collective ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass::epilogue::collective { namespace detail { - template - struct FusionOpInfo { - static_assert(cutlass::detail::dependent_false, - "Could not find a builder specialization."); - }; - - template < - class ElementD, - class ElementCompute, - class ElementC - > - struct FusionOpInfo> { - constexpr static bool HasBuilder = true; - - template < - 
class DispatchPolicy, - class TileShape_MNK, - class EpilogueTile, - class> - using FusionCallbacks = cutlass::epilogue::fusion::FusionCallbacks< - DispatchPolicy, - cutlass::epilogue::fusion::LinearCombination, - TileShape_MNK, - EpilogueTile - >; - }; - - template < - template class ActivationFn, - class ElementD, - class ElementCompute, - class ElementC - > - struct FusionOpInfo> { - constexpr static bool HasBuilder = true; - template < - class DispatchPolicy, - class TileShape_MNK, - class EpilogueTile, - class> - - using FusionCallbacks = cutlass::epilogue::fusion::FusionCallbacks< - DispatchPolicy, - cutlass::epilogue::fusion::LinCombEltAct, - TileShape_MNK, - EpilogueTile - >; - }; +template +struct FusionOpInfo { + static_assert(cutlass::detail::dependent_false, + "Could not find a builder specialization."); +}; - template < - class GmemLayoutTagC, - template class ActivationFn, - class ElementD, - class ElementCompute, - class ElementC - > - struct FusionOpInfo> { - constexpr static bool HasBuilder = true; +template +struct FusionOpInfo> { + constexpr static bool HasBuilder = true; + + template + using FusionCallbacks = cutlass::epilogue::fusion::FusionCallbacks< + DispatchPolicy, + cutlass::epilogue::fusion::LinearCombination, + TileShape_MNK, EpilogueTile>; +}; - template < - class DispatchPolicy, - class TileShape_MNK, - class EpilogueTile, - class CopyOpG2R> - using FusionCallbacks = cutlass::epilogue::fusion::FusionCallbacks< - DispatchPolicy, - cutlass::epilogue::fusion::LinCombDeEltAct, - TileShape_MNK, - EpilogueTile, - CopyOpG2R - >; - }; -} // namespace detail +template