From f22f9566dc7764f03cf3680b331b9fa932934294 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Fri, 9 May 2025 17:14:32 -0700 Subject: [PATCH 01/11] kernel: add grouped gemm support for moe --- src/kernels/CMakeLists.txt | 10 ++++----- src/kernels/gemm/CMakeLists.txt | 24 +++++++++++++++++++++ src/kernels/gemm/grouped_gemm_kernel_sm80.h | 3 +++ 3 files changed, 32 insertions(+), 5 deletions(-) create mode 100644 src/kernels/gemm/CMakeLists.txt create mode 100644 src/kernels/gemm/grouped_gemm_kernel_sm80.h diff --git a/src/kernels/CMakeLists.txt b/src/kernels/CMakeLists.txt index 999a1393..1dfdf69e 100644 --- a/src/kernels/CMakeLists.txt +++ b/src/kernels/CMakeLists.txt @@ -1,16 +1,16 @@ include(cc_library) cc_library( - NAME + NAME kernels - HDRS + HDRS reduce_kernel_utils.cuh activation_kernels.h layernorm_kernels.h pos_embedding_kernels.h kv_cache_kernels.h sampling/sampling_kernels.h - SRCS + SRCS activation_kernels.cu layernorm_kernels.cu pos_embedding_kernels.cu @@ -28,7 +28,7 @@ cc_library( add_subdirectory(attention) add_subdirectory(moe) +add_subdirectory(gemm) add_subdirectory(quantization) add_subdirectory(playground) -add_subdirectory(triton) - +# add_subdirectory(triton) diff --git a/src/kernels/gemm/CMakeLists.txt b/src/kernels/gemm/CMakeLists.txt new file mode 100644 index 00000000..3fb7302c --- /dev/null +++ b/src/kernels/gemm/CMakeLists.txt @@ -0,0 +1,24 @@ +include(cc_library) +include(cc_test) + +cc_library( + NAME + gemm.kernels + HDRS + grouped_gemm_kernel_sm80.h + DEPS + cutlass +) + + +# cc_test( +# NAME +# gemm_kernel_test +# SRCS +# grouped_gemm_kernel_sm80_test.cu +# DEPS +# :gemm.kernels +# absl::random_random +# GTest::gtest_main +# torch +# ) diff --git a/src/kernels/gemm/grouped_gemm_kernel_sm80.h b/src/kernels/gemm/grouped_gemm_kernel_sm80.h new file mode 100644 index 00000000..6b74c06f --- /dev/null +++ b/src/kernels/gemm/grouped_gemm_kernel_sm80.h @@ -0,0 +1,3 @@ +#pragma once + +namespace llm {} // namespace llm From 311df8c7049a72c899750377a940c660de8485b1 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Fri, 9 May 2025 17:14:32 -0700 Subject: [PATCH 02/11] kernel: add grouped gemm support for moe --- src/kernels/gemm/CMakeLists.txt | 24 ++-- src/kernels/gemm/grouped_gemm_kernel_sm80.cuh | 115 ++++++++++++++++++ 2 files changed, 127 insertions(+), 12 deletions(-) create mode 100644 src/kernels/gemm/grouped_gemm_kernel_sm80.cuh diff --git a/src/kernels/gemm/CMakeLists.txt b/src/kernels/gemm/CMakeLists.txt index 3fb7302c..99cc5920 100644 --- a/src/kernels/gemm/CMakeLists.txt +++ b/src/kernels/gemm/CMakeLists.txt @@ -5,20 +5,20 @@ cc_library( NAME gemm.kernels HDRS - grouped_gemm_kernel_sm80.h + grouped_gemm_kernel_sm80.cuh DEPS cutlass ) -# cc_test( -# NAME -# gemm_kernel_test -# SRCS -# grouped_gemm_kernel_sm80_test.cu -# DEPS -# :gemm.kernels -# absl::random_random -# GTest::gtest_main -# torch -# ) +cc_test( + NAME + gemm_kernel_test + SRCS + grouped_gemm_kernel_sm80_test.cu + DEPS + :gemm.kernels + absl::random_random + GTest::gtest_main + torch +) diff --git a/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh b/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh new file mode 100644 index 00000000..5de78663 --- /dev/null +++ b/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh @@ -0,0 +1,115 @@ +#pragma once +#include +#include + +#include +#include +#include + +namespace llm { +using namespace cute; +template +struct GEMMTraitsSM80 { + static constexpr int kDim = DIM; + static constexpr int kBlockM = BLK_M; + static constexpr int kBlockN = BLK_N; + static constexpr int 
kBlockK = BLK_K; + static constexpr int kStages = STAGES; + + static_assert(kBlockM % 64 == 0); + static_assert(kBlockN % 32 == 0); + static_assert(kBlockK % 16 == 0); + + // helpful aliases + using DType = DTYPE; + using _BLK_M = Int; + using _BLK_N = Int; + using _BLK_K = Int; + using _STAGES = Int; + using _DIM = Int; + + // TiledMMA: (64x32x16) + using MMA_Atom_ = + std::conditional_t, + MMA_Atom, + MMA_Atom>; + using TiledMma = + TiledMMA>, Tile<_64, _32, _16>>; + + // Shared memory LayoutAtom (8x64) + using SmemLayoutAtom_8x64 = + decltype(composition(Swizzle<3, 3, 3>{}, + Layout, Stride<_64, _1>>{})); + using SmemLayoutAtom_8x32 = + decltype(composition(Swizzle<2, 3, 3>{}, + Layout, Stride<_32, _1>>{})); + + using SmemLayoutAtomK = std::conditional_t; + // SMEM Layout for A: (BLK_M, BLK_K, STAGES) + using SmemLayoutA = + decltype(tile_to_shape(SmemLayoutAtomK{}, Shape<_BLK_M, _BLK_K>{})); + // SMEM Layout for B: (BLK_N, BLK_K, STAGES) + using SmemLayoutB = + decltype(tile_to_shape(SmemLayoutAtomK{}, Shape<_BLK_N, _BLK_K>{})); + + // Gmem tiled copy: copy A/B from global memory to shared memory (32x64) + using GmemTiledCopy = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout, Stride<_8, _1>>{}, // Thr layout: (_32, _8) + Layout>{} // Val layout: 8 vals per read + )); +}; + +template +struct GEMMSharedStorageSM80 {}; + +struct GEMMParams { + // input tensors: + // A: (m, k) + // sorted_token_idxes: (m*topk+) = (n_blocks, BLK_M) + // B: (e, n, k) + // expert_ids/group_ids: (n_blocks) + // n_tokens + // output tensors: + // C: ((m, topk), n) + // +}; + +template +__global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( + __grid_constant__ const Params params) { + using namespace cute; + + // each block takes care of one block: (BLK_M, BLK_N) + // 1: load A to smem: (BLK_M, BLK_K, STAGES) + // load sorted_token_idxes from global memory to registers, (BLK_M) + // 2: load B to smem: (BLK_N, BLK_K, STAGES) + // load group_id for current block from global memory to registers, (1) + // 3: iterate over k + // 4: partition A to tCsA, tCrA + // 5: partition B to tCsB, tCrB + // load a, b to registers + // 6: compute tCrA * tCrB with gemm + // 7: write tCrC to global memory using sorted_token_idxes +} + +template +void launch_grouped_gemm_kernel_sm80(const Params& params, + cudaStream_t stream) { + const auto batch_size = params.batch_size; + const auto max_q_packed_len = params.max_q_len * params.n_heads; + + const auto smem_size = sizeof(GEMMSharedStorageSM80); + + auto gemm_kernel = grouped_gemm_kernel_sm80; + cudaFuncSetAttribute( + gemm_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + // TODO: support persistent kernels + dim3 grid(cute::ceil_div(max_q_packed_len, Traits::kBlockM), batch_size, 1); + dim3 block = Traits::kThreadNum; + gemm_kernel<<>>(params); +} + +} // namespace llm From f0de7a00c1ec92f4e5aa240f1611e88e15039585 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Mon, 2 Jun 2025 10:19:09 -0700 Subject: [PATCH 03/11] add unit test file --- src/kernels/gemm/grouped_gemm_kernel_sm80.cuh | 63 ++++++++++-------- src/kernels/gemm/grouped_gemm_kernel_sm80.h | 3 - .../gemm/grouped_gemm_kernel_sm80_test.cu | 64 +++++++++++++++++++ 3 files changed, 99 insertions(+), 31 deletions(-) delete mode 100644 src/kernels/gemm/grouped_gemm_kernel_sm80.h create mode 100644 src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu diff --git a/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh b/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh index 5de78663..cdca5f89 
100644 --- a/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh +++ b/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh @@ -8,6 +8,7 @@ namespace llm { using namespace cute; + template struct GEMMTraitsSM80 { static constexpr int kDim = DIM; @@ -36,30 +37,34 @@ struct GEMMTraitsSM80 { using TiledMma = TiledMMA>, Tile<_64, _32, _16>>; - // Shared memory LayoutAtom (8x64) - using SmemLayoutAtom_8x64 = - decltype(composition(Swizzle<3, 3, 3>{}, - Layout, Stride<_64, _1>>{})); - using SmemLayoutAtom_8x32 = - decltype(composition(Swizzle<2, 3, 3>{}, - Layout, Stride<_32, _1>>{})); - - using SmemLayoutAtomK = std::conditional_t; - // SMEM Layout for A: (BLK_M, BLK_K, STAGES) - using SmemLayoutA = - decltype(tile_to_shape(SmemLayoutAtomK{}, Shape<_BLK_M, _BLK_K>{})); - // SMEM Layout for B: (BLK_N, BLK_K, STAGES) - using SmemLayoutB = - decltype(tile_to_shape(SmemLayoutAtomK{}, Shape<_BLK_N, _BLK_K>{})); - - // Gmem tiled copy: copy A/B from global memory to shared memory (32x64) - using GmemTiledCopy = decltype(make_tiled_copy( - Copy_Atom, DType>{}, - Layout, Stride<_8, _1>>{}, // Thr layout: (_32, _8) - Layout>{} // Val layout: 8 vals per read - )); + // // Shared memory LayoutAtom (8x64) + // using SmemLayoutAtom_8x64 = + // decltype(composition(Swizzle<3, 3, 3>{}, + // Layout, Stride<_64, _1>>{})); + // using SmemLayoutAtom_8x32 = + // decltype(composition(Swizzle<2, 3, 3>{}, + // Layout, Stride<_32, _1>>{})); + + // using SmemLayoutAtomK = std::conditional_t; + // // SMEM Layout for A: (BLK_M, BLK_K, STAGES) + // using SmemLayoutA = + // decltype(tile_to_shape(SmemLayoutAtomK{}, Shape<_BLK_M, _BLK_K>{})); + // // SMEM Layout for B: (BLK_N, BLK_K, STAGES) + // using SmemLayoutB = + // decltype(tile_to_shape(SmemLayoutAtomK{}, Shape<_BLK_N, _BLK_K>{})); + + // // Gmem tiled copy: copy A/B from global memory to shared memory (32x64) + // using GmemTiledCopy = decltype(make_tiled_copy( + // Copy_Atom, DType>{}, + // Layout, Stride<_8, _1>>{}, // Thr layout: (_32, _8) + // Layout>{} // Val layout: 8 vals per + // read + // )); + + // constexpr values for kernel launch + static constexpr size_t kThreadNum = size(TiledMma{}); }; template @@ -80,7 +85,7 @@ struct GEMMParams { template __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( __grid_constant__ const Params params) { - using namespace cute; + // using namespace cute; // each block takes care of one block: (BLK_M, BLK_N) // 1: load A to smem: (BLK_M, BLK_K, STAGES) @@ -98,8 +103,8 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( template void launch_grouped_gemm_kernel_sm80(const Params& params, cudaStream_t stream) { - const auto batch_size = params.batch_size; - const auto max_q_packed_len = params.max_q_len * params.n_heads; + // const auto batch_size = params.batch_size; + // const auto max_q_packed_len = params.max_q_len * params.n_heads; const auto smem_size = sizeof(GEMMSharedStorageSM80); @@ -107,7 +112,9 @@ void launch_grouped_gemm_kernel_sm80(const Params& params, cudaFuncSetAttribute( gemm_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); // TODO: support persistent kernels - dim3 grid(cute::ceil_div(max_q_packed_len, Traits::kBlockM), batch_size, 1); + // dim3 grid(cute::ceil_div(max_q_packed_len, Traits::kBlockM), batch_size, + // 1); + dim3 grid(1, 1, 1); // Placeholder for grid dimensions, adjust as needed dim3 block = Traits::kThreadNum; gemm_kernel<<>>(params); } diff --git a/src/kernels/gemm/grouped_gemm_kernel_sm80.h 
b/src/kernels/gemm/grouped_gemm_kernel_sm80.h deleted file mode 100644 index 6b74c06f..00000000 --- a/src/kernels/gemm/grouped_gemm_kernel_sm80.h +++ /dev/null @@ -1,3 +0,0 @@ -#pragma once - -namespace llm {} // namespace llm diff --git a/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu b/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu new file mode 100644 index 00000000..2fc3ee73 --- /dev/null +++ b/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu @@ -0,0 +1,64 @@ +#include +#include + +#include + +#include "grouped_gemm_kernel_sm80.cuh" // IWYU pragma: keep + +namespace llm { + +namespace { + +torch::Tensor grouped_gemm_sm80() { + auto out = torch::empty({}); // Placeholder for output tensor + + using Traits = GEMMTraitsSM80; /*STAGES*/ + + GEMMParams params; + launch_grouped_gemm_kernel_sm80(params, nullptr); + + return out; +} + +torch::Tensor grouped_gemm_ref() { + return torch::empty({}); // Replace with actual reference computation +} + +} // namespace + +class GroupedGemmKernelTest + : public ::testing::TestWithParam> { + public: + void SetUp() override { + // Set random seed for test stability + torch::manual_seed(0); + } +}; + +TEST_P(GroupedGemmKernelTest, GEMM) { + const auto [dtype, batch_size, dim] = GetParam(); + + const auto options = torch::dtype(dtype).device(torch::kCUDA); + + auto ref_out = grouped_gemm_ref(); + auto out = grouped_gemm_sm80(); + + EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-3, /*atol=*/1e-3)); +} + +INSTANTIATE_TEST_SUITE_P( + GEMM, + GroupedGemmKernelTest, + ::testing::Combine(::testing::Values(torch::kHalf), // q_dtype + ::testing::Values(1), // batch_size + ::testing::Values(64) // dim + )); + +} // namespace llm From 2a5eb590b07155800a83947bce2a758a21c7b0f8 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Sun, 8 Jun 2025 10:34:37 -0700 Subject: [PATCH 04/11] added grouped_gemm_ref --- .../gemm/grouped_gemm_kernel_sm80_test.cu | 76 ++++++++++++++++--- src/kernels/moe/align_block_kernel_test.cu | 4 - 2 files changed, 64 insertions(+), 16 deletions(-) diff --git a/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu b/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu index 2fc3ee73..d1a608d9 100644 --- a/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu +++ b/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu @@ -25,16 +25,51 @@ torch::Tensor grouped_gemm_sm80() { return out; } -torch::Tensor grouped_gemm_ref() { - return torch::empty({}); // Replace with actual reference computation +// returns (m, topk, n) +torch::Tensor grouped_gemm_ref(const torch::Tensor& a, // (m, k) + const torch::Tensor& w, // (e, n, k) + const torch::Tensor& topk_ids // (m, topk) + +) { + const auto m = a.size(0); + const auto k = a.size(1); + const auto n = w.size(1); + const auto n_experts = w.size(0); + const auto topk = topk_ids.size(1); + + // (m * topk, n) + auto out = torch::zeros({m * topk, n}, a.options()); + + // (m, k) => (m, topk, k) => (m * topk, k) + auto a_expanded_flat = + a.unsqueeze(/*dim=*/1).expand({-1, topk, -1}).reshape({-1, k}); + // (m, topk) => (m * topk) + auto topk_ids_flat = topk_ids.reshape({-1}); + + // process each expert + for (int64_t e = 0; e < n_experts; ++e) { + // 1D indices for the current expert + auto indices = torch::nonzero(topk_ids_flat == e).squeeze(); + // select corresponding tokens + auto a_selected = a_expanded_flat.index_select(/*dim=*/0, indices); + // perform the GEMM operation for this expert + auto e_out = torch::matmul(a_selected, w[e].transpose(0, 1)); + // copy the results into the output tensor + 
out.index_copy_(/*dim=*/0, indices, e_out); + } + // (m * topk, n) => (m, topk, n) + return out.view({m, topk, n}); } } // namespace class GroupedGemmKernelTest - : public ::testing::TestWithParam> { + : public ::testing::TestWithParam> { public: void SetUp() override { // Set random seed for test stability @@ -43,22 +78,39 @@ class GroupedGemmKernelTest }; TEST_P(GroupedGemmKernelTest, GEMM) { - const auto [dtype, batch_size, dim] = GetParam(); + const auto [dtype, m, n, k, n_experts, topk] = GetParam(); const auto options = torch::dtype(dtype).device(torch::kCUDA); - auto ref_out = grouped_gemm_ref(); - auto out = grouped_gemm_sm80(); + // Create input tensors + auto a = torch::randn({m, k}, options); + auto w = torch::randn({n_experts, n, k}, options); + + // Get top-k indices + auto logits = torch::randn({m, n_experts}, options).softmax(/*dim=*/1); + auto [topk_weights, topk_ids] = logits.topk(topk, /*dim=*/1); + + // LOG(ERROR) << "a: " << a; + // LOG(ERROR) << "w: " << w; + // LOG(ERROR) << "topk_ids: " << topk_ids; + // LOG(ERROR) << "topk_weights: " << topk_weights; + + auto ref_out = grouped_gemm_ref(a, w, topk_ids); + // LOG(ERROR) << "ref_out: " << ref_out; + // auto out = grouped_gemm_sm80(); - EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-3, /*atol=*/1e-3)); + // EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-3, /*atol=*/1e-3)); } INSTANTIATE_TEST_SUITE_P( GEMM, GroupedGemmKernelTest, - ::testing::Combine(::testing::Values(torch::kHalf), // q_dtype - ::testing::Values(1), // batch_size - ::testing::Values(64) // dim + ::testing::Combine(::testing::Values(torch::kHalf), // dtype + ::testing::Values(1), // m + ::testing::Values(64), // n + ::testing::Values(64), // k + ::testing::Values(8), // n_experts + ::testing::Values(4) // topk )); } // namespace llm diff --git a/src/kernels/moe/align_block_kernel_test.cu b/src/kernels/moe/align_block_kernel_test.cu index 73593799..66f0a418 100644 --- a/src/kernels/moe/align_block_kernel_test.cu +++ b/src/kernels/moe/align_block_kernel_test.cu @@ -146,10 +146,6 @@ class AlignBlockTest TEST_P(AlignBlockTest, AlignBlock) { const auto [dtype, n_tokens, dim, n_experts, topk, block_size] = GetParam(); const int64_t n_flatten_tokens = n_tokens * topk; - if (n_flatten_tokens >= 1024 || n_experts > 64) { - // TODO: reenable unittest after fixing tokens out of order issue - return; - } const auto options = torch::dtype(dtype).device(torch::kCUDA); const auto options_int32 = options.dtype(torch::kInt32); From 45adfb6e85e8d6a2e205e3b2c2f1d6da6c3e1524 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Mon, 9 Jun 2025 11:28:25 -0700 Subject: [PATCH 05/11] added code for gmem tensor --- src/kernels/gemm/gather_tensor.hpp | 167 ++++++++++++++++++ src/kernels/gemm/grouped_gemm_kernel_sm80.cuh | 103 +++++++++-- .../gemm/grouped_gemm_kernel_sm80_test.cu | 21 ++- 3 files changed, 275 insertions(+), 16 deletions(-) create mode 100644 src/kernels/gemm/gather_tensor.hpp diff --git a/src/kernels/gemm/gather_tensor.hpp b/src/kernels/gemm/gather_tensor.hpp new file mode 100644 index 00000000..79ca581b --- /dev/null +++ b/src/kernels/gemm/gather_tensor.hpp @@ -0,0 +1,167 @@ +// adapted from +// https://github.com/NVIDIA/cutlass/blob/main/examples/common/gather_tensor.hpp +#pragma once + +#include "cute/layout.hpp" +#include "cute/layout_composed.hpp" +#include "cute/tensor.hpp" +namespace llm { + +using namespace cute; + +namespace detail { + +// every stride must be divisible by div +template +CUTE_HOST_DEVICE constexpr auto safe_stride_div(Stride 
const& s, + const Div& div) { + if constexpr (is_tuple::value) { + return transform(s, [&](auto const& a) { return safe_stride_div(a, div); }); + } else { + return safe_div(s, div); + } + CUTE_GCC_UNREACHABLE; +} + +} // namespace detail + +/// Custom stride object that applies a function followed by a stride +template +struct CustomStride { + CUTE_HOST_DEVICE constexpr CustomStride(const Func& func, + const Stride& stride) + : func_(func), stride_(stride) {} + + template + CUTE_HOST_DEVICE constexpr friend auto operator*(I i, const CustomStride& s) { + return inner_product(s.func_(i), s.stride_); + } + + template + CUTE_HOST_DEVICE constexpr friend auto operator*(const CustomStride& s, I i) { + return inner_product(s.func_(i), s.stride_); + } + + template + CUTE_HOST_DEVICE constexpr friend auto safe_div(const CustomStride& s, + const Div& div) { + auto stride = detail::safe_stride_div(s.stride_, div); + return CustomStride(s.func_, stride); + } + + template + CUTE_HOST_DEVICE constexpr friend auto make_layout( + const Shape& shape, + const CustomStride& stride) { + return Layout(shape, stride); + } + + CUTE_HOST_DEVICE friend void print(CustomStride const& s) { + print("CustomStride{func,"); + print(s.stride_); + print("}"); + } + + Func func_; + Stride stride_; +}; + +template +CUTLASS_HOST_DEVICE auto make_custom_stride_layout(Func&& func, + const Shape& shape, + const Stride& stride) { + // Use a dummy shape and replace the first non-unit stride with a custom + // gather stride + auto idx = + find_if(stride, [](auto x) { return not is_constant<1, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; + return make_layout( + repeat_like(shape, _1{}), + replace(stride, + CustomStride{static_cast(func), get(stride)})); +} + +/// Helper function to optionally create a gather tensor +template +CUTLASS_HOST_DEVICE auto make_gather_tensor(Iterator iter, + const Shape& shape, + const Stride& stride, + Func&& func) { + Layout matrix_layout = make_identity_layout(shape); + auto offset = as_arithmetic_tuple(repeat_like(shape, _0{})); + Layout gather_layout = + make_custom_stride_layout(static_cast(func), shape, stride); + return make_tensor(iter, + ComposedLayout{gather_layout, offset, matrix_layout}); +} + +} // namespace llm + +namespace cute { + +template +CUTE_HOST_DEVICE constexpr auto upcast(Shape const& shape, + Stride const& stride) { + if constexpr (is_tuple::value) { + return transform_layout(shape, stride, [](auto const& s, auto const& d) { + return upcast(s, d); + }); + } else if constexpr (is_scaled_basis::value) { + if constexpr (Stride::mode() == I) { + return make_layout(ceil_div(shape, Int{}), ceil_div(stride, Int{})); + } else { + return make_layout(shape, stride); + } + } else { + return upcast(shape, stride); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr auto upcast( + ComposedLayout, + Offset, + Layout> const& layout) { + // Find index of the stride-1 mode - that is the only one that requires + // updating inner shape and offset + auto idx = find_if(layout.layout_a().stride(), + [](auto x) { return is_constant<1, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; + + // Upcast the outer layout (works as expected) + auto outer = upcast(layout.layout_a()); + + // Upcast the accumulated offset along stride-1 mode + auto offset = as_arithmetic_tuple( + replace(layout.offset(), upcast(get(layout.offset())))); + + // Upcast the inner layout's shape along stride-1 mode + auto inner = + upcast(layout.layout_b().shape(), 
layout.layout_b().stride()); + + return composition(outer, offset, inner); +} + +template +CUTE_HOST_DEVICE constexpr auto max_common_vector( + Layout const& a, + ComposedLayout, + OffsetB, + Layout> const& b) { + return max_common_vector(b.layout_b(), a); +} + +} // namespace cute diff --git a/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh b/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh index cdca5f89..bceee724 100644 --- a/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh +++ b/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh @@ -6,6 +6,8 @@ #include #include +#include "gather_tensor.hpp" + namespace llm { using namespace cute; @@ -71,33 +73,112 @@ template struct GEMMSharedStorageSM80 {}; struct GEMMParams { - // input tensors: + using AStride = Stride; + using BStride = Stride; + using CStride = Stride; + // A: (m, k) - // sorted_token_idxes: (m*topk+) = (n_blocks, BLK_M) + const void* __restrict__ a_ptr = nullptr; + AStride a_stride; + // B: (e, n, k) - // expert_ids/group_ids: (n_blocks) - // n_tokens - // output tensors: + const void* __restrict__ b_ptr = nullptr; + BStride b_stride; + // C: ((m, topk), n) - // + void* __restrict__ c_ptr = nullptr; + CStride c_stride; + + // (m_blocks, BLK_M) + const int* __restrict__ sorted_token_idxes_ptr = nullptr; + // (m_blocks) + const int* __restrict__ expert_ids_ptr = nullptr; + + int m = 0; + int n = 0; + int k = 0; + int topk = 0; + int n_tokens_padded = 0; }; template __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( __grid_constant__ const Params params) { - // using namespace cute; + // Traits + constexpr int kBlockM = Traits::kBlockM; + constexpr int kBlockN = Traits::kBlockN; + constexpr int kBlockK = Traits::kBlockK; + + using _BLK_M = Int; + using _BLK_N = Int; + using _BLK_K = Int; + + using DTYPE = typename Traits::DType; + + const auto topk = params.topk; + // ProblemShape + // each thread block takes care of one block: (BLK_M, BLK_N) + const auto m_block_idx = blockIdx.x; + const auto n_block_idx = blockIdx.y; + // const auto expert_id = params.expert_ids_ptr[m_block_idx]; + const auto expert_id = 0; - // each block takes care of one block: (BLK_M, BLK_N) // 1: load A to smem: (BLK_M, BLK_K, STAGES) - // load sorted_token_idxes from global memory to registers, (BLK_M) + // load sorted_token_idxes from gmem, (m, topk) => (BLK_M) + const int* sorted_token_idxes = + params.sorted_token_idxes_ptr + m_block_idx * kBlockM; + auto idx_to_t_idx = [sorted_token_idxes, topk](int idx) { + // Convert to token index + return sorted_token_idxes[idx] / topk; + }; + // A: (BLK_M, K) + auto A = make_gather_tensor(make_gmem_ptr((const DTYPE*)params.a_ptr), + make_shape(kBlockM, params.k), + make_stride(get<0>(params.a_stride), _1{}), + idx_to_t_idx); + if (thread0()) { + print("A: "); + print(A); + print("\n"); + } + // 2: load B to smem: (BLK_N, BLK_K, STAGES) - // load group_id for current block from global memory to registers, (1) + // load expert_id for current block from gmem, (1) + // B: (BLK_N, K) + // (e, n, k) => (BLK_N, k) + const auto b_offset = expert_id * get<0>(params.b_stride) + + n_block_idx * get<1>(params.b_stride); + auto B = make_tensor(make_gmem_ptr((const DTYPE*)params.b_ptr + b_offset), + make_shape(kBlockN, params.k), + make_stride(get<1>(params.b_stride), _1{})); + if (thread0()) { + print("B: "); + print(B); + print("\n"); + } + + // Accumulator: (BLK_M, BLK_N) // 3: iterate over k // 4: partition A to tCsA, tCrA // 5: partition B to tCsB, tCrB // load a, b to registers // 6: compute tCrA * tCrB 
with gemm - // 7: write tCrC to global memory using sorted_token_idxes + + // C: (BLK_M, BLK_N) + // 7: write tCrC to global memory using sorted_token_idxes (m, topk) + auto idx_to_f_idx = [sorted_token_idxes](int idx) { + // Convert to token index + return sorted_token_idxes[idx]; + }; + auto C = make_gather_tensor(make_gmem_ptr((const DTYPE*)params.c_ptr), + make_shape(kBlockM, kBlockN), + make_stride(get<0>(params.c_stride), _1{}), + idx_to_f_idx); + if (thread0()) { + print("C: "); + print(C); + print("\n"); + } } template diff --git a/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu b/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu index d1a608d9..38880940 100644 --- a/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu +++ b/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu @@ -9,8 +9,18 @@ namespace llm { namespace { -torch::Tensor grouped_gemm_sm80() { - auto out = torch::empty({}); // Placeholder for output tensor +torch::Tensor grouped_gemm_sm80(const torch::Tensor& a, // (m, k) + const torch::Tensor& w, // (e, n, k) + const torch::Tensor& topk_ids // (m, topk) +) { + const auto m = a.size(0); + const auto k = a.size(1); + const auto n = w.size(1); + const auto n_experts = w.size(0); + const auto topk = topk_ids.size(1); + + // (m * topk, n) + auto out = torch::zeros({m * topk, n}, a.options()); using Traits = GEMMTraitsSM80(params, nullptr); - return out; + // (m * topk, n) => (m, topk, n) + return out.view({m, topk, n}); } // returns (m, topk, n) @@ -97,7 +108,7 @@ TEST_P(GroupedGemmKernelTest, GEMM) { auto ref_out = grouped_gemm_ref(a, w, topk_ids); // LOG(ERROR) << "ref_out: " << ref_out; - // auto out = grouped_gemm_sm80(); + auto out = grouped_gemm_sm80(a, w, topk_ids); // EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-3, /*atol=*/1e-3)); } @@ -106,7 +117,7 @@ INSTANTIATE_TEST_SUITE_P( GEMM, GroupedGemmKernelTest, ::testing::Combine(::testing::Values(torch::kHalf), // dtype - ::testing::Values(1), // m + ::testing::Values(64), // m ::testing::Values(64), // n ::testing::Values(64), // k ::testing::Values(8), // n_experts From fa2112f8feaf19fa4f65a950d5a7386052eb955c Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Mon, 9 Jun 2025 17:03:34 -0700 Subject: [PATCH 06/11] added tiledcopy and tiledmma --- src/kernels/gemm/grouped_gemm_kernel_sm80.cuh | 312 ++++++++++++++---- .../gemm/grouped_gemm_kernel_sm80_test.cu | 90 ++++- 2 files changed, 330 insertions(+), 72 deletions(-) diff --git a/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh b/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh index bceee724..1e19a9d1 100644 --- a/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh +++ b/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh @@ -31,46 +31,92 @@ struct GEMMTraitsSM80 { using _STAGES = Int; using _DIM = Int; - // TiledMMA: (64x32x16) + // MMA Atom: (16x8x16) for F32F16F16F32 or F32BF16BF16F32 using MMA_Atom_ = std::conditional_t, MMA_Atom, MMA_Atom>; - using TiledMma = - TiledMMA>, Tile<_64, _32, _16>>; - - // // Shared memory LayoutAtom (8x64) - // using SmemLayoutAtom_8x64 = - // decltype(composition(Swizzle<3, 3, 3>{}, - // Layout, Stride<_64, _1>>{})); - // using SmemLayoutAtom_8x32 = - // decltype(composition(Swizzle<2, 3, 3>{}, - // Layout, Stride<_32, _1>>{})); - - // using SmemLayoutAtomK = std::conditional_t; - // // SMEM Layout for A: (BLK_M, BLK_K, STAGES) - // using SmemLayoutA = - // decltype(tile_to_shape(SmemLayoutAtomK{}, Shape<_BLK_M, _BLK_K>{})); - // // SMEM Layout for B: (BLK_N, BLK_K, STAGES) - // using SmemLayoutB = - // decltype(tile_to_shape(SmemLayoutAtomK{}, 
Shape<_BLK_N, _BLK_K>{})); - - // // Gmem tiled copy: copy A/B from global memory to shared memory (32x64) - // using GmemTiledCopy = decltype(make_tiled_copy( - // Copy_Atom, DType>{}, - // Layout, Stride<_8, _1>>{}, // Thr layout: (_32, _8) - // Layout>{} // Val layout: 8 vals per - // read - // )); + + // TiledMMA: (64x32x16) + using TiledMma = TiledMMA>, // warp layout: (4x1x1) + Tile<_64, _32, _16>>; // tile layout: (64x16x16) + + // Shared memory LayoutAtom (8x64) + using SmemLayoutAtom_8x64 = + decltype(composition(Swizzle<3, 3, 3>{}, + Layout, Stride<_64, _1>>{})); + using SmemLayoutAtom_8x32 = + decltype(composition(Swizzle<2, 3, 3>{}, + Layout, Stride<_32, _1>>{})); + + using SmemLayoutAtom = std::conditional_t; + // SMEM Layout for A: (BLK_M, BLK_K, STAGES) + using SmemLayoutA = + decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_M, _BLK_K>{})); + // SMEM Layout for B: (BLK_N, BLK_K, STAGES) + using SmemLayoutB = + decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_N, _BLK_K>{})); + + // Thread layout for gmem copy: (_16,_8)/(_32, _4) + using GmemCopyThrLayout = + std::conditional_t, Stride<_4, _1>>, + Layout, Stride<_8, _1>>>; + // g2s tiled copy: copy A/B from global memory to shared memory + using GmemTiledCopyAB = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + GmemCopyThrLayout{}, // Thr layout: (_16,_8)/(_32, _4) + Layout>{} // Val layout: 8 vals per read + )); + + // s2r tiled copy for A and B + using SmemTiledCopyA = + decltype(make_tiled_copy_A(Copy_Atom{}, + TiledMma{})); + using SmemTiledCopyB = + decltype(make_tiled_copy_B(Copy_Atom{}, + TiledMma{})); + + // ******* Epilogue ******* + + using SmemLayoutAtomC = std::conditional_t; + using SmemLayoutC = + decltype(tile_to_shape(SmemLayoutAtomC{}, Shape<_BLK_M, _BLK_N>{})); + + // use 128-bit vectorizing copy + using VectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<128>; + // s2g tiled copy for C + using SmemTiledCopyC = + decltype(make_tiled_copy_C(Copy_Atom{}, + TiledMma{})); // constexpr values for kernel launch static constexpr size_t kThreadNum = size(TiledMma{}); }; template -struct GEMMSharedStorageSM80 {}; +struct GEMMSharedStorageSM80 { + using DType = typename Traits::DType; + using SmemLayoutA = typename Traits::SmemLayoutA; + using SmemLayoutB = typename Traits::SmemLayoutB; + using SmemLayoutC = typename Traits::SmemLayoutC; + + union { + struct { + // Shared memory for A: (BLK_M, BLK_K, STAGES) + cute::array_aligned> a_smem; + // Shared memory for B: (BLK_N, BLK_K, STAGES) + cute::array_aligned> b_smem; + }; + // Shared memory for C: (BLK_M, BLK_N) + cute::array_aligned> c_smem; + }; +}; struct GEMMParams { using AStride = Stride; @@ -94,11 +140,13 @@ struct GEMMParams { // (m_blocks) const int* __restrict__ expert_ids_ptr = nullptr; + const int* __restrict__ n_tokens_padded = nullptr; + int m = 0; int n = 0; int k = 0; int topk = 0; - int n_tokens_padded = 0; + int n_experts = 0; }; template @@ -114,49 +162,192 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( using _BLK_K = Int; using DTYPE = typename Traits::DType; + using TiledMma = typename Traits::TiledMma; + + using SmemLayoutA = typename Traits::SmemLayoutA; + using SmemLayoutB = typename Traits::SmemLayoutB; + using SmemLayoutC = typename Traits::SmemLayoutC; + + using GmemTiledCopyAB = typename Traits::GmemTiledCopyAB; + using SmemTiledCopyA = typename Traits::SmemTiledCopyA; + using SmemTiledCopyB = typename Traits::SmemTiledCopyB; + using SmemTiledCopyC = typename Traits::SmemTiledCopyC; + + using 
SharedStorage = GEMMSharedStorageSM80; + + // TODO: m + const auto M = kBlockM * gridDim.x; + const auto N = params.n; + const auto K = params.k; const auto topk = params.topk; - // ProblemShape + const auto n_experts = params.n_experts; + // each thread block takes care of one block: (BLK_M, BLK_N) const auto m_block_idx = blockIdx.x; const auto n_block_idx = blockIdx.y; - // const auto expert_id = params.expert_ids_ptr[m_block_idx]; - const auto expert_id = 0; + const auto tidx = threadIdx.x; - // 1: load A to smem: (BLK_M, BLK_K, STAGES) - // load sorted_token_idxes from gmem, (m, topk) => (BLK_M) - const int* sorted_token_idxes = - params.sorted_token_idxes_ptr + m_block_idx * kBlockM; + const int expert_id = params.expert_ids_ptr[m_block_idx]; + + if (thread0()) { + print("m: %d, n: %d, k: %d, topk: %d, n_experts: %d\n", + M, + N, + K, + topk, + n_experts); + print("m_block_idx: %d, n_block_idx: %d, expert_id: %d\n", + m_block_idx, + n_block_idx, + expert_id); + } + + // ProblemShape + const int* sorted_token_idxes = params.sorted_token_idxes_ptr; auto idx_to_t_idx = [sorted_token_idxes, topk](int idx) { - // Convert to token index return sorted_token_idxes[idx] / topk; }; - // A: (BLK_M, K) + // A: (M, K), k-major auto A = make_gather_tensor(make_gmem_ptr((const DTYPE*)params.a_ptr), - make_shape(kBlockM, params.k), + make_shape(M, K), make_stride(get<0>(params.a_stride), _1{}), idx_to_t_idx); + + // B: (N, K), k-major + const auto b_offset = expert_id * get<0>(params.b_stride); + auto B = make_tensor(make_gmem_ptr((const DTYPE*)params.b_ptr + b_offset), + make_shape(N, K), + make_stride(get<1>(params.b_stride), _1{})); + + // C: (M, N), n-major + auto idx_to_f_idx = [sorted_token_idxes](int idx) { + return sorted_token_idxes[idx]; + }; + auto C = make_gather_tensor(make_gmem_ptr((const DTYPE*)params.c_ptr), + make_shape(M, N), + make_stride(get<0>(params.c_stride), _1{}), + idx_to_f_idx); + if (thread0()) { print("A: "); print(A); print("\n"); + print("B: "); + print(B); + print("\n"); + print("C: "); + print(C); + print("\n"); } - // 2: load B to smem: (BLK_N, BLK_K, STAGES) - // load expert_id for current block from gmem, (1) - // B: (BLK_N, K) - // (e, n, k) => (BLK_N, k) - const auto b_offset = expert_id * get<0>(params.b_stride) + - n_block_idx * get<1>(params.b_stride); - auto B = make_tensor(make_gmem_ptr((const DTYPE*)params.b_ptr + b_offset), - make_shape(kBlockN, params.k), - make_stride(get<1>(params.b_stride), _1{})); + // (M*TOPK, K) => (BLK_M, BLK_K, k) + Tensor gA = + local_tile(A, Shape<_BLK_M, _BLK_K>{}, make_coord(m_block_idx, _)); + // (N, K) => (BLK_N, BLK_K, k) + Tensor gB = + local_tile(B, Shape<_BLK_N, _BLK_K>{}, make_coord(n_block_idx, _)); + // (M, N) => (BLK_M, BLK_N) + Tensor gC = local_tile( + C, Shape<_BLK_M, _BLK_N>{}, make_coord(m_block_idx, n_block_idx)); + if (thread0()) { - print("B: "); - print(B); + print("gA: "); + print(gA); + print("\n"); + print("gB: "); + print(gB); + print("\n"); + print("gC: "); + print(gC); print("\n"); } + // Smem + extern __shared__ char smem[]; + auto& ss = *reinterpret_cast(smem); + + // (BLK_M, BLK_K) + Tensor sA = make_tensor(make_smem_ptr(ss.a_smem.data()), SmemLayoutA{}); + // (BLK_N, BLK_K) + Tensor sB = make_tensor(make_smem_ptr(ss.b_smem.data()), SmemLayoutB{}); + // (BLK_M, BLK_N) + // Tensor sC = make_tensor(make_smem_ptr(ss.c_smem.data()), SmemLayoutC{}); + + // Tiled Copy + GmemTiledCopyAB gmem_tiled_copy_ab; + auto gmem_thr_copy_ab = gmem_tiled_copy_ab.get_thread_slice(tidx); + + auto produce_a = [&](int 
ki) { + // (BLK_M, BLK_K, k) => (COPY, CP_M, CP_K) + auto tAgA = gmem_thr_copy_ab.partition_S(gA(_, _, ki)); + // (BLK_M, BLK_K) => (COPY, CP_M, CP_K) + auto tAsA = gmem_thr_copy_ab.partition_D(sA); + copy(gmem_tiled_copy_ab, tAgA, tAsA); + }; + + auto produce_b = [&](int ki) { + // (BLK_N, BLK_K, k) => (COPY, CP_N, CP_K) + auto tBgB = gmem_thr_copy_ab.partition_S(gB(_, _, ki)); + // (BLK_N, BLK_K) => (COPY, CP_N, CP_K) + auto tBsB = gmem_thr_copy_ab.partition_D(sB); + copy(gmem_tiled_copy_ab, tBgB, tBsB); + }; + + // GEMM: C = A@B.T + TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + // rA: (BLK_M, BLK_K) + auto tCrA = thr_mma.partition_fragment_A(sA); + // rB: (BLK_N, BLK_K) + auto tCrB = thr_mma.partition_fragment_B(sB); + + // s2r tiled copy for A and B + auto smem_tiled_copy_a = SmemTiledCopyA{}; + auto smem_thr_copy_a = smem_tiled_copy_a.get_thread_slice(tidx); + auto tCsA = smem_thr_copy_a.partition_S(sA); + auto tCrA_cpv = smem_thr_copy_a.retile_D(tCrA); + + auto smem_tiled_copy_b = SmemTiledCopyB{}; + auto smem_thr_copy_b = smem_tiled_copy_b.get_thread_slice(tidx); + auto tCsB = smem_thr_copy_b.partition_S(sB); + auto tCrB_cpv = smem_thr_copy_b.retile_D(tCrB); + + // ############### Prologue ############### + + // ############### Mainloop ############### + // Accumulator: (BLK_M, BLK_N) + auto tCrC = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{}); + cute::clear(tCrC); // Clear the accumulator + + CUTE_NO_UNROLL + for (int ki = 0; ki < size<2>(gA); ++ki) { + // load A and B to shared memory + produce_a(ki); + produce_b(ki); + cp_async_fence(); + + // Wait for A and B to be loaded + cp_async_wait<0>(); + __syncthreads(); + + // copy sA and sB to registers + cute::copy(smem_tiled_copy_a, tCsA, tCrA_cpv); + cute::copy(smem_tiled_copy_b, tCsB, tCrB_cpv); + + // compute tCrA * tCrB with gemm + cute::gemm(tiled_mma, tCrA, tCrB, tCrC); + } + + // ############### Epilogue ############### + // write output to global memory + + // 1: load A to smem: (BLK_M, BLK_K, STAGES) + // load sorted_token_idxes from gmem, (m*topk) => (m_blocks, BLK_M) + + // 2: load B to smem: (BLK_N, BLK_K, STAGES) + // load expert_id for current block from gmem, (1) + // Accumulator: (BLK_M, BLK_N) // 3: iterate over k // 4: partition A to tCsA, tCrA @@ -164,21 +355,7 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( // load a, b to registers // 6: compute tCrA * tCrB with gemm - // C: (BLK_M, BLK_N) // 7: write tCrC to global memory using sorted_token_idxes (m, topk) - auto idx_to_f_idx = [sorted_token_idxes](int idx) { - // Convert to token index - return sorted_token_idxes[idx]; - }; - auto C = make_gather_tensor(make_gmem_ptr((const DTYPE*)params.c_ptr), - make_shape(kBlockM, kBlockN), - make_stride(get<0>(params.c_stride), _1{}), - idx_to_f_idx); - if (thread0()) { - print("C: "); - print(C); - print("\n"); - } } template @@ -188,6 +365,7 @@ void launch_grouped_gemm_kernel_sm80(const Params& params, // const auto max_q_packed_len = params.max_q_len * params.n_heads; const auto smem_size = sizeof(GEMMSharedStorageSM80); + std::cout << "SMEM size: " << smem_size << " bytes\n"; auto gemm_kernel = grouped_gemm_kernel_sm80; cudaFuncSetAttribute( @@ -195,7 +373,7 @@ void launch_grouped_gemm_kernel_sm80(const Params& params, // TODO: support persistent kernels // dim3 grid(cute::ceil_div(max_q_packed_len, Traits::kBlockM), batch_size, // 1); - dim3 grid(1, 1, 1); // Placeholder for grid dimensions, adjust as needed + dim3 grid(1, 1); // Placeholder for 
grid dimensions, adjust as needed dim3 block = Traits::kThreadNum; gemm_kernel<<>>(params); } diff --git a/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu b/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu index 38880940..3ecd2b5f 100644 --- a/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu +++ b/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu @@ -9,6 +9,61 @@ namespace llm { namespace { +// reference implementation +std::tuple permute_align_block( + torch::Tensor topk_ids, // [n_tokens, topk] + int64_t n_experts, + int64_t block_size) { + const int64_t n_tokens = topk_ids.size(0); + const int64_t topk = topk_ids.size(1); + const int64_t n_flatten_tokens = topk_ids.numel(); + + auto topk_ids_cpu = topk_ids.cpu().contiguous(); + const int32_t* topk_ids_ptr = topk_ids_cpu.data_ptr(); + + std::vector> expert_to_idxes(n_experts); + for (int i = 0; i < n_flatten_tokens; ++i) { + const int32_t expert_id = topk_ids_ptr[i]; + assert(expert_id >= 0 && expert_id < n_experts); + expert_to_idxes[expert_id].push_back(i); + } + + std::vector sorted_token_idxes; + std::vector expert_ids; + int32_t n_padded_tokens = 0; + for (int e_idx = 0; e_idx < n_experts; ++e_idx) { + // flatten indices for each expert, sorted by token id + const auto& idxes = expert_to_idxes[e_idx]; + if (idxes.empty()) { + continue; + } + const auto count = idxes.size(); + const auto n_blocks = cute::ceil_div(count, block_size); + n_padded_tokens += (n_blocks * block_size); + // fill flatten indices for each block + for (int b_idx = 0; b_idx < n_blocks; ++b_idx) { + // expert id for each block + expert_ids.push_back(e_idx); + for (int offset = 0; offset < block_size; ++offset) { + auto idx = (b_idx * block_size) + offset; + if (idx < count) { + // fill flatten indices + sorted_token_idxes.push_back(idxes[idx]); + } else { + // fill padding + sorted_token_idxes.push_back(n_flatten_tokens); + } + } + } + } + + // construct tensor and return + const auto options = topk_ids.options(); + return {torch::tensor(sorted_token_idxes, options), + torch::tensor(expert_ids, options), + torch::tensor({n_padded_tokens}, options)}; +} + torch::Tensor grouped_gemm_sm80(const torch::Tensor& a, // (m, k) const torch::Tensor& w, // (e, n, k) const torch::Tensor& topk_ids // (m, topk) @@ -19,6 +74,14 @@ torch::Tensor grouped_gemm_sm80(const torch::Tensor& a, // (m, k) const auto n_experts = w.size(0); const auto topk = topk_ids.size(1); + // construct aligned + auto [sorted_token_idex, expert_ids, n_tokens_padded] = permute_align_block( + topk_ids.to(torch::kInt32), n_experts, /*block_size=*/64); + + // LOG(ERROR) << "sorted_token_idex: " << sorted_token_idex; + // LOG(ERROR) << "expert_ids: " << expert_ids; + // LOG(ERROR) << "n_padded_tokens: " << n_tokens_padded; + // (m * topk, n) auto out = torch::zeros({m * topk, n}, a.options()); @@ -29,7 +92,24 @@ torch::Tensor grouped_gemm_sm80(const torch::Tensor& a, // (m, k) 64, /*BLK_K*/ 2>; /*STAGES*/ + // construct params GEMMParams params; + params.a_ptr = a.const_data_ptr(); + params.a_stride = make_stride(a.stride(0)); + params.b_ptr = w.const_data_ptr(); + params.b_stride = make_stride(w.stride(0), w.stride(1)); + params.c_ptr = out.mutable_data_ptr(); + params.c_stride = make_stride(out.stride(0)); + + params.sorted_token_idxes_ptr = sorted_token_idex.const_data_ptr(); + params.expert_ids_ptr = expert_ids.const_data_ptr(); + params.n_tokens_padded = n_tokens_padded.const_data_ptr(); + + params.m = m; + params.n = n; + params.k = k; + params.topk = topk; + 
launch_grouped_gemm_kernel_sm80(params, nullptr); // (m * topk, n) => (m, topk, n) @@ -101,16 +181,16 @@ TEST_P(GroupedGemmKernelTest, GEMM) { auto logits = torch::randn({m, n_experts}, options).softmax(/*dim=*/1); auto [topk_weights, topk_ids] = logits.topk(topk, /*dim=*/1); - // LOG(ERROR) << "a: " << a; - // LOG(ERROR) << "w: " << w; - // LOG(ERROR) << "topk_ids: " << topk_ids; - // LOG(ERROR) << "topk_weights: " << topk_weights; - auto ref_out = grouped_gemm_ref(a, w, topk_ids); // LOG(ERROR) << "ref_out: " << ref_out; auto out = grouped_gemm_sm80(a, w, topk_ids); // EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-3, /*atol=*/1e-3)); + + // LOG(ERROR) << "a: " << a; + // LOG(ERROR) << "w: " << w; + // LOG(ERROR) << "topk_ids: " << topk_ids; + // LOG(ERROR) << "topk_weights: " << topk_weights; } INSTANTIATE_TEST_SUITE_P( From 79858e2f0d30ee1ad1ad588a3def95d419695f4b Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Mon, 9 Jun 2025 21:58:24 -0700 Subject: [PATCH 07/11] added epilogue part --- src/kernels/gemm/grouped_gemm_kernel_sm80.cuh | 89 +++++++------------ .../gemm/grouped_gemm_kernel_sm80_test.cu | 11 ++- 2 files changed, 38 insertions(+), 62 deletions(-) diff --git a/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh b/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh index 1e19a9d1..ce04acda 100644 --- a/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh +++ b/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh @@ -37,10 +37,10 @@ struct GEMMTraitsSM80 { MMA_Atom, MMA_Atom>; - // TiledMMA: (64x32x16) + // TiledMMA: (64x16x16) using TiledMma = TiledMMA>, // warp layout: (4x1x1) - Tile<_64, _32, _16>>; // tile layout: (64x16x16) + Layout>, // warp layout: (4x1x1) + Tile<_64, _16, _16>>; // tile layout: (64x16x16) // Shared memory LayoutAtom (8x64) using SmemLayoutAtom_8x64 = @@ -90,11 +90,18 @@ struct GEMMTraitsSM80 { // use 128-bit vectorizing copy using VectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<128>; - // s2g tiled copy for C + // r2s tiled copy for C using SmemTiledCopyC = decltype(make_tiled_copy_C(Copy_Atom{}, TiledMma{})); + // s2g tiled copy for O + using GmemTiledCopyC = decltype(make_tiled_copy( + Copy_Atom{}, + GmemCopyThrLayout{}, // Thr layout: (_16,_8)/(_32, _4) + Layout>{} // Val layout: 8 vals per read + )); + // constexpr values for kernel launch static constexpr size_t kThreadNum = size(TiledMma{}); }; @@ -172,6 +179,7 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( using SmemTiledCopyA = typename Traits::SmemTiledCopyA; using SmemTiledCopyB = typename Traits::SmemTiledCopyB; using SmemTiledCopyC = typename Traits::SmemTiledCopyC; + using GmemTiledCopyC = typename Traits::GmemTiledCopyC; using SharedStorage = GEMMSharedStorageSM80; @@ -190,19 +198,6 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( const int expert_id = params.expert_ids_ptr[m_block_idx]; - if (thread0()) { - print("m: %d, n: %d, k: %d, topk: %d, n_experts: %d\n", - M, - N, - K, - topk, - n_experts); - print("m_block_idx: %d, n_block_idx: %d, expert_id: %d\n", - m_block_idx, - n_block_idx, - expert_id); - } - // ProblemShape const int* sorted_token_idxes = params.sorted_token_idxes_ptr; auto idx_to_t_idx = [sorted_token_idxes, topk](int idx) { @@ -224,24 +219,12 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( auto idx_to_f_idx = [sorted_token_idxes](int idx) { return sorted_token_idxes[idx]; }; - auto C = make_gather_tensor(make_gmem_ptr((const DTYPE*)params.c_ptr), + auto C = 
make_gather_tensor(make_gmem_ptr((DTYPE*)params.c_ptr), make_shape(M, N), make_stride(get<0>(params.c_stride), _1{}), idx_to_f_idx); - if (thread0()) { - print("A: "); - print(A); - print("\n"); - print("B: "); - print(B); - print("\n"); - print("C: "); - print(C); - print("\n"); - } - - // (M*TOPK, K) => (BLK_M, BLK_K, k) + // (M, K) => (BLK_M, BLK_K, k) Tensor gA = local_tile(A, Shape<_BLK_M, _BLK_K>{}, make_coord(m_block_idx, _)); // (N, K) => (BLK_N, BLK_K, k) @@ -251,18 +234,6 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( Tensor gC = local_tile( C, Shape<_BLK_M, _BLK_N>{}, make_coord(m_block_idx, n_block_idx)); - if (thread0()) { - print("gA: "); - print(gA); - print("\n"); - print("gB: "); - print(gB); - print("\n"); - print("gC: "); - print(gC); - print("\n"); - } - // Smem extern __shared__ char smem[]; auto& ss = *reinterpret_cast(smem); @@ -340,22 +311,22 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( } // ############### Epilogue ############### - // write output to global memory - - // 1: load A to smem: (BLK_M, BLK_K, STAGES) - // load sorted_token_idxes from gmem, (m*topk) => (m_blocks, BLK_M) - - // 2: load B to smem: (BLK_N, BLK_K, STAGES) - // load expert_id for current block from gmem, (1) - - // Accumulator: (BLK_M, BLK_N) - // 3: iterate over k - // 4: partition A to tCsA, tCrA - // 5: partition B to tCsB, tCrB - // load a, b to registers - // 6: compute tCrA * tCrB with gemm - - // 7: write tCrC to global memory using sorted_token_idxes (m, topk) + // (BLK_M, BLK_N) + Tensor sC = make_tensor(make_smem_ptr(ss.c_smem.data()), SmemLayoutC{}); + + // copy tCrC from registers to smem + SmemTiledCopyC smem_tiled_copy_c; + auto smem_thr_copy_c = smem_tiled_copy_c.get_thread_slice(tidx); + auto tSrC = smem_thr_copy_c.retile_S(tCrC); + auto tSsC = smem_thr_copy_c.partition_D(sC); + cute::copy(smem_tiled_copy_c, tSrC, tSsC); + + // copy sC from smem to gmem + GmemTiledCopyC gmem_tiled_copy_c; + auto gmem_thr_copy_c = gmem_tiled_copy_c.get_thread_slice(tidx); + auto tGsC = gmem_thr_copy_c.partition_S(sC); + auto tGgC = gmem_thr_copy_c.partition_D(gC); + cute::copy(gmem_tiled_copy_c, tGsC, tGgC); } template diff --git a/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu b/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu index 3ecd2b5f..0e924363 100644 --- a/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu +++ b/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu @@ -185,7 +185,12 @@ TEST_P(GroupedGemmKernelTest, GEMM) { // LOG(ERROR) << "ref_out: " << ref_out; auto out = grouped_gemm_sm80(a, w, topk_ids); - // EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-3, /*atol=*/1e-3)); + EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-3, /*atol=*/1e-3)); + + // auto max_diff = (out - ref_out).abs().max().item(); + // LOG(ERROR) << "Max diff: " << max_diff; + // LOG(ERROR) << "ref_out: " << ref_out; + // LOG(ERROR) << "out: " << out; // LOG(ERROR) << "a: " << a; // LOG(ERROR) << "w: " << w; @@ -200,8 +205,8 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(64), // m ::testing::Values(64), // n ::testing::Values(64), // k - ::testing::Values(8), // n_experts - ::testing::Values(4) // topk + ::testing::Values(1), // n_experts + ::testing::Values(1) // topk )); } // namespace llm From 49282d60b43827419e130bd0e90e71b1215188d1 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Thu, 12 Jun 2025 21:40:02 -0700 Subject: [PATCH 08/11] pipeline for gemm --- src/kernels/gemm/grouped_gemm_kernel_sm80.cuh | 86 
++++++++++++------- .../gemm/grouped_gemm_kernel_sm80_test.cu | 16 ++-- 2 files changed, 61 insertions(+), 41 deletions(-) diff --git a/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh b/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh index ce04acda..446d5b1c 100644 --- a/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh +++ b/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh @@ -142,7 +142,7 @@ struct GEMMParams { void* __restrict__ c_ptr = nullptr; CStride c_stride; - // (m_blocks, BLK_M) + // (m_blocks*BLK_M) const int* __restrict__ sorted_token_idxes_ptr = nullptr; // (m_blocks) const int* __restrict__ expert_ids_ptr = nullptr; @@ -154,6 +154,9 @@ struct GEMMParams { int k = 0; int topk = 0; int n_experts = 0; + + int m_blocks = 0; + int n_blocks = 0; }; template @@ -183,7 +186,6 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( using SharedStorage = GEMMSharedStorageSM80; - // TODO: m const auto M = kBlockM * gridDim.x; const auto N = params.n; const auto K = params.k; @@ -249,20 +251,22 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( GmemTiledCopyAB gmem_tiled_copy_ab; auto gmem_thr_copy_ab = gmem_tiled_copy_ab.get_thread_slice(tidx); + // (BLK_M, BLK_K, k) => (COPY, CP_M, CP_K, k) + auto tAgA = gmem_thr_copy_ab.partition_S(gA); + // (BLK_M, BLK_K) => (COPY, CP_M, CP_K) + auto tAsA = gmem_thr_copy_ab.partition_D(sA); + + // (BLK_N, BLK_K, k) => (COPY, CP_N, CP_K, k) + auto tBgB = gmem_thr_copy_ab.partition_S(gB); + // (BLK_N, BLK_K) => (COPY, CP_N, CP_K) + auto tBsB = gmem_thr_copy_ab.partition_D(sB); + auto produce_a = [&](int ki) { - // (BLK_M, BLK_K, k) => (COPY, CP_M, CP_K) - auto tAgA = gmem_thr_copy_ab.partition_S(gA(_, _, ki)); - // (BLK_M, BLK_K) => (COPY, CP_M, CP_K) - auto tAsA = gmem_thr_copy_ab.partition_D(sA); - copy(gmem_tiled_copy_ab, tAgA, tAsA); + copy(gmem_tiled_copy_ab, tAgA(_, _, _, ki), tAsA); }; auto produce_b = [&](int ki) { - // (BLK_N, BLK_K, k) => (COPY, CP_N, CP_K) - auto tBgB = gmem_thr_copy_ab.partition_S(gB(_, _, ki)); - // (BLK_N, BLK_K) => (COPY, CP_N, CP_K) - auto tBsB = gmem_thr_copy_ab.partition_D(sB); - copy(gmem_tiled_copy_ab, tBgB, tBsB); + copy(gmem_tiled_copy_ab, tBgB(_, _, _, ki), tBsB); }; // GEMM: C = A@B.T @@ -285,35 +289,55 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( auto tCrB_cpv = smem_thr_copy_b.retile_D(tCrB); // ############### Prologue ############### + produce_a(0); + produce_b(0); + cp_async_fence(); // ############### Mainloop ############### // Accumulator: (BLK_M, BLK_N) auto tCrC = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{}); cute::clear(tCrC); // Clear the accumulator - CUTE_NO_UNROLL - for (int ki = 0; ki < size<2>(gA); ++ki) { - // load A and B to shared memory - produce_a(ki); - produce_b(ki); - cp_async_fence(); + // total count of tiles in the k dimension + int k_tiles = size<2>(gA); - // Wait for A and B to be loaded + CUTE_NO_UNROLL + for (int ki = 0; ki < k_tiles; ++ki) { + // Wait for A and B to be loaded into smem cp_async_wait<0>(); __syncthreads(); - // copy sA and sB to registers - cute::copy(smem_tiled_copy_a, tCsA, tCrA_cpv); - cute::copy(smem_tiled_copy_b, tCsB, tCrB_cpv); - - // compute tCrA * tCrB with gemm - cute::gemm(tiled_mma, tCrA, tCrB, tCrC); + // prefetch sA and sB to registers + cute::copy(smem_tiled_copy_a, tCsA(_, _, _0{}), tCrA_cpv(_, _, _0{})); + cute::copy(smem_tiled_copy_b, tCsB(_, _, _0{}), tCrB_cpv(_, _, _0{})); + + CUTE_UNROLL + for (int i = 0; i < size<2>(tCrA); ++i) { + const auto 
next_i = i + 1; + if (next_i < size<2>(tCrA)) { + cute::copy( + smem_tiled_copy_a, tCsA(_, _, next_i), tCrA_cpv(_, _, next_i)); + cute::copy( + smem_tiled_copy_b, tCsB(_, _, next_i), tCrB_cpv(_, _, next_i)); + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), tCrC); + } + + // load next A and B to smem + const int next_ki = ki + 1; + if (next_ki < k_tiles) { + produce_a(next_ki); + produce_b(next_ki); + cp_async_fence(); + } } // ############### Epilogue ############### // (BLK_M, BLK_N) Tensor sC = make_tensor(make_smem_ptr(ss.c_smem.data()), SmemLayoutC{}); + // TODO: fastcast tCrC to DTYPE + // copy tCrC from registers to smem SmemTiledCopyC smem_tiled_copy_c; auto smem_thr_copy_c = smem_tiled_copy_c.get_thread_slice(tidx); @@ -321,6 +345,9 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( auto tSsC = smem_thr_copy_c.partition_D(sC); cute::copy(smem_tiled_copy_c, tSrC, tSsC); + // wait for smem copy done before gmem copy + __syncthreads(); + // copy sC from smem to gmem GmemTiledCopyC gmem_tiled_copy_c; auto gmem_thr_copy_c = gmem_tiled_copy_c.get_thread_slice(tidx); @@ -332,19 +359,14 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( template void launch_grouped_gemm_kernel_sm80(const Params& params, cudaStream_t stream) { - // const auto batch_size = params.batch_size; - // const auto max_q_packed_len = params.max_q_len * params.n_heads; - const auto smem_size = sizeof(GEMMSharedStorageSM80); - std::cout << "SMEM size: " << smem_size << " bytes\n"; + // std::cout << "SMEM size: " << smem_size << " bytes\n"; auto gemm_kernel = grouped_gemm_kernel_sm80; cudaFuncSetAttribute( gemm_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); // TODO: support persistent kernels - // dim3 grid(cute::ceil_div(max_q_packed_len, Traits::kBlockM), batch_size, - // 1); - dim3 grid(1, 1); // Placeholder for grid dimensions, adjust as needed + dim3 grid(params.m_blocks, params.n_blocks); dim3 block = Traits::kThreadNum; gemm_kernel<<>>(params); } diff --git a/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu b/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu index 0e924363..e6758f29 100644 --- a/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu +++ b/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu @@ -110,6 +110,9 @@ torch::Tensor grouped_gemm_sm80(const torch::Tensor& a, // (m, k) params.k = k; params.topk = topk; + params.m_blocks = expert_ids.size(0); + params.n_blocks = cute::ceil_div(n, 64); + launch_grouped_gemm_kernel_sm80(params, nullptr); // (m * topk, n) => (m, topk, n) @@ -187,24 +190,19 @@ TEST_P(GroupedGemmKernelTest, GEMM) { EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-3, /*atol=*/1e-3)); - // auto max_diff = (out - ref_out).abs().max().item(); + // auto max_diff = (out - ref_out).abs().max(); // LOG(ERROR) << "Max diff: " << max_diff; // LOG(ERROR) << "ref_out: " << ref_out; // LOG(ERROR) << "out: " << out; - - // LOG(ERROR) << "a: " << a; - // LOG(ERROR) << "w: " << w; - // LOG(ERROR) << "topk_ids: " << topk_ids; - // LOG(ERROR) << "topk_weights: " << topk_weights; } INSTANTIATE_TEST_SUITE_P( GEMM, GroupedGemmKernelTest, ::testing::Combine(::testing::Values(torch::kHalf), // dtype - ::testing::Values(64), // m - ::testing::Values(64), // n - ::testing::Values(64), // k + ::testing::Values(64, 128), // m + ::testing::Values(64, 128), // n + ::testing::Values(64, 128), // k ::testing::Values(1), // n_experts ::testing::Values(1) // topk )); From ceda21610259f1e11c36d87fa72f5fbc03d24f90 Mon Sep 17 
00:00:00 2001 From: Michael Mi Date: Thu, 12 Jun 2025 22:35:12 -0700 Subject: [PATCH 09/11] added pipe support --- src/kernels/gemm/grouped_gemm_kernel_sm80.cuh | 152 +++++++++++------- .../gemm/grouped_gemm_kernel_sm80_test.cu | 3 +- 2 files changed, 98 insertions(+), 57 deletions(-) diff --git a/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh b/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh index 446d5b1c..7676cc85 100644 --- a/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh +++ b/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh @@ -11,13 +11,12 @@ namespace llm { using namespace cute; -template +template struct GEMMTraitsSM80 { - static constexpr int kDim = DIM; static constexpr int kBlockM = BLK_M; static constexpr int kBlockN = BLK_N; static constexpr int kBlockK = BLK_K; - static constexpr int kStages = STAGES; + static constexpr int kPipe = PIPE; static_assert(kBlockM % 64 == 0); static_assert(kBlockN % 32 == 0); @@ -28,8 +27,7 @@ struct GEMMTraitsSM80 { using _BLK_M = Int; using _BLK_N = Int; using _BLK_K = Int; - using _STAGES = Int; - using _DIM = Int; + using _PIPE = Int; // MMA Atom: (16x8x16) for F32F16F16F32 or F32BF16BF16F32 using MMA_Atom_ = @@ -53,12 +51,12 @@ struct GEMMTraitsSM80 { using SmemLayoutAtom = std::conditional_t; - // SMEM Layout for A: (BLK_M, BLK_K, STAGES) + // SMEM Layout for A: (BLK_M, BLK_K, PIPE) using SmemLayoutA = - decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_M, _BLK_K>{})); - // SMEM Layout for B: (BLK_N, BLK_K, STAGES) + decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_M, _BLK_K, _PIPE>{})); + // SMEM Layout for B: (BLK_N, BLK_K, PIPE) using SmemLayoutB = - decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_N, _BLK_K>{})); + decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_N, _BLK_K, _PIPE>{})); // Thread layout for gmem copy: (_16,_8)/(_32, _4) using GmemCopyThrLayout = @@ -115,9 +113,9 @@ struct GEMMSharedStorageSM80 { union { struct { - // Shared memory for A: (BLK_M, BLK_K, STAGES) + // Shared memory for A: (BLK_M, BLK_K, PIPE) cute::array_aligned> a_smem; - // Shared memory for B: (BLK_N, BLK_K, STAGES) + // Shared memory for B: (BLK_N, BLK_K, PIPE) cute::array_aligned> b_smem; }; // Shared memory for C: (BLK_M, BLK_N) @@ -240,9 +238,9 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( extern __shared__ char smem[]; auto& ss = *reinterpret_cast(smem); - // (BLK_M, BLK_K) + // (BLK_M, BLK_K, PIPE) Tensor sA = make_tensor(make_smem_ptr(ss.a_smem.data()), SmemLayoutA{}); - // (BLK_N, BLK_K) + // (BLK_N, BLK_K, PIPE) Tensor sB = make_tensor(make_smem_ptr(ss.b_smem.data()), SmemLayoutB{}); // (BLK_M, BLK_N) // Tensor sC = make_tensor(make_smem_ptr(ss.c_smem.data()), SmemLayoutC{}); @@ -253,82 +251,126 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( // (BLK_M, BLK_K, k) => (COPY, CP_M, CP_K, k) auto tAgA = gmem_thr_copy_ab.partition_S(gA); - // (BLK_M, BLK_K) => (COPY, CP_M, CP_K) + // (BLK_M, BLK_K, PIPE) => (COPY, CP_M, CP_K, PIPE) auto tAsA = gmem_thr_copy_ab.partition_D(sA); // (BLK_N, BLK_K, k) => (COPY, CP_N, CP_K, k) auto tBgB = gmem_thr_copy_ab.partition_S(gB); - // (BLK_N, BLK_K) => (COPY, CP_N, CP_K) + // (BLK_N, BLK_K, PIPE) => (COPY, CP_N, CP_K, PIPE) auto tBsB = gmem_thr_copy_ab.partition_D(sB); - auto produce_a = [&](int ki) { - copy(gmem_tiled_copy_ab, tAgA(_, _, _, ki), tAsA); - }; - - auto produce_b = [&](int ki) { - copy(gmem_tiled_copy_ab, tBgB(_, _, _, ki), tBsB); - }; - // GEMM: C = A@B.T TiledMma tiled_mma; auto thr_mma = tiled_mma.get_thread_slice(tidx); 
- // rA: (BLK_M, BLK_K) - auto tCrA = thr_mma.partition_fragment_A(sA); - // rB: (BLK_N, BLK_K) - auto tCrB = thr_mma.partition_fragment_B(sB); + // rA: (BLK_M, BLK_K) => (MMA,MMA_M,MMA_K) + auto tCrA = thr_mma.partition_fragment_A(sA(_, _, _0{})); + // rB: (BLK_N, BLK_K) => (MMA,MMA_N,MMA_K) + auto tCrB = thr_mma.partition_fragment_B(sB(_, _, _0{})); // s2r tiled copy for A and B auto smem_tiled_copy_a = SmemTiledCopyA{}; auto smem_thr_copy_a = smem_tiled_copy_a.get_thread_slice(tidx); + // (BLK_M, BLK_K, PIPE) => (COPY, COPY_M, COPY_K, PIPE) auto tCsA = smem_thr_copy_a.partition_S(sA); + // (COPY, COPY_M, COPY_K) auto tCrA_cpv = smem_thr_copy_a.retile_D(tCrA); auto smem_tiled_copy_b = SmemTiledCopyB{}; auto smem_thr_copy_b = smem_tiled_copy_b.get_thread_slice(tidx); + // (BLK_N, BLK_K, PIPE) => (COPY, COPY_N, COPY_K, PIPE) auto tCsB = smem_thr_copy_b.partition_S(sB); + // (COPY, COPY_N, COPY_K) auto tCrB_cpv = smem_thr_copy_b.retile_D(tCrB); // ############### Prologue ############### - produce_a(0); - produce_b(0); - cp_async_fence(); + // remaining k-tile count + int k_tile_remaining = size<3>(tAgA); + // next tile index in gmem to read from + int k_tile_next = 0; + + // async loads for all pipes except the last one + auto kPipe = size<3>(tAsA); + CUTE_UNROLL + for (int k_pipe = 0; k_pipe < kPipe - 1; ++k_pipe) { + copy(gmem_tiled_copy_ab, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, k_pipe)); + copy(gmem_tiled_copy_ab, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, k_pipe)); + cp_async_fence(); + + // advance to next k-tile + if (--k_tile_remaining > 0) { + ++k_tile_next; + } + } // ############### Mainloop ############### - // Accumulator: (BLK_M, BLK_N) + // (BLK_M, BLK_N) => (MMA, MMA_M, MMA_N) auto tCrC = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{}); cute::clear(tCrC); // Clear the accumulator - // total count of tiles in the k dimension - int k_tiles = size<2>(gA); + // pipe index in smem to read from + int pipe_read = 0; + // pipe index in smem to write to + int pipe_write = kPipe - 1; - CUTE_NO_UNROLL - for (int ki = 0; ki < k_tiles; ++ki) { - // Wait for A and B to be loaded into smem - cp_async_wait<0>(); + // pipe to read from: (COPY, COPY_N, COPY_K) + Tensor tCsA_p = tCsA(_, _, _, pipe_read); + Tensor tCsB_p = tCsB(_, _, _, pipe_read); + + // Size of the register pipeline + auto kBlocks = size<2>(tCrA); + + // prefetch register pipeline + if (kBlocks > 1) { + // wait until our first prefetched tile is loaded in + cp_async_wait(); __syncthreads(); - // prefetch sA and sB to registers - cute::copy(smem_tiled_copy_a, tCsA(_, _, _0{}), tCrA_cpv(_, _, _0{})); - cute::copy(smem_tiled_copy_b, tCsB(_, _, _0{}), tCrB_cpv(_, _, _0{})); + // prefetch the first rmem from the first k-tile + cute::copy(smem_tiled_copy_a, tCsA_p(_, _, _0{}), tCrA_cpv(_, _, _0{})); + cute::copy(smem_tiled_copy_b, tCsB_p(_, _, _0{}), tCrB_cpv(_, _, _0{})); + } + CUTE_NO_UNROLL + while (k_tile_remaining > -(kPipe - 1)) { CUTE_UNROLL - for (int i = 0; i < size<2>(tCrA); ++i) { - const auto next_i = i + 1; - if (next_i < size<2>(tCrA)) { - cute::copy( - smem_tiled_copy_a, tCsA(_, _, next_i), tCrA_cpv(_, _, next_i)); - cute::copy( - smem_tiled_copy_b, tCsB(_, _, next_i), tCrB_cpv(_, _, next_i)); + for (int ki = 0; ki < kBlocks; ++ki) { + if (ki == kBlocks - 1) { + // advance to next pipe to read from + tCsA_p = tCsA(_, _, _, pipe_read); + tCsB_p = tCsB(_, _, _, pipe_read); + + // wait until our next prefetched tile is loaded in + cp_async_wait(); + __syncthreads(); + } + + // load A, B from smem to 
registers for next ki + auto ki_next = (ki + _1{}) % kBlocks; + copy(smem_tiled_copy_a, tCsA_p(_, _, ki_next), tCrA_cpv(_, _, ki_next)); + copy(smem_tiled_copy_b, tCsB_p(_, _, ki_next), tCrB_cpv(_, _, ki_next)); + + if (ki == 0) { + // copy gmem to smeme for next pipe + copy(gmem_tiled_copy_ab, + tAgA(_, _, _, k_tile_next), + tAsA(_, _, _, pipe_write)); + copy(gmem_tiled_copy_ab, + tBgB(_, _, _, k_tile_next), + tBsB(_, _, _, pipe_write)); + cp_async_fence(); + + // advance to next k-tile + if (--k_tile_remaining > 0) { + ++k_tile_next; + } + + // advance to next pipe + pipe_write = pipe_read; + pipe_read = (pipe_read == kPipe - 1) ? 0 : pipe_read + 1; } - cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), tCrC); - } - // load next A and B to smem - const int next_ki = ki + 1; - if (next_ki < k_tiles) { - produce_a(next_ki); - produce_b(next_ki); - cp_async_fence(); + // thread-level gemm for ki + gemm(tiled_mma, tCrA(_, _, ki), tCrB(_, _, ki), tCrC); } } diff --git a/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu b/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu index e6758f29..6915d0ad 100644 --- a/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu +++ b/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu @@ -86,11 +86,10 @@ torch::Tensor grouped_gemm_sm80(const torch::Tensor& a, // (m, k) auto out = torch::zeros({m * topk, n}, a.options()); using Traits = GEMMTraitsSM80; /*STAGES*/ + 2>; /*PIPE*/ // construct params GEMMParams params; From 1b2a8fe6f14d60bca933e28e2b171f54cf97ea58 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Thu, 12 Jun 2025 22:35:12 -0700 Subject: [PATCH 10/11] added pipe support --- src/kernels/gemm/grouped_gemm_kernel_sm80.cuh | 71 +++++++++++++++++-- 1 file changed, 67 insertions(+), 4 deletions(-) diff --git a/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh b/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh index 7676cc85..981cf200 100644 --- a/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh +++ b/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh @@ -64,7 +64,7 @@ struct GEMMTraitsSM80 { Layout, Stride<_4, _1>>, Layout, Stride<_8, _1>>>; // g2s tiled copy: copy A/B from global memory to shared memory - using GmemTiledCopyAB = decltype(make_tiled_copy( + using GmemTiledCopy = decltype(make_tiled_copy( Copy_Atom, DType>{}, GmemCopyThrLayout{}, // Thr layout: (_16,_8)/(_32, _4) Layout>{} // Val layout: 8 vals per read @@ -176,7 +176,7 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( using SmemLayoutB = typename Traits::SmemLayoutB; using SmemLayoutC = typename Traits::SmemLayoutC; - using GmemTiledCopyAB = typename Traits::GmemTiledCopyAB; + using GmemTiledCopy = typename Traits::GmemTiledCopy; using SmemTiledCopyA = typename Traits::SmemTiledCopyA; using SmemTiledCopyB = typename Traits::SmemTiledCopyB; using SmemTiledCopyC = typename Traits::SmemTiledCopyC; @@ -246,10 +246,11 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( // Tensor sC = make_tensor(make_smem_ptr(ss.c_smem.data()), SmemLayoutC{}); // Tiled Copy - GmemTiledCopyAB gmem_tiled_copy_ab; - auto gmem_thr_copy_ab = gmem_tiled_copy_ab.get_thread_slice(tidx); + GmemTiledCopy gmem_tiled_copy; + auto gmem_thr_copy = gmem_tiled_copy.get_thread_slice(tidx); // (BLK_M, BLK_K, k) => (COPY, CP_M, CP_K, k) +<<<<<<< HEAD auto tAgA = gmem_thr_copy_ab.partition_S(gA); // (BLK_M, BLK_K, PIPE) => (COPY, CP_M, CP_K, PIPE) auto tAsA = gmem_thr_copy_ab.partition_D(sA); @@ -258,6 +259,16 @@ __global__ __launch_bounds__(Traits::kThreadNum) void 
grouped_gemm_kernel_sm80( auto tBgB = gmem_thr_copy_ab.partition_S(gB); // (BLK_N, BLK_K, PIPE) => (COPY, CP_N, CP_K, PIPE) auto tBsB = gmem_thr_copy_ab.partition_D(sB); +======= + auto tAgA = gmem_thr_copy.partition_S(gA); + // (BLK_M, BLK_K, PIPE) => (COPY, CP_M, CP_K, PIPE) + auto tAsA = gmem_thr_copy.partition_D(sA); + + // (BLK_N, BLK_K, k) => (COPY, CP_N, CP_K, k) + auto tBgB = gmem_thr_copy.partition_S(gB); + // (BLK_N, BLK_K, PIPE) => (COPY, CP_N, CP_K, PIPE) + auto tBsB = gmem_thr_copy.partition_D(sB); +>>>>>>> b916d0d (added pipe support) // GEMM: C = A@B.T TiledMma tiled_mma; @@ -284,7 +295,11 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( // ############### Prologue ############### // remaining k-tile count +<<<<<<< HEAD int k_tile_remaining = size<3>(tAgA); +======= + int k_tiles_remaining = size<3>(tAgA); +>>>>>>> b916d0d (added pipe support) // next tile index in gmem to read from int k_tile_next = 0; @@ -292,12 +307,21 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( auto kPipe = size<3>(tAsA); CUTE_UNROLL for (int k_pipe = 0; k_pipe < kPipe - 1; ++k_pipe) { +<<<<<<< HEAD copy(gmem_tiled_copy_ab, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, k_pipe)); copy(gmem_tiled_copy_ab, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, k_pipe)); cp_async_fence(); // advance to next k-tile if (--k_tile_remaining > 0) { +======= + copy(gmem_tiled_copy, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, k_pipe)); + copy(gmem_tiled_copy, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, k_pipe)); + cp_async_fence(); + + // advance to next k-tile + if (--k_tiles_remaining > 0) { +>>>>>>> b916d0d (added pipe support) ++k_tile_next; } } @@ -331,6 +355,7 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( } CUTE_NO_UNROLL +<<<<<<< HEAD while (k_tile_remaining > -(kPipe - 1)) { CUTE_UNROLL for (int ki = 0; ki < kBlocks; ++ki) { @@ -339,16 +364,52 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( tCsA_p = tCsA(_, _, _, pipe_read); tCsB_p = tCsB(_, _, _, pipe_read); +======= + while (k_tiles_remaining > -(kPipe - 1)) { + CUTE_UNROLL + for (int ki = 0; ki < kBlocks; ++ki) { + // first block + if (ki == 0) { + // copy gmem to smem for next pipe + copy(gmem_tiled_copy, + tAgA(_, _, _, k_tile_next), + tAsA(_, _, _, pipe_write)); + copy(gmem_tiled_copy, + tBgB(_, _, _, k_tile_next), + tBsB(_, _, _, pipe_write)); + cp_async_fence(); + + // advance to next k-tile + if (--k_tiles_remaining > 0) { + ++k_tile_next; + } + } + // last block + if (ki == kBlocks - 1) { + // advance to next pipe + pipe_write = pipe_read; + pipe_read = (pipe_read == kPipe - 1) ? 0 : pipe_read + 1; + + // advance to next pipe to read from + tCsA_p = tCsA(_, _, _, pipe_read); + tCsB_p = tCsB(_, _, _, pipe_read); + +>>>>>>> b916d0d (added pipe support) // wait until our next prefetched tile is loaded in cp_async_wait(); __syncthreads(); } +<<<<<<< HEAD // load A, B from smem to registers for next ki +======= + // prefetch for next ki +>>>>>>> b916d0d (added pipe support) auto ki_next = (ki + _1{}) % kBlocks; copy(smem_tiled_copy_a, tCsA_p(_, _, ki_next), tCrA_cpv(_, _, ki_next)); copy(smem_tiled_copy_b, tCsB_p(_, _, ki_next), tCrB_cpv(_, _, ki_next)); +<<<<<<< HEAD if (ki == 0) { // copy gmem to smeme for next pipe copy(gmem_tiled_copy_ab, @@ -369,6 +430,8 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( pipe_read = (pipe_read == kPipe - 1) ? 
0 : pipe_read + 1; } +======= +>>>>>>> b916d0d (added pipe support) // thread-level gemm for ki gemm(tiled_mma, tCrA(_, _, ki), tCrB(_, _, ki), tCrC); } From d9114d2a01f20b37841cf8558b54d55969d022ce Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Thu, 12 Jun 2025 22:53:51 -0700 Subject: [PATCH 11/11] fix build --- src/kernels/gemm/grouped_gemm_kernel_sm80.cuh | 62 ------------------- 1 file changed, 62 deletions(-) diff --git a/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh b/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh index 981cf200..3711c3e5 100644 --- a/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh +++ b/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh @@ -250,16 +250,6 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( auto gmem_thr_copy = gmem_tiled_copy.get_thread_slice(tidx); // (BLK_M, BLK_K, k) => (COPY, CP_M, CP_K, k) -<<<<<<< HEAD - auto tAgA = gmem_thr_copy_ab.partition_S(gA); - // (BLK_M, BLK_K, PIPE) => (COPY, CP_M, CP_K, PIPE) - auto tAsA = gmem_thr_copy_ab.partition_D(sA); - - // (BLK_N, BLK_K, k) => (COPY, CP_N, CP_K, k) - auto tBgB = gmem_thr_copy_ab.partition_S(gB); - // (BLK_N, BLK_K, PIPE) => (COPY, CP_N, CP_K, PIPE) - auto tBsB = gmem_thr_copy_ab.partition_D(sB); -======= auto tAgA = gmem_thr_copy.partition_S(gA); // (BLK_M, BLK_K, PIPE) => (COPY, CP_M, CP_K, PIPE) auto tAsA = gmem_thr_copy.partition_D(sA); @@ -268,7 +258,6 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( auto tBgB = gmem_thr_copy.partition_S(gB); // (BLK_N, BLK_K, PIPE) => (COPY, CP_N, CP_K, PIPE) auto tBsB = gmem_thr_copy.partition_D(sB); ->>>>>>> b916d0d (added pipe support) // GEMM: C = A@B.T TiledMma tiled_mma; @@ -295,11 +284,7 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( // ############### Prologue ############### // remaining k-tile count -<<<<<<< HEAD - int k_tile_remaining = size<3>(tAgA); -======= int k_tiles_remaining = size<3>(tAgA); ->>>>>>> b916d0d (added pipe support) // next tile index in gmem to read from int k_tile_next = 0; @@ -307,21 +292,12 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( auto kPipe = size<3>(tAsA); CUTE_UNROLL for (int k_pipe = 0; k_pipe < kPipe - 1; ++k_pipe) { -<<<<<<< HEAD - copy(gmem_tiled_copy_ab, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, k_pipe)); - copy(gmem_tiled_copy_ab, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, k_pipe)); - cp_async_fence(); - - // advance to next k-tile - if (--k_tile_remaining > 0) { -======= copy(gmem_tiled_copy, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, k_pipe)); copy(gmem_tiled_copy, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, k_pipe)); cp_async_fence(); // advance to next k-tile if (--k_tiles_remaining > 0) { ->>>>>>> b916d0d (added pipe support) ++k_tile_next; } } @@ -355,16 +331,6 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( } CUTE_NO_UNROLL -<<<<<<< HEAD - while (k_tile_remaining > -(kPipe - 1)) { - CUTE_UNROLL - for (int ki = 0; ki < kBlocks; ++ki) { - if (ki == kBlocks - 1) { - // advance to next pipe to read from - tCsA_p = tCsA(_, _, _, pipe_read); - tCsB_p = tCsB(_, _, _, pipe_read); - -======= while (k_tiles_remaining > -(kPipe - 1)) { CUTE_UNROLL for (int ki = 0; ki < kBlocks; ++ki) { @@ -394,44 +360,16 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( tCsA_p = tCsA(_, _, _, pipe_read); tCsB_p = tCsB(_, _, _, pipe_read); ->>>>>>> b916d0d (added pipe support) // wait until our next prefetched tile is loaded in 
cp_async_wait(); __syncthreads(); } -<<<<<<< HEAD - // load A, B from smem to registers for next ki -======= // prefetch for next ki ->>>>>>> b916d0d (added pipe support) auto ki_next = (ki + _1{}) % kBlocks; copy(smem_tiled_copy_a, tCsA_p(_, _, ki_next), tCrA_cpv(_, _, ki_next)); copy(smem_tiled_copy_b, tCsB_p(_, _, ki_next), tCrB_cpv(_, _, ki_next)); -<<<<<<< HEAD - if (ki == 0) { - // copy gmem to smeme for next pipe - copy(gmem_tiled_copy_ab, - tAgA(_, _, _, k_tile_next), - tAsA(_, _, _, pipe_write)); - copy(gmem_tiled_copy_ab, - tBgB(_, _, _, k_tile_next), - tBsB(_, _, _, pipe_write)); - cp_async_fence(); - - // advance to next k-tile - if (--k_tile_remaining > 0) { - ++k_tile_next; - } - - // advance to next pipe - pipe_write = pipe_read; - pipe_read = (pipe_read == kPipe - 1) ? 0 : pipe_read + 1; - } - -======= ->>>>>>> b916d0d (added pipe support) // thread-level gemm for ki gemm(tiled_mma, tCrA(_, _, ki), tCrB(_, _, ki), tCrC); }
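A note on the PIPE parameter introduced above: PATCH 09 replaces the single-stage smem tiles with ring buffers by swapping the STAGES/DIM parameters for a PIPE count and handing tile_to_shape a third mode, so SmemLayoutA becomes (BLK_M, BLK_K, PIPE) and SmemLayoutB becomes (BLK_N, BLK_K, PIPE). The snippet below is a small host-compilable sketch of that layout change only; it assumes the CuTe headers are on the include path, uses a plain row-major 8x64 atom in place of the kernel's swizzled SmemLayoutAtom, and picks 64x64 tiles with 3 stages as hypothetical values for _BLK_M, _BLK_K and _PIPE.

#include <cute/tensor.hpp>

int main() {
  using namespace cute;

  // Stand-in for the kernel's SmemLayoutAtom (the real one is swizzled).
  auto atom = Layout<Shape<_8, _64>, Stride<_64, _1>>{};

  // Single-stage tile, as before the patch: (BLK_M, BLK_K)
  auto sA_one  = tile_to_shape(atom, Shape<Int<64>, Int<64>>{});
  // Multi-stage ring buffer, as after the patch: (BLK_M, BLK_K, PIPE)
  auto sA_pipe = tile_to_shape(atom, Shape<Int<64>, Int<64>, Int<3>>{});

  print("one stage : "); print(sA_one);  print("\n");
  print("with PIPE : "); print(sA_pipe); print("\n");
  // cosize() is the number of elements the layout actually addresses, i.e.
  // the smem footprint: the PIPE mode multiplies it by the stage count.
  print("smem elements: "); print(cosize(sA_one));
  print(" vs ");            print(cosize(sA_pipe)); print("\n");
  return 0;
}

The extra mode is what lets the kernel slice out a single stage in the mainloop, and it is also why the a_smem/b_smem arrays in GEMMSharedStorageSM80 grow by a factor of kPipe, which bounds how large PIPE can be for a given BLK_M/BLK_N/BLK_K before the kernel runs out of shared memory.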
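The mainloop, in the form that survives after PATCH 11, is the usual SM80 multi-stage cp.async pipeline: the prologue issues gmem->smem loads for kPipe - 1 k-tiles (one cp_async_fence per stage), and each mainloop iteration overlaps the MMAs that consume smem stage pipe_read with the load of the next k-tile into stage pipe_write; on the last register sub-block of the iteration the two stages rotate and the kernel waits on the outstanding cp.async groups before prefetching from the new read stage. In cute, cp_async_fence() commits the async copies issued so far as one group and cp_async_wait<N>() blocks until at most N groups are still in flight, which is what makes the kPipe - 1 tiles of look-ahead safe. The program below is a minimal host-side sketch of that bookkeeping only, not the kernel: k_tiles and kPipe are hypothetical stand-ins for size<3>(tAgA) and size<3>(tAsA), the printf calls stand in for the copy/fence/wait/gemm calls, and the smem vector merely records which k-tile occupies each stage.

#include <cstdio>
#include <vector>

int main() {
  // Hypothetical sizes, standing in for size<3>(tAgA) and size<3>(tAsA).
  const int k_tiles = 6;  // k-tiles to consume from gmem
  const int kPipe = 3;    // smem stages (the PIPE template parameter)

  std::vector<int> smem(kPipe, -1);  // which k-tile currently occupies each stage
  int k_tiles_remaining = k_tiles;
  int k_tile_next = 0;

  // Prologue: issue async loads for all pipes except the last one.
  for (int k_pipe = 0; k_pipe < kPipe - 1; ++k_pipe) {
    std::printf("prologue: load k-tile %d -> pipe %d, fence\n", k_tile_next, k_pipe);
    smem[k_pipe] = k_tile_next;
    if (--k_tiles_remaining > 0) ++k_tile_next;
  }

  int pipe_read = 0;           // stage consumed by this iteration's MMAs
  int pipe_write = kPipe - 1;  // stage the next load is issued into

  // Mainloop: k_tiles_remaining is allowed to go negative so the last
  // kPipe - 1 issued loads are still consumed while the pipeline drains.
  while (k_tiles_remaining > -(kPipe - 1)) {
    // ki == 0 in the kernel: issue the next gmem->smem load and fence.
    std::printf("mainloop: load k-tile %d -> pipe %d, fence\n", k_tile_next, pipe_write);
    smem[pipe_write] = k_tile_next;  // past the end this re-loads the last tile
    if (--k_tiles_remaining > 0) ++k_tile_next;

    // ki == 0 .. kBlocks-1: MMAs consume registers fed from pipe_read.
    std::printf("mainloop: compute k-tile %d from pipe %d\n", smem[pipe_read], pipe_read);

    // ki == kBlocks-1 in the kernel: rotate stages, then wait + __syncthreads()
    // before the smem->register prefetch touches the new pipe_read.
    pipe_write = pipe_read;
    pipe_read = (pipe_read == kPipe - 1) ? 0 : pipe_read + 1;
    std::printf("mainloop: wait, next prefetch reads pipe %d\n", pipe_read);
  }
  return 0;
}

With these numbers the trace shows every k-tile consumed exactly once from the stage it was loaded into, and the final iterations re-issuing the last tile so the count of committed cp.async groups stays balanced while the pipeline drains; that balance is why the while-condition lets k_tiles_remaining run down to -(kPipe - 1) instead of stopping at zero. The register pipeline inside each iteration works the same way one level down: the copies into tCrA_cpv/tCrB_cpv at sub-block ki prefetch operands for ki_next = (ki + 1) % kBlocks, so the smem->register loads for the next sub-block are in flight while the current gemm call runs.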