From 3af93da946e6266224923f9f979fdac01c36c052 Mon Sep 17 00:00:00 2001
From: Michael Mi
Date: Thu, 19 Jun 2025 23:33:07 -0700
Subject: [PATCH 1/3] feat: added single tile scheduler for attn kernel

---
 src/kernels/attention/sm80_kernel_mha.cuh | 154 ++++++++++++----------
 src/kernels/attention/sm80_mha_launch.cuh |  21 ++-
 src/kernels/attention/tile_scheduler.cuh  |  64 +++++++++
 3 files changed, 161 insertions(+), 78 deletions(-)
 create mode 100644 src/kernels/attention/tile_scheduler.cuh

diff --git a/src/kernels/attention/sm80_kernel_mha.cuh b/src/kernels/attention/sm80_kernel_mha.cuh
index 19d4cb64..adfa2e32 100644
--- a/src/kernels/attention/sm80_kernel_mha.cuh
+++ b/src/kernels/attention/sm80_kernel_mha.cuh
@@ -13,11 +13,14 @@ namespace llm {
 
 using namespace cute;
 
-template <class CollectiveMainloop_, class CollectiveEpilogue_>
+template <class CollectiveMainloop_,
+          class CollectiveEpilogue_,
+          class TileScheduler_>
 class Sm80KernelMha {
  public:
   using CollectiveMainloop = CollectiveMainloop_;
   using CollectiveEpilogue = CollectiveEpilogue_;
+  using TileScheduler = TileScheduler_;
 
   using TiledMma = typename CollectiveMainloop::TiledMma;
 
@@ -39,84 +42,89 @@ class Sm80KernelMha {
   // Kernel params
   using MainloopParams = typename CollectiveMainloop::Params;
   using EpilogueParams = typename CollectiveEpilogue::Params;
+  using TileSchedulerParams = typename TileScheduler::Params;
 
   template <class Params>
-  CUTE_DEVICE void operator()(const Params& params, char* smem) {
+  CUTE_DEVICE void operator()(const Params& params,
+                              const TileSchedulerParams& scheduler_params,
+                              char* smem) {
     CollectiveMainloop mha;
     CollectiveEpilogue epilogue;
 
-    const auto tidx = threadIdx.x;
-
-    // block coord
-    const int m_block_idx = blockIdx.x;
-    const int batch_idx = blockIdx.y;
-    const int kv_head_idx = blockIdx.z;
-    auto block_coord_mnk = make_coord(m_block_idx, batch_idx, kv_head_idx);
-
-    // (q_packed_len, HEAD_DIM)
-    MHATile tile(params, batch_idx, kv_head_idx);
-    auto [Q, O] = tile.template get_qo_tile();
-    // (kv_len, HEAD_DIM)
-    auto [K, V] = tile.template get_kv_tile();
-
-    // problem shape
-    const int q_packed_len = size<0>(Q);
-    const int kv_len = size<0>(K);
-    const int head_dim = params.head_dim;
-    auto problem_shape_mnk = make_shape(q_packed_len, kv_len, head_dim);
-
-    if (m_block_idx * kBlockM >= q_packed_len) {
-      // m out of bound, return
-      return;
-    }
+    TileScheduler scheduler(scheduler_params);
+
+    for (auto work_tile = scheduler.get_initial_work(); work_tile.valid();
+         work_tile = scheduler.get_next_work(work_tile)) {
+      // block coord: (m_block_idx, batch_idx, kv_head_idx)
+      auto block_coord_mnk = work_tile.get_block_coord();
+      auto [m_block_idx, batch_idx, kv_head_idx] = block_coord_mnk;
+      const auto tidx = threadIdx.x;
+
+      // (q_packed_len, HEAD_DIM)
+      MHATile tile(params, batch_idx, kv_head_idx);
+      auto [Q, O] = tile.template get_qo_tile();
+      // (kv_len, HEAD_DIM)
+      auto [K, V] = tile.template get_kv_tile();
+
+      // problem shape
+      const int q_packed_len = size<0>(Q);
+      const int kv_len = size<0>(K);
+      const int head_dim = params.head_dim;
+      auto problem_shape_mnk = make_shape(q_packed_len, kv_len, head_dim);
+
+      if (m_block_idx * kBlockM >= q_packed_len) {
+        // m out of bound, skip this block
+        continue;
+      }
+
+      // (BLK_M, HEAD_DIM)
+      Tensor gQ = local_tile(
+          Q, Shape{}, make_coord(m_block_idx, _0{}));
+      Tensor gO = local_tile(
+          O, Shape{}, make_coord(m_block_idx, _0{}));
+      // (BLK_N, HEAD_DIM, n)
+      Tensor gK = local_tile(K, Shape{}, make_coord(_, _0{}));
+      Tensor gV = local_tile(V, Shape{}, make_coord(_, _0{}));
+
+      // construct params
+      MainloopParams mainloop_params{params.sliding_window,
+                                     params.logits_soft_cap,
+                                     params.sm_scale,
+                                     params.sm_scale_log2,
+                                     params.alibi_slopes_ptr,
+                                     params.group_size};
+      EpilogueParams epilogue_params;
+
+      TiledMma tiled_mma;
+      // accumulator: (MMA, MMA_M, MMA_N)
+      auto tOrAccO = partition_fragment_C(tiled_mma, Shape{});
+      clear(tOrAccO);
+
+      constexpr int kRowsPerThr = kRowsPerMMA * size<1>(tOrAccO);
+      OnlineSoftmax<kRowsPerThr> softmax(params.sm_scale_log2);
+
+      // mainloop
+      mha(mainloop_params,
+          gQ,
+          gK,
+          gV,
+          tOrAccO,
+          softmax,
+          tidx,
+          block_coord_mnk,
+          problem_shape_mnk,
+          smem);
+
+      // epilogue
+      epilogue(epilogue_params,
+               tOrAccO,
+               tiled_mma,
+               gO,
+               tidx,
+               block_coord_mnk,
+               problem_shape_mnk,
+               smem);
+    }
-
-    // (BLK_M, HEAD_DIM)
-    Tensor gQ =
-        local_tile(Q, Shape{}, make_coord(m_block_idx, _0{}));
-    Tensor gO =
-        local_tile(O, Shape{}, make_coord(m_block_idx, _0{}));
-    // (BLK_N, HEAD_DIM, n)
-    Tensor gK = local_tile(K, Shape{}, make_coord(_, _0{}));
-    Tensor gV = local_tile(V, Shape{}, make_coord(_, _0{}));
-
-    // construct params
-    MainloopParams mainloop_params{params.sliding_window,
-                                   params.logits_soft_cap,
-                                   params.sm_scale,
-                                   params.sm_scale_log2,
-                                   params.alibi_slopes_ptr,
-                                   params.group_size};
-    EpilogueParams epilogue_params;
-
-    TiledMma tiled_mma;
-    // accumulator: (MMA, MMA_M, MMA_N)
-    auto tOrAccO = partition_fragment_C(tiled_mma, Shape{});
-    clear(tOrAccO);
-
-    constexpr int kRowsPerThr = kRowsPerMMA * size<1>(tOrAccO);
-    OnlineSoftmax<kRowsPerThr> softmax(params.sm_scale_log2);
-
-    // mainloop
-    mha(mainloop_params,
-        gQ,
-        gK,
-        gV,
-        tOrAccO,
-        softmax,
-        tidx,
-        block_coord_mnk,
-        problem_shape_mnk,
-        smem);
-
-    // epilogue
-    epilogue(epilogue_params,
-             tOrAccO,
-             tiled_mma,
-             gO,
-             tidx,
-             block_coord_mnk,
-             problem_shape_mnk,
-             smem);
   }
 };
 
diff --git a/src/kernels/attention/sm80_mha_launch.cuh b/src/kernels/attention/sm80_mha_launch.cuh
index 791b22a6..61944ae0 100644
--- a/src/kernels/attention/sm80_mha_launch.cuh
+++ b/src/kernels/attention/sm80_mha_launch.cuh
@@ -9,6 +9,7 @@
 #include "sm80_collective_epilogue.cuh"
 #include "sm80_collective_mha.cuh"
 #include "sm80_kernel_mha.cuh"
+#include "tile_scheduler.cuh"
 
 namespace llm {
 
@@ -16,10 +17,12 @@ namespace detail {
 
 /// Generic kernel template.
 template <class Operator, class Params>
 __global__ __launch_bounds__(Operator::kMmaThreads) void device_kernel(
-    __grid_constant__ const Params params) {
+    __grid_constant__ const Params params,
+    __grid_constant__ const typename Operator::TileSchedulerParams
+        scheduler_params) {
   extern __shared__ char smem[];
   Operator op;
-  op(params, smem);
+  op(params, scheduler_params, smem);
 }
 
 }  // namespace detail
 
@@ -60,8 +63,16 @@ void sm80_launch_mha_kernel(const Params& params, cudaStream_t stream) {
                                        LOCAL>;
   using CollectiveEpilogue = Sm80CollectiveEpilogue;
+  using TileScheduler = SingleTileScheduler;
 
-  using AttnKernel = Sm80KernelMha<CollectiveMainloop, CollectiveEpilogue>;
+  const auto m_blocks = cute::ceil_div(max_q_packed_len, BLK_M);
+  typename TileScheduler::Arguments scheduler_args{
+      m_blocks, batch_size, n_kv_heads};
+  auto scheduler_params =
+      TileScheduler::to_underlying_arguments(scheduler_args);
+
+  using AttnKernel =
+      Sm80KernelMha<CollectiveMainloop, CollectiveEpilogue, TileScheduler>;
 
   auto mha_kernel = detail::device_kernel<AttnKernel, Params>;
 
@@ -72,10 +83,10 @@ void sm80_launch_mha_kernel(const Params& params, cudaStream_t stream) {
   }
 
   // TODO: support persistent kernels
-  dim3 grid(cute::ceil_div(max_q_packed_len, BLK_M), batch_size, n_kv_heads);
+  dim3 grid = TileScheduler::get_grid_dim(scheduler_args);
   dim3 block = AttnKernel::kMmaThreads;
 
-  mha_kernel<<<grid, block, smem_size, stream>>>(params);
+  mha_kernel<<<grid, block, smem_size, stream>>>(params, scheduler_params);
 
   // TODO: check launch status
 }
diff --git a/src/kernels/attention/tile_scheduler.cuh b/src/kernels/attention/tile_scheduler.cuh
new file mode 100644
index 00000000..705b52ac
--- /dev/null
+++ b/src/kernels/attention/tile_scheduler.cuh
@@ -0,0 +1,64 @@
+#pragma once
+
+#include
+#include
+
+#include
+#include
+
+namespace llm {
+
+class SingleTileScheduler {
+ public:
+  // Host side kernel arguments
+  struct Arguments {
+    int m_blocks = 0;
+    int batch_size = 0;
+    int n_kv_heads = 0;
+  };
+
+  // Device side kernel params
+  using Params = Arguments;
+
+  static Params to_underlying_arguments(const Arguments& args) { return args; }
+
+  static dim3 get_grid_dim(Arguments const& args) {
+    return {(uint32_t)args.m_blocks,
+            (uint32_t)args.batch_size,
+            (uint32_t)args.n_kv_heads};
+  }
+
+  struct WorkTileInfo {
+    int m_block_idx = 0;
+    int batch_idx = 0;
+    int kv_head_idx = 0;
+    bool is_valid = false;
+
+    CUTE_DEVICE
+    bool valid() const { return is_valid; }
+
+    CUTE_DEVICE
+    auto get_block_coord() const {
+      return cute::tuple{m_block_idx, batch_idx, kv_head_idx};
+    }
+  };
+
+  CUTE_DEVICE
+  SingleTileScheduler(const Params& params) {}
+
+  CUTE_DEVICE
+  WorkTileInfo get_initial_work() const {
+    return {(int)blockIdx.x,
+            (int)blockIdx.y,
+            (int)blockIdx.z,
+            /*is_valid_tile*/ true};
+  }
+
+  CUTE_DEVICE WorkTileInfo
+  get_next_work(const WorkTileInfo& /*current_work*/) const {
+    // no more work to hand out
+    return {0, 0, 0, /*is_valid_tile*/ false};
+  }
+};
+
+}  // namespace llm
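Note on the scheduler contract introduced by PATCH 1: Arguments is filled on the host, converted once into Params at launch, and the grid is derived from the same Arguments. A minimal host-side sketch of that flow follows; it assumes only the SingleTileScheduler above, and the numbers are illustrative, not taken from this series.

// Sketch: exercising the SingleTileScheduler host API from PATCH 1.
// Assumes tile_scheduler.cuh is on the include path; values are made up.
#include <cassert>

#include "tile_scheduler.cuh"

void smoke_test_scheduler_host_api() {
  llm::SingleTileScheduler::Arguments args;
  args.m_blocks = 4;  // i.e. cute::ceil_div(max_q_packed_len, BLK_M)
  args.batch_size = 2;
  args.n_kv_heads = 8;

  // Params is an alias of Arguments, so this is a pass-through today,
  // but it leaves room for host-side precomputation later.
  auto params = llm::SingleTileScheduler::to_underlying_arguments(args);
  (void)params;

  // One CTA per (m_block, batch, kv_head) work tile.
  const dim3 grid = llm::SingleTileScheduler::get_grid_dim(args);
  assert(grid.x == 4 && grid.y == 2 && grid.z == 8);
}

Splitting Arguments from Params mirrors the CUTLASS 3.x convention of converting host arguments into kernel-ready params exactly once, at launch time.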
From 6fb531125e32793a3a4ddfae9ed2d29173e5a606 Mon Sep 17 00:00:00 2001
From: Michael Mi
Date: Fri, 20 Jun 2025 00:12:23 -0700
Subject: [PATCH 2/3] add for loop support

---
 .../attention/sm80_collective_epilogue.cuh    |  2 +-
 src/kernels/attention/sm80_collective_mha.cuh |  2 +-
 src/kernels/attention/sm80_kernel_mha.cuh     | 37 +++++++------
 src/kernels/attention/sm80_mha_launch.cuh     |  7 +--
 src/kernels/attention/tile_scheduler.cuh      | 53 +++++++++----------
 5 files changed, 53 insertions(+), 48 deletions(-)

diff --git a/src/kernels/attention/sm80_collective_epilogue.cuh b/src/kernels/attention/sm80_collective_epilogue.cuh
index 1e0a0182..6a00ffc7 100644
--- a/src/kernels/attention/sm80_collective_epilogue.cuh
+++ b/src/kernels/attention/sm80_collective_epilogue.cuh
@@ -83,7 +83,7 @@ struct Sm80CollectiveEpilogue {
                            char* smem) {
     static constexpr int kBlockM = get<0>(TileShape{});
 
-    const auto [m_block_idx, batch_idx, kv_head_idx] = block_coord_mnk;
+    const auto [batch_idx, m_block_idx, kv_head_idx] = block_coord_mnk;
     const auto [q_packed_len, kv_len, head_dim] = problem_shape_mnk;
 
     // Smem
diff --git a/src/kernels/attention/sm80_collective_mha.cuh b/src/kernels/attention/sm80_collective_mha.cuh
index 68d61833..d4413569 100644
--- a/src/kernels/attention/sm80_collective_mha.cuh
+++ b/src/kernels/attention/sm80_collective_mha.cuh
@@ -169,7 +169,7 @@ struct Sm80CollectiveMha {
     static constexpr int kBlockM = get<0>(TileShape{});
     static constexpr int kBlockN = get<1>(TileShape{});
 
-    const auto [m_block_idx, batch_idx, kv_head_idx] = block_coord_mnk;
+    const auto [batch_idx, m_block_idx, kv_head_idx] = block_coord_mnk;
     const auto [q_packed_len, kv_len, head_dim] = problem_shape_mnk;
 
     const int sliding_window = LOCAL ? params.sliding_window : kv_len;
diff --git a/src/kernels/attention/sm80_kernel_mha.cuh b/src/kernels/attention/sm80_kernel_mha.cuh
index adfa2e32..954bf40b 100644
--- a/src/kernels/attention/sm80_kernel_mha.cuh
+++ b/src/kernels/attention/sm80_kernel_mha.cuh
@@ -44,20 +44,34 @@ class Sm80KernelMha {
   using EpilogueParams = typename CollectiveEpilogue::Params;
   using TileSchedulerParams = typename TileScheduler::Params;
 
+  // returns grid and block shape for kernel launch
+  using TileSchedulerArgs = typename TileScheduler::Arguments;
+  static dim3 get_grid_shape(TileSchedulerArgs const& args) {
+    return TileScheduler::get_grid_shape(args);
+  }
+  static dim3 get_block_shape() { return kMmaThreads; }
+
   template <class Params>
   CUTE_DEVICE void operator()(const Params& params,
                               const TileSchedulerParams& scheduler_params,
                               char* smem) {
     CollectiveMainloop mha;
     CollectiveEpilogue epilogue;
     TileScheduler scheduler(scheduler_params);
 
-    for (auto work_tile = scheduler.get_initial_work(); work_tile.valid();
-         work_tile = scheduler.get_next_work(work_tile)) {
-      // block coord: (m_block_idx, batch_idx, kv_head_idx)
-      auto block_coord_mnk = work_tile.get_block_coord();
-      auto [m_block_idx, batch_idx, kv_head_idx] = block_coord_mnk;
+    // construct params
+    MainloopParams mainloop_params{params.sliding_window,
+                                   params.logits_soft_cap,
+                                   params.sm_scale,
+                                   params.sm_scale_log2,
+                                   params.alibi_slopes_ptr,
+                                   params.group_size};
+    EpilogueParams epilogue_params;
+
+    // process each block
+    for (const auto block_coord : scheduler) {
+      // block coord: (batch_idx, m_block_idx, kv_head_idx)
+      const auto [batch_idx, m_block_idx, kv_head_idx] = block_coord;
       const auto tidx = threadIdx.x;
 
       // (q_packed_len, HEAD_DIM)
@@ -86,15 +100,6 @@ class Sm80KernelMha {
       Tensor gK = local_tile(K, Shape{}, make_coord(_, _0{}));
       Tensor gV = local_tile(V, Shape{}, make_coord(_, _0{}));
 
-      // construct params
-      MainloopParams mainloop_params{params.sliding_window,
-                                     params.logits_soft_cap,
-                                     params.sm_scale,
-                                     params.sm_scale_log2,
-                                     params.alibi_slopes_ptr,
-                                     params.group_size};
-      EpilogueParams epilogue_params;
-
       TiledMma tiled_mma;
       // accumulator: (MMA, MMA_M, MMA_N)
       auto tOrAccO = partition_fragment_C(tiled_mma, Shape{});
@@ -111,7 +116,7 @@ class Sm80KernelMha {
           tOrAccO,
           softmax,
           tidx,
-          block_coord_mnk,
+          block_coord,
           problem_shape_mnk,
           smem);
 
@@ -121,7 +126,7 @@ class Sm80KernelMha {
                tiled_mma,
                gO,
                tidx,
-               block_coord_mnk,
+               block_coord,
                problem_shape_mnk,
                smem);
     }
diff --git a/src/kernels/attention/sm80_mha_launch.cuh b/src/kernels/attention/sm80_mha_launch.cuh
index 61944ae0..675070b9 100644
--- a/src/kernels/attention/sm80_mha_launch.cuh
+++ b/src/kernels/attention/sm80_mha_launch.cuh
@@ -63,6 +63,8 @@ void sm80_launch_mha_kernel(const Params& params, cudaStream_t stream) {
                                        LOCAL>;
   using CollectiveEpilogue = Sm80CollectiveEpilogue;
+
+  // TODO: support persistent kernels
   using TileScheduler = SingleTileScheduler;
 
   const auto m_blocks = cute::ceil_div(max_q_packed_len, BLK_M);
@@ -82,9 +84,8 @@ void sm80_launch_mha_kernel(const Params& params, cudaStream_t stream) {
         mha_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
   }
 
-  // TODO: support persistent kernels
-  dim3 grid = TileScheduler::get_grid_dim(scheduler_args);
-  dim3 block = AttnKernel::kMmaThreads;
+  const dim3 grid = AttnKernel::get_grid_shape(scheduler_args);
+  const dim3 block = AttnKernel::get_block_shape();
 
   mha_kernel<<<grid, block, smem_size, stream>>>(params, scheduler_params);
 
   // TODO: check launch status
diff --git a/src/kernels/attention/tile_scheduler.cuh b/src/kernels/attention/tile_scheduler.cuh
index 705b52ac..75a2082d 100644
--- a/src/kernels/attention/tile_scheduler.cuh
+++ b/src/kernels/attention/tile_scheduler.cuh
@@ -12,53 +12,52 @@ class SingleTileScheduler {
  public:
   // Host side kernel arguments
   struct Arguments {
-    int m_blocks = 0;
     int batch_size = 0;
+    int m_blocks = 0;
     int n_kv_heads = 0;
   };
 
+  static dim3 get_grid_shape(Arguments const& args) {
+    return {(uint32_t)args.batch_size,
+            (uint32_t)args.m_blocks,
+            (uint32_t)args.n_kv_heads};
+  }
+
   // Device side kernel params
   using Params = Arguments;
+
   static Params to_underlying_arguments(const Arguments& args) { return args; }
 
-  static dim3 get_grid_dim(Arguments const& args) {
-    return {(uint32_t)args.m_blocks,
-            (uint32_t)args.batch_size,
-            (uint32_t)args.n_kv_heads};
-  }
-
-  struct WorkTileInfo {
-    int m_block_idx = 0;
-    int batch_idx = 0;
-    int kv_head_idx = 0;
-    bool is_valid = false;
+  // End Iterator tag
+  class EndIterator {};
 
+  class Iterator {
+   public:
     CUTE_DEVICE
-    bool valid() const { return is_valid; }
+    Iterator() = default;
 
     CUTE_DEVICE
-    auto get_block_coord() const {
-      return cute::tuple{m_block_idx, batch_idx, kv_head_idx};
-    }
-  };
+    dim3 operator*() const { return blockIdx; }
+
+    CUTE_DEVICE
+    Iterator& operator++() {
+      valid_ = false;
+      return *this;
+    }
+
+    // compare against end iterator
+    CUTE_DEVICE
+    bool operator!=(const EndIterator&) const { return valid_; }
+
+   private:
+    bool valid_ = true;
+  };
 
   CUTE_DEVICE
   SingleTileScheduler(const Params& params) {}
 
   CUTE_DEVICE
-  WorkTileInfo get_initial_work() const {
-    return {(int)blockIdx.x,
-            (int)blockIdx.y,
-            (int)blockIdx.z,
-            /*is_valid_tile*/ true};
-  }
+  Iterator begin() const { return {}; }
 
-  CUTE_DEVICE WorkTileInfo
-  get_next_work(const WorkTileInfo& /*current_work*/) const {
-    // no more work to hand out
-    return {0, 0, 0, /*is_valid_tile*/ false};
-  }
+  CUTE_DEVICE
+  EndIterator end() const { return {}; }
 };
 
 }  // namespace llm
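The range-for interface PATCH 2 settles on (get_grid_shape, begin()/end(), operator*, operator++) is exactly the seam the persistent-kernel TODO needs: only the scheduler has to change, not the kernel loop. What follows is a hypothetical sketch of such a drop-in scheduler; the name StaticPersistentTileScheduler and every detail below are assumptions layered on top of this series, not part of it.

// Hypothetical persistent scheduler conforming to the same iterator
// interface as SingleTileScheduler after PATCH 2. A sketch only.
class StaticPersistentTileScheduler {
 public:
  struct Arguments {
    int batch_size = 0;
    int m_blocks = 0;
    int n_kv_heads = 0;
    int num_ctas = 0;  // e.g. number of SMs on the device (assumed field)
  };
  using Params = Arguments;
  static Params to_underlying_arguments(const Arguments& args) { return args; }

  // launch a fixed number of CTAs; tiles are strided across them
  static dim3 get_grid_shape(Arguments const& args) {
    return {(uint32_t)args.num_ctas, 1, 1};
  }

  class EndIterator {};

  class Iterator {
   public:
    CUTE_DEVICE
    Iterator(int idx, int step, int total, const Params& params)
        : idx_(idx), step_(step), total_(total), params_(params) {}

    // decode linear tile index -> (batch_idx, m_block_idx, kv_head_idx),
    // matching the coordinate order the kernel destructures
    CUTE_DEVICE
    dim3 operator*() const {
      const int tiles_per_batch = params_.m_blocks * params_.n_kv_heads;
      const int batch_idx = idx_ / tiles_per_batch;
      const int rest = idx_ % tiles_per_batch;
      return {(uint32_t)batch_idx,
              (uint32_t)(rest / params_.n_kv_heads),
              (uint32_t)(rest % params_.n_kv_heads)};
    }

    CUTE_DEVICE
    Iterator& operator++() {
      idx_ += step_;  // grid-stride over the flattened tile space
      return *this;
    }

    CUTE_DEVICE
    bool operator!=(const EndIterator&) const { return idx_ < total_; }

   private:
    int idx_;
    int step_;
    int total_;
    const Params& params_;
  };

  CUTE_DEVICE
  StaticPersistentTileScheduler(const Params& params) : params_(params) {}

  CUTE_DEVICE
  Iterator begin() const {
    const int total =
        params_.batch_size * params_.m_blocks * params_.n_kv_heads;
    return {(int)blockIdx.x, (int)gridDim.x, total, params_};
  }

  CUTE_DEVICE
  EndIterator end() const { return {}; }

 private:
  Params params_;
};

A real version would also have to revisit shared-memory lifetime across loop iterations (e.g. a __syncthreads() before reuse), a case the single-tile scheduler never exercises since each CTA runs exactly one iteration.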
From d119b338f6742f963da8376c75870245954f9629 Mon Sep 17 00:00:00 2001
From: Michael Mi
Date: Fri, 20 Jun 2025 00:18:08 -0700
Subject: [PATCH 3/3] fix scheduler args

---
 src/kernels/attention/sm80_mha_launch.cuh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/kernels/attention/sm80_mha_launch.cuh b/src/kernels/attention/sm80_mha_launch.cuh
index 675070b9..f26f71e8 100644
--- a/src/kernels/attention/sm80_mha_launch.cuh
+++ b/src/kernels/attention/sm80_mha_launch.cuh
@@ -69,7 +69,7 @@ void sm80_launch_mha_kernel(const Params& params, cudaStream_t stream) {
 
   const auto m_blocks = cute::ceil_div(max_q_packed_len, BLK_M);
   typename TileScheduler::Arguments scheduler_args{
-      m_blocks, batch_size, n_kv_heads};
+      batch_size, m_blocks, n_kv_heads};
   auto scheduler_params =
       TileScheduler::to_underlying_arguments(scheduler_args);
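Post-script on PATCH 3: the bug it fixes crept in because PATCH 2 reordered the Arguments fields while the launch site kept positional brace initialization. If the host compiler is built as C++20 (an assumption; the series does not say), designated initializers would have turned the silent argument swap into a compile error:

// Sketch: the PATCH 3 launch-site initialization rewritten with C++20
// designated initializers, so member names rather than field order bind
// the values. Assumes C++20 support; not part of this series.
typename TileScheduler::Arguments scheduler_args{
    .batch_size = batch_size,
    .m_blocks = m_blocks,
    .n_kv_heads = n_kv_heads,
};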