diff --git a/src/kernels/attention/collective/sm120_collective_epilogue.cuh b/src/kernels/attention/collective/sm120_collective_epilogue.cuh index 8ce427ec..3543db08 100644 --- a/src/kernels/attention/collective/sm120_collective_epilogue.cuh +++ b/src/kernels/attention/collective/sm120_collective_epilogue.cuh @@ -53,7 +53,12 @@ struct Sm120CollectiveEpilogue { using Params = Arguments; // Convert host side arguments to device side params - static Params to_underlying_arguments(Arguments const& args) { return args; } + template + static Params to_underlying_arguments(const ProblemShape& /*problem_shape*/, + const Arguments& args, + void* /*workspace*/) { + return args; + } template CUTE_DEVICE void operator()(const Params& /*params*/, diff --git a/src/kernels/attention/collective/sm120_collective_fmha_mainloop_ws.cuh b/src/kernels/attention/collective/sm120_collective_fmha_mainloop_ws.cuh index ba5e8f39..e747c913 100644 --- a/src/kernels/attention/collective/sm120_collective_fmha_mainloop_ws.cuh +++ b/src/kernels/attention/collective/sm120_collective_fmha_mainloop_ws.cuh @@ -124,6 +124,18 @@ struct Sm120CollectiveFMhaWs { // Host side arguments struct Arguments { + // sliding window attention + int sliding_window; + // softcap + float logits_soft_cap; + // softmax scale + float sm_scale; + // alibi slopes pointer, [n_heads] + const float* alibi_slopes_ptr; + }; + + // Device side params + struct Params { // sliding window attention int sliding_window; // softcap @@ -132,20 +144,37 @@ struct Sm120CollectiveFMhaWs { float sm_scale; // softmax scale in log2 float sm_scale_log2; - // group size - const FastDivmod& group_size; - // alibi slopes pointer const float* alibi_slopes_ptr; }; - // Device side params - using Params = Arguments; - // Convert host side arguments to device side params - static Params to_underlying_arguments(Arguments const& args) { - // no convertion needed. 
- return args; + template + static Params to_underlying_arguments(const ProblemShape& /*problem_shape*/, + const Arguments& args, + void* /*workspace*/) { + float sm_scale = args.sm_scale; + float logits_soft_cap = args.logits_soft_cap; + if (logits_soft_cap != 0.0) { + // Softmax(x * sm_scale) + apply_logits_soft_cap + // => Softmax(Tanh(x * sm_scale / soft_cap) * soft_cap) + // => Softmax(S' * sm_scale') where + // S' = Tanh(x * sm_scale / soft_cap) + // = Tanh(x * soft_cap') + // soft_cap' = sm_scale / soft_cap + // sm_scale' = soft_cap + const auto sm_scale_hat = logits_soft_cap; + logits_soft_cap = sm_scale / logits_soft_cap; + sm_scale = sm_scale_hat; + } + const float sm_scale_log2 = static_cast(sm_scale * M_LOG2E); + return { + .sliding_window = args.sliding_window, + .logits_soft_cap = logits_soft_cap, + .sm_scale = sm_scale, + .sm_scale_log2 = sm_scale_log2, + .alibi_slopes_ptr = args.alibi_slopes_ptr, + }; } // load Q/K/V from gmem to smem @@ -193,6 +222,7 @@ struct Sm120CollectiveFMhaWs { const auto q_packed_len = block.get_packed_len(); const auto q_len = block.get_q_len(); const auto kv_len = block.get_kv_len(); + const auto& group_size = block.get_group_size(); // Construct smem tensors // (BLK_M, BLK_K), k-major @@ -335,7 +365,7 @@ struct Sm120CollectiveFMhaWs { // Create softmax and mask OnlineSoftmax softmax(params.sm_scale_log2); Mask mask( - q_len, kv_len, params.group_size, params.sliding_window); + q_len, kv_len, group_size, params.sliding_window); if constexpr (kAlibi) { const auto tScS_mn = tScMN_mn(_, _, _0{}); mask.init_alibi( diff --git a/src/kernels/attention/common/fmha_block.h b/src/kernels/attention/common/fmha_block.h index 8d11f818..ffb4e450 100644 --- a/src/kernels/attention/common/fmha_block.h +++ b/src/kernels/attention/common/fmha_block.h @@ -1,9 +1,10 @@ #pragma once +#include #include #include -#include "cute/config.hpp" +#include "common/fast_math.h" #include "gather_tensor.h" namespace llm { @@ -11,11 +12,82 @@ namespace llm { using namespace cute; // AttentionTile specialization for AttentionParams -template struct FmhaBlock { + // (B, Q, H, D) + using StrideQ = Stride; + using StrideO = StrideQ; + // (B, K, KH, D) + using StrideK = Stride; + using StrideV = StrideK; + + // Host side parameters + + struct Arguments { + const void* __restrict__ q_ptr; + const void* __restrict__ k_ptr; + const void* __restrict__ v_ptr; + void* __restrict__ o_ptr; + + StrideQ q_stride; + StrideK k_stride; + StrideV v_stride; + StrideO o_stride; + + int sliding_window = -1; // -1 means no sliding window + }; + + // Device side parameters + struct Params { + const void* __restrict__ q_ptr; + const void* __restrict__ k_ptr; + const void* __restrict__ v_ptr; + void* __restrict__ o_ptr; + + StrideQ q_stride; + StrideK k_stride; + StrideV v_stride; + StrideO o_stride; + + int sliding_window; + + // Parameters from problem shape + int q_len; + int kv_len; + int head_dim; + FastDivmod group_size; + }; + + template + static Params to_underlying_arguments(const ProblemShape& problem_shape, + const Arguments& args, + void* workspace = nullptr) { + // ProblemShape: (Q, K, D, ((KH, G), B)) + const int q_len = size<0>(problem_shape); + const int kv_len = size<1>(problem_shape); + const int head_dim = size<2>(problem_shape); + const int group_size = size<3, 0, 1>(problem_shape); + + // TODO: construct tma_load for k/v tensors + return { + .q_ptr = args.q_ptr, + .k_ptr = args.k_ptr, + .v_ptr = args.v_ptr, + .o_ptr = args.o_ptr, + .q_stride = args.q_stride, + .k_stride = 
args.k_stride, + .v_stride = args.v_stride, + .o_stride = args.o_stride, + .sliding_window = args.sliding_window, + .q_len = q_len, + .kv_len = kv_len, + .head_dim = head_dim, + .group_size = FastDivmod(group_size), + }; + } + static constexpr int kBlockM = get<0>(TileShape{}); static constexpr int kBlockN = get<1>(TileShape{}); static constexpr int kBlockK = get<2>(TileShape{}); @@ -26,6 +98,7 @@ struct FmhaBlock { // hold a reference to the parameters and block coordination const Params& params_; + // TODO: (m_block_idx, (kv_head_idx, batch_idx)) // (batch_idx, m_block_idx, kv_head_idx) const tuple& blk_coord_; @@ -33,6 +106,7 @@ struct FmhaBlock { int m_block_base_; int packed_len_; + // Constructor CUTE_HOST_DEVICE FmhaBlock(const Params& params, const tuple& blk_coord) : params_(params), blk_coord_(blk_coord) { @@ -50,20 +124,26 @@ struct FmhaBlock { // returns packed_len CUTE_HOST_DEVICE int get_packed_len() const { return packed_len_; } - // returns q_len + // returns actual query length CUTE_HOST_DEVICE int get_q_len() const { return params_.q_len; } - // returns kv_len + // returns actual kv length CUTE_HOST_DEVICE int get_kv_len() const { return params_.kv_len; } // returns head_dim CUTE_HOST_DEVICE int get_head_dim() const { return params_.head_dim; } + // returns group size + CUTE_HOST_DEVICE const FastDivmod& get_group_size() const { + return params_.group_size; + } + // returns redidue mnk CUTE_HOST_DEVICE auto get_residue_mnk() const { return make_tuple(packed_len_, params_.kv_len, params_.head_dim); } + // returns (m_block_idx, (kv_head_idx, batch_idx)) // return (batch_idx, m_block_idx, kv_head_idx) CUTE_HOST_DEVICE const auto& get_block_coord() const { return blk_coord_; } diff --git a/src/kernels/attention/common/tile_scheduler.cuh b/src/kernels/attention/common/tile_scheduler.cuh index b155c5ad..42941066 100644 --- a/src/kernels/attention/common/tile_scheduler.cuh +++ b/src/kernels/attention/common/tile_scheduler.cuh @@ -7,24 +7,36 @@ #include namespace llm { - +using namespace cute; class SingleTileScheduler { public: - // Host side kernel arguments - struct Arguments { + // Device side kernel arguments + struct Params { int batch_size = 0; int m_blocks = 0; int n_kv_heads = 0; }; - static dim3 get_grid_shape(Arguments const& args) { - return {(uint32_t)args.batch_size, - (uint32_t)args.m_blocks, - (uint32_t)args.n_kv_heads}; + + static dim3 get_grid_shape(Params const& params) { + return {(uint32_t)params.batch_size, + (uint32_t)params.m_blocks, + (uint32_t)params.n_kv_heads}; } - // Device side kernel params - using Params = Arguments; - static Params to_underlying_arguments(const Arguments& args) { return args; } + template + static Params to_underlying_arguments(const ProblemShape& problem_shape, + const TileShape& tile_shape) { + // problem_shape: (Q, K, D, ((KH, G), B)) + const int max_q_len = size<0>(problem_shape); + const int n_kv_heads = size<3, 0, 0>(problem_shape); + const int group_size = size<3, 0, 1>(problem_shape); + const int batch_size = size<3, 1>(problem_shape); + + const int max_q_packed_len = max_q_len * group_size; + const int m_blocks = ceil_div(max_q_packed_len, size<0>(tile_shape)); + + return {batch_size, m_blocks, n_kv_heads}; + } // End Iterator tag class EndIterator {}; @@ -34,7 +46,7 @@ class SingleTileScheduler { Iterator() = default; CUTE_DEVICE - cute::tuple operator*() const { + tuple operator*() const { // (batch, m_blocks, kv_heads) return {blockIdx.x, blockIdx.y, blockIdx.z}; } diff --git 
a/src/kernels/attention/device/sm120_fmha_launch.cuh b/src/kernels/attention/device/sm120_fmha_launch.cuh index a6cefe04..396be6a4 100644 --- a/src/kernels/attention/device/sm120_fmha_launch.cuh +++ b/src/kernels/attention/device/sm120_fmha_launch.cuh @@ -8,6 +8,7 @@ #include "collective/sm120_collective_epilogue.cuh" #include "collective/sm120_collective_fmha_mainloop_ws.cuh" +#include "common/fmha_block.h" #include "common/tile_scheduler.cuh" #include "kernel/sm120_kernel_fmha_ws.cuh" @@ -15,29 +16,23 @@ namespace llm { namespace detail { /// Generic kernel template. -template +template __global__ __launch_bounds__(Operator::kThreadsPerBlock) void device_kernel( - __grid_constant__ const Params params, - __grid_constant__ const typename Operator::TileSchedulerParams - scheduler_params) { + __grid_constant__ const typename Operator::Params params) { extern __shared__ char smem[]; Operator op; - op(params, scheduler_params, smem); + op(params, smem); } } // namespace detail template void sm120_launch_mha_kernel(const Params& params, cudaStream_t stream) { - const auto batch_size = params.batch_size; - const auto n_kv_heads = params.n_kv_heads; - const auto max_q_packed_len = params.max_q_len * params.group_size; - // TODO: tune tile shape M/N based on the head dim and smem size // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability // SM | 7.0 | 7.2 | 7.5 | 8.0 | 8.6 | 8.7 | 8.9 | 9.0 | 10.x | 12.0| @@ -55,7 +50,14 @@ void sm120_launch_mha_kernel(const Params& params, cudaStream_t stream) { // TMA is used for K/V loading constexpr bool KV_USE_TMA = false; - using TileShape = Shape, Int, Int>; + assert(params.n_heads % params.n_kv_heads == 0 && + "n_heads must be divisible by n_kv_heads"); + const int group_size = params.n_heads / params.n_kv_heads; + + using TileShape = Shape, Int, Int>; + + using Block = FmhaBlock; + using CollectiveMainloop = Sm120CollectiveFMhaWs; - - auto mha_kernel = detail::device_kernel; + // TODO: pass in max_q_len and max_kv_len for variable length + // MNKL: (Q K D ((KH G), B)) + using ProblemShape = + cute::tuple, int>>; + ProblemShape problem_shape = make_tuple( + params.q_len, + params.kv_len, + params.head_dim, + make_tuple(make_tuple(params.n_kv_heads, group_size), params.batch_size)); + + using AttnKernel = Sm120KernelFmhaWs; + + // TODO: convert params to Kernel Args + auto q_stride = make_stride( + params.q_batch_stride, params.q_seq_stride, params.q_head_stride, _1{}); + auto k_stride = make_stride( + params.k_batch_stride, params.k_seq_stride, params.k_head_stride, _1{}); + auto v_stride = make_stride( + params.v_batch_stride, params.v_seq_stride, params.v_head_stride, _1{}); + auto o_stride = make_stride( + params.o_batch_stride, params.o_seq_stride, params.o_head_stride, _1{}); + + typename AttnKernel::Arguments attn_args{ + .problem_shape = problem_shape, + // Block arguments + .block = + { + .q_ptr = params.q_ptr, + .k_ptr = params.k_ptr, + .v_ptr = params.v_ptr, + .o_ptr = params.o_ptr, + .q_stride = q_stride, + .k_stride = k_stride, + .v_stride = v_stride, + .o_stride = o_stride, + .sliding_window = params.sliding_window, + }, + // mainloop arguments + .mainloop = + { + .sliding_window = params.sliding_window, + .logits_soft_cap = params.logits_soft_cap, + .sm_scale = params.sm_scale, + .alibi_slopes_ptr = params.alibi_slopes_ptr, + }, + // epilogue arguments + .epilogue = {}, + }; + + auto attn_params = + 
AttnKernel::to_underlying_arguments(attn_args, /*workspace*/ nullptr); + + auto mha_kernel = detail::device_kernel; const auto smem_size = AttnKernel::kSharedStorageSize; if (smem_size >= 48 * 1024) { @@ -85,10 +134,10 @@ void sm120_launch_mha_kernel(const Params& params, cudaStream_t stream) { mha_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); } - const dim3 grid = AttnKernel::get_grid_shape(scheduler_args); + const dim3 grid = AttnKernel::get_grid_shape(attn_params); const dim3 block = AttnKernel::get_block_shape(); - mha_kernel<<>>(params, scheduler_params); + mha_kernel<<>>(attn_params); // TODO: check launch status } diff --git a/src/kernels/attention/device/sm80_mha_launch.cuh b/src/kernels/attention/device/sm80_mha_launch.cuh index 1c44ac90..0c232c40 100644 --- a/src/kernels/attention/device/sm80_mha_launch.cuh +++ b/src/kernels/attention/device/sm80_mha_launch.cuh @@ -61,11 +61,8 @@ void sm80_launch_mha_kernel(const Params& params, cudaStream_t stream) { using TileScheduler = SingleTileScheduler; const auto m_blocks = cute::ceil_div(max_q_packed_len, BLK_M); - typename TileScheduler::Arguments scheduler_args{ - batch_size, m_blocks, n_kv_heads}; - auto scheduler_params = - TileScheduler::to_underlying_arguments(scheduler_args); - + typename TileScheduler::Params scheduler_params{ + .batch_size = batch_size, .m_blocks = m_blocks, .n_kv_heads = n_kv_heads}; using AttnKernel = Sm80KernelMha; @@ -77,7 +74,7 @@ void sm80_launch_mha_kernel(const Params& params, cudaStream_t stream) { mha_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); } - const dim3 grid = AttnKernel::get_grid_shape(scheduler_args); + const dim3 grid = AttnKernel::get_grid_shape(scheduler_params); const dim3 block = AttnKernel::get_block_shape(); mha_kernel<<>>(params, scheduler_params); diff --git a/src/kernels/attention/device/sm80_mla_launch.cuh b/src/kernels/attention/device/sm80_mla_launch.cuh index 33f0aad9..462d2baa 100644 --- a/src/kernels/attention/device/sm80_mla_launch.cuh +++ b/src/kernels/attention/device/sm80_mla_launch.cuh @@ -86,10 +86,8 @@ void sm80_launch_mla_kernel(const Params& params, cudaStream_t stream) { constexpr int BLK_M = get<0>(TileShape{}); const auto m_blocks = cute::ceil_div(max_q_packed_len, BLK_M); - typename TileScheduler::Arguments scheduler_args{ - batch_size, m_blocks, /*n_kv_heads=*/1}; - auto scheduler_params = - TileScheduler::to_underlying_arguments(scheduler_args); + typename TileScheduler::Params scheduler_params{ + .batch_size = batch_size, .m_blocks = m_blocks, .n_kv_heads = 1}; using AttnKernel = Sm80KernelMla; @@ -102,7 +100,7 @@ void sm80_launch_mla_kernel(const Params& params, cudaStream_t stream) { mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); } - const dim3 grid = AttnKernel::get_grid_shape(scheduler_args); + const dim3 grid = AttnKernel::get_grid_shape(scheduler_params); const dim3 block = AttnKernel::get_block_shape(); mla_kernel<<>>(params, scheduler_params); diff --git a/src/kernels/attention/fmha_params.h b/src/kernels/attention/fmha_params.h index f36bfbe0..e39287e1 100644 --- a/src/kernels/attention/fmha_params.h +++ b/src/kernels/attention/fmha_params.h @@ -46,20 +46,20 @@ struct FmhaParams { int q_len = 0; int kv_len = 0; - // Only used for variable length sequence - // array of length batch_size + 1 holding starting offset of each sequence. 
- const int* __restrict__ q_cu_lens = nullptr; - const int* __restrict__ kv_cu_lens = nullptr; - - //////////////////////////////////////////////// - // Parameters for paged KV cache - //////////////////////////////////////////////// - // size for each cache block - int block_size = 1; - // the first slot id of each block - const int* __restrict__ block_table = nullptr; - // array of length batch_size + 1 holding starting offset of each sequence. - const int* __restrict__ block_cu_lens = nullptr; + // // Only used for variable length sequence + // // array of length batch_size + 1 holding starting offset of each sequence. + // const int* __restrict__ q_cu_lens = nullptr; + // const int* __restrict__ kv_cu_lens = nullptr; + + // //////////////////////////////////////////////// + // // Parameters for paged KV cache + // //////////////////////////////////////////////// + // // size for each cache block + // int block_size = 1; + // // the first slot id of each block + // const int* __restrict__ block_table = nullptr; + // // array of length batch_size + 1 holding starting offset of each sequence. + // const int* __restrict__ block_cu_lens = nullptr; //////////////////////////////////////////////// // Parameters for local attention @@ -87,7 +87,7 @@ struct FmhaParams { // Parameters for scheduling //////////////////////////////////////////////// // TODO: remove it after persistent kernel - int max_q_len = 0; + // int max_q_len = 0; }; } // namespace llm diff --git a/src/kernels/attention/kernel/sm120_kernel_fmha_ws.cuh b/src/kernels/attention/kernel/sm120_kernel_fmha_ws.cuh index be36b22f..af8e9674 100644 --- a/src/kernels/attention/kernel/sm120_kernel_fmha_ws.cuh +++ b/src/kernels/attention/kernel/sm120_kernel_fmha_ws.cuh @@ -8,8 +8,6 @@ #include #include -#include "common/fmha_block.h" - namespace llm { using namespace cute; @@ -66,7 +64,9 @@ CUTE_DEVICE void warpgroup_reg_set() { } // namespace detail -template @@ -100,32 +100,49 @@ class Sm120KernelFmhaWs { static constexpr int kSharedStorageSize = sizeof(SharedStorage); - // Kernel params - using MainloopParams = typename CollectiveMainloop::Params; - using EpilogueParams = typename CollectiveEpilogue::Params; - using TileSchedulerParams = typename TileScheduler::Params; + struct Arguments { + ProblemShape problem_shape; // (Q, K, D, ((KH, G), B)) + typename Block::Arguments block; + typename CollectiveMainloop::Arguments mainloop; + typename CollectiveEpilogue::Arguments epilogue; + // cutlass::KernelHardwareInfo hw_info; + }; + + struct Params { + typename Block::Params block; + typename CollectiveMainloop::Params mainloop; + typename CollectiveEpilogue::Params epilogue; + typename TileScheduler::Params scheduler; + }; + + // convert arguments to params + static Params to_underlying_arguments(Arguments const& args, + void* workspace) { + return Params{Block::to_underlying_arguments( + args.problem_shape, args.block, workspace), + CollectiveMainloop::to_underlying_arguments( + args.problem_shape, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments( + args.problem_shape, args.epilogue, workspace), + TileScheduler::to_underlying_arguments(args.problem_shape, + TileShape{})}; + } // returns grid and block shape for kernel launch - using TileSchedulerArgs = typename TileScheduler::Arguments; - static dim3 get_grid_shape(TileSchedulerArgs const& args) { - return TileScheduler::get_grid_shape(args); + static dim3 get_grid_shape(const Params& params) { + return TileScheduler::get_grid_shape(params.scheduler); } static dim3 
get_block_shape() { return kThreadsPerBlock; } - template CUTE_DEVICE void load_loop(const Params& params, - const TileSchedulerParams& scheduler_params, PipelineQ& q_pipeline, PipelineKV& kv_pipeline, SharedStorage& ss) { - static constexpr bool kLocal = CollectiveMainloop::kLocal; - using Block = FmhaBlock; - auto q_state = cutlass::make_producer_start_state(); auto kv_state = cutlass::make_producer_start_state(); CollectiveMainloop mainloop; - TileScheduler scheduler(scheduler_params); + TileScheduler scheduler(params.scheduler); // thread idx within warp group (4 warps = 128 threads) const auto tidx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; @@ -133,7 +150,7 @@ class Sm120KernelFmhaWs { // process each block for (const auto blk_coord : scheduler) { // block coord: (batch_idx, m_block_idx, kv_head_idx) - const Block block(params, blk_coord); + const Block block(params.block, blk_coord); mainloop.load( block, tidx, q_pipeline, q_state, kv_pipeline, kv_state, ss.mainloop); } @@ -143,52 +160,36 @@ class Sm120KernelFmhaWs { kv_pipeline.producer_tail(kv_state); } // end of load_loop - template CUTE_DEVICE void fmha_loop(const Params& params, - const TileSchedulerParams& scheduler_params, PipelineQ& q_pipeline, PipelineKV& kv_pipeline, SharedStorage& ss) { - static constexpr bool kLocal = CollectiveMainloop::kLocal; using TiledMma = typename CollectiveMainloop::TiledMma; using BLK_M = typename CollectiveMainloop::BLK_M; using BLK_K = typename CollectiveMainloop::BLK_K; - using Block = FmhaBlock; - PipelineStateQ q_state; PipelineStateKV kv_state; CollectiveMainloop mainloop; CollectiveEpilogue epilogue; - // construct params - MainloopParams mainloop_params{params.sliding_window, - params.logits_soft_cap, - params.sm_scale, - params.sm_scale_log2, - params.group_size, - params.alibi_slopes_ptr}; - - EpilogueParams epilogue_params; - - TileScheduler scheduler(scheduler_params); + TileScheduler scheduler(params.scheduler); // thread idx within warp group (4 warps = 128 threads) const auto tidx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; // process each block - const auto& group_size = params.group_size; for (const auto blk_coord : scheduler) { // block coord: (batch_idx, m_block_idx, kv_head_idx) - const Block block(params, blk_coord); + const Block block(params.block, blk_coord); TiledMma tiled_mma; // accumulator: (MMA,MMA_M,MMA_K) auto tOrAccO = partition_fragment_C(tiled_mma, Shape{}); clear(tOrAccO); - mainloop.fmha(mainloop_params, + mainloop.fmha(params.mainloop, block, tOrAccO, tidx, @@ -198,14 +199,11 @@ class Sm120KernelFmhaWs { kv_state, ss.mainloop); - epilogue(epilogue_params, block, tOrAccO, tiled_mma, tidx, ss.epilogue); + epilogue(params.epilogue, block, tOrAccO, tiled_mma, tidx, ss.epilogue); } } // end of fmha_loop - template - CUTE_DEVICE void operator()(const Params& params, - const TileSchedulerParams& scheduler_params, - char* smem) { + CUTE_DEVICE void operator()(const Params& params, char* smem) { static constexpr bool kKVUseTma = CollectiveMainloop::kKVUseTma; static constexpr int kNumThreadsLoad = WarpScheduler::kNumWarpsLoad * cutlass::NumThreadsPerWarp; @@ -264,11 +262,11 @@ class Sm120KernelFmhaWs { if (role == WarpRole::Load) { detail::warpgroup_reg_set(); // load Q, K, V from gmem to smem - load_loop(params, scheduler_params, q_pipeline, kv_pipeline, ss); + load_loop(params, q_pipeline, kv_pipeline, ss); } else if (role == WarpRole::FMHA) { detail::warpgroup_reg_set(); // FMHA mainloop - fmha_loop(params, scheduler_params, q_pipeline, kv_pipeline, ss); 
+ fmha_loop(params, q_pipeline, kv_pipeline, ss); } else if (role == WarpRole::Empty) { // Empty warp, do nothing except donating registers detail::warpgroup_reg_set(); diff --git a/src/kernels/attention/kernel/sm80_kernel_mha.cuh b/src/kernels/attention/kernel/sm80_kernel_mha.cuh index 6b67f279..1eec8c07 100644 --- a/src/kernels/attention/kernel/sm80_kernel_mha.cuh +++ b/src/kernels/attention/kernel/sm80_kernel_mha.cuh @@ -199,9 +199,8 @@ class Sm80KernelMha { using TileSchedulerParams = typename TileScheduler::Params; // returns grid and block shape for kernel launch - using TileSchedulerArgs = typename TileScheduler::Arguments; - static dim3 get_grid_shape(TileSchedulerArgs const& args) { - return TileScheduler::get_grid_shape(args); + static dim3 get_grid_shape(TileSchedulerParams const& params) { + return TileScheduler::get_grid_shape(params); } static dim3 get_block_shape() { return kMmaThreads; } diff --git a/src/kernels/attention/kernel/sm80_kernel_mla.cuh b/src/kernels/attention/kernel/sm80_kernel_mla.cuh index 1f3a79cb..79b1803b 100644 --- a/src/kernels/attention/kernel/sm80_kernel_mla.cuh +++ b/src/kernels/attention/kernel/sm80_kernel_mla.cuh @@ -188,9 +188,8 @@ class Sm80KernelMla { using TileSchedulerParams = typename TileScheduler::Params; // returns grid and block shape for kernel launch - using TileSchedulerArgs = typename TileScheduler::Arguments; - static dim3 get_grid_shape(TileSchedulerArgs const& args) { - return TileScheduler::get_grid_shape(args); + static dim3 get_grid_shape(TileSchedulerParams const& params) { + return TileScheduler::get_grid_shape(params); } static dim3 get_block_shape() { return kMmaThreads; } diff --git a/src/kernels/attention/tests/sm120_fmha_test.cu b/src/kernels/attention/tests/sm120_fmha_test.cu index bd481276..ddbd2ce4 100644 --- a/src/kernels/attention/tests/sm120_fmha_test.cu +++ b/src/kernels/attention/tests/sm120_fmha_test.cu @@ -6,7 +6,7 @@ #include "common/static_dispatch.h" #include "device/sm120_fmha_launch.cuh" -#include "mha_params.h" +#include "fmha_params.h" #include "tests/mha_ref.h" namespace llm { @@ -51,36 +51,45 @@ torch::Tensor sm120_fmha( const float sm_scale = 1.0 / sqrt(head_dim); // construct attention params - MHAParams params; + FmhaParams params; params.q_ptr = query.const_data_ptr(); - params.q_stride = - make_stride(query.stride(0), query.stride(1), query.stride(2), _1{}); params.k_ptr = key.const_data_ptr(); - params.k_stride = - make_stride(key.stride(0), key.stride(1), key.stride(2), _1{}); params.v_ptr = value.const_data_ptr(); - params.v_stride = - make_stride(value.stride(0), value.stride(1), value.stride(2), _1{}); params.o_ptr = out.mutable_data_ptr(); - params.o_stride = - make_stride(out.stride(0), out.stride(1), out.stride(2), _1{}); - params.alibi_slopes_ptr = alibi_slopes.has_value() - ? 
alibi_slopes.value().const_data_ptr() - : nullptr; + + params.q_batch_stride = query.stride(0); + params.q_seq_stride = query.stride(1); + params.q_head_stride = query.stride(2); + + params.k_batch_stride = key.stride(0); + params.k_seq_stride = key.stride(1); + params.k_head_stride = key.stride(2); + + params.v_batch_stride = value.stride(0); + params.v_seq_stride = value.stride(1); + params.v_head_stride = value.stride(2); + + params.o_batch_stride = out.stride(0); + params.o_seq_stride = out.stride(1); + params.o_head_stride = out.stride(2); params.batch_size = batch_size; - params.max_q_len = max_q_len; params.n_heads = n_heads; params.n_kv_heads = n_kv_heads; + params.head_dim = head_dim; + params.q_len = q_len; params.kv_len = kv_len; - params.head_dim = head_dim; - params.sm_scale = sm_scale; - params.logits_soft_cap = logits_soft_cap; + params.sliding_window = sliding_window; + params.logits_soft_cap = logits_soft_cap; + params.sm_scale = sm_scale; + + params.alibi_slopes_ptr = alibi_slopes.has_value() + ? alibi_slopes.value().const_data_ptr() + : nullptr; - // normalize params that for performance optimization - params.normalize(); + // params.max_q_len = max_q_len; DISPATCH_TORCH_DTYPE_(query.dtype(), Dtype, [&] { DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] { @@ -91,7 +100,7 @@ torch::Tensor sm120_fmha( /*ALIBI*/ false, /*SOFT_CAP*/ false, /*LOCAL*/ false, - MHAParams>(params, nullptr); + FmhaParams>(params, nullptr); }); }); });
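
A note on the softcap folding added in Sm120CollectiveFMhaWs::to_underlying_arguments: when logits_soft_cap is non-zero, the host-side conversion rewrites soft_cap * tanh(x * sm_scale / soft_cap) as sm_scale' * tanh(x * soft_cap') with sm_scale' = soft_cap and soft_cap' = sm_scale / soft_cap, so the kernel applies a single multiply before the exp2-based softmax (sm_scale_log2 = sm_scale' * M_LOG2E). The sketch below is a standalone host-side sanity check of that identity; it is not part of the kernel sources, and the sample values and tolerance are arbitrary.

// Standalone sanity check for the softcap/scale folding done in
// Sm120CollectiveFMhaWs::to_underlying_arguments(). Hypothetical test
// harness, not part of this PR.
#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
  const float sm_scale = 0.125f;  // e.g. 1/sqrt(head_dim) for head_dim = 64
  const float soft_cap = 30.0f;   // logits_soft_cap from the host Arguments

  // Folded parameters, exactly as computed in to_underlying_arguments():
  //   soft_cap' = sm_scale / soft_cap,  sm_scale' = soft_cap
  const float soft_cap_hat = sm_scale / soft_cap;
  const float sm_scale_hat = soft_cap;

  for (float x = -1000.0f; x <= 1000.0f; x += 10.0f) {
    // Original formulation: scale the logit, then apply the soft cap.
    const float ref = soft_cap * std::tanh(x * sm_scale / soft_cap);
    // Folded formulation used by the kernel: tanh with soft_cap',
    // then a plain softmax scale of sm_scale'.
    const float folded = sm_scale_hat * std::tanh(x * soft_cap_hat);
    assert(std::fabs(ref - folded) < 1e-3f);
  }
  std::printf("softcap folding identity holds\n");
  return 0;
}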
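
The new SingleTileScheduler::to_underlying_arguments derives its grid directly from the (Q, K, D, ((KH, G), B)) problem shape: the G query heads that share one KV head are packed along M, so m_blocks is computed over max_q_len * group_size. Below is a simplified model of that computation using plain ints in place of the cute shape tuple; fmha_grid, GridShape, and tile_m are illustrative names, not symbols from the PR.

// Simplified model of SingleTileScheduler::to_underlying_arguments() plus
// get_grid_shape(): plain ints instead of the cute problem-shape tuple.
#include <cstdint>

struct GridShape {
  uint32_t x, y, z;
};

// Same rounding as cute::ceil_div used by the scheduler.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// problem shape: (Q, K, D, ((KH, G), B))
GridShape fmha_grid(int max_q_len,
                    int n_kv_heads,   // KH
                    int group_size,   // G = n_heads / n_kv_heads
                    int batch_size,   // B
                    int tile_m) {     // BLK_M, i.e. get<0>(TileShape{})
  // Queries of the G heads sharing one KV head are packed along M,
  // so the M-extent of the scheduling space is q_len * group_size.
  const int max_q_packed_len = max_q_len * group_size;
  const int m_blocks = ceil_div(max_q_packed_len, tile_m);
  // grid = (batch, m_blocks, kv_heads), matching get_grid_shape().
  return {static_cast<uint32_t>(batch_size),
          static_cast<uint32_t>(m_blocks),
          static_cast<uint32_t>(n_kv_heads)};
}

For example, with n_heads = 32 and n_kv_heads = 8 (group_size = 4), q_len = 1000 and BLK_M = 64, the packed length is 4000 and the scheduler launches ceil_div(4000, 64) = 63 M-blocks per (batch, kv_head) pair.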
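
On group_size becoming a FastDivmod in the device-side Params: the mainloop now reads it through FmhaBlock::get_group_size() and hands it to Mask, which presumably needs to split each packed query row back into a (query position, head-within-group) pair. The sketch below shows that mapping with plain integer division, assuming the head index is the fastest-varying component of the packed row (the usual layout for this kind of GQA packing); the actual Mask code and the FastDivmod API in common/fast_math.h may differ, and unpack_query_row is an illustrative name only.

// Hedged illustration of why group_size is precomputed as a FastDivmod on
// the host: the device code repeatedly needs a divmod by group_size, shown
// here with plain integer ops.
struct PackedCoord {
  int q_idx;        // query position within the sequence
  int head_offset;  // which of the G heads sharing this KV head
};

inline PackedCoord unpack_query_row(int q_packed_idx, int group_size) {
  // Assumes the head index varies fastest within the packed M dimension.
  return {q_packed_idx / group_size, q_packed_idx % group_size};
}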