diff --git a/src/kernels/attention/sm80_collective_epilogue.cuh b/src/kernels/attention/sm80_collective_epilogue.cuh
index 1e0a0182..6a00ffc7 100644
--- a/src/kernels/attention/sm80_collective_epilogue.cuh
+++ b/src/kernels/attention/sm80_collective_epilogue.cuh
@@ -83,7 +83,7 @@ struct Sm80CollectiveEpilogue {
                              char* smem) {
     static constexpr int kBlockM = get<0>(TileShape{});
 
-    const auto [m_block_idx, batch_idx, kv_head_idx] = block_coord_mnk;
+    const auto [batch_idx, m_block_idx, kv_head_idx] = block_coord_mnk;
     const auto [q_packed_len, kv_len, head_dim] = problem_shape_mnk;
 
     // Smem
diff --git a/src/kernels/attention/sm80_collective_mha.cuh b/src/kernels/attention/sm80_collective_mha.cuh
index 68d61833..d4413569 100644
--- a/src/kernels/attention/sm80_collective_mha.cuh
+++ b/src/kernels/attention/sm80_collective_mha.cuh
@@ -169,7 +169,7 @@ struct Sm80CollectiveMha {
     static constexpr int kBlockM = get<0>(TileShape{});
     static constexpr int kBlockN = get<1>(TileShape{});
 
-    const auto [m_block_idx, batch_idx, kv_head_idx] = block_coord_mnk;
+    const auto [batch_idx, m_block_idx, kv_head_idx] = block_coord_mnk;
     const auto [q_packed_len, kv_len, head_dim] = problem_shape_mnk;
 
     const int sliding_window = LOCAL ? params.sliding_window : kv_len;
diff --git a/src/kernels/attention/sm80_kernel_mha.cuh b/src/kernels/attention/sm80_kernel_mha.cuh
index 19d4cb64..954bf40b 100644
--- a/src/kernels/attention/sm80_kernel_mha.cuh
+++ b/src/kernels/attention/sm80_kernel_mha.cuh
@@ -13,11 +13,14 @@ namespace llm {
 
 using namespace cute;
 
-template <class CollectiveMainloop_, class CollectiveEpilogue_>
+template <class CollectiveMainloop_,
+          class CollectiveEpilogue_,
+          class TileScheduler_>
 class Sm80KernelMha {
  public:
   using CollectiveMainloop = CollectiveMainloop_;
   using CollectiveEpilogue = CollectiveEpilogue_;
+  using TileScheduler = TileScheduler_;
 
   using TiledMma = typename CollectiveMainloop::TiledMma;
 
@@ -39,45 +42,22 @@ class Sm80KernelMha {
   // Kernel params
   using MainloopParams = typename CollectiveMainloop::Params;
   using EpilogueParams = typename CollectiveEpilogue::Params;
+  using TileSchedulerParams = typename TileScheduler::Params;
+
+  // returns grid and block shape for kernel launch
+  using TileSchedulerArgs = typename TileScheduler::Arguments;
+  static dim3 get_grid_shape(TileSchedulerArgs const& args) {
+    return TileScheduler::get_grid_shape(args);
+  }
+  static dim3 get_block_shape() { return kMmaThreads; }
 
   template <class Params>
-  CUTE_DEVICE void operator()(const Params& params, char* smem) {
+  CUTE_DEVICE void operator()(const Params& params,
+                              const TileSchedulerParams& scheduler_params,
+                              char* smem) {
     CollectiveMainloop mha;
     CollectiveEpilogue epilogue;
-
-    const auto tidx = threadIdx.x;
-
-    // block coord
-    const int m_block_idx = blockIdx.x;
-    const int batch_idx = blockIdx.y;
-    const int kv_head_idx = blockIdx.z;
-    auto block_coord_mnk = make_coord(m_block_idx, batch_idx, kv_head_idx);
-
-    // (q_packed_len, HEAD_DIM)
-    MHATile tile(params, batch_idx, kv_head_idx);
-    auto [Q, O] = tile.template get_qo_tile<Element>();
-    // (kv_len, HEAD_DIM)
-    auto [K, V] = tile.template get_kv_tile<Element>();
-
-    // problem shape
-    const int q_packed_len = size<0>(Q);
-    const int kv_len = size<0>(K);
-    const int head_dim = params.head_dim;
-    auto problem_shape_mnk = make_shape(q_packed_len, kv_len, head_dim);
-
-    if (m_block_idx * kBlockM >= q_packed_len) {
-      // m out of bound, return
-      return;
-    }
-
-    // (BLK_M, HEAD_DIM)
-    Tensor gQ =
-        local_tile(Q, Shape<_BLK_M, _HEAD_DIM>{}, make_coord(m_block_idx, _0{}));
-    Tensor gO =
-        local_tile(O, Shape<_BLK_M, _HEAD_DIM>{}, make_coord(m_block_idx, _0{}));
-    // (BLK_N, HEAD_DIM, n)
-    Tensor gK = local_tile(K, Shape<_BLK_N, _HEAD_DIM>{}, make_coord(_, _0{}));
-    Tensor gV =
-        local_tile(V, Shape<_BLK_N, _HEAD_DIM>{}, make_coord(_, _0{}));
+    TileScheduler scheduler(scheduler_params);
 
     // construct params
     MainloopParams mainloop_params{params.sliding_window,
@@ -88,35 +68,68 @@ class Sm80KernelMha {
                                    params.group_size};
     EpilogueParams epilogue_params;
 
-    TiledMma tiled_mma;
-    // accumulator: MMA,MMA_M,MMA_K)
-    auto tOrAccO = partition_fragment_C(tiled_mma, Shape<_BLK_M, _HEAD_DIM>{});
-    clear(tOrAccO);
-
-    constexpr int kRowsPerThr = kRowsPerMMA * size<1>(tOrAccO);
-    OnlineSoftmax<kRowsPerThr> softmax(params.sm_scale_log2);
-
-    // mainloop
-    mha(mainloop_params,
-        gQ,
-        gK,
-        gV,
-        tOrAccO,
-        softmax,
-        tidx,
-        block_coord_mnk,
-        problem_shape_mnk,
-        smem);
-
-    // epilogue
-    epilogue(epilogue_params,
-             tOrAccO,
-             tiled_mma,
-             gO,
-             tidx,
-             block_coord_mnk,
-             problem_shape_mnk,
-             smem);
+    // process each block
+    for (const auto block_coord : scheduler) {
+      // block coord: (batch_idx, m_block_idx, kv_head_idx)
+      const auto [batch_idx, m_block_idx, kv_head_idx] = block_coord;
+      const auto tidx = threadIdx.x;
+
+      // (q_packed_len, HEAD_DIM)
+      MHATile tile(params, batch_idx, kv_head_idx);
+      auto [Q, O] = tile.template get_qo_tile<Element>();
+      // (kv_len, HEAD_DIM)
+      auto [K, V] = tile.template get_kv_tile<Element>();
+
+      // problem shape
+      const int q_packed_len = size<0>(Q);
+      const int kv_len = size<0>(K);
+      const int head_dim = params.head_dim;
+      auto problem_shape_mnk = make_shape(q_packed_len, kv_len, head_dim);
+
+      if (m_block_idx * kBlockM >= q_packed_len) {
+        // m out of bound, skip this block
+        continue;
+      }
+
+      // (BLK_M, HEAD_DIM)
+      Tensor gQ = local_tile(
+          Q, Shape<_BLK_M, _HEAD_DIM>{}, make_coord(m_block_idx, _0{}));
+      Tensor gO = local_tile(
+          O, Shape<_BLK_M, _HEAD_DIM>{}, make_coord(m_block_idx, _0{}));
+      // (BLK_N, HEAD_DIM, n)
+      Tensor gK = local_tile(K, Shape<_BLK_N, _HEAD_DIM>{}, make_coord(_, _0{}));
+      Tensor gV = local_tile(V, Shape<_BLK_N, _HEAD_DIM>{}, make_coord(_, _0{}));
+
+      TiledMma tiled_mma;
+      // accumulator: (MMA, MMA_M, MMA_N)
+      auto tOrAccO = partition_fragment_C(tiled_mma, Shape<_BLK_M, _HEAD_DIM>{});
+      clear(tOrAccO);
+
+      constexpr int kRowsPerThr = kRowsPerMMA * size<1>(tOrAccO);
+      OnlineSoftmax<kRowsPerThr> softmax(params.sm_scale_log2);
+
+      // mainloop
+      mha(mainloop_params,
+          gQ,
+          gK,
+          gV,
+          tOrAccO,
+          softmax,
+          tidx,
+          block_coord,
+          problem_shape_mnk,
+          smem);
+
+      // epilogue
+      epilogue(epilogue_params,
+               tOrAccO,
+               tiled_mma,
+               gO,
+               tidx,
+               block_coord,
+               problem_shape_mnk,
+               smem);
+    }
   }
 };
 
diff --git a/src/kernels/attention/sm80_mha_launch.cuh b/src/kernels/attention/sm80_mha_launch.cuh
index 791b22a6..f26f71e8 100644
--- a/src/kernels/attention/sm80_mha_launch.cuh
+++ b/src/kernels/attention/sm80_mha_launch.cuh
@@ -9,6 +9,7 @@
 #include "sm80_collective_epilogue.cuh"
 #include "sm80_collective_mha.cuh"
 #include "sm80_kernel_mha.cuh"
+#include "tile_scheduler.cuh"
 
 namespace llm {
 
@@ -16,10 +17,12 @@ namespace detail {
 /// Generic kernel template.
 template <typename Operator, typename Params>
 __global__ __launch_bounds__(Operator::kMmaThreads) void device_kernel(
-    __grid_constant__ const Params params) {
+    __grid_constant__ const Params params,
+    __grid_constant__ const typename Operator::TileSchedulerParams
+        scheduler_params) {
   extern __shared__ char smem[];
   Operator op;
-  op(params, smem);
+  op(params, scheduler_params, smem);
 }
 
 }  // namespace detail
@@ -61,7 +64,17 @@ void sm80_launch_mha_kernel(const Params& params, cudaStream_t stream) {
   using CollectiveEpilogue = Sm80CollectiveEpilogue<...>;
 
-  using AttnKernel = Sm80KernelMha<CollectiveMainloop, CollectiveEpilogue>;
+  // TODO: support persistent kernels
+  using TileScheduler = SingleTileScheduler;
+
+  const auto m_blocks = cute::ceil_div(max_q_packed_len, BLK_M);
+  typename TileScheduler::Arguments scheduler_args{
+      batch_size, m_blocks, n_kv_heads};
+  auto scheduler_params =
+      TileScheduler::to_underlying_arguments(scheduler_args);
+
+  using AttnKernel =
+      Sm80KernelMha<CollectiveMainloop, CollectiveEpilogue, TileScheduler>;
 
   auto mha_kernel = detail::device_kernel<AttnKernel, Params>;
 
@@ -71,11 +84,10 @@ void sm80_launch_mha_kernel(const Params& params, cudaStream_t stream) {
         mha_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
   }
 
-  // TODO: support persistent kernels
-  dim3 grid(cute::ceil_div(max_q_packed_len, BLK_M), batch_size, n_kv_heads);
-  dim3 block = AttnKernel::kMmaThreads;
+  const dim3 grid = AttnKernel::get_grid_shape(scheduler_args);
+  const dim3 block = AttnKernel::get_block_shape();
 
-  mha_kernel<<<grid, block, smem_size, stream>>>(params);
+  mha_kernel<<<grid, block, smem_size, stream>>>(params, scheduler_params);
   // TODO: check launch status
 }
 
diff --git a/src/kernels/attention/tile_scheduler.cuh b/src/kernels/attention/tile_scheduler.cuh
new file mode 100644
index 00000000..75a2082d
--- /dev/null
+++ b/src/kernels/attention/tile_scheduler.cuh
@@ -0,0 +1,63 @@
+#pragma once
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <cute/config.hpp>
+#include <cute/tensor.hpp>
+
+namespace llm {
+
+class SingleTileScheduler {
+ public:
+  // Host side kernel arguments
+  struct Arguments {
+    int batch_size = 0;
+    int m_blocks = 0;
+    int n_kv_heads = 0;
+  };
+  static dim3 get_grid_shape(Arguments const& args) {
+    return {(uint32_t)args.batch_size,
+            (uint32_t)args.m_blocks,
+            (uint32_t)args.n_kv_heads};
+  }
+
+  // Device side kernel params
+  using Params = Arguments;
+  static Params to_underlying_arguments(const Arguments& args) { return args; }
+
+  // End Iterator tag
+  class EndIterator {};
+
+  class Iterator {
+   public:
+    CUTE_DEVICE
+    Iterator() = default;
+
+    CUTE_DEVICE
+    dim3 operator*() const { return blockIdx; }
+
+    CUTE_DEVICE
+    Iterator& operator++() {
+      valid_ = false;
+      return *this;
+    }
+
+    // compare against end iterator
+    CUTE_DEVICE
+    bool operator!=(const EndIterator&) const { return valid_; }
+
+   private:
+    bool valid_ = true;
+  };
+
+  CUTE_DEVICE
+  SingleTileScheduler(const Params& params) {}
+
+  CUTE_DEVICE
+  Iterator begin() const { return {}; }
+
+  CUTE_DEVICE
+  EndIterator end() const { return {}; }
+};
+
+}  // namespace llm
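For reference, below is a minimal standalone sketch (not part of the patch) of the iteration pattern this change introduces: the launch side turns scheduler `Arguments` into a grid shape and device-side `Params`, and the kernel body range-fors over the scheduler, with each CTA receiving exactly one `(batch_idx, m_block_idx, kv_head_idx)` coordinate. `DemoScheduler`, `demo_kernel`, and the file name are hypothetical stand-ins for `SingleTileScheduler` and `Sm80KernelMha` so the example builds without the repo's CuTe headers (`nvcc -std=c++17`).

```cuda
// demo_single_tile_scheduler.cu (hypothetical; build with: nvcc -std=c++17)
#include <cstdio>
#include <cuda_runtime.h>

// Simplified stand-in for the patch's llm::SingleTileScheduler, inlined here
// (plain __device__ instead of CUTE_DEVICE) so the example has no repo deps.
struct DemoScheduler {
  // Host-side arguments; one CTA per (batch, m_block, kv_head) coordinate.
  struct Arguments {
    int batch_size, m_blocks, n_kv_heads;
  };
  using Params = Arguments;

  static dim3 get_grid_shape(const Arguments& args) {
    return {(unsigned)args.batch_size, (unsigned)args.m_blocks,
            (unsigned)args.n_kv_heads};
  }
  static Params to_underlying_arguments(const Arguments& args) { return args; }

  struct EndIterator {};
  struct Iterator {
    bool valid_ = true;
    // The single-tile scheduler simply hands out this CTA's blockIdx once.
    __device__ dim3 operator*() const {
      return {blockIdx.x, blockIdx.y, blockIdx.z};
    }
    __device__ Iterator& operator++() {
      valid_ = false;  // one coordinate per CTA, then stop
      return *this;
    }
    __device__ bool operator!=(const EndIterator&) const { return valid_; }
  };

  __device__ DemoScheduler(const Params& /*params*/) {}
  __device__ Iterator begin() const { return {}; }
  __device__ EndIterator end() const { return {}; }
};

// Mirrors the loop structure of Sm80KernelMha::operator() in the patch.
__global__ void demo_kernel(DemoScheduler::Params scheduler_params) {
  DemoScheduler scheduler(scheduler_params);
  for (const auto block_coord : scheduler) {
    const auto [batch_idx, m_block_idx, kv_head_idx] = block_coord;
    if (threadIdx.x == 0) {
      printf("batch=%u m_block=%u kv_head=%u\n",
             batch_idx, m_block_idx, kv_head_idx);
    }
  }
}

int main() {
  DemoScheduler::Arguments args{/*batch_size=*/2, /*m_blocks=*/3,
                                /*n_kv_heads=*/1};
  demo_kernel<<<DemoScheduler::get_grid_shape(args), 32>>>(
      DemoScheduler::to_underlying_arguments(args));
  return cudaDeviceSynchronize() == cudaSuccess ? 0 : 1;
}
```

Keeping the scheduler behind an iterator interface is what the `// TODO: support persistent kernels` comment leans on: a future persistent scheduler can yield several coordinates to the same CTA without any change to the kernel body.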