diff --git a/src/kernels/attention/device/fmha.cuh b/src/kernels/attention/device/fmha.cuh
new file mode 100644
index 00000000..4b1531ff
--- /dev/null
+++ b/src/kernels/attention/device/fmha.cuh
@@ -0,0 +1,74 @@
+#pragma once
+
+#include "cutlass/cluster_launch.hpp"
+#include "cutlass/cutlass.h"
+#include "cutlass/device_kernel.h"
+#include "cutlass/kernel_launch.h"
+
+////////////////////////////////////////////////////////////////////////////////
+
+namespace llm {
+using namespace cute;
+
+template <class Kernel>
+class Fmha {
+ public:
+  using Arguments = typename Kernel::Arguments;
+  using Params = typename Kernel::Params;
+  using ClusterShape = typename Kernel::ClusterShape;
+
+  bool initialize(Arguments const& args, void* workspace = nullptr) {
+    params_ = Kernel::to_underlying_arguments(args, workspace);
+    if (is_initialized_) {
+      return true;
+    }
+
+    const int smem_size = Kernel::kSharedStorageSize;
+    if (smem_size >= (48 << 10)) {
+      cudaError_t result =
+          cudaFuncSetAttribute(cutlass::device_kernel<Kernel>,
+                               cudaFuncAttributeMaxDynamicSharedMemorySize,
+                               smem_size);
+      if (cudaSuccess != result) {
+        result = cudaGetLastError();  // to clear the error bit
+        return false;
+      }
+    }
+    is_initialized_ = true;
+    return true;
+  }
+
+  bool run(cudaStream_t stream = nullptr) const {
+    const dim3 block = Kernel::get_block_shape();
+    const dim3 grid = Kernel::get_grid_shape(params_);
+    constexpr int smem_size = Kernel::kSharedStorageSize;
+
+    cutlass::Status status;
+    if constexpr (Kernel::ArchTag::kMinComputeCapability >= 90) {
+      dim3 cluster(size<0>(ClusterShape{}),
+                   size<1>(ClusterShape{}),
+                   size<2>(ClusterShape{}));
+
+      cutlass::ClusterLaunchParams launch_params{
+          .grid_dims = grid,
+          .block_dims = block,
+          .cluster_dims = cluster,
+          .smem_size_in_bytes = smem_size,
+          .cuda_stream = stream,
+      };
+      void const* kernel = (void const*)cutlass::device_kernel<Kernel>;
+      status =
+          cutlass::launch_kernel_on_cluster(launch_params, kernel, params_);
+    } else {
+      status = cutlass::kernel_launch<Kernel>(
+          grid, block, smem_size, stream, params_, /*launch_with_pdl=*/false);
+    }
+    return cutlass::Status::kSuccess == status;
+  }
+
+ private:
+  Params params_;
+  bool is_initialized_ = false;
+};
+
+}  // namespace llm
diff --git a/src/kernels/attention/device/sm120_fmha_dispatch.cuh b/src/kernels/attention/device/sm120_fmha_dispatch.cuh
deleted file mode 100644
index 56b79157..00000000
--- a/src/kernels/attention/device/sm120_fmha_dispatch.cuh
+++ /dev/null
@@ -1,43 +0,0 @@
-#pragma once
-
-#include
-#include
-
-#include "common/static_dispatch.h"
-
-namespace llm {
-// forward declaration
-template
-void sm120_launch_mha_kernel(const Params& params, cudaStream_t stream);
-
-// user-facing function to run the attention kernel
-template
-void sm120_run_mha(Params& params, cudaStream_t stream = nullptr) {
-  // normalize params for performance optimization
-  params.normalize();
-
-  // dispatch to proper kernel instantiation based on params
-  DISPATCH_BOOL(params.head_dim == HEAD_DIM, EVEN_K, [&] {
-    DISPATCH_BOOL(params.alibi_slopes_ptr != nullptr, ALIBI, [&] {
-      DISPATCH_BOOL(params.logits_soft_cap > 0, SOFT_CAP, [&] {
-        DISPATCH_BOOL(params.sliding_window >= 0, LOCAL, [&] {
-          sm120_launch_mha_kernel(params, stream);
-        });
-      });
-    });
-  });
-}
-
-}  // namespace llm
diff --git a/src/kernels/attention/device/sm120_fmha_launch.cuh b/src/kernels/attention/device/sm120_fmha_launch.cuh
deleted file mode 100644
index 396be6a4..00000000
--- a/src/kernels/attention/device/sm120_fmha_launch.cuh
+++ /dev/null
@@ -1,144 +0,0 @@
-#pragma once
-
-#include
-#include
-
-#include
-#include
-
-#include "collective/sm120_collective_epilogue.cuh"
-#include "collective/sm120_collective_fmha_mainloop_ws.cuh"
-#include "common/fmha_block.h"
-#include "common/tile_scheduler.cuh"
-#include "kernel/sm120_kernel_fmha_ws.cuh"
-
-namespace llm {
-
-namespace detail {
-/// Generic kernel template.
-template <class Operator>
-__global__ __launch_bounds__(Operator::kThreadsPerBlock) void device_kernel(
-    __grid_constant__ const typename Operator::Params params) {
-  extern __shared__ char smem[];
-  Operator op;
-  op(params, smem);
-}
-}  // namespace detail
-
-template
-void sm120_launch_mha_kernel(const Params& params, cudaStream_t stream) {
-  // TODO: tune tile shape M/N based on the head dim and smem size
-  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
-  // SM           | 7.0 | 7.2 | 7.5 | 8.0 | 8.6 | 8.7 | 8.9 | 9.0 | 10.x | 12.0|
-  // Max SMEM (KB)| 96  | 64  | 164 | 100 | 164 | 100 | 228 | 100 |
-  // valid dynamic shared memory sizes for different compute capabilities:
-  // * 7.0 | 7.2 : 0, 8, 16, 32, 64, 96
-  // * 7.5       : 0, 32, 64
-  // * 8.0 | 8.7 : 0, 8, 16, 32, 64, 100, 132, 164
-  // * 8.6 | 8.9 : 0, 8, 16, 32, 64, 100
-  // * 9.0 | 10.x: 0, 8, 16, 32, 64, 100, 132, 164, 196, 228
-  // * 12.0      : 0, 8, 16, 32, 64, 100
-  constexpr int BLK_M = 64;
-  constexpr int BLK_N = 64;
-
-  // TMA is used for K/V loading
-  constexpr bool KV_USE_TMA = false;
-
-  assert(params.n_heads % params.n_kv_heads == 0 &&
-         "n_heads must be divisible by n_kv_heads");
-  const int group_size = params.n_heads / params.n_kv_heads;
-
-  using TileShape = Shape<Int<BLK_M>, Int<BLK_N>, Int<HEAD_DIM>>;
-
-  using Block = FmhaBlock;
-
-  using CollectiveMainloop = Sm120CollectiveFMhaWs;
-  using CollectiveEpilogue = Sm120CollectiveEpilogue;
-
-  // TODO: support persistent kernels
-  using TileScheduler = SingleTileScheduler;
-
-  // TODO: pass in max_q_len and max_kv_len for variable length
-  // MNKL: (Q K D ((KH G), B))
-  using ProblemShape =
-      cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>;
-  ProblemShape problem_shape = make_tuple(
-      params.q_len,
-      params.kv_len,
-      params.head_dim,
-      make_tuple(make_tuple(params.n_kv_heads, group_size), params.batch_size));
-
-  using AttnKernel = Sm120KernelFmhaWs;
-
-  // TODO: convert params to Kernel Args
-  auto q_stride = make_stride(
-      params.q_batch_stride, params.q_seq_stride, params.q_head_stride, _1{});
-  auto k_stride = make_stride(
-      params.k_batch_stride, params.k_seq_stride, params.k_head_stride, _1{});
-  auto v_stride = make_stride(
-      params.v_batch_stride, params.v_seq_stride, params.v_head_stride, _1{});
-  auto o_stride = make_stride(
-      params.o_batch_stride, params.o_seq_stride, params.o_head_stride, _1{});
-
-  typename AttnKernel::Arguments attn_args{
-      .problem_shape = problem_shape,
-      // Block arguments
-      .block =
-          {
-              .q_ptr = params.q_ptr,
-              .k_ptr = params.k_ptr,
-              .v_ptr = params.v_ptr,
-              .o_ptr = params.o_ptr,
-              .q_stride = q_stride,
-              .k_stride = k_stride,
-              .v_stride = v_stride,
-              .o_stride = o_stride,
-              .sliding_window = params.sliding_window,
-          },
-      // mainloop arguments
-      .mainloop =
-          {
-              .sliding_window = params.sliding_window,
-              .logits_soft_cap = params.logits_soft_cap,
-              .sm_scale = params.sm_scale,
-              .alibi_slopes_ptr = params.alibi_slopes_ptr,
-          },
-      // epilogue arguments
-      .epilogue = {},
-  };
-
-  auto attn_params =
-      AttnKernel::to_underlying_arguments(attn_args, /*workspace*/ nullptr);
-
-  auto mha_kernel = detail::device_kernel<AttnKernel>;
-
-  const auto smem_size = AttnKernel::kSharedStorageSize;
-  if (smem_size >= 48 * 1024) {
-    cudaFuncSetAttribute(
-        mha_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
-  }
-
-  const dim3 grid = AttnKernel::get_grid_shape(attn_params);
-  const dim3 block = AttnKernel::get_block_shape();
-
-  mha_kernel<<<grid, block, smem_size, stream>>>(attn_params);
-  // TODO: check launch status
-}
-
-}  // namespace llm
diff --git a/src/kernels/attention/fmha_runner.h b/src/kernels/attention/fmha_runner.h
new file mode 100644
index 00000000..9236bd80
--- /dev/null
+++ b/src/kernels/attention/fmha_runner.h
@@ -0,0 +1,139 @@
+#pragma once
+#include
+#include
+
+#include
+#include
+
+#include "collective/sm120_collective_epilogue.cuh"
+#include "collective/sm120_collective_fmha_mainloop_ws.cuh"
+#include "common/fmha_block.h"
+#include "common/tile_scheduler.cuh"
+#include "device/fmha.cuh"
+#include "fmha_params.h"
+#include "kernel/sm120_kernel_fmha_ws.cuh"
+
+namespace llm {
+// ? Should include ArchTag?
+//   * select the right kernel based on ArchTag
+// ? how to support fast compiling?
+//   * only compile the kernel for the target compute capability
+template <typename Dtype, int kHeadDim>
+class FmhaRunner {
+ public:
+  static bool run(const FmhaParams& params, cudaStream_t stream = nullptr) {
+    assert(params.head_dim <= kHeadDim);
+    // dispatch to proper kernel instantiation based on params
+    DISPATCH_BOOL(params.head_dim == kHeadDim, EVEN_K, [&] {
+      DISPATCH_BOOL(params.alibi_slopes_ptr != nullptr, ALIBI, [&] {
+        DISPATCH_BOOL(params.logits_soft_cap > 0, SOFT_CAP, [&] {
+          DISPATCH_BOOL(params.sliding_window >= 0, LOCAL, [&] {
+            return run_kernel<EVEN_K, ALIBI, SOFT_CAP, LOCAL>(params, stream);
+          });
+        });
+      });
+    });
+    return false;  // should never reach here
+  }
+
+  template <bool EVEN_K, bool ALIBI, bool SOFT_CAP, bool LOCAL>
+  static bool run_kernel(const FmhaParams& params,
+                         cudaStream_t stream = nullptr) {
+    // TODO: tune tile shape M/N based on the head dim and smem size
+    // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
+    // SM           | 7.0 | 7.2 | 7.5 | 8.0 | 8.6 | 8.7 | 8.9 | 9.0 | 10.x | 12.0|
+    // Max SMEM (KB)| 96  | 64  | 164 | 100 | 164 | 100 | 228 | 100 |
+    // valid dynamic shared memory sizes for different compute capabilities:
+    // * 7.0 | 7.2 : 0, 8, 16, 32, 64, 96
+    // * 7.5       : 0, 32, 64
+    // * 8.0 | 8.7 : 0, 8, 16, 32, 64, 100, 132, 164
+    // * 8.6 | 8.9 : 0, 8, 16, 32, 64, 100
+    // * 9.0 | 10.x: 0, 8, 16, 32, 64, 100, 132, 164, 196, 228
+    // * 12.0      : 0, 8, 16, 32, 64, 100
+    static constexpr int BLK_M = 64;
+    static constexpr int BLK_N = 64;
+
+    // TMA is used for K/V loading
+    constexpr bool KV_USE_TMA = false;
+
+    // TODO: pass in max_q_len and max_kv_len for variable length
+    // MNKL: (Q K D ((KH G), B))
+    using ProblemShape =
+        cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>;
+
+    using TileShape = Shape<Int<BLK_M>, Int<BLK_N>, Int<kHeadDim>>;
+
+    using Block = FmhaBlock;
+
+    using CollectiveMainloop = Sm120CollectiveFMhaWs;
+    using CollectiveEpilogue =
+        Sm120CollectiveEpilogue;
+
+    // TODO: support persistent kernels
+    using TileScheduler = SingleTileScheduler;
+
+    using AttnKernel = Sm120KernelFmhaWs;
+
+    assert(params.n_heads % params.n_kv_heads == 0 &&
+           "n_heads must be divisible by n_kv_heads");
+    const int group_size = params.n_heads / params.n_kv_heads;
+
+    ProblemShape problem_shape =
+        make_tuple(params.q_len,
+                   params.kv_len,
+                   params.head_dim,
+                   make_tuple(make_tuple(params.n_kv_heads, group_size),
+                              params.batch_size));
+
+    auto q_stride = make_stride(
+        params.q_batch_stride, params.q_seq_stride, params.q_head_stride, _1{});
+    auto k_stride = make_stride(
+        params.k_batch_stride, params.k_seq_stride, params.k_head_stride, _1{});
+    auto v_stride = make_stride(
+        params.v_batch_stride, params.v_seq_stride, params.v_head_stride, _1{});
+    auto o_stride = make_stride(
+        params.o_batch_stride, params.o_seq_stride, params.o_head_stride, _1{});
+
+    typename AttnKernel::Arguments attn_args{
+        .problem_shape = problem_shape,
+        .block =
+            {
+                .q_ptr = params.q_ptr,
+                .k_ptr = params.k_ptr,
+                .v_ptr = params.v_ptr,
+                .o_ptr = params.o_ptr,
+                .q_stride = q_stride,
+                .k_stride = k_stride,
+                .v_stride = v_stride,
+                .o_stride = o_stride,
+                .sliding_window = params.sliding_window,
+            },
+        .mainloop =
+            {
+                .sliding_window = params.sliding_window,
+                .logits_soft_cap = params.logits_soft_cap,
+                .sm_scale = params.sm_scale,
+                .alibi_slopes_ptr = params.alibi_slopes_ptr,
+            },
+        .epilogue = {},
+    };
+
+    Fmha<AttnKernel> fmha;
+    if (!fmha.initialize(attn_args, /*workspace=*/nullptr)) {
+      return false;
+    }
+    return fmha.run(stream);
+  }
+};
+}  // namespace llm
diff --git a/src/kernels/attention/kernel/sm120_kernel_fmha_ws.cuh b/src/kernels/attention/kernel/sm120_kernel_fmha_ws.cuh
index af8e9674..854190da 100644
--- a/src/kernels/attention/kernel/sm120_kernel_fmha_ws.cuh
+++ b/src/kernels/attention/kernel/sm120_kernel_fmha_ws.cuh
@@ -2,6 +2,7 @@
 
 #include
 #include
+#include
 
 #include
 #include
@@ -72,13 +73,12 @@ template
 class Sm120KernelFmhaWs {
  public:
+  using ArchTag = cutlass::arch::Sm120;
+
   using TileShape = typename CollectiveMainloop::TileShape;
   using Element = typename CollectiveMainloop::Element;
   using ClusterShape = typename CollectiveMainloop::ClusterShape;
 
-  static const int kThreadsPerBlock =
-      WarpScheduler::kNumWarps * cutlass::NumThreadsPerWarp;
-
   using PipelineQ = typename CollectiveMainloop::PipelineQ;
   using PipelineKV = typename CollectiveMainloop::PipelineKV;
 
@@ -100,6 +100,11 @@ class Sm120KernelFmhaWs {
   static constexpr int kSharedStorageSize = sizeof(SharedStorage);
 
+  // needed for cutlass::device_kernel
+  static constexpr uint32_t MaxThreadsPerBlock =
+      WarpScheduler::kNumWarps * cutlass::NumThreadsPerWarp;
+  static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
+
   struct Arguments {
     ProblemShape problem_shape;  // (Q, K, D, ((KH, G), B))
     typename Block::Arguments block;
@@ -132,7 +137,7 @@ class Sm120KernelFmhaWs {
   static dim3 get_grid_shape(const Params& params) {
     return TileScheduler::get_grid_shape(params.scheduler);
   }
-  static dim3 get_block_shape() { return kThreadsPerBlock; }
+  static dim3 get_block_shape() { return MaxThreadsPerBlock; }
 
   CUTE_DEVICE void load_loop(const Params& params,
                              PipelineQ& q_pipeline,
diff --git a/src/kernels/attention/tests/sm120_fmha_test.cu b/src/kernels/attention/tests/sm120_fmha_test.cu
index ddbd2ce4..8dfbb72a 100644
--- a/src/kernels/attention/tests/sm120_fmha_test.cu
+++ b/src/kernels/attention/tests/sm120_fmha_test.cu
@@ -4,9 +4,8 @@
 #include
 #include
 
-#include "common/static_dispatch.h"
-#include "device/sm120_fmha_launch.cuh"
 #include "fmha_params.h"
+#include "fmha_runner.h"
 #include "tests/mha_ref.h"
 
 namespace llm {
@@ -93,15 +92,7 @@ torch::Tensor sm120_fmha(
   DISPATCH_TORCH_DTYPE_(query.dtype(), Dtype, [&] {
     DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] {
-      DISPATCH_BOOL(params.head_dim == HEAD_DIM, EVEN_K, [&] {
-        sm120_launch_mha_kernel(params, nullptr);
-      });
+      FmhaRunner<Dtype, HEAD_DIM>::run(params, /*stream=*/nullptr);
     });
   });
   return out;
@@ -182,9 +173,9 @@ INSTANTIATE_TEST_SUITE_P(
         ::testing::Values(6),                                // n_heads
         ::testing::Values(6 /*mha*/, 3 /*gqa*/, 1 /*mqa*/),  // n_kv_heads
         ::testing::Values(32, 64),                           // head_dim
-        ::testing::Values(0.0),                              // logits_soft_cap
-        ::testing::Values(false),                            // alibi slope
-        ::testing::Values(-1)                                // sliding window
+        ::testing::Values(0.0, 50.0),                        // logits_soft_cap
+        ::testing::Values(false, true),                      // alibi slope
+        ::testing::Values(-1, 0, 10)                         // sliding window
         ));
 
 }  // namespace llm
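
Usage note (host-side sketch, not part of the patch): the snippet below mirrors how sm120_fmha_test.cu drives the new FmhaRunner entry point, to make the intended call pattern explicit. FmhaParams field names are taken from the diff; the cute::half_t element type, the concrete shape values, the aggregate zero-initialization of FmhaParams, and run_fmha_example itself are illustrative assumptions, and the stride fields are omitted.

#include <cuda_runtime.h>

#include <cute/numeric/numeric_types.hpp>  // cute::half_t (assumed element type)

#include "fmha_params.h"  // FmhaParams (field names as used in the diff)
#include "fmha_runner.h"  // FmhaRunner<Dtype, kHeadDim>

// Illustrative only: fill FmhaParams for a single fixed-length batch and run
// the SM120 attention kernel through FmhaRunner, which dispatches EVEN_K /
// ALIBI / SOFT_CAP / LOCAL and launches via the Fmha<Kernel> device wrapper.
bool run_fmha_example(void* q, void* k, void* v, void* o, cudaStream_t stream) {
  llm::FmhaParams params{};  // zero-init, then set the fields we need
  params.q_ptr = q;
  params.k_ptr = k;
  params.v_ptr = v;
  params.o_ptr = o;
  params.batch_size = 1;
  params.n_heads = 8;
  params.n_kv_heads = 8;            // n_heads % n_kv_heads must be 0
  params.q_len = 128;
  params.kv_len = 128;
  params.head_dim = 64;             // must be <= the kHeadDim template argument
  params.sm_scale = 0.125f;         // 1/sqrt(head_dim)
  params.logits_soft_cap = 0.0f;    // 0 disables the SOFT_CAP path
  params.sliding_window = -1;       // negative disables the LOCAL path
  params.alibi_slopes_ptr = nullptr;  // null disables the ALIBI path
  // The q/k/v/o batch, seq, and head stride fields must also be set to match
  // the actual tensor layout; they are omitted here.

  // The Dtype and head-dim template arguments must match the tensors above.
  return llm::FmhaRunner<cute::half_t, 64>::run(params, stream);
}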