diff --git a/src/kernels/gemm/CMakeLists.txt b/src/kernels/gemm/CMakeLists.txt index 99cc5920..97f28a80 100644 --- a/src/kernels/gemm/CMakeLists.txt +++ b/src/kernels/gemm/CMakeLists.txt @@ -5,7 +5,10 @@ cc_library( NAME gemm.kernels HDRS - grouped_gemm_kernel_sm80.cuh + sm80_collective_grouped_gemm.cuh + sm80_collective_epilogue.cuh + sm80_grouped_gemm_launch.cuh + tile_scheduler.cuh DEPS cutlass ) @@ -15,7 +18,7 @@ cc_test( NAME gemm_kernel_test SRCS - grouped_gemm_kernel_sm80_test.cu + sm80_grouped_gemm_test.cu DEPS :gemm.kernels absl::random_random diff --git a/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh b/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh deleted file mode 100644 index 091a037f..00000000 --- a/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh +++ /dev/null @@ -1,483 +0,0 @@ -#pragma once -#include -#include - -#include -#include -#include - -#include "fast_cast.cuh" -#include "gather_tensor.hpp" -#include "safe_copy.hpp" - -namespace llm { -using namespace cute; - -template -struct GEMMTraitsSM80 { - static constexpr int kBlockM = BLK_M; - static constexpr int kBlockN = BLK_N; - static constexpr int kBlockK = BLK_K; - static constexpr int kPipe = PIPE; - - static_assert(kBlockM % 64 == 0); - static_assert(kBlockN % 32 == 0); - static_assert(kBlockK % 16 == 0); - - // helpful aliases - using DType = DTYPE; - using _BLK_M = Int; - using _BLK_N = Int; - using _BLK_K = Int; - using _PIPE = Int; - - // MMA Atom: (16x8x16) for F32F16F16F32 or F32BF16BF16F32 - using MMA_Atom_ = - std::conditional_t, - MMA_Atom, - MMA_Atom>; - - // TiledMMA: (64x16x16) - using TiledMma = TiledMMA>, // warp layout: (4x1x1) - Tile<_64, _16, _16>>; // tile layout: (64x16x16) - - // Shared memory LayoutAtom (8x64) - using SmemLayoutAtom_8x64 = - decltype(composition(Swizzle<3, 3, 3>{}, - Layout, Stride<_64, _1>>{})); - using SmemLayoutAtom_8x32 = - decltype(composition(Swizzle<2, 3, 3>{}, - Layout, Stride<_32, _1>>{})); - - using SmemLayoutAtom = std::conditional_t; - // SMEM Layout for A: (BLK_M, BLK_K, PIPE) - using SmemLayoutA = - decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_M, _BLK_K, _PIPE>{})); - // SMEM Layout for B: (BLK_N, BLK_K, PIPE) - using SmemLayoutB = - decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_N, _BLK_K, _PIPE>{})); - - // Thread layout for gmem copy: (_16,_8)/(_32, _4) - using GmemCopyThrLayout = - std::conditional_t, Stride<_4, _1>>, - Layout, Stride<_8, _1>>>; - // g2s tiled copy: copy A/B from global memory to shared memory - using GmemTiledCopy = decltype(make_tiled_copy( - Copy_Atom, DType>{}, - GmemCopyThrLayout{}, // Thr layout: (_16,_8)/(_32, _4) - Layout>{} // Val layout: 8 vals per read - )); - - // s2r tiled copy for A and B - using SmemTiledCopyA = - decltype(make_tiled_copy_A(Copy_Atom{}, - TiledMma{})); - using SmemTiledCopyB = - decltype(make_tiled_copy_B(Copy_Atom{}, - TiledMma{})); - - // ******* Epilogue ******* - - using SmemLayoutAtomC = std::conditional_t; - using SmemLayoutC = - decltype(tile_to_shape(SmemLayoutAtomC{}, Shape<_BLK_M, _BLK_N>{})); - - // use 128-bit vectorizing copy - using VectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<128>; - // r2s tiled copy for C - using SmemTiledCopyC = - decltype(make_tiled_copy_C(Copy_Atom{}, - TiledMma{})); - - // s2g tiled copy for O - using GmemTiledCopyC = decltype(make_tiled_copy( - Copy_Atom{}, - GmemCopyThrLayout{}, // Thr layout: (_16,_8)/(_32, _4) - Layout>{} // Val layout: 8 vals per read - )); - - // constexpr values for kernel launch - static constexpr size_t kThreadNum = 
size(TiledMma{}); -}; - -template -struct GEMMSharedStorageSM80 { - using DType = typename Traits::DType; - using SmemLayoutA = typename Traits::SmemLayoutA; - using SmemLayoutB = typename Traits::SmemLayoutB; - using SmemLayoutC = typename Traits::SmemLayoutC; - - union { - struct { - // Shared memory for A: (BLK_M, BLK_K, PIPE) - cute::array_aligned> a_smem; - // Shared memory for B: (BLK_N, BLK_K, PIPE) - cute::array_aligned> b_smem; - }; - // Shared memory for C: (BLK_M, BLK_N) - cute::array_aligned> c_smem; - }; -}; - -struct GEMMParams { - using AStride = Stride; - using BStride = Stride; - using CStride = Stride; - - // A: (m, k) - const void* __restrict__ a_ptr = nullptr; - AStride a_stride; - - // B: (e, n, k) - const void* __restrict__ b_ptr = nullptr; - BStride b_stride; - - // C: ((m, topk), n) - void* __restrict__ c_ptr = nullptr; - CStride c_stride; - - // (m_blocks*BLK_M) - const int* __restrict__ sorted_token_idxes_ptr = nullptr; - // (m_blocks) - const int* __restrict__ expert_ids_ptr = nullptr; - - const int* __restrict__ n_tokens_padded = nullptr; - - int m = 0; - int n = 0; - int k = 0; - int topk = 0; - - int m_blocks = 0; - int n_blocks = 0; -}; - -template -__global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80( - __grid_constant__ const Params params) { - // Traits - constexpr int kBlockM = Traits::kBlockM; - constexpr int kBlockN = Traits::kBlockN; - constexpr int kBlockK = Traits::kBlockK; - - using _BLK_M = Int; - using _BLK_N = Int; - using _BLK_K = Int; - - using DType = typename Traits::DType; - using TiledMma = typename Traits::TiledMma; - - using SmemLayoutA = typename Traits::SmemLayoutA; - using SmemLayoutB = typename Traits::SmemLayoutB; - using SmemLayoutC = typename Traits::SmemLayoutC; - - using GmemTiledCopy = typename Traits::GmemTiledCopy; - using SmemTiledCopyA = typename Traits::SmemTiledCopyA; - using SmemTiledCopyB = typename Traits::SmemTiledCopyB; - using SmemTiledCopyC = typename Traits::SmemTiledCopyC; - using GmemTiledCopyC = typename Traits::GmemTiledCopyC; - - using SharedStorage = GEMMSharedStorageSM80; - - const auto M = kBlockM * gridDim.x; - const auto N = params.n; - const auto K = params.k; - const auto topk = params.topk; - - // each thread block takes care of one block: (BLK_M, BLK_N) - const auto m_block_idx = blockIdx.x; - const auto n_block_idx = blockIdx.y; - const auto tidx = threadIdx.x; - - const int expert_id = params.expert_ids_ptr[m_block_idx]; - const int n_flatten_tokens = params.m * topk; - - // ProblemShape - const int* sorted_token_idxes = params.sorted_token_idxes_ptr; - auto idx_to_t_idx = [sorted_token_idxes, topk](int idx) { - return sorted_token_idxes[idx] / topk; - }; - // A: (M, K), k-major - auto A = make_gather_tensor(make_gmem_ptr((const DType*)params.a_ptr), - make_shape(M, K), - params.a_stride, - idx_to_t_idx); - - // B: (N, K), k-major - const auto b_offset = expert_id * get<0>(params.b_stride); - auto B = make_tensor(make_gmem_ptr((const DType*)params.b_ptr + b_offset), - make_shape(N, K), - select<1, 2>(params.b_stride)); - - // C: (M, N), n-major - auto idx_to_f_idx = [sorted_token_idxes](int idx) { - return sorted_token_idxes[idx]; - }; - auto C = make_gather_tensor(make_gmem_ptr((DType*)params.c_ptr), - make_shape(M, N), - params.c_stride, - idx_to_f_idx); - - auto max_coord_mk = make_coord(M, K); - auto max_coord_nk = make_coord(N, K); - auto max_coord_mn = make_coord(M, N); - - // (M, K) => (BLK_M, BLK_K, k) - Tensor gA = - local_tile(A, Shape<_BLK_M, _BLK_K>{}, 
make_coord(m_block_idx, _)); - // (BLK_M, BLK_K, k) => (M, K) - Tensor cA = local_tile(make_identity_tensor(make_shape(M, K)), - Shape<_BLK_M, _BLK_K>{}, - make_coord(m_block_idx, _)); - // (N, K) => (BLK_N, BLK_K, k) - Tensor gB = - local_tile(B, Shape<_BLK_N, _BLK_K>{}, make_coord(n_block_idx, _)); - // (BLK_N, BLK_K, k) => (N, K) - Tensor cB = local_tile(make_identity_tensor(make_shape(N, K)), - Shape<_BLK_N, _BLK_K>{}, - make_coord(n_block_idx, _)); - // (M, N) => (BLK_M, BLK_N) - Tensor gC = local_tile( - C, Shape<_BLK_M, _BLK_N>{}, make_coord(m_block_idx, n_block_idx)); - // (BLK_M, BLK_N) => (M, N) - Tensor cC = local_tile(make_identity_tensor(make_shape(M, N)), - Shape<_BLK_M, _BLK_N>{}, - make_coord(m_block_idx, n_block_idx)); - - // Smem - extern __shared__ char smem[]; - auto& ss = *reinterpret_cast(smem); - - // (BLK_M, BLK_K, PIPE) - Tensor sA = make_tensor(make_smem_ptr(ss.a_smem.data()), SmemLayoutA{}); - // (BLK_N, BLK_K, PIPE) - Tensor sB = make_tensor(make_smem_ptr(ss.b_smem.data()), SmemLayoutB{}); - // (BLK_M, BLK_N) - // Tensor sC = make_tensor(make_smem_ptr(ss.c_smem.data()), SmemLayoutC{}); - - // Tiled Copy - GmemTiledCopy gmem_tiled_copy; - auto gmem_thr_copy = gmem_tiled_copy.get_thread_slice(tidx); - - // (BLK_M, BLK_K, k) => (CPY, CPY_M, CPY_K, k) - auto tAgA = gmem_thr_copy.partition_S(gA); - // (CPY, CPY_M, CPY_K, k) => (M, K) - auto tAcA = gmem_thr_copy.partition_S(cA); - // (BLK_M, BLK_K, PIPE) => (CPY, CPY_M, CPY_K, PIPE) - auto tAsA = gmem_thr_copy.partition_D(sA); - - // (CPY_M) => (M, K) - auto tAcA_m = tAcA(_0{}, _, _0{}, _0{}); - auto tApA = make_tensor(make_shape(size(tAcA_m))); - CUTE_UNROLL - for (int i = 0; i < size(tApA); ++i) { - const auto f_idx = sorted_token_idxes[get<0>(tAcA_m(i))]; - tApA(i) = f_idx < n_flatten_tokens; - } - - // (BLK_N, BLK_K, k) => (CPY, CPY_N, CPY_K, k) - auto tBgB = gmem_thr_copy.partition_S(gB); - // (CPY, CPY_N, CPY_K, k) => (N, K) - auto tBcB = gmem_thr_copy.partition_S(cB); - // (BLK_N, BLK_K, PIPE) => (CPY, CPY_N, CPY_K, PIPE) - auto tBsB = gmem_thr_copy.partition_D(sB); - - auto produce_ab = [&](int k_tile, int k_pipe) { - safe_copy_m( - gmem_tiled_copy, - tAgA(_, _, _, k_tile), - tAsA(_, _, _, k_pipe), - tApA, - tAcA(_, _, _, k_tile), - max_coord_mk); - - safe_copy_n( - gmem_tiled_copy, - tBgB(_, _, _, k_tile), - tBsB(_, _, _, k_pipe), - tBcB(_, _, _, k_tile), - max_coord_nk); - }; - - auto produce_ab_no_oob = [&](int k_tile, int k_pipe) { - safe_copy_m( - gmem_tiled_copy, - tAgA(_, _, _, k_tile), - tAsA(_, _, _, k_pipe), - tApA, - tAcA(_, _, _, k_tile), - max_coord_mk); - - safe_copy_n( - gmem_tiled_copy, - tBgB(_, _, _, k_tile), - tBsB(_, _, _, k_pipe), - tBcB(_, _, _, k_tile), - max_coord_nk); - }; - - // GEMM: C = A@B.T - TiledMma tiled_mma; - auto thr_mma = tiled_mma.get_thread_slice(tidx); - // rA: (BLK_M, BLK_K) => (MMA,MMA_M,MMA_K) - auto tCrA = thr_mma.partition_fragment_A(sA(_, _, _0{})); - // rB: (BLK_N, BLK_K) => (MMA,MMA_N,MMA_K) - auto tCrB = thr_mma.partition_fragment_B(sB(_, _, _0{})); - - // s2r tiled copy for A and B - auto smem_tiled_copy_a = SmemTiledCopyA{}; - auto smem_thr_copy_a = smem_tiled_copy_a.get_thread_slice(tidx); - // (BLK_M, BLK_K, PIPE) => (CPY, CPY_M, CPY_K, PIPE) - auto tCsA = smem_thr_copy_a.partition_S(sA); - // (CPY, CPY_M, CPY_K) - auto tCrA_cpv = smem_thr_copy_a.retile_D(tCrA); - - auto smem_tiled_copy_b = SmemTiledCopyB{}; - auto smem_thr_copy_b = smem_tiled_copy_b.get_thread_slice(tidx); - // (BLK_N, BLK_K, PIPE) => (CPY, CPY_N, CPY_K, PIPE) - auto tCsB = 
smem_thr_copy_b.partition_S(sB); - // (CPY, CPY_N, CPY_K) - auto tCrB_cpv = smem_thr_copy_b.retile_D(tCrB); - - // ############### Prologue ############### - // remaining k-tile count - int k_tiles_remaining = size<3>(tAgA); - // next tile index in gmem to read from - int k_tile = 0; - - // async loads for all pipes except the last one - auto kPipe = size<3>(tAsA); - CUTE_UNROLL - for (int k_pipe = 0; k_pipe < kPipe - 1; ++k_pipe) { - if (k_pipe == 0) { - produce_ab(k_tile, k_pipe); - } else { - produce_ab_no_oob(k_tile, k_pipe); - } - cp_async_fence(); - - // advance to next k-tile - if (--k_tiles_remaining > 0) { - ++k_tile; - } - } - - // ############### Mainloop ############### - // (BLK_M, BLK_N) => (MMA, MMA_M, MMA_N) - auto tCrAccC = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{}); - cute::clear(tCrAccC); // Clear the accumulator - - // pipe index in smem to read from - int pipe_read = 0; - // pipe index in smem to write to - int pipe_write = kPipe - 1; - - // pipe to read from: (CPY, CPY_N, CPY_K) - Tensor tCsA_p = tCsA(_, _, _, pipe_read); - Tensor tCsB_p = tCsB(_, _, _, pipe_read); - - // Size of the register pipeline - auto kBlocks = size<2>(tCrA); - - // prefetch register pipeline - if (kBlocks > 1) { - // wait until our first prefetched tile is loaded in - cp_async_wait(); - __syncthreads(); - - // prefetch the first rmem from the first k-tile - cute::copy(smem_tiled_copy_a, tCsA_p(_, _, _0{}), tCrA_cpv(_, _, _0{})); - cute::copy(smem_tiled_copy_b, tCsB_p(_, _, _0{}), tCrB_cpv(_, _, _0{})); - } - - CUTE_NO_UNROLL - while (k_tiles_remaining > -(kPipe - 1)) { - CUTE_UNROLL - for (int ki = 0; ki < kBlocks; ++ki) { - // first block - if (ki == 0) { - // copy gmem to smem for next pipe - produce_ab_no_oob(k_tile, pipe_write); - cp_async_fence(); - - // advance to next k-tile - if (--k_tiles_remaining > 0) { - ++k_tile; - } - } - // last block - if (ki == kBlocks - 1) { - // advance to next pipe - pipe_write = pipe_read; - pipe_read = (pipe_read == kPipe - 1) ? 
0 : pipe_read + 1; - - // advance to next pipe to read from - tCsA_p = tCsA(_, _, _, pipe_read); - tCsB_p = tCsB(_, _, _, pipe_read); - - // wait until our next prefetched tile is loaded in - cp_async_wait(); - __syncthreads(); - } - - // prefetch for next ki - auto ki_next = (ki + _1{}) % kBlocks; - copy(smem_tiled_copy_a, tCsA_p(_, _, ki_next), tCrA_cpv(_, _, ki_next)); - copy(smem_tiled_copy_b, tCsB_p(_, _, ki_next), tCrB_cpv(_, _, ki_next)); - - // thread-level gemm for ki - gemm(tiled_mma, tCrA(_, _, ki), tCrB(_, _, ki), tCrAccC); - } - } - - // ############### Epilogue ############### - // (BLK_M, BLK_N) - Tensor sC = make_tensor(make_smem_ptr(ss.c_smem.data()), SmemLayoutC{}); - - // fastcast tCrAccC to DType - auto tCrC = make_tensor_like(tCrAccC); - fast_cast(tCrAccC, tCrC); - - // copy tCrC from registers to smem - SmemTiledCopyC smem_tiled_copy_c; - auto smem_thr_copy_c = smem_tiled_copy_c.get_thread_slice(tidx); - auto tSrC = smem_thr_copy_c.retile_S(tCrC); - auto tSsC = smem_thr_copy_c.partition_D(sC); - cute::copy(smem_tiled_copy_c, tSrC, tSsC); - - // wait for smem copy done before gmem copy - __syncthreads(); - - // copy sC from smem to gmem - GmemTiledCopyC gmem_tiled_copy_c; - auto gmem_thr_copy_c = gmem_tiled_copy_c.get_thread_slice(tidx); - auto tGsC = gmem_thr_copy_c.partition_S(sC); - auto tGgC = gmem_thr_copy_c.partition_D(gC); - // (CPY, CPY_M, CPY_N) => (M, N) - auto tGcC = gmem_thr_copy_c.partition_D(cC); - safe_copy_m( - gmem_tiled_copy_c, tGsC, tGgC, tApA, tGcC, max_coord_mn); -} - -template -void launch_grouped_gemm_kernel_sm80(const Params& params, - cudaStream_t stream) { - const auto smem_size = sizeof(GEMMSharedStorageSM80); - // std::cout << "SMEM size: " << smem_size << " bytes\n"; - - auto gemm_kernel = grouped_gemm_kernel_sm80; - cudaFuncSetAttribute( - gemm_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - // TODO: support persistent kernels - dim3 grid(params.m_blocks, params.n_blocks); - dim3 block = Traits::kThreadNum; - gemm_kernel<<>>(params); -} - -} // namespace llm diff --git a/src/kernels/gemm/safe_copy.hpp b/src/kernels/gemm/safe_copy.hpp index 423f71c2..ea7c16d6 100644 --- a/src/kernels/gemm/safe_copy.hpp +++ b/src/kernels/gemm/safe_copy.hpp @@ -24,7 +24,7 @@ template CUTE_HOST_DEVICE void safe_copy_m( @@ -32,8 +32,8 @@ CUTE_HOST_DEVICE void safe_copy_m( const SrcTensor& src, // (CPY, CPY_M, CPY_K) DstTensor& dst, // (CPY, CPY_M, CPY_K) const PrdTensor& pred_m, // (CPY_M) -> bool - const IdenTensor& identity, // (CPY, CPY_M, CPY_K) -> (blk_m, blk_k) - const MaxCoord& max_coord // max_coord(blk_m, blk_k) + const IdenTensor& identity, // (CPY, CPY_M, CPY_K) -> (m, k) + const ResidueMK& residue_mk // (m, k) ) { CUTE_STATIC_ASSERT_V(size<0>(src) == size<0>(dst)); // CPY == CPY CUTE_STATIC_ASSERT_V(size<0>(src) == size<0>(identity)); // CPY == CPY @@ -50,7 +50,7 @@ CUTE_HOST_DEVICE void safe_copy_m( if (pred_m(mi)) { CUTE_UNROLL for (int ki = 0; ki < size<2>(src); ++ki) { - if (elem_less<1>(identity(_0{}, _0{}, ki), max_coord)) { + if (elem_less<1>(identity(_0{}, _0{}, ki), residue_mk)) { copy(copy_atom, src(_, mi, ki), dst(_, mi, ki)); } else if constexpr (ZFILL_K) { clear(dst(_, mi, ki)); @@ -82,15 +82,15 @@ template CUTE_HOST_DEVICE void safe_copy_n( const TiledCopy& tiled_copy, const SrcTensor& src, // (CPY, CPY_N, CPY_K) DstTensor& dst, // (CPY, CPY_N, CPY_K) - const IdenTensor& identity, // (CPY, CPY_N, CPY_K) -> (blk_n, blk_k) - const MaxCoord& max_coord // max_coord(blk_n, blk_k) + const IdenTensor& identity, // (CPY, 
CPY_N, CPY_K) -> (n, k) + const ResidueNK& residue_nk // (n, k) ) { CUTE_STATIC_ASSERT_V(size<0>(src) == size<0>(dst)); // CPY == CPY CUTE_STATIC_ASSERT_V(size<0>(src) == size<0>(identity)); // CPY == CPY @@ -104,10 +104,10 @@ CUTE_HOST_DEVICE void safe_copy_n( // handle both n and k oob CUTE_UNROLL for (int ni = 0; ni < size<1>(src); ++ni) { - if (elem_less<0>(identity(_0{}, ni, _0{}), max_coord)) { + if (elem_less<0>(identity(_0{}, ni, _0{}), residue_nk)) { CUTE_UNROLL for (int ki = 0; ki < size<2>(src); ++ki) { - if (elem_less<1>(identity(_0{}, _0{}, ki), max_coord)) { + if (elem_less<1>(identity(_0{}, _0{}, ki), residue_nk)) { copy(copy_atom, src(_, ni, ki), dst(_, ni, ki)); } else if constexpr (ZFILL_K) { clear(dst(_, ni, ki)); @@ -121,7 +121,7 @@ CUTE_HOST_DEVICE void safe_copy_n( // only handle n oob CUTE_UNROLL for (int mi = 0; mi < size<1>(src); ++mi) { - if (elem_less<0>(identity(_0{}, mi, _0{}), max_coord)) { + if (elem_less<0>(identity(_0{}, mi, _0{}), residue_nk)) { copy(copy_atom, src(_, mi, _), dst(_, mi, _)); } else if constexpr (ZFILL_N) { clear(dst(_, mi, _)); @@ -131,7 +131,7 @@ CUTE_HOST_DEVICE void safe_copy_n( // only handle k oob CUTE_UNROLL for (int ki = 0; ki < size<2>(src); ++ki) { - if (elem_less<1>(identity(_0{}, _0{}, ki), max_coord)) { + if (elem_less<1>(identity(_0{}, _0{}, ki), residue_nk)) { copy(copy_atom, src(_, _, ki), dst(_, _, ki)); } else if constexpr (ZFILL_K) { clear(dst(_, _, ki)); @@ -152,17 +152,17 @@ template + class ResidueMK> CUTE_HOST_DEVICE void safe_copy_m( const CopyPolicy& tiled_copy, const SrcTensor& src, // (CPY, CPY_M, CPY_K) DstTensor&& dst, // (CPY, CPY_M, CPY_K) const PrdTensor& pred_m, // (CPY_M) -> bool - const IdenTensor& identity, // (CPY, CPY_M, CPY_K) -> (blk_m, blk_k) - const MaxCoord& max_coord // max_coord(blk_m, blk_k) + const IdenTensor& identity, // (CPY, CPY_M, CPY_K) -> (m, k) + const ResidueMK& residue_mk // (m, k) ) { return safe_copy_m( - tiled_copy, src, dst, pred_m, identity, max_coord); + tiled_copy, src, dst, pred_m, identity, residue_mk); } template + class ResidueNK> CUTE_HOST_DEVICE void safe_copy_n( const CopyPolicy& tiled_copy, const SrcTensor& src, // (CPY, CPY_N, CPY_K) DstTensor&& dst, // (CPY, CPY_N, CPY_K) - const IdenTensor& identity, // (CPY, CPY_N, CPY_K) -> (blk_n, blk_k) - const MaxCoord& max_coord // max_coord(blk_n, blk_k) + const IdenTensor& identity, // (CPY, CPY_N, CPY_K) -> (n, k) + const ResidueNK& residue_nk // (n, k) ) { return safe_copy_n( - tiled_copy, src, dst, identity, max_coord); + tiled_copy, src, dst, identity, residue_nk); } } // namespace llm diff --git a/src/kernels/gemm/sm80_collective_epilogue.cuh b/src/kernels/gemm/sm80_collective_epilogue.cuh new file mode 100644 index 00000000..ac84ae00 --- /dev/null +++ b/src/kernels/gemm/sm80_collective_epilogue.cuh @@ -0,0 +1,142 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#include "fast_cast.cuh" +#include "safe_copy.hpp" + +namespace llm { +using namespace cute; + +template +struct Sm80CollectiveEpilogue { + using TileShape = TileShape_; + using Element = Element_; + + static constexpr bool EVEN_N = EVEN_N_; + + static constexpr int kBlockM = get<0>(TileShape{}); + static constexpr int kBlockN = get<1>(TileShape{}); + + using BLK_M = Int; + using BLK_N = Int; + + // Shared memory LayoutAtom (8x64) + using SmemLayoutAtom_8x64 = + decltype(composition(Swizzle<3, 3, 3>{}, + Layout, Stride<_64, _1>>{})); + using SmemLayoutAtom_8x32 = + decltype(composition(Swizzle<2, 3, 3>{}, + Layout, Stride<_32, _1>>{})); + 
+ using SmemLayoutAtomC = std::conditional_t; + using SmemLayoutC = + decltype(tile_to_shape(SmemLayoutAtomC{}, Shape{})); + + // use 128-bit vectorizing copy + using VectorizingCopy_ = AutoVectorizingCopyWithAssumedAlignment<128>; + // r2s tiled copy for C + using SmemCopyAtomC_ = Copy_Atom; + + // Thread layout for gmem copy: (_16,_8)/(_32, _4) + using GmemCopyThrLayout = + std::conditional_t, Stride<_4, _1>>, + Layout, Stride<_8, _1>>>; + + // s2g tiled copy for C + using GmemTiledCopyC = decltype(make_tiled_copy( + Copy_Atom{}, + GmemCopyThrLayout{}, // Thr layout: (_16,_8)/(_32, _4) + Layout>{} // Val layout: 8 vals per read + )); + + struct SharedStorage : cute::aligned_struct<128> { + // Shared memory for C: (BLK_M, BLK_N) + cute::array_aligned> c_smem; + }; + + // Host side kernel arguments + struct Arguments { + const int* sorted_token_idxes_ptr = nullptr; + int n_flatten_tokens = 0; + }; + + // Device side kernel params + using Params = Arguments; + + // Convert host side arguments to device side params + static Params to_underlying_arguments(Arguments const& args) { return args; } + + template + CUTE_DEVICE void operator()( + const Params& params, + const FrgTensor& tCrAccC, // (MMA, MMA_M, MMA_N) + TiledMma tiled_mma, + TensorC& gC, // (BLK_M, BLK_N) + const IdentTensorC& cC, // (BLK_M, BLK_N) => (M, N) + int tidx, + const ResidueMNK& residue_mnk, + char* smem) { + static_assert(is_rmem::value, + "Accum tensor must be rmem resident."); + static_assert(is_gmem::value, "C tensor must be gmem resident."); + + static constexpr int kBlockM = get<0>(TileShape{}); + const auto residue_mn = select<0, 1>(residue_mnk); + + // Smem + auto& ss = *reinterpret_cast(smem); + // (BLK_M, BLK_N) + auto sC = make_tensor(make_smem_ptr(ss.c_smem.data()), SmemLayoutC{}); + + // fastcast tCrAccC to Element + auto tCrC = make_tensor_like(tCrAccC); + fast_cast(tCrAccC, tCrC); + + // copy tCrC from registers to smem + auto smem_tiled_copy_c = make_tiled_copy_C(SmemCopyAtomC_{}, tiled_mma); + auto smem_thr_copy_c = smem_tiled_copy_c.get_thread_slice(tidx); + auto tSrC = smem_thr_copy_c.retile_S(tCrC); + auto tSsC = smem_thr_copy_c.partition_D(sC); + cute::copy(smem_tiled_copy_c, tSrC, tSsC); + + // copy sC from smem to gmem + GmemTiledCopyC gmem_tiled_copy_c; + auto gmem_thr_copy_c = gmem_tiled_copy_c.get_thread_slice(tidx); + auto tGsC = gmem_thr_copy_c.partition_S(sC); + auto tGgC = gmem_thr_copy_c.partition_D(gC); + // (CPY, CPY_M, CPY_N) => (M, N) + auto tGcC = gmem_thr_copy_c.partition_D(cC); + + const int* sorted_token_idxes = params.sorted_token_idxes_ptr; + const int n_flatten_tokens = params.n_flatten_tokens; + + // (CPY_M) => (M, N) + auto tGcC_m = tGcC(_0{}, _, _0{}); + auto tGpC = make_tensor(make_shape(size(tGcC_m))); + CUTE_UNROLL + for (int i = 0; i < size(tGpC); ++i) { + const auto f_idx = sorted_token_idxes[get<0>(tGcC_m(i))]; + tGpC(i) = f_idx < n_flatten_tokens; + } + + // wait for smem copy done before gmem copy + __syncthreads(); + + safe_copy_m( + gmem_tiled_copy_c, tGsC, tGgC, tGpC, tGcC, residue_mn); + } +}; +} // namespace llm diff --git a/src/kernels/gemm/sm80_collective_grouped_gemm.cuh b/src/kernels/gemm/sm80_collective_grouped_gemm.cuh new file mode 100644 index 00000000..240fc6dd --- /dev/null +++ b/src/kernels/gemm/sm80_collective_grouped_gemm.cuh @@ -0,0 +1,318 @@ +#pragma once +#include +#include + +#include +#include +#include + +#include "safe_copy.hpp" + +namespace llm { +using namespace cute; + +template +struct Sm80CollectiveGroupedGEMM { + using TileShape = 
TileShape_; + using Element = Element_; + using ElementAccum = float; + + static constexpr int kBlockM = get<0>(TileShape{}); + static constexpr int kBlockN = get<1>(TileShape{}); + static constexpr int kBlockK = get<2>(TileShape{}); + + static_assert(kBlockM % 64 == 0); + static_assert(kBlockN % 32 == 0); + static_assert(kBlockK % 16 == 0); + + using BLK_M = Int; + using BLK_N = Int; + using BLK_K = Int; + using PIPE = Int; + + // MMA Atom: (16x8x16) for F32F16F16F32 or F32BF16BF16F32 + using MMA_Atom_ = + std::conditional_t, + MMA_Atom, + MMA_Atom>; + + // TiledMMA: (64x16x16) + using TiledMma = TiledMMA>, // warp layout: (4x1x1) + Tile<_64, _16, _16>>; // tile layout: (64x16x16) + + static constexpr int kMmaThreads = size(TiledMma{}); + + // Shared memory LayoutAtom (8x64) + using SmemLayoutAtom_8x64 = + decltype(composition(Swizzle<3, 3, 3>{}, + Layout, Stride<_64, _1>>{})); + using SmemLayoutAtom_8x32 = + decltype(composition(Swizzle<2, 3, 3>{}, + Layout, Stride<_32, _1>>{})); + + using SmemLayoutAtom = std::conditional_t; + // SMEM Layout for A: (BLK_M, BLK_K, PIPE) + using SmemLayoutA = + decltype(tile_to_shape(SmemLayoutAtom{}, Shape{})); + // SMEM Layout for B: (BLK_N, BLK_K, PIPE) + using SmemLayoutB = + decltype(tile_to_shape(SmemLayoutAtom{}, Shape{})); + + // Thread layout for gmem copy: (_16,_8)/(_32, _4) + using GmemCopyThrLayout = + std::conditional_t, Stride<_4, _1>>, + Layout, Stride<_8, _1>>>; + // g2s tiled copy: copy A/B from global memory to shared memory + using GmemTiledCopy = decltype(make_tiled_copy( + Copy_Atom, Element>{}, + GmemCopyThrLayout{}, // Thr layout: (_16,_8)/(_32, _4) + Layout>{} // Val layout: 8 vals per read + )); + + // s2r tiled copy for A and B + using SmemTiledCopyA = + decltype(make_tiled_copy_A(Copy_Atom{}, + TiledMma{})); + using SmemTiledCopyB = + decltype(make_tiled_copy_B(Copy_Atom{}, + TiledMma{})); + + struct SharedStorage : cute::aligned_struct<128> { + // Shared memory for A: (BLK_M, BLK_K, PIPE) + cute::array_aligned> a_smem; + // Shared memory for B: (BLK_N, BLK_K, PIPE) + cute::array_aligned> b_smem; + }; + + // Host side arguments + struct Arguments { + const int* sorted_token_idxes_ptr = nullptr; + int n_flatten_tokens = 0; + }; + + // Device side params + using Params = Arguments; + + // Convert host side arguments to device side params + static Params to_underlying_arguments(Arguments const& args) { + // no convertion needed. 
+ return args; + } + + // runs the pipelined mainloop for one (BLK_M, BLK_N) tile, accumulating into tCrAccC + template + CUTE_DEVICE void operator()( + const Params& params, + const TensorA& gA, // (BLK_M, BLK_K, k) + const IdentTensorA& cA, // (BLK_M, BLK_K, k) => (M, K) + const TensorB& gB, // (BLK_N, BLK_K, k) + const IdentTensorB& cB, // (BLK_N, BLK_K, k) => (N, K) + FrgTensor& tCrAccC, // (MMA, MMA_M, MMA_N) + int tidx, + const ResidueMNK& residue_mnk, + char* smem) { + static_assert(is_rmem::value, + "Accum tensor must be rmem resident."); + static_assert(is_gmem::value, "A tensor must be gmem resident."); + static_assert(is_gmem::value, "B tensor must be gmem resident."); + + const auto residue_mk = select<0, 2>(residue_mnk); + const auto residue_nk = select<1, 2>(residue_mnk); + + auto& ss = *reinterpret_cast(smem); + + // (BLK_M, BLK_K, PIPE) + Tensor sA = make_tensor(make_smem_ptr(ss.a_smem.data()), SmemLayoutA{}); + // (BLK_N, BLK_K, PIPE) + Tensor sB = make_tensor(make_smem_ptr(ss.b_smem.data()), SmemLayoutB{}); + + // Tiled Copy + GmemTiledCopy gmem_tiled_copy; + auto gmem_thr_copy = gmem_tiled_copy.get_thread_slice(tidx); + + // (BLK_M, BLK_K, k) => (CPY, CPY_M, CPY_K, k) + auto tAgA = gmem_thr_copy.partition_S(gA); + // (CPY, CPY_M, CPY_K, k) => (M, K) + auto tAcA = gmem_thr_copy.partition_S(cA); + // (BLK_M, BLK_K, PIPE) => (CPY, CPY_M, CPY_K, PIPE) + auto tAsA = gmem_thr_copy.partition_D(sA); + + // (CPY_M) => (M, K) + const int* sorted_token_idxes = params.sorted_token_idxes_ptr; + const int n_flatten_tokens = params.n_flatten_tokens; + auto tAcA_m = tAcA(_0{}, _, _0{}, _0{}); + // (CPY_M) => bool + auto tApA = make_tensor(make_shape(size(tAcA_m))); + CUTE_UNROLL + for (int i = 0; i < size(tApA); ++i) { + const auto f_idx = sorted_token_idxes[get<0>(tAcA_m(i))]; + tApA(i) = f_idx < n_flatten_tokens; + } + + // (BLK_N, BLK_K, k) => (CPY, CPY_N, CPY_K, k) + auto tBgB = gmem_thr_copy.partition_S(gB); + // (CPY, CPY_N, CPY_K, k) => (N, K) + auto tBcB = gmem_thr_copy.partition_S(cB); + // (BLK_N, BLK_K, PIPE) => (CPY, CPY_N, CPY_K, PIPE) + auto tBsB = gmem_thr_copy.partition_D(sB); + + auto produce_ab = [&](int k_tile, int k_pipe) { + safe_copy_m( + gmem_tiled_copy, + tAgA(_, _, _, k_tile), + tAsA(_, _, _, k_pipe), + tApA, + tAcA(_, _, _, k_tile), + residue_mk); + + safe_copy_n( + gmem_tiled_copy, + tBgB(_, _, _, k_tile), + tBsB(_, _, _, k_pipe), + tBcB(_, _, _, k_tile), + residue_nk); + }; + + auto produce_ab_no_oob = [&](int k_tile, int k_pipe) { + safe_copy_m( + gmem_tiled_copy, + tAgA(_, _, _, k_tile), + tAsA(_, _, _, k_pipe), + tApA, + tAcA(_, _, _, k_tile), + residue_mk); + + safe_copy_n( + gmem_tiled_copy, + tBgB(_, _, _, k_tile), + tBsB(_, _, _, k_pipe), + tBcB(_, _, _, k_tile), + residue_nk); + }; + + // GEMM: C = A@B.T + TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + // rA: (BLK_M, BLK_K) => (MMA,MMA_M,MMA_K) + auto tCrA = thr_mma.partition_fragment_A(sA(_, _, _0{})); + // rB: (BLK_N, BLK_K) => (MMA,MMA_N,MMA_K) + auto tCrB = thr_mma.partition_fragment_B(sB(_, _, _0{})); + + // s2r tiled copy for A and B + auto smem_tiled_copy_a = SmemTiledCopyA{}; + auto smem_thr_copy_a = smem_tiled_copy_a.get_thread_slice(tidx); + // (BLK_M, BLK_K, PIPE) => (CPY, CPY_M, CPY_K, PIPE) + auto tCsA = smem_thr_copy_a.partition_S(sA); + // (CPY, CPY_M, CPY_K) + auto tCrA_cpv = smem_thr_copy_a.retile_D(tCrA); + + auto smem_tiled_copy_b = SmemTiledCopyB{}; + auto smem_thr_copy_b = smem_tiled_copy_b.get_thread_slice(tidx); + // (BLK_N, BLK_K, PIPE) => (CPY, CPY_N, CPY_K, PIPE) + auto tCsB = 
smem_thr_copy_b.partition_S(sB); + // (CPY, CPY_N, CPY_K) + auto tCrB_cpv = smem_thr_copy_b.retile_D(tCrB); + + // ############### Prologue ############### + // remaining k-tile count + int k_tiles_remaining = size<3>(tAgA); + // next tile index in gmem to read from + int k_tile = 0; + + // async loads for all pipes except the last one + auto kPipe = size<3>(tAsA); + CUTE_UNROLL + for (int k_pipe = 0; k_pipe < kPipe - 1; ++k_pipe) { + if (k_pipe == 0) { + produce_ab(k_tile, k_pipe); + } else { + produce_ab_no_oob(k_tile, k_pipe); + } + cp_async_fence(); + + // advance to next k-tile + if (--k_tiles_remaining > 0) { + ++k_tile; + } + } + + // ############### Mainloop ############### + // pipe index in smem to read from + int pipe_read = 0; + // pipe index in smem to write to + int pipe_write = kPipe - 1; + + // pipe to read from: (CPY, CPY_N, CPY_K) + Tensor tCsA_p = tCsA(_, _, _, pipe_read); + Tensor tCsB_p = tCsB(_, _, _, pipe_read); + + // Size of the register pipeline + auto kBlocks = size<2>(tCrA); + + // prefetch register pipeline + if (kBlocks > 1) { + // wait until our first prefetched tile is loaded in + cp_async_wait(); + __syncthreads(); + + // prefetch the first rmem from the first k-tile + cute::copy(smem_tiled_copy_a, tCsA_p(_, _, _0{}), tCrA_cpv(_, _, _0{})); + cute::copy(smem_tiled_copy_b, tCsB_p(_, _, _0{}), tCrB_cpv(_, _, _0{})); + } + + CUTE_NO_UNROLL + while (k_tiles_remaining > -(kPipe - 1)) { + CUTE_UNROLL + for (int ki = 0; ki < kBlocks; ++ki) { + // first block + if (ki == 0) { + // copy gmem to smem for next pipe + produce_ab_no_oob(k_tile, pipe_write); + cp_async_fence(); + + // advance to next k-tile + if (--k_tiles_remaining > 0) { + ++k_tile; + } + } + // last block + if (ki == kBlocks - 1) { + // advance to next pipe + pipe_write = pipe_read; + pipe_read = (pipe_read == kPipe - 1) ? 
0 : pipe_read + 1; + + // advance to next pipe to read from + tCsA_p = tCsA(_, _, _, pipe_read); + tCsB_p = tCsB(_, _, _, pipe_read); + + // wait until our next prefetched tile is loaded in + cp_async_wait(); + __syncthreads(); + } + + // prefetch for next ki + auto ki_next = (ki + _1{}) % kBlocks; + copy(smem_tiled_copy_a, tCsA_p(_, _, ki_next), tCrA_cpv(_, _, ki_next)); + copy(smem_tiled_copy_b, tCsB_p(_, _, ki_next), tCrB_cpv(_, _, ki_next)); + + // thread-level gemm for ki + gemm(tiled_mma, tCrA(_, _, ki), tCrB(_, _, ki), tCrAccC); + } + } + } +}; + +} // namespace llm diff --git a/src/kernels/gemm/sm80_grouped_gemm_dispatch.cuh b/src/kernels/gemm/sm80_grouped_gemm_dispatch.cuh new file mode 100644 index 00000000..815e6190 --- /dev/null +++ b/src/kernels/gemm/sm80_grouped_gemm_dispatch.cuh @@ -0,0 +1,90 @@ +#pragma once + +#include +#include + +#include "huggingface/safetensors.h" +#include "sm80_grouped_gemm_launch.cuh" +#include "static_dispatch.h" + +namespace llm { +using namespace cute; + +struct GEMMParams { + using AStride = Stride; + using BStride = Stride; + using CStride = Stride; + + // A: (m, k) + const void* __restrict__ a_ptr = nullptr; + AStride a_stride; + + // B: (e, n, k) + const void* __restrict__ b_ptr = nullptr; + BStride b_stride; + + // C: ((m, topk), n) + void* __restrict__ c_ptr = nullptr; + CStride c_stride; + + // (m_blocks*BLK_M) + const int* __restrict__ sorted_token_idxes_ptr = nullptr; + // (m_blocks) + const int* __restrict__ expert_ids_ptr = nullptr; + + const int* __restrict__ n_tokens_padded = nullptr; + + int m = 0; + int n = 0; + int k = 0; + int topk = 0; + int n_experts = 0; + + int m_blocks = 0; +}; + +// forward declaration +// template +// void sm80_launch_grouped_gemm_kernel(const Params& params, cudaStream_t +// stream); + +// user-facing function to run the grouped GEMM kernel +template +void sm80_run_grouped_gemm(Params& params, cudaStream_t stream = nullptr) { + // TODO: tune block shape MNK based on the problem shape and smem size + // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability + // SM | 7.0 | 7.2 | 7.5 | 8.0 | 8.6 | 8.7 | 8.9 | 9.0 | 10.x | 12.0| + // Max SMEM (KB)| 96 | 64 | 164 | 100 | 164 | 100 | 228 | 100 | + // valid dynamic shared memory sizes for different compute capabilities: + // * 7.0 | 7.2 : 0, 8, 16, 32, 64, 96 + // * 7.5 : 0, 32, 64 + // * 8.0 | 8.7 : 0, 8, 16, 32, 64, 100, 132, 164 + // * 8.6 | 8.9 : 0, 8, 16, 32, 64, 100 + // * 9.0 | 10.x: 0, 8, 16, 32, 64, 100, 132, 164, 196, 228 + // * 12.0 : 0, 8, 16, 32, 64, 100 + constexpr int BLK_M = 64; + constexpr int BLK_N = 64; + constexpr int BLK_K = 64; + constexpr int Stages = 2; + + using TileShape = Shape, Int, Int>; + + // dispatch to proper kernel instantiation based on params + DISPATCH_BOOL((params.n % BLK_N) == 0, EVEN_N, [&] { + DISPATCH_BOOL((params.k % BLK_K) == 0, EVEN_K, [&] { + sm80_launch_grouped_gemm_kernel(params, stream); + }); + }); +} + +} // namespace llm diff --git a/src/kernels/gemm/sm80_grouped_gemm_launch.cuh b/src/kernels/gemm/sm80_grouped_gemm_launch.cuh new file mode 100644 index 00000000..de1929f4 --- /dev/null +++ b/src/kernels/gemm/sm80_grouped_gemm_launch.cuh @@ -0,0 +1,71 @@ +#pragma once + +#include +#include + +#include +#include + +#include "sm80_collective_epilogue.cuh" +#include "sm80_collective_grouped_gemm.cuh" +#include "sm80_kernel_grouped_gemm.cuh" +#include "tile_scheduler.cuh" + +namespace llm { + +namespace detail { +/// Generic 
kernel template. +template +__global__ __launch_bounds__(Operator::kMmaThreads) void device_kernel( + __grid_constant__ const Params params, + __grid_constant__ const typename Operator::TileSchedulerParams + scheduler_params) { + extern __shared__ char smem[]; + Operator op; + op(params, scheduler_params, smem); +} +} // namespace detail + +template +void sm80_launch_grouped_gemm_kernel(const Params& params, + cudaStream_t stream) { + constexpr int kBlockN = get<1>(TileShape{}); + + using CollectiveMainloop = + Sm80CollectiveGroupedGEMM; + + using CollectiveEpilogue = Sm80CollectiveEpilogue; + + // TODO: support persistent kernels + using TileScheduler = SingleTileScheduler; + + const auto n_blocks = cute::ceil_div(params.n, kBlockN); + typename TileScheduler::Arguments scheduler_args{params.m_blocks, n_blocks}; + auto scheduler_params = + TileScheduler::to_underlying_arguments(scheduler_args); + + using GEMMKernel = Sm80KernelGroupedGEMM; + + auto gemm_kernel = detail::device_kernel; + + const auto smem_size = GEMMKernel::kSharedStorageSize; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + gemm_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } + + const dim3 grid = GEMMKernel::get_grid_shape(scheduler_args); + const dim3 block = GEMMKernel::get_block_shape(); + + gemm_kernel<<>>(params, scheduler_params); + // TODO: check launch status +} + +} // namespace llm diff --git a/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu b/src/kernels/gemm/sm80_grouped_gemm_test.cu similarity index 91% rename from src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu rename to src/kernels/gemm/sm80_grouped_gemm_test.cu index 2fad39a5..0503d5a4 100644 --- a/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu +++ b/src/kernels/gemm/sm80_grouped_gemm_test.cu @@ -3,7 +3,7 @@ #include -#include "grouped_gemm_kernel_sm80.cuh" // IWYU pragma: keep +#include "sm80_grouped_gemm_dispatch.cuh" // IWYU pragma: keep #include "static_dispatch.h" namespace llm { @@ -68,8 +68,8 @@ torch::Tensor grouped_gemm_sm80(const torch::Tensor& a, // (m, k) ) { const auto m = a.size(0); const auto k = a.size(1); - const auto n = w.size(1); const auto n_experts = w.size(0); + const auto n = w.size(1); const auto topk = topk_ids.size(1); // construct aligned @@ -95,24 +95,11 @@ torch::Tensor grouped_gemm_sm80(const torch::Tensor& a, // (m, k) params.n = n; params.k = k; params.topk = topk; - - constexpr int BLK_M = 64; - constexpr int BLK_N = 64; - constexpr int BLK_K = 64; - constexpr int PIPE = 2; - + params.n_experts = n_experts; params.m_blocks = expert_ids.size(0); - params.n_blocks = cute::ceil_div(n, BLK_N); - - DISPATCH_TORCH_DTYPE(a.dtype(), DTYPE, [&] { - DISPATCH_BOOL((n % BLK_N) == 0, EVEN_N, [&] { - DISPATCH_BOOL((k % BLK_K) == 0, EVEN_K, [&] { - using Traits = GEMMTraitsSM80; - launch_grouped_gemm_kernel_sm80(params, - nullptr); - }); - }); - }); + + DISPATCH_TORCH_DTYPE( + a.dtype(), DTYPE, [&] { sm80_run_grouped_gemm(params); }); // (m * topk, n) => (m, topk, n) return out.view({m, topk, n}); diff --git a/src/kernels/gemm/sm80_kernel_grouped_gemm.cuh b/src/kernels/gemm/sm80_kernel_grouped_gemm.cuh new file mode 100644 index 00000000..b21e060d --- /dev/null +++ b/src/kernels/gemm/sm80_kernel_grouped_gemm.cuh @@ -0,0 +1,151 @@ +#pragma once + +#include +#include + +#include +#include + +#include "gather_tensor.hpp" + +namespace llm { + +using namespace cute; + +template +class Sm80KernelGroupedGEMM { + public: + using CollectiveMainloop = CollectiveMainloop_; + using CollectiveEpilogue = 
CollectiveEpilogue_; + using TileScheduler = TileScheduler_; + + using TiledMma = typename CollectiveMainloop::TiledMma; + + using Element = typename CollectiveMainloop::Element; + using BLK_M = typename CollectiveMainloop::BLK_M; + using BLK_N = typename CollectiveMainloop::BLK_N; + using BLK_K = typename CollectiveMainloop::BLK_K; + + static constexpr int kBlockM = CollectiveMainloop::kBlockM; + static constexpr int kBlockN = CollectiveMainloop::kBlockN; + + static constexpr int kSharedStorageSize = + cute::max(sizeof(typename CollectiveMainloop::SharedStorage), + sizeof(typename CollectiveEpilogue::SharedStorage)); + + static constexpr int kMmaThreads = CollectiveMainloop::kMmaThreads; + + // Kernel params + using MainloopParams = typename CollectiveMainloop::Params; + using EpilogueParams = typename CollectiveEpilogue::Params; + using TileSchedulerParams = typename TileScheduler::Params; + + // returns grid and block shape for kernel launch + using TileSchedulerArgs = typename TileScheduler::Arguments; + static dim3 get_grid_shape(TileSchedulerArgs const& args) { + return TileScheduler::get_grid_shape(args); + } + static dim3 get_block_shape() { return kMmaThreads; } + + template + CUTE_DEVICE void operator()(const Params& params, + const TileSchedulerParams& scheduler_params, + char* smem) { + CollectiveMainloop gemm; + CollectiveEpilogue epilogue; + TileScheduler scheduler(scheduler_params); + + // ProblemShape + const auto M = kBlockM * params.m_blocks; + const auto N = params.n; + const auto K = params.k; + const auto topk = params.topk; + const auto n_experts = params.n_experts; + const auto n_flatten_tokens = params.m * topk; + + const auto residue_mnk = make_tuple(M, N, K); + + const int* sorted_token_idxes = params.sorted_token_idxes_ptr; + auto idx_to_t_idx = [sorted_token_idxes, topk](int idx) { + return sorted_token_idxes[idx] / topk; + }; + // A: (M, K), k-major + auto A = make_gather_tensor(make_gmem_ptr((const Element*)params.a_ptr), + make_shape(M, K), + params.a_stride, + idx_to_t_idx); + // (M, K) => (BLK_M, BLK_K, m, k) + Tensor gA_t = local_tile(A, Shape{}, make_coord(_, _)); + // (BLK_M, BLK_K, m, k) => (M, K) + Tensor cA_t = local_tile(make_identity_tensor(make_shape(M, K)), + Shape{}, + make_coord(_, _)); + + // B: (E, N, K), k-major + auto B = make_tensor(make_gmem_ptr((const Element*)params.b_ptr), + make_shape(n_experts, N, K), + params.b_stride); + // (E, N, K) => (_1, BLK_N, BLK_K, e, n, k) + Tensor gB_t = local_tile(B, Shape<_1, BLK_N, BLK_K>{}, make_coord(_, _, _)); + // (BLK_N, BLK_K, n, k) => (N, K) + Tensor cB_t = local_tile(make_identity_tensor(make_shape(N, K)), + Shape{}, + make_coord(_, _)); + + // C: (M, N), n-major + auto idx_to_f_idx = [sorted_token_idxes](int idx) { + return sorted_token_idxes[idx]; + }; + auto C = make_gather_tensor(make_gmem_ptr((Element*)params.c_ptr), + make_shape(M, N), + params.c_stride, + idx_to_f_idx); + // (M, N) => (BLK_M, BLK_N, m, n) + Tensor gC_t = local_tile(C, Shape{}, make_coord(_, _)); + // (BLK_M, BLK_N, m, n) => (M, N) + Tensor cC_t = local_tile(make_identity_tensor(make_shape(M, N)), + Shape{}, + make_coord(_, _)); + + // construct params + MainloopParams mainloop_params{params.sorted_token_idxes_ptr, + n_flatten_tokens}; + EpilogueParams epilogue_params{params.sorted_token_idxes_ptr, + n_flatten_tokens}; + + // process each block + for (const auto block_coord : scheduler) { + // block coord: (m_block_idx, n_block_idx) + const auto [m_block_idx, n_block_idx] = block_coord; + const auto tidx = 
threadIdx.x; + const int expert_id = params.expert_ids_ptr[m_block_idx]; + + // (BLK_M, BLK_K, m, k) => (BLK_M, BLK_K, k) + auto gA = gA_t(_, _, m_block_idx, _); + auto cA = cA_t(_, _, m_block_idx, _); + // (_1, BLK_N, BLK_K, e, n, k) => (BLK_N, BLK_K, k) + auto gB = gB_t(_0{}, _, _, expert_id, n_block_idx, _); + // (BLK_N, BLK_K, n, k) => (BLK_N, BLK_K, k) + auto cB = cB_t(_, _, n_block_idx, _); + // (BLK_M, BLK_N, m, n) => (BLK_M, BLK_N) + auto gC = gC_t(_, _, m_block_idx, n_block_idx); + auto cC = cC_t(_, _, m_block_idx, n_block_idx); + + TiledMma tiled_mma; + // (BLK_M, BLK_N) => (MMA, MMA_M, MMA_N) + auto tCrAccC = partition_fragment_C(tiled_mma, Shape{}); + cute::clear(tCrAccC); // Clear the accumulator + + // mainloop + gemm(mainloop_params, gA, cA, gB, cB, tCrAccC, tidx, residue_mnk, smem); + + // epilogue + epilogue( + epilogue_params, tCrAccC, tiled_mma, gC, cC, tidx, residue_mnk, smem); + } + } +}; + +} // namespace llm diff --git a/src/kernels/gemm/tile_scheduler.cuh b/src/kernels/gemm/tile_scheduler.cuh new file mode 100644 index 00000000..5388e5aa --- /dev/null +++ b/src/kernels/gemm/tile_scheduler.cuh @@ -0,0 +1,60 @@ +#pragma once + +#include +#include + +#include +#include + +namespace llm { + +class SingleTileScheduler { + public: + // Host side kernel arguments + struct Arguments { + int m_blocks = 0; + int n_blocks = 0; + }; + static dim3 get_grid_shape(Arguments const& args) { + return {(uint32_t)args.m_blocks, (uint32_t)args.n_blocks}; + } + + // Device side kernel params + using Params = Arguments; + static Params to_underlying_arguments(const Arguments& args) { return args; } + + // End Iterator tag + class EndIterator {}; + class Iterator { + public: + CUTE_DEVICE + Iterator() = default; + + CUTE_DEVICE + cute::tuple operator*() const { return {blockIdx.x, blockIdx.y}; } + + CUTE_DEVICE + Iterator& operator++() { + valid_ = false; + return *this; + } + + // compare against end iterator + CUTE_DEVICE + bool operator!=(const EndIterator&) const { return valid_; } + + private: + bool valid_ = true; + }; + + CUTE_DEVICE + SingleTileScheduler(const Params& params) {} + + CUTE_DEVICE + Iterator begin() const { return {}; } + + CUTE_DEVICE + EndIterator end() const { return {}; } +}; + +} // namespace llm
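Reviewer note, not part of the diff: the kernel's `for (const auto block_coord : scheduler)` loop only depends on the `begin()`/`end()` iterator protocol that `SingleTileScheduler` establishes, so the "TODO: support persistent kernels" in sm80_grouped_gemm_launch.cuh could later be addressed by swapping in a scheduler with the same surface. The sketch below is hypothetical: the class name `PersistentTileScheduler`, the one-CTA-per-SM grid policy, and the row-major linear-to-(m, n) mapping are illustrative assumptions, not the author's implementation.

```cpp
#include <cuda_runtime.h>

#include <cute/config.hpp>
#include <cute/container/tuple.hpp>

namespace llm {

// Hypothetical persistent scheduler: each CTA walks a strided range of tile
// indices instead of owning exactly one (m_block_idx, n_block_idx).
class PersistentTileScheduler {
 public:
  struct Arguments {
    int m_blocks = 0;
    int n_blocks = 0;
  };

  // Launch at most one CTA per SM (grid policy is an assumption).
  static dim3 get_grid_shape(Arguments const& args) {
    int num_sms = 1;
    cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, /*device=*/0);
    const int n_tiles = args.m_blocks * args.n_blocks;
    return {(uint32_t)(n_tiles < num_sms ? n_tiles : num_sms), 1, 1};
  }

  using Params = Arguments;
  static Params to_underlying_arguments(const Arguments& args) { return args; }

  class EndIterator {};
  class Iterator {
   public:
    CUTE_DEVICE Iterator(int tile_idx, int n_blocks, int n_tiles)
        : tile_idx_(tile_idx), n_blocks_(n_blocks), n_tiles_(n_tiles) {}

    // map the linear tile index back to (m_block_idx, n_block_idx)
    CUTE_DEVICE cute::tuple<int, int> operator*() const {
      return cute::make_tuple(tile_idx_ / n_blocks_, tile_idx_ % n_blocks_);
    }

    // stride by the grid size so each CTA sweeps multiple tiles
    CUTE_DEVICE Iterator& operator++() {
      tile_idx_ += gridDim.x;
      return *this;
    }

    CUTE_DEVICE bool operator!=(const EndIterator&) const {
      return tile_idx_ < n_tiles_;
    }

   private:
    int tile_idx_ = 0;
    int n_blocks_ = 0;
    int n_tiles_ = 0;
  };

  CUTE_DEVICE PersistentTileScheduler(const Params& params)
      : n_blocks_(params.n_blocks),
        n_tiles_(params.m_blocks * params.n_blocks) {}

  CUTE_DEVICE Iterator begin() const {
    return {(int)blockIdx.x, n_blocks_, n_tiles_};
  }
  CUTE_DEVICE EndIterator end() const { return {}; }

 private:
  int n_blocks_ = 0;
  int n_tiles_ = 0;
};

}  // namespace llm
```

With this shape, `Sm80KernelGroupedGEMM` and `sm80_launch_grouped_gemm_kernel` would not need to change beyond substituting the scheduler type and its grid sizing.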