diff --git a/src/kernels/CMakeLists.txt b/src/kernels/CMakeLists.txt index 999a1393..1dfdf69e 100644 --- a/src/kernels/CMakeLists.txt +++ b/src/kernels/CMakeLists.txt @@ -1,16 +1,16 @@ include(cc_library) cc_library( - NAME + NAME kernels - HDRS + HDRS reduce_kernel_utils.cuh activation_kernels.h layernorm_kernels.h pos_embedding_kernels.h kv_cache_kernels.h sampling/sampling_kernels.h - SRCS + SRCS activation_kernels.cu layernorm_kernels.cu pos_embedding_kernels.cu @@ -28,7 +28,7 @@ cc_library( add_subdirectory(attention) add_subdirectory(moe) +add_subdirectory(gemm) add_subdirectory(quantization) add_subdirectory(playground) -add_subdirectory(triton) - +# add_subdirectory(triton) diff --git a/src/kernels/gemm/CMakeLists.txt b/src/kernels/gemm/CMakeLists.txt new file mode 100644 index 00000000..99cc5920 --- /dev/null +++ b/src/kernels/gemm/CMakeLists.txt @@ -0,0 +1,24 @@ +include(cc_library) +include(cc_test) + +cc_library( + NAME + gemm.kernels + HDRS + grouped_gemm_kernel_sm80.cuh + DEPS + cutlass +) + + +cc_test( + NAME + gemm_kernel_test + SRCS + grouped_gemm_kernel_sm80_test.cu + DEPS + :gemm.kernels + absl::random_random + GTest::gtest_main + torch +) diff --git a/src/kernels/gemm/gather_tensor.hpp b/src/kernels/gemm/gather_tensor.hpp new file mode 100644 index 00000000..79ca581b --- /dev/null +++ b/src/kernels/gemm/gather_tensor.hpp @@ -0,0 +1,167 @@ +// adapted from +// https://github.com/NVIDIA/cutlass/blob/main/examples/common/gather_tensor.hpp +#pragma once + +#include "cute/layout.hpp" +#include "cute/layout_composed.hpp" +#include "cute/tensor.hpp" +namespace llm { + +using namespace cute; + +namespace detail { + +// every stride must be divisible by div +template +CUTE_HOST_DEVICE constexpr auto safe_stride_div(Stride const& s, + const Div& div) { + if constexpr (is_tuple::value) { + return transform(s, [&](auto const& a) { return safe_stride_div(a, div); }); + } else { + return safe_div(s, div); + } + CUTE_GCC_UNREACHABLE; +} + +} // namespace detail + +/// Custom stride object that applies a function followed by a stride +template +struct CustomStride { + CUTE_HOST_DEVICE constexpr CustomStride(const Func& func, + const Stride& stride) + : func_(func), stride_(stride) {} + + template + CUTE_HOST_DEVICE constexpr friend auto operator*(I i, const CustomStride& s) { + return inner_product(s.func_(i), s.stride_); + } + + template + CUTE_HOST_DEVICE constexpr friend auto operator*(const CustomStride& s, I i) { + return inner_product(s.func_(i), s.stride_); + } + + template + CUTE_HOST_DEVICE constexpr friend auto safe_div(const CustomStride& s, + const Div& div) { + auto stride = detail::safe_stride_div(s.stride_, div); + return CustomStride(s.func_, stride); + } + + template + CUTE_HOST_DEVICE constexpr friend auto make_layout( + const Shape& shape, + const CustomStride& stride) { + return Layout(shape, stride); + } + + CUTE_HOST_DEVICE friend void print(CustomStride const& s) { + print("CustomStride{func,"); + print(s.stride_); + print("}"); + } + + Func func_; + Stride stride_; +}; + +template +CUTLASS_HOST_DEVICE auto make_custom_stride_layout(Func&& func, + const Shape& shape, + const Stride& stride) { + // Use a dummy shape and replace the first non-unit stride with a custom + // gather stride + auto idx = + find_if(stride, [](auto x) { return not is_constant<1, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; + return make_layout( + repeat_like(shape, _1{}), + replace(stride, + CustomStride{static_cast(func), get(stride)})); +} + +/// Helper function to 
optionally create a gather tensor +template +CUTLASS_HOST_DEVICE auto make_gather_tensor(Iterator iter, + const Shape& shape, + const Stride& stride, + Func&& func) { + Layout matrix_layout = make_identity_layout(shape); + auto offset = as_arithmetic_tuple(repeat_like(shape, _0{})); + Layout gather_layout = + make_custom_stride_layout(static_cast(func), shape, stride); + return make_tensor(iter, + ComposedLayout{gather_layout, offset, matrix_layout}); +} + +} // namespace llm + +namespace cute { + +template +CUTE_HOST_DEVICE constexpr auto upcast(Shape const& shape, + Stride const& stride) { + if constexpr (is_tuple::value) { + return transform_layout(shape, stride, [](auto const& s, auto const& d) { + return upcast(s, d); + }); + } else if constexpr (is_scaled_basis::value) { + if constexpr (Stride::mode() == I) { + return make_layout(ceil_div(shape, Int{}), ceil_div(stride, Int{})); + } else { + return make_layout(shape, stride); + } + } else { + return upcast(shape, stride); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr auto upcast( + ComposedLayout, + Offset, + Layout> const& layout) { + // Find index of the stride-1 mode - that is the only one that requires + // updating inner shape and offset + auto idx = find_if(layout.layout_a().stride(), + [](auto x) { return is_constant<1, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; + + // Upcast the outer layout (works as expected) + auto outer = upcast(layout.layout_a()); + + // Upcast the accumulated offset along stride-1 mode + auto offset = as_arithmetic_tuple( + replace(layout.offset(), upcast(get(layout.offset())))); + + // Upcast the inner layout's shape along stride-1 mode + auto inner = + upcast(layout.layout_b().shape(), layout.layout_b().stride()); + + return composition(outer, offset, inner); +} + +template +CUTE_HOST_DEVICE constexpr auto max_common_vector( + Layout const& a, + ComposedLayout, + OffsetB, + Layout> const& b) { + return max_common_vector(b.layout_b(), a); +} + +} // namespace cute diff --git a/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh b/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh new file mode 100644 index 00000000..3711c3e5 --- /dev/null +++ b/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh @@ -0,0 +1,417 @@ +#pragma once +#include +#include + +#include +#include +#include + +#include "gather_tensor.hpp" + +namespace llm { +using namespace cute; + +template +struct GEMMTraitsSM80 { + static constexpr int kBlockM = BLK_M; + static constexpr int kBlockN = BLK_N; + static constexpr int kBlockK = BLK_K; + static constexpr int kPipe = PIPE; + + static_assert(kBlockM % 64 == 0); + static_assert(kBlockN % 32 == 0); + static_assert(kBlockK % 16 == 0); + + // helpful aliases + using DType = DTYPE; + using _BLK_M = Int; + using _BLK_N = Int; + using _BLK_K = Int; + using _PIPE = Int; + + // MMA Atom: (16x8x16) for F32F16F16F32 or F32BF16BF16F32 + using MMA_Atom_ = + std::conditional_t, + MMA_Atom, + MMA_Atom>; + + // TiledMMA: (64x16x16) + using TiledMma = TiledMMA>, // warp layout: (4x1x1) + Tile<_64, _16, _16>>; // tile layout: (64x16x16) + + // Shared memory LayoutAtom (8x64) + using SmemLayoutAtom_8x64 = + decltype(composition(Swizzle<3, 3, 3>{}, + Layout, Stride<_64, _1>>{})); + using SmemLayoutAtom_8x32 = + decltype(composition(Swizzle<2, 3, 3>{}, + Layout, Stride<_32, _1>>{})); + + using SmemLayoutAtom = std::conditional_t; + // SMEM Layout for A: (BLK_M, BLK_K, PIPE) + using SmemLayoutA = + decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_M, _BLK_K, _PIPE>{})); + // 
SMEM Layout for B: (BLK_N, BLK_K, PIPE) + using SmemLayoutB = + decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_N, _BLK_K, _PIPE>{})); + + // Thread layout for gmem copy: (_16,_8)/(_32, _4) + using GmemCopyThrLayout = + std::conditional_t, Stride<_4, _1>>, + Layout, Stride<_8, _1>>>; + // g2s tiled copy: copy A/B from global memory to shared memory + using GmemTiledCopy = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + GmemCopyThrLayout{}, // Thr layout: (_16,_8)/(_32, _4) + Layout>{} // Val layout: 8 vals per read + )); + + // s2r tiled copy for A and B + using SmemTiledCopyA = + decltype(make_tiled_copy_A(Copy_Atom{}, + TiledMma{})); + using SmemTiledCopyB = + decltype(make_tiled_copy_B(Copy_Atom{}, + TiledMma{})); + + // ******* Epilogue ******* + + using SmemLayoutAtomC = std::conditional_t; + using SmemLayoutC = + decltype(tile_to_shape(SmemLayoutAtomC{}, Shape<_BLK_M, _BLK_N>{})); + + // use 128-bit vectorizing copy + using VectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<128>; + // r2s tiled copy for C + using SmemTiledCopyC = + decltype(make_tiled_copy_C(Copy_Atom{}, + TiledMma{})); + + // s2g tiled copy for O + using GmemTiledCopyC = decltype(make_tiled_copy( + Copy_Atom{}, + GmemCopyThrLayout{}, // Thr layout: (_16,_8)/(_32, _4) + Layout>{} // Val layout: 8 vals per read + )); + + // constexpr values for kernel launch + static constexpr size_t kThreadNum = size(TiledMma{}); +}; + +template +struct GEMMSharedStorageSM80 { + using DType = typename Traits::DType; + using SmemLayoutA = typename Traits::SmemLayoutA; + using SmemLayoutB = typename Traits::SmemLayoutB; + using SmemLayoutC = typename Traits::SmemLayoutC; + + union { + struct { + // Shared memory for A: (BLK_M, BLK_K, PIPE) + cute::array_aligned> a_smem; + // Shared memory for B: (BLK_N, BLK_K, PIPE) + cute::array_aligned> b_smem; + }; + // Shared memory for C: (BLK_M, BLK_N) + cute::array_aligned> c_smem; + }; +}; + +struct GEMMParams { + using AStride = Stride; + using BStride = Stride; + using CStride = Stride; + + // A: (m, k) + const void* __restrict__ a_ptr = nullptr; + AStride a_stride; + + // B: (e, n, k) + const void* __restrict__ b_ptr = nullptr; + BStride b_stride; + + // C: ((m, topk), n) + void* __restrict__ c_ptr = nullptr; + CStride c_stride; + + // (m_blocks*BLK_M) + const int* __restrict__ sorted_token_idxes_ptr = nullptr; + // (m_blocks) + const int* __restrict__ expert_ids_ptr = nullptr; + + const int* __restrict__ n_tokens_padded = nullptr; + + int m = 0; + int n = 0; + int k = 0; + int topk = 0; + int n_experts = 0; + + int m_blocks = 0; + int n_blocks = 0; +}; + +template +__global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( + __grid_constant__ const Params params) { + // Traits + constexpr int kBlockM = Traits::kBlockM; + constexpr int kBlockN = Traits::kBlockN; + constexpr int kBlockK = Traits::kBlockK; + + using _BLK_M = Int; + using _BLK_N = Int; + using _BLK_K = Int; + + using DTYPE = typename Traits::DType; + using TiledMma = typename Traits::TiledMma; + + using SmemLayoutA = typename Traits::SmemLayoutA; + using SmemLayoutB = typename Traits::SmemLayoutB; + using SmemLayoutC = typename Traits::SmemLayoutC; + + using GmemTiledCopy = typename Traits::GmemTiledCopy; + using SmemTiledCopyA = typename Traits::SmemTiledCopyA; + using SmemTiledCopyB = typename Traits::SmemTiledCopyB; + using SmemTiledCopyC = typename Traits::SmemTiledCopyC; + using GmemTiledCopyC = typename Traits::GmemTiledCopyC; + + using SharedStorage = GEMMSharedStorageSM80; + + const 
auto M = kBlockM * gridDim.x; + const auto N = params.n; + const auto K = params.k; + + const auto topk = params.topk; + const auto n_experts = params.n_experts; + + // each thread block takes care of one block: (BLK_M, BLK_N) + const auto m_block_idx = blockIdx.x; + const auto n_block_idx = blockIdx.y; + const auto tidx = threadIdx.x; + + const int expert_id = params.expert_ids_ptr[m_block_idx]; + + // ProblemShape + const int* sorted_token_idxes = params.sorted_token_idxes_ptr; + auto idx_to_t_idx = [sorted_token_idxes, topk](int idx) { + return sorted_token_idxes[idx] / topk; + }; + // A: (M, K), k-major + auto A = make_gather_tensor(make_gmem_ptr((const DTYPE*)params.a_ptr), + make_shape(M, K), + make_stride(get<0>(params.a_stride), _1{}), + idx_to_t_idx); + + // B: (N, K), k-major + const auto b_offset = expert_id * get<0>(params.b_stride); + auto B = make_tensor(make_gmem_ptr((const DTYPE*)params.b_ptr + b_offset), + make_shape(N, K), + make_stride(get<1>(params.b_stride), _1{})); + + // C: (M, N), n-major + auto idx_to_f_idx = [sorted_token_idxes](int idx) { + return sorted_token_idxes[idx]; + }; + auto C = make_gather_tensor(make_gmem_ptr((DTYPE*)params.c_ptr), + make_shape(M, N), + make_stride(get<0>(params.c_stride), _1{}), + idx_to_f_idx); + + // (M, K) => (BLK_M, BLK_K, k) + Tensor gA = + local_tile(A, Shape<_BLK_M, _BLK_K>{}, make_coord(m_block_idx, _)); + // (N, K) => (BLK_N, BLK_K, k) + Tensor gB = + local_tile(B, Shape<_BLK_N, _BLK_K>{}, make_coord(n_block_idx, _)); + // (M, N) => (BLK_M, BLK_N) + Tensor gC = local_tile( + C, Shape<_BLK_M, _BLK_N>{}, make_coord(m_block_idx, n_block_idx)); + + // Smem + extern __shared__ char smem[]; + auto& ss = *reinterpret_cast(smem); + + // (BLK_M, BLK_K, PIPE) + Tensor sA = make_tensor(make_smem_ptr(ss.a_smem.data()), SmemLayoutA{}); + // (BLK_N, BLK_K, PIPE) + Tensor sB = make_tensor(make_smem_ptr(ss.b_smem.data()), SmemLayoutB{}); + // (BLK_M, BLK_N) + // Tensor sC = make_tensor(make_smem_ptr(ss.c_smem.data()), SmemLayoutC{}); + + // Tiled Copy + GmemTiledCopy gmem_tiled_copy; + auto gmem_thr_copy = gmem_tiled_copy.get_thread_slice(tidx); + + // (BLK_M, BLK_K, k) => (COPY, CP_M, CP_K, k) + auto tAgA = gmem_thr_copy.partition_S(gA); + // (BLK_M, BLK_K, PIPE) => (COPY, CP_M, CP_K, PIPE) + auto tAsA = gmem_thr_copy.partition_D(sA); + + // (BLK_N, BLK_K, k) => (COPY, CP_N, CP_K, k) + auto tBgB = gmem_thr_copy.partition_S(gB); + // (BLK_N, BLK_K, PIPE) => (COPY, CP_N, CP_K, PIPE) + auto tBsB = gmem_thr_copy.partition_D(sB); + + // GEMM: C = A@B.T + TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + // rA: (BLK_M, BLK_K) => (MMA,MMA_M,MMA_K) + auto tCrA = thr_mma.partition_fragment_A(sA(_, _, _0{})); + // rB: (BLK_N, BLK_K) => (MMA,MMA_N,MMA_K) + auto tCrB = thr_mma.partition_fragment_B(sB(_, _, _0{})); + + // s2r tiled copy for A and B + auto smem_tiled_copy_a = SmemTiledCopyA{}; + auto smem_thr_copy_a = smem_tiled_copy_a.get_thread_slice(tidx); + // (BLK_M, BLK_K, PIPE) => (COPY, COPY_M, COPY_K, PIPE) + auto tCsA = smem_thr_copy_a.partition_S(sA); + // (COPY, COPY_M, COPY_K) + auto tCrA_cpv = smem_thr_copy_a.retile_D(tCrA); + + auto smem_tiled_copy_b = SmemTiledCopyB{}; + auto smem_thr_copy_b = smem_tiled_copy_b.get_thread_slice(tidx); + // (BLK_N, BLK_K, PIPE) => (COPY, COPY_N, COPY_K, PIPE) + auto tCsB = smem_thr_copy_b.partition_S(sB); + // (COPY, COPY_N, COPY_K) + auto tCrB_cpv = smem_thr_copy_b.retile_D(tCrB); + + // ############### Prologue ############### + // remaining k-tile count + int k_tiles_remaining = 
size<3>(tAgA); + // next tile index in gmem to read from + int k_tile_next = 0; + + // async loads for all pipes except the last one + auto kPipe = size<3>(tAsA); + CUTE_UNROLL + for (int k_pipe = 0; k_pipe < kPipe - 1; ++k_pipe) { + copy(gmem_tiled_copy, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, k_pipe)); + copy(gmem_tiled_copy, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, k_pipe)); + cp_async_fence(); + + // advance to next k-tile + if (--k_tiles_remaining > 0) { + ++k_tile_next; + } + } + + // ############### Mainloop ############### + // (BLK_M, BLK_N) => (MMA, MMA_M, MMA_N) + auto tCrC = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{}); + cute::clear(tCrC); // Clear the accumulator + + // pipe index in smem to read from + int pipe_read = 0; + // pipe index in smem to write to + int pipe_write = kPipe - 1; + + // pipe to read from: (COPY, COPY_N, COPY_K) + Tensor tCsA_p = tCsA(_, _, _, pipe_read); + Tensor tCsB_p = tCsB(_, _, _, pipe_read); + + // Size of the register pipeline + auto kBlocks = size<2>(tCrA); + + // prefetch register pipeline + if (kBlocks > 1) { + // wait until our first prefetched tile is loaded in + cp_async_wait(); + __syncthreads(); + + // prefetch the first rmem from the first k-tile + cute::copy(smem_tiled_copy_a, tCsA_p(_, _, _0{}), tCrA_cpv(_, _, _0{})); + cute::copy(smem_tiled_copy_b, tCsB_p(_, _, _0{}), tCrB_cpv(_, _, _0{})); + } + + CUTE_NO_UNROLL + while (k_tiles_remaining > -(kPipe - 1)) { + CUTE_UNROLL + for (int ki = 0; ki < kBlocks; ++ki) { + // first block + if (ki == 0) { + // copy gmem to smem for next pipe + copy(gmem_tiled_copy, + tAgA(_, _, _, k_tile_next), + tAsA(_, _, _, pipe_write)); + copy(gmem_tiled_copy, + tBgB(_, _, _, k_tile_next), + tBsB(_, _, _, pipe_write)); + cp_async_fence(); + + // advance to next k-tile + if (--k_tiles_remaining > 0) { + ++k_tile_next; + } + } + // last block + if (ki == kBlocks - 1) { + // advance to next pipe + pipe_write = pipe_read; + pipe_read = (pipe_read == kPipe - 1) ? 
0 : pipe_read + 1; + + // advance to next pipe to read from + tCsA_p = tCsA(_, _, _, pipe_read); + tCsB_p = tCsB(_, _, _, pipe_read); + + // wait until our next prefetched tile is loaded in + cp_async_wait(); + __syncthreads(); + } + + // prefetch for next ki + auto ki_next = (ki + _1{}) % kBlocks; + copy(smem_tiled_copy_a, tCsA_p(_, _, ki_next), tCrA_cpv(_, _, ki_next)); + copy(smem_tiled_copy_b, tCsB_p(_, _, ki_next), tCrB_cpv(_, _, ki_next)); + + // thread-level gemm for ki + gemm(tiled_mma, tCrA(_, _, ki), tCrB(_, _, ki), tCrC); + } + } + + // ############### Epilogue ############### + // (BLK_M, BLK_N) + Tensor sC = make_tensor(make_smem_ptr(ss.c_smem.data()), SmemLayoutC{}); + + // TODO: fastcast tCrC to DTYPE + + // copy tCrC from registers to smem + SmemTiledCopyC smem_tiled_copy_c; + auto smem_thr_copy_c = smem_tiled_copy_c.get_thread_slice(tidx); + auto tSrC = smem_thr_copy_c.retile_S(tCrC); + auto tSsC = smem_thr_copy_c.partition_D(sC); + cute::copy(smem_tiled_copy_c, tSrC, tSsC); + + // wait for smem copy done before gmem copy + __syncthreads(); + + // copy sC from smem to gmem + GmemTiledCopyC gmem_tiled_copy_c; + auto gmem_thr_copy_c = gmem_tiled_copy_c.get_thread_slice(tidx); + auto tGsC = gmem_thr_copy_c.partition_S(sC); + auto tGgC = gmem_thr_copy_c.partition_D(gC); + cute::copy(gmem_tiled_copy_c, tGsC, tGgC); +} + +template +void launch_grouped_gemm_kernel_sm80(const Params& params, + cudaStream_t stream) { + const auto smem_size = sizeof(GEMMSharedStorageSM80); + // std::cout << "SMEM size: " << smem_size << " bytes\n"; + + auto gemm_kernel = grouped_gemm_kernel_sm80; + cudaFuncSetAttribute( + gemm_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + // TODO: support persistent kernels + dim3 grid(params.m_blocks, params.n_blocks); + dim3 block = Traits::kThreadNum; + gemm_kernel<<>>(params); +} + +} // namespace llm diff --git a/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu b/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu new file mode 100644 index 00000000..6915d0ad --- /dev/null +++ b/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu @@ -0,0 +1,209 @@ +#include +#include + +#include + +#include "grouped_gemm_kernel_sm80.cuh" // IWYU pragma: keep + +namespace llm { + +namespace { + +// reference implementation +std::tuple permute_align_block( + torch::Tensor topk_ids, // [n_tokens, topk] + int64_t n_experts, + int64_t block_size) { + const int64_t n_tokens = topk_ids.size(0); + const int64_t topk = topk_ids.size(1); + const int64_t n_flatten_tokens = topk_ids.numel(); + + auto topk_ids_cpu = topk_ids.cpu().contiguous(); + const int32_t* topk_ids_ptr = topk_ids_cpu.data_ptr(); + + std::vector> expert_to_idxes(n_experts); + for (int i = 0; i < n_flatten_tokens; ++i) { + const int32_t expert_id = topk_ids_ptr[i]; + assert(expert_id >= 0 && expert_id < n_experts); + expert_to_idxes[expert_id].push_back(i); + } + + std::vector sorted_token_idxes; + std::vector expert_ids; + int32_t n_padded_tokens = 0; + for (int e_idx = 0; e_idx < n_experts; ++e_idx) { + // flatten indices for each expert, sorted by token id + const auto& idxes = expert_to_idxes[e_idx]; + if (idxes.empty()) { + continue; + } + const auto count = idxes.size(); + const auto n_blocks = cute::ceil_div(count, block_size); + n_padded_tokens += (n_blocks * block_size); + // fill flatten indices for each block + for (int b_idx = 0; b_idx < n_blocks; ++b_idx) { + // expert id for each block + expert_ids.push_back(e_idx); + for (int offset = 0; offset < block_size; ++offset) { + auto idx = 
(b_idx * block_size) + offset; + if (idx < count) { + // fill flatten indices + sorted_token_idxes.push_back(idxes[idx]); + } else { + // fill padding + sorted_token_idxes.push_back(n_flatten_tokens); + } + } + } + } + + // construct tensor and return + const auto options = topk_ids.options(); + return {torch::tensor(sorted_token_idxes, options), + torch::tensor(expert_ids, options), + torch::tensor({n_padded_tokens}, options)}; +} + +torch::Tensor grouped_gemm_sm80(const torch::Tensor& a, // (m, k) + const torch::Tensor& w, // (e, n, k) + const torch::Tensor& topk_ids // (m, topk) +) { + const auto m = a.size(0); + const auto k = a.size(1); + const auto n = w.size(1); + const auto n_experts = w.size(0); + const auto topk = topk_ids.size(1); + + // construct aligned + auto [sorted_token_idex, expert_ids, n_tokens_padded] = permute_align_block( + topk_ids.to(torch::kInt32), n_experts, /*block_size=*/64); + + // LOG(ERROR) << "sorted_token_idex: " << sorted_token_idex; + // LOG(ERROR) << "expert_ids: " << expert_ids; + // LOG(ERROR) << "n_padded_tokens: " << n_tokens_padded; + + // (m * topk, n) + auto out = torch::zeros({m * topk, n}, a.options()); + + using Traits = GEMMTraitsSM80; /*PIPE*/ + + // construct params + GEMMParams params; + params.a_ptr = a.const_data_ptr(); + params.a_stride = make_stride(a.stride(0)); + params.b_ptr = w.const_data_ptr(); + params.b_stride = make_stride(w.stride(0), w.stride(1)); + params.c_ptr = out.mutable_data_ptr(); + params.c_stride = make_stride(out.stride(0)); + + params.sorted_token_idxes_ptr = sorted_token_idex.const_data_ptr(); + params.expert_ids_ptr = expert_ids.const_data_ptr(); + params.n_tokens_padded = n_tokens_padded.const_data_ptr(); + + params.m = m; + params.n = n; + params.k = k; + params.topk = topk; + + params.m_blocks = expert_ids.size(0); + params.n_blocks = cute::ceil_div(n, 64); + + launch_grouped_gemm_kernel_sm80(params, nullptr); + + // (m * topk, n) => (m, topk, n) + return out.view({m, topk, n}); +} + +// returns (m, topk, n) +torch::Tensor grouped_gemm_ref(const torch::Tensor& a, // (m, k) + const torch::Tensor& w, // (e, n, k) + const torch::Tensor& topk_ids // (m, topk) + +) { + const auto m = a.size(0); + const auto k = a.size(1); + const auto n = w.size(1); + const auto n_experts = w.size(0); + const auto topk = topk_ids.size(1); + + // (m * topk, n) + auto out = torch::zeros({m * topk, n}, a.options()); + + // (m, k) => (m, topk, k) => (m * topk, k) + auto a_expanded_flat = + a.unsqueeze(/*dim=*/1).expand({-1, topk, -1}).reshape({-1, k}); + // (m, topk) => (m * topk) + auto topk_ids_flat = topk_ids.reshape({-1}); + + // process each expert + for (int64_t e = 0; e < n_experts; ++e) { + // 1D indices for the current expert + auto indices = torch::nonzero(topk_ids_flat == e).squeeze(); + // select corresponding tokens + auto a_selected = a_expanded_flat.index_select(/*dim=*/0, indices); + // perform the GEMM operation for this expert + auto e_out = torch::matmul(a_selected, w[e].transpose(0, 1)); + // copy the results into the output tensor + out.index_copy_(/*dim=*/0, indices, e_out); + } + // (m * topk, n) => (m, topk, n) + return out.view({m, topk, n}); +} + +} // namespace + +class GroupedGemmKernelTest + : public ::testing::TestWithParam> { + public: + void SetUp() override { + // Set random seed for test stability + torch::manual_seed(0); + } +}; + +TEST_P(GroupedGemmKernelTest, GEMM) { + const auto [dtype, m, n, k, n_experts, topk] = GetParam(); + + const auto options = torch::dtype(dtype).device(torch::kCUDA); + + // 
Create input tensors + auto a = torch::randn({m, k}, options); + auto w = torch::randn({n_experts, n, k}, options); + + // Get top-k indices + auto logits = torch::randn({m, n_experts}, options).softmax(/*dim=*/1); + auto [topk_weights, topk_ids] = logits.topk(topk, /*dim=*/1); + + auto ref_out = grouped_gemm_ref(a, w, topk_ids); + // LOG(ERROR) << "ref_out: " << ref_out; + auto out = grouped_gemm_sm80(a, w, topk_ids); + + EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-3, /*atol=*/1e-3)); + + // auto max_diff = (out - ref_out).abs().max(); + // LOG(ERROR) << "Max diff: " << max_diff; + // LOG(ERROR) << "ref_out: " << ref_out; + // LOG(ERROR) << "out: " << out; +} + +INSTANTIATE_TEST_SUITE_P( + GEMM, + GroupedGemmKernelTest, + ::testing::Combine(::testing::Values(torch::kHalf), // dtype + ::testing::Values(64, 128), // m + ::testing::Values(64, 128), // n + ::testing::Values(64, 128), // k + ::testing::Values(1), // n_experts + ::testing::Values(1) // topk + )); + +} // namespace llm diff --git a/src/kernels/moe/align_block_kernel_test.cu b/src/kernels/moe/align_block_kernel_test.cu index 73593799..66f0a418 100644 --- a/src/kernels/moe/align_block_kernel_test.cu +++ b/src/kernels/moe/align_block_kernel_test.cu @@ -146,10 +146,6 @@ class AlignBlockTest TEST_P(AlignBlockTest, AlignBlock) { const auto [dtype, n_tokens, dim, n_experts, topk, block_size] = GetParam(); const int64_t n_flatten_tokens = n_tokens * topk; - if (n_flatten_tokens >= 1024 || n_experts > 64) { - // TODO: reenable unittest after fixing tokens out of order issue - return; - } const auto options = torch::dtype(dtype).device(torch::kCUDA); const auto options_int32 = options.dtype(torch::kInt32);
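Note (outside the diff): a small worked example of the reference permute_align_block above, since its padded, block-sorted encoding is what the kernel consumes. With topk_ids = [[0], [1], [0], [1]] (n_tokens = 4, topk = 1, so flattened indices 0..3), n_experts = 2, and a toy block_size of 4 instead of the kernel's 64: expert 0 owns flattened indices {0, 2} and expert 1 owns {1, 3}; each expert is padded up to one full block using the sentinel value n_flatten_tokens = 4, giving sorted_token_idxes = [0, 2, 4, 4, 1, 3, 4, 4], expert_ids = [0, 1] (one entry per BLK_M block), and n_tokens_padded = [8]. Block b of the grid then uses the weights w[expert_ids[b]] and gathers activation rows sorted_token_idxes[i] / topk.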
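Note (outside the diff): gather_tensor.hpp makes the gather implicit in the layout rather than in explicit index arithmetic. CustomStride applies a user-supplied index mapping before combining with the stride, and make_gather_tensor wraps that in a ComposedLayout so local_tile and tiled copies work unchanged. The sketch below shows the mechanism in isolation; it uses only make_gather_tensor from this patch plus standard CuTe calls, and the device function name and its parameters are illustrative assumptions, not part of the patch.

#include <cute/tensor.hpp>

#include "gather_tensor.hpp"  // make_gather_tensor / CustomStride from this patch

// Illustrative sketch: build the same kind of gathered (M, K) view of the
// activation matrix that grouped_gemm_kernel_sm80 builds for A.
template <typename DType>
__device__ void gathered_a_view(const DType* a_ptr,             // (m, k), k-major
                                const int* sorted_token_idxes,  // block-sorted flat indices
                                int M,    // padded token count (m_blocks * BLK_M)
                                int K,
                                int lda,  // row stride of A
                                int topk) {
  using namespace cute;
  // Same mapping as in the kernel: flattened index -> source token row.
  auto idx_to_t_idx = [sorted_token_idxes, topk](int idx) {
    return sorted_token_idxes[idx] / topk;
  };
  // Logical (M, K), k-major view; element (i, j) actually reads
  // a_ptr[idx_to_t_idx(i) * lda + j], so tiles taken with local_tile come out
  // already permuted into block-sorted order.
  auto A = llm::make_gather_tensor(make_gmem_ptr(a_ptr),
                                   make_shape(M, K),
                                   make_stride(lda, _1{}),
                                   idx_to_t_idx);
  (void)A;  // a real kernel would local_tile(A, ...) and feed a tiled copy
}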