diff --git a/src/kernels/attention/CMakeLists.txt b/src/kernels/attention/CMakeLists.txt index 10d97757..7c671dd1 100644 --- a/src/kernels/attention/CMakeLists.txt +++ b/src/kernels/attention/CMakeLists.txt @@ -11,16 +11,13 @@ cc_library( layout_convertor.h fast_cast.cuh online_softmax.cuh + safe_copy.h mask.h static_dispatch.h mha_params.h mha_tile.h - mha_kernel_sm80.cuh - sm80_mha_dispatch.cuh mla_params.h mla_tile.h - mla_traits_sm80.h - mla_kernel_sm80.cuh attn_combine_kernel.cuh DEPS cutlass @@ -74,9 +71,8 @@ cc_test( NAME mla_kernel_test SRCS - mla_traits_test.cpp - mla_kernel_sm80_test.cu - mla_kernel_sm80_pagedkv_test.cu + sm80_mla_test.cu + sm80_mla_pagedkv_test.cu DEPS :attention.template absl::random_random @@ -117,7 +113,7 @@ nvbench_binary( nvbench_binary( NAME - mla_sm80_bench + sm80_mla_bench SRCS mla_sm80_bench.cu DEPS diff --git a/src/kernels/attention/attn_combine_kernel.cuh b/src/kernels/attention/attn_combine_kernel.cuh index 89a35185..2ad749ed 100644 --- a/src/kernels/attention/attn_combine_kernel.cuh +++ b/src/kernels/attention/attn_combine_kernel.cuh @@ -6,8 +6,8 @@ #include #include -#include "cute_extensions.cuh" #include "fast_cast.cuh" +#include "safe_copy.h" namespace llm { @@ -239,4 +239,4 @@ void launch_attn_combine_kernel(const Params& params, cudaStream_t stream) { combine_kernel<<>>(params); } -} // namespace llm \ No newline at end of file +} // namespace llm diff --git a/src/kernels/attention/generate_instantiation_cu.py b/src/kernels/attention/generate_instantiation_cu.py index fb1fb339..4bb2e7a2 100755 --- a/src/kernels/attention/generate_instantiation_cu.py +++ b/src/kernels/attention/generate_instantiation_cu.py @@ -39,17 +39,18 @@ """ MLA_KERNEL_TEMPLATE = """ -#include "mla_kernel_sm80.cuh" // IWYU pragma: export +#include "sm80_mla_launch.cuh" // IWYU pragma: export #include "mla_params.h" // IWYU pragma: export -#include "mla_traits_sm80.h" // IWYU pragma: export namespace llm {{ -using Traits = MLATraitsSM80<{DTYPE}, {HEAD_DIM}, {ROPE_HEAD_DIM}, {BLK_M}, {BLK_N}, {BLK_K}, {STAGES}>; using Params = MLAPagedKVParams; -template void launch_mla_kernel_sm80(const Params& params, - cudaStream_t stream); +template void sm80_launch_mla_kernel(const Params& params, + cudaStream_t stream); }} // namespace llm """ @@ -87,28 +88,18 @@ class MLAKernel: dtype: str head_dim: int rope_head_dim: int - blk_m: int - blk_n: int - blk_k: int - stages: int @property def template(self) -> str: - assert self.head_dim % self.blk_k == 0 - return MLA_KERNEL_TEMPLATE.format( DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, ROPE_HEAD_DIM=self.rope_head_dim, - BLK_M=self.blk_m, - BLK_N=self.blk_n, - BLK_K=self.blk_k, - STAGES=self.stages, ) @property def filename(self) -> str: - return f"mla_{self.dtype}_hd{self.head_dim}_rhd{self.rope_head_dim}_m{self.blk_m}_n{self.blk_n}_k{self.blk_k}_s{self.stages}_sm80.cu" + return f"sm80_mla_{self.dtype}_hd{self.head_dim}_rhd{self.rope_head_dim}.cu" def gen_mha_kernels() -> Iterator[MHAKernel]: @@ -141,25 +132,15 @@ def gen_mha_kernels() -> Iterator[MHAKernel]: def gen_mla_kernels() -> Iterator[MLAKernel]: # TODO: choose BLK_M, BLK_N, BLK_K, STAGES based on compute capability # mla kernel instantiations - for dtype, head_dim, rope_head_dim, ( - blk_m, - blk_n, - blk_k, - stages, - ) in itertools.product( + for dtype, head_dim, rope_head_dim in itertools.product( ["fp16", "bf16"], # dtype [512], # head_dim [64], # rope_head_dim - [(64, 16, 128, 1)], # blk_m, blk_n, blk_k, stages ): yield MLAKernel( dtype=dtype, head_dim=head_dim, 
rope_head_dim=rope_head_dim, - blk_m=blk_m, - blk_n=blk_n, - blk_k=blk_k, - stages=stages, ) diff --git a/src/kernels/attention/mla_kernel_sm80.cuh b/src/kernels/attention/mla_kernel_sm80.cuh deleted file mode 100644 index ca1d47e2..00000000 --- a/src/kernels/attention/mla_kernel_sm80.cuh +++ /dev/null @@ -1,665 +0,0 @@ -#pragma once - -#include -#include -#include - -#include -#include - -#include "cute/config.hpp" -#include "cute_extensions.cuh" -#include "fast_cast.cuh" -#include "layout_convertor.h" -#include "mask.h" -#include "mla_tile.h" -#include "online_softmax.cuh" - -namespace llm { - -template -struct MLASharedStorage { - using DType = typename Traits::DType; - using SmemLayoutQ = typename Traits::SmemLayoutQ; - using SmemLayoutKV = typename Traits::SmemLayoutKV; - using SmemLayoutP = typename Traits::SmemLayoutP; - using SmemLayoutQRope = typename Traits::SmemLayoutQRope; - using SmemLayoutKRope = typename Traits::SmemLayoutKRope; - using SmemLayoutVt = typename Traits::SmemLayoutVt; - using SmemLayoutO = typename Traits::SmemLayoutO; - using SmemLayoutRowmax = typename Traits::SmemLayoutRowmax; - using SmemLayoutRowsum = typename Traits::SmemLayoutRowsum; - - union { - struct { - cute::array_aligned> q_smem; - union { - cute::array_aligned> kv_smem; - cute::array_aligned> vt_smem; - }; - cute::array_aligned> p_smem; - cute::array_aligned> q_rope_smem; - cute::array_aligned> k_rope_smem; - union { - cute::array_aligned> - row_max_smem; - cute::array_aligned> - row_sum_smem; - }; - }; - - cute::array_aligned> o_smem; - }; -}; - -template -__global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80( - __grid_constant__ const Params params) { - using namespace cute; - - constexpr int kBlockM = Traits::kBlockM; - constexpr int kBlockN = Traits::kBlockN; - constexpr int kBlockK = Traits::kBlockK; - constexpr int kHeadDim = Traits::kHeadDim; - constexpr int kSteps = Traits::kSteps; - constexpr int kStages = Traits::kStages; - constexpr int kRopeHeadDim = Traits::kRopeHeadDim; - constexpr int kRowsPerMMA = Traits::kRowsPerMMA; - - using _BLK_M = Int; - using _BLK_N = Int; - using _BLK_K = Int; - using _STEPS = Int; - using _HEAD_DIM = Int; - using _ROPE_HEAD_DIM = Int; - - // type alias - using DType = typename Traits::DType; - - using TiledMma_QK = typename Traits::TiledMma_QK; - using TiledMma_PV = typename Traits::TiledMma_PV; - - using SmemLayoutQ = typename Traits::SmemLayoutQ; - using SmemLayoutKV = typename Traits::SmemLayoutKV; - using SmemLayoutP = typename Traits::SmemLayoutP; - using SmemLayoutQRope = typename Traits::SmemLayoutQRope; - using SmemLayoutKRope = typename Traits::SmemLayoutKRope; - using SmemLayoutVt = typename Traits::SmemLayoutVt; - using SmemLayoutO = typename Traits::SmemLayoutO; - using SmemLayoutRowmax = typename Traits::SmemLayoutRowmax; - using SmemLayoutRowsum = typename Traits::SmemLayoutRowsum; - using SharedStorage = MLASharedStorage; - - using GmemTiledCopyQ = typename Traits::GmemTiledCopyQ; - using GmemTiledCopyQRope = typename Traits::GmemTiledCopyQRope; - using GmemTiledCopyKV = typename Traits::GmemTiledCopyKV; - using GmemTiledCopyKRope = typename Traits::GmemTiledCopyKRope; - using GmemTiledCopyO = typename Traits::GmemTiledCopyO; - - using SmemTiledCopyQ = typename Traits::SmemTiledCopyQ; - using SmemTiledCopyK = typename Traits::SmemTiledCopyK; - using SmemTiledCopyS = typename Traits::SmemTiledCopyS; - using SmemTiledCopyP = typename Traits::SmemTiledCopyP; - using SmemTiledCopyVt = typename Traits::SmemTiledCopyVt; - using 
SmemTiledCopyO = typename Traits::SmemTiledCopyO; - - const int m_block_idx = blockIdx.x; - const int batch_idx = blockIdx.y; - const int tidx = threadIdx.x; - - const auto& group_size = params.group_size; - - // ProblemShape - // Q/O: (q_packed_len, HEAD_DIM) - // Q_ROPE: (q_packed_len, ROPE_HEAD_DIM) - MLATile tile(params, batch_idx); - auto [Q, Q_ROPE, O] = tile.template get_qo_tile(); - // KV: (kv_len, HEAD_DIM) - // K_ROPE: (kv_len, ROPE_HEAD_DIM) - auto [KV, K_ROPE] = tile.template get_kv_tile(); - - const int q_packed_len = size<0>(Q); - const int q_len = q_packed_len / group_size; - const int kv_len = size<0>(KV); - - if (m_block_idx * kBlockM >= size<0>(Q)) { - // m out of bound, return - return; - } - - // Gmem - // (BLK_M, BLK_K, STEPS) - Tensor gQ = - local_tile(Q, Shape<_BLK_M, _BLK_K>{}, make_coord(m_block_idx, _)); - Tensor gO = - local_tile(O, Shape<_BLK_M, _BLK_K>{}, make_coord(m_block_idx, _)); - // (BLK_N, BLK_K, n, STEPS) - Tensor gKV = local_tile(KV, Shape<_BLK_N, _BLK_K>{}, make_coord(_, _)); - - // (BLK_M, ROPE_HEAD_DIM) - Tensor gQ_rope = local_tile( - Q_ROPE, Shape<_BLK_M, _ROPE_HEAD_DIM>{}, make_coord(m_block_idx, _0{})); - // (BLK_N, ROPE_HEAD_DIM, n) - Tensor gK_rope = - local_tile(K_ROPE, Shape<_BLK_N, _ROPE_HEAD_DIM>{}, make_coord(_, _0{})); - - // Smem - extern __shared__ char smem[]; - auto& ss = *reinterpret_cast(smem); - - // (BLK_M, BLK_K, STEPS), k-major - Tensor sQ = make_tensor(make_smem_ptr(ss.q_smem.data()), SmemLayoutQ{}); - // (BLK_N, BLK_K, STEPS, STAGES), k-major - Tensor sKV = make_tensor(make_smem_ptr(ss.kv_smem.data()), SmemLayoutKV{}); - - // (BLK_M, BLK_N), k-major - Tensor sP = make_tensor(make_smem_ptr(ss.p_smem.data()), SmemLayoutP{}); - - // (BLK_M, ROPE_HEAD_DIM), k-major - Tensor sQ_rope = - make_tensor(make_smem_ptr(ss.q_rope_smem.data()), SmemLayoutQRope{}); - // (BLK_N, ROPE_HEAD_DIM, STAGES), k-major - Tensor sK_rope = - make_tensor(make_smem_ptr(ss.k_rope_smem.data()), SmemLayoutKRope{}); - - // Tensor for V^t; used in GEMM-II. - // (BLK_K, BLK_N, STEPS, STAGES) - Tensor sVt = make_tensor(make_smem_ptr(ss.vt_smem.data()), SmemLayoutVt{}); - - // (BLK_M, BLK_K, STEPS), reuse smem - Tensor sO = make_tensor(make_smem_ptr(ss.o_smem.data()), SmemLayoutO{}); - - // (BLK_M, 2) - Tensor sRowmax = - make_tensor(make_smem_ptr(ss.row_max_smem.data()), SmemLayoutRowmax{}); - Tensor sRowsum = - make_tensor(make_smem_ptr(ss.row_max_smem.data()), SmemLayoutRowsum{}); - - // reduce rowmax/rowsum accross 2 warps via shared memory - // thread layout: (32, (4, 2)), each thread process 2 rows - // (store_idx, load_idx) = (0, 64) or (1, 65), ... 
- const int row_store_idx = tidx / 4 * 2; - const int row_load_idx = row_store_idx ^ kBlockM; - auto reduce_rowmax = [&](auto& row_max) { - CUTE_UNROLL - for (int i = 0; i < size(row_max); ++i) { - sRowmax(row_store_idx + i) = row_max(i); - } - __syncthreads(); - CUTE_UNROLL - for (int i = 0; i < size(row_max); ++i) { - row_max(i) = max(row_max(i), sRowmax(row_load_idx + i)); - } - }; - auto reduce_rowsum = [&](auto& row_sum) { - CUTE_UNROLL - for (int i = 0; i < size(row_sum); ++i) { - sRowsum(row_store_idx + i) = row_sum(i); - } - __syncthreads(); - CUTE_UNROLL - for (int i = 0; i < size(row_sum); ++i) { - row_sum(i) += sRowsum(row_load_idx + i); - } - }; - - // g2s tiled copy for q - GmemTiledCopyQ gmem_tiled_copy_Q; - auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx); - - // coordinate tensor for oob handling - // (BLK_M, BLK_K) -> (blk_m, blk_k) - Tensor cQ = make_identity_tensor(Shape<_BLK_M, _BLK_K>{}); - Tensor tCcQ = gmem_thr_copy_Q.partition_S(cQ(_, _)); - auto max_coord_Q = make_coord(q_packed_len - m_block_idx * kBlockM, kBlockK); - - auto produce_q = [&](int step) { - // gQ/sQ: (BLK_M, BLK_K, STEPS) - auto tCgQ = gmem_thr_copy_Q.partition_S(gQ(_, _, step)); - auto tCsQ = gmem_thr_copy_Q.partition_D(sQ(_, _, step)); - safe_copy( - gmem_tiled_copy_Q, tCgQ, tCsQ, tCcQ, max_coord_Q); - }; - - // g2s tiled copy for q_rope - GmemTiledCopyQRope gmem_tiled_copy_Q_rope; - auto gmem_thr_copy_Q_rope = gmem_tiled_copy_Q_rope.get_slice(tidx); - - // (BLK_M, ROPE_HEAD_DIM) -> (blk_m, rope_head_dim) - Tensor cQ_rope = make_identity_tensor(Shape<_BLK_M, _ROPE_HEAD_DIM>{}); - Tensor tCcQ_rope = gmem_thr_copy_Q_rope.partition_S(cQ_rope); - - auto produce_q_rope = [&]() { - auto tCgQ_rope = gmem_thr_copy_Q_rope.partition_S(gQ_rope); - auto tCsQ_rope = gmem_thr_copy_Q_rope.partition_D(sQ_rope); - auto max_coord = - make_coord(q_packed_len - m_block_idx * kBlockM, kRopeHeadDim); - safe_copy( - gmem_tiled_copy_Q_rope, tCgQ_rope, tCsQ_rope, tCcQ_rope, max_coord); - }; - - // g2s tiled copy for kv - GmemTiledCopyKV gmem_tiled_copy_KV; - auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_slice(tidx); - - // (BLK_N, BLK_K, STEPS) -> (blk_n, head_dim) - Tensor cKV = make_identity_tensor(Shape<_BLK_N, _BLK_K>{}); - Tensor tCcKV = gmem_thr_copy_KV.partition_S(cKV); - - auto produce_kv = [&](int ni, int step, int stage) { - // gKV: (BLK_N, BLK_K, n, STEPS) - // sKV: (BLK_N, BLK_K, STEPS, STAGES) - auto tCgKV = gmem_thr_copy_KV.partition_S(gKV(_, _, ni, step)); - auto tCsKV = gmem_thr_copy_KV.partition_D(sKV(_, _, step, stage)); - auto max_coord = make_coord(kv_len - ni * kBlockN, kBlockK); - safe_copy( - gmem_tiled_copy_KV, tCgKV, tCsKV, tCcKV, max_coord); - }; - - auto produce_kv_no_oob = [&](int ni, int step, int stage) { - // gKV: (BLK_N, BLK_K, n, STEPS) - // sKV: (BLK_N, BLK_K, STEPS, STAGES) - auto tCgKV = gmem_thr_copy_KV.partition_S(gKV(_, _, ni, step)); - auto tCsKV = gmem_thr_copy_KV.partition_D(sKV(_, _, step, stage)); - cute::copy(gmem_tiled_copy_KV, tCgKV, tCsKV); - }; - - // g2s tiled copy for k_rope - GmemTiledCopyKRope gmem_tiled_copy_K_rope; - auto gmem_thr_copy_K_rope = gmem_tiled_copy_K_rope.get_slice(tidx); - - // (BLK_N, ROPE_HEAD_DIM) -> (blk_n, rope_head_dim) - Tensor cK_rope = make_identity_tensor(Shape<_BLK_N, _ROPE_HEAD_DIM>{}); - Tensor tKcK_rope = gmem_thr_copy_K_rope.partition_S(cK_rope); - - auto produce_k_rope = [&](int ni, int stage) { - // gK_rope: (BLK_N, ROPE_HEAD_DIM, n) - // sK_rope: (BLK_N, ROPE_HEAD_DIM, STAGES) - auto tKgK_rope = 
gmem_thr_copy_K_rope.partition_S(gK_rope(_, _, ni)); - auto tKsK_rope = gmem_thr_copy_K_rope.partition_D(sK_rope(_, _, stage)); - auto max_coord = make_coord(kv_len - ni * kBlockN, kRopeHeadDim); - safe_copy( - gmem_tiled_copy_K_rope, tKgK_rope, tKsK_rope, tKcK_rope, max_coord); - }; - - auto produce_k_rope_no_oob = [&](int ni, int stage) { - // gK_rope: (BLK_N, ROPE_HEAD_DIM, n) - // sK_rope: (BLK_N, ROPE_HEAD_DIM, STAGES) - auto tKgK_rope = gmem_thr_copy_K_rope.partition_S(gK_rope(_, _, ni)); - auto tKsK_rope = gmem_thr_copy_K_rope.partition_D(sK_rope(_, _, stage)); - cute::copy(gmem_tiled_copy_K_rope, tKgK_rope, tKsK_rope); - }; - - // GEMM-I: S = Q@K.T - TiledMma_QK tiled_mma_qk; - auto thr_mma_qk = tiled_mma_qk.get_slice(tidx); - // sQ: (BLK_M, BLK_K, STEPS) - auto tSrQ = thr_mma_qk.partition_fragment_A(sQ(_, _, _0{})); - // sKV: (BLK_N, BLK_K, STEPS, STAGES) - auto tSrK = thr_mma_qk.partition_fragment_B(sKV(_, _, _0{}, _0{})); - - // s2r tiled copy for q/q_rope - SmemTiledCopyQ smem_tiled_copy_Q; - auto smem_thr_copy_Q = smem_tiled_copy_Q.get_slice(tidx); - // (CPY, CPY_M, CPY_K, STEPS) - auto tCsQ = smem_thr_copy_Q.partition_S(sQ); - // (CPY, CPY_M, CPY_K) - auto tCrQ = smem_thr_copy_Q.retile_D(tSrQ); - - // s2r tiled copy for k/k_rope - SmemTiledCopyK smem_tiled_copy_K; - auto smem_thr_copy_K = smem_tiled_copy_K.get_slice(tidx); - // (CPY, CPY_N, CPY_K, STEPS, STAGES) - auto tCsK = smem_thr_copy_K.partition_S(sKV); - // (CPY, CPY_N, CPY_K) - auto tCrK = smem_thr_copy_K.retile_D(tSrK); - - // S = Q@K.T - // tSrS: (MMA,MMA_M,MMA_N) - auto compute_qk = [&](auto& tSrS, int step, int stage) { - // tCsQ: (CPY, CPY_M, CPY_K, STEPS) - auto tCsQ_s = tCsQ(_, _, _, step); - // TCsK: (CPY, CPY_N, CPY_K, STEPS, STAGES) - auto tCsK_s = tCsK(_, _, _, step, stage); - // prefetch kv - cute::copy(smem_tiled_copy_Q, tCsQ_s(_, _, _0{}), tCrQ(_, _, _0{})); - cute::copy(smem_tiled_copy_K, tCsK_s(_, _, _0{}), tCrK(_, _, _0{})); - - CUTE_UNROLL - for (int k = 0; k < size<2>(tCsQ_s); ++k) { - // prefetch next kv - if (k != size<2>(tCsQ_s) - 1) { - const auto next_k = k + 1; - cute::copy(smem_tiled_copy_Q, tCsQ_s(_, _, next_k), tCrQ(_, _, next_k)); - cute::copy(smem_tiled_copy_K, tCsK_s(_, _, next_k), tCrK(_, _, next_k)); - } - cute::gemm(tiled_mma_qk, tSrQ(_, _, k), tSrK(_, _, k), tSrS); - } - }; - - // sQ_rope: (BLK_N, ROPE_HEAD_DIM) - auto tSrQ_rope = thr_mma_qk.partition_fragment_A(sQ_rope); - // sK_rope: (BLK_N, ROPE_HEAD_DIM, STAGES) - auto tSrK_rope = thr_mma_qk.partition_fragment_B(sK_rope(_, _, _0{})); - // (CPY, CPY_M, CPY_K) - auto tCsQ_rope = smem_thr_copy_Q.partition_S(sQ_rope); - // (CPY, CPY_M, CPY_K) - auto tCrQ_rope = smem_thr_copy_Q.retile_D(tSrQ_rope); - // (CPY, CPY_N, CPY_K, STAGES) - auto tCsK_rope = smem_thr_copy_K.partition_S(sK_rope); - // (CPY, CPY_N, CPY_K) - auto tCrK_rope = smem_thr_copy_K.retile_D(tSrK_rope); - auto compute_qk_rope = [&](auto& tSrS, int stage) { - auto tCsK_rope_s = tCsK_rope(_, _, _, stage); - cute::copy(smem_tiled_copy_Q, tCsQ_rope(_, _, _0{}), tCrQ_rope(_, _, _0{})); - cute::copy( - smem_tiled_copy_K, tCsK_rope_s(_, _, _0{}), tCrK_rope(_, _, _0{})); - - CUTE_UNROLL - for (int k = 0; k < size<2>(tCsQ_rope); ++k) { - if (k != size<2>(tCsQ_rope) - 1) { - const auto next_k = k + 1; - cute::copy(smem_tiled_copy_Q, - tCsQ_rope(_, _, next_k), - tCrQ_rope(_, _, next_k)); - cute::copy(smem_tiled_copy_K, - tCsK_rope_s(_, _, next_k), - tCrK_rope(_, _, next_k)); - } - cute::gemm(tiled_mma_qk, tSrQ_rope(_, _, k), tSrK_rope(_, _, k), tSrS); - } - }; - - // GEMM-II: O = 
softmax(S)@V - TiledMma_PV tiled_mma_pv; - auto thr_mma_pv = tiled_mma_pv.get_slice(tidx); - // sP: (BLK_M, BLK_N) - auto tOrP = thr_mma_pv.partition_fragment_A(sP); - // sVt: (BLK_K, BLK_N, STEPS, STAGES) - auto tOrVt = thr_mma_pv.partition_fragment_B(sVt(_, _, _0{}, _0{})); - - // s2r tiled copy for p - SmemTiledCopyP smem_tiled_copy_P; - auto smem_thr_copy_P = smem_tiled_copy_P.get_slice(tidx); - // (CPY, CPY_M, CPY_K) - auto tCsP = smem_thr_copy_P.partition_S(sP); - // (CPY, CPY_M, CPY_K) - auto tCrP = smem_thr_copy_P.retile_D(tOrP); - - // s2r tiled copy for vt - SmemTiledCopyVt smem_tiled_copy_Vt; - auto smem_thr_copy_Vt = smem_tiled_copy_Vt.get_slice(tidx); - // (CPY, CPY_N, CPY_K, STEPS, STAGES) - auto tCsVt = smem_thr_copy_Vt.partition_S(sVt); - // (CPY, CPY_N, CPY_K) - auto tCrVt = smem_thr_copy_Vt.retile_D(tOrVt); - - // O = P*V = softmax(S)*V - // tOrO: (MMA,MMA_M,MMA_K,STEPS) - auto compute_pv = [&](auto& tOrO, int step, int stage) { - auto tOrO_s = tOrO(_, _, _, step); - auto tCsVt_s = tCsVt(_, _, _, step, stage); - - cute::copy(smem_tiled_copy_P, tCsP(_, _, _0{}), tCrP(_, _, _0{})); - cute::copy(smem_tiled_copy_Vt, tCsVt_s(_, _, _0{}), tCrVt(_, _, _0{})); - - CUTE_UNROLL - for (int k = 0; k < size<2>(tCsVt_s); ++k) { - if (k != size<2>(tCsVt_s) - 1) { - const auto next_k = k + 1; - cute::copy(smem_tiled_copy_P, tCsP(_, _, next_k), tCrP(_, _, next_k)); - cute::copy( - smem_tiled_copy_Vt, tCsVt_s(_, _, next_k), tCrVt(_, _, next_k)); - } - cute::gemm(tiled_mma_pv, tCrP(_, _, k), tOrVt(_, _, k), tOrO_s); - } - }; - - // r2s tiled copy for S/P - SmemTiledCopyS smem_tiled_copy_S; - auto smem_thr_copy_S = smem_tiled_copy_S.get_slice(tidx); - auto store_s_to_smem = [&](const auto& tSrS) { - // cast Accumulator to Element type - auto tSrS_ = make_tensor_like(tSrS); - fast_cast(tSrS, tSrS_); - // copy scores from rmem to smem - auto tCrS = smem_thr_copy_S.retile_S(tSrS_); - auto tCsS = smem_thr_copy_S.partition_D(sP); - cute::copy(smem_tiled_copy_S, tCrS, tCsS); - }; - - // tOrO: (MMA,MMA_M,MMA_K,STEPS) - auto epilogue = [&](const auto& tOrO) { - // write output to gmem - // 1. copy output from reg to smem (reuse sQ) - SmemTiledCopyO smem_tiled_copy_O; - auto smem_thr_copy_O = smem_tiled_copy_O.get_slice(tidx); - CUTE_UNROLL - for (int step = 0; step < kSteps; ++step) { - auto tOrO_s = tOrO(_, _, _, step); - auto sO_s = sO(_, _, step); - - // cast Accumulator to Element type - auto tOrO_ = make_tensor_like(tOrO_s); - fast_cast(tOrO_s, tOrO_); - - auto tCrO = smem_thr_copy_O.retile_S(tOrO_); - auto tCsO = smem_thr_copy_O.partition_D(sO_s); - cute::copy(smem_tiled_copy_O, tCrO, tCsO); - } - - __syncthreads(); - - // 2. 
copy output from smem to gmem - GmemTiledCopyO gmem_tiled_copy_O; - auto gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx); - - // (BLK_M, BLK_K) -> (blk_m, blk_k) - auto cO = make_identity_tensor(Shape<_BLK_M, _BLK_K>{}); - auto tCcO = gmem_thr_copy_Q.partition_S(cO); - auto max_coord_O = - make_coord(q_packed_len - m_block_idx * kBlockM, kBlockK); - - CUTE_UNROLL - for (int step = 0; step < kSteps; ++step) { - auto tCsO = gmem_thr_copy_O.partition_S(sO(_, _, step)); - auto tCgO = gmem_thr_copy_O.partition_D(gO(_, _, step)); - - safe_copy( - gmem_tiled_copy_O, tCsO, tCgO, tCcO, max_coord_O); - } - }; - - // output accumulator: (MMA,MMA_M,MMA_K,STEPS) - auto tOrO = - partition_fragment_C(tiled_mma_pv, Shape<_BLK_M, _BLK_K, _STEPS>{}); - auto tOrO_mn = - make_tensor(tOrO.data(), LayoutConvertor::to_mns(tOrO.layout())); - clear(tOrO); - - const int n_block_min = 0; - // process kv in range: [0, kv_idx_max) - const int diagonal = (m_block_idx * kBlockM) / group_size + kv_len - q_len; - const int kv_idx_max = std::min(kv_len, diagonal + kBlockM); - const int n_block_max = cute::ceil_div(kv_idx_max, kBlockN); - - if (n_block_min >= n_block_max) { - // write output to gmem - epilogue(tOrO); - return; - } - - // ############### Prologue ############### - // g2s async data copy pipelines - // | stage | queue | - // | 1 | [k_r, kv0, kv1] | - // | 2 | [k_r, kv0, kv1, (nop, nop, nop, k_r, kv0, kv1)] | - // ^ kWait = (kSteps + 1) * (2*kStages - 1) - 1 - constexpr int kWait = (kSteps + 1) * (2 * kStages - 1) - 1; - // produce q_rope/q - produce_q_rope(); - CUTE_UNROLL - for (int step = 0; step < kSteps; ++step) { - produce_q(step); - } - // produce k_rope/kv - CUTE_UNROLL - for (int stage = 0; stage < kStages; ++stage) { - const int ni = n_block_max - 1 - stage; - // insert nops between stages for a perfect pipeline - if (stage != 0) { - cp_async_fence(); - CUTE_UNROLL - for (int step = 0; step < kSteps; ++step) { - cp_async_fence(); - } - } - // handle oob kv - if (ni >= n_block_min) { - stage == 0 ? produce_k_rope(ni, stage) : produce_k_rope_no_oob(ni, stage); - cp_async_fence(); - CUTE_UNROLL - for (int step = 0; step < kSteps; ++step) { - stage == 0 ? 
produce_kv(ni, step, stage) - : produce_kv_no_oob(ni, step, stage); - cp_async_fence(); - } - } else { - cp_async_fence(); - CUTE_UNROLL - for (int step = 0; step < kSteps; ++step) { - cp_async_fence(); - } - } - } - - // ############### Mainloop ############### - // attention score accumulator, (MMA,MMA_M,MMA_N) - auto tSrS = partition_fragment_C(tiled_mma_qk, Shape<_BLK_M, _BLK_N>{}); - auto tScS = - thr_mma_qk.partition_C(make_identity_tensor(Shape<_BLK_M, _BLK_N>{})); - auto tSrS_mn = - make_tensor(tSrS.data(), LayoutConvertor::to_mn(tSrS.layout())); - auto tScS_mn = - make_tensor(tScS.data(), LayoutConvertor::to_mn(tScS.layout())); - - constexpr int kRowsPerThr = kRowsPerMMA * size<1>(tSrS); - using Softmax = OnlineSoftmax; - using Mask = Mask; - - Softmax softmax(params.sm_scale_log2); - Mask mask(q_len, kv_len, group_size, /*sliding_window=*/kv_len); - - constexpr int n_oob_mask = cute::ceil_div(kBlockM, kBlockN) + 1; - const int n_blocks = n_block_max - n_block_min; - int stage = 0; - CUTE_NO_UNROLL - for (int i = 0; i < n_blocks; ++i) { - const int ni = n_block_max - 1 - i; - clear(tSrS); - - cp_async_wait(); - __syncthreads(); - - // 1> S = Q_rope@K_rope.T - compute_qk_rope(tSrS, stage); - cp_async_fence(); - - // 2> S += Q@K.T - CUTE_UNROLL - for (int step = 0; step < kSteps; ++step) { - cp_async_wait(); - __syncthreads(); - - compute_qk(tSrS, step, stage); - cp_async_fence(); - } - - // apply mask - if (i < n_oob_mask) { - mask.apply(tSrS_mn, tScS_mn, m_block_idx * kBlockM, ni * kBlockN); - } else { - mask.apply( - tSrS_mn, tScS_mn, m_block_idx * kBlockM, ni * kBlockN); - } - - softmax.rescale(tSrS_mn, tOrO_mn, reduce_rowmax); - - // save tSrS from rmem to smem - store_s_to_smem(tSrS); - __syncthreads(); - - // 3> O = softmax(S)*V - const auto next_ni = ni - kStages; - if (next_ni >= n_block_min) { - produce_k_rope_no_oob(next_ni, stage); - cp_async_fence(); - - CUTE_UNROLL - for (int step = 0; step < kSteps; ++step) { - compute_pv(tOrO, step, stage); - __syncthreads(); - produce_kv_no_oob(next_ni, step, stage); - cp_async_fence(); - } - } else { - cp_async_fence(); - CUTE_UNROLL - for (int step = 0; step < kSteps; ++step) { - compute_pv(tOrO, step, stage); - cp_async_fence(); - } - } - - // move to next stage - if constexpr (kStages == 1) { - // do nothing - } else if constexpr (kStages == 2) { - stage = stage ^ 1; - } else { - stage = (stage + 1) % kStages; - } - } - - // ############### Epilogue ############### - - // normalize output: o /= rowsum - softmax.finalize(tOrO_mn, reduce_rowsum); - - // write output to gmem - epilogue(tOrO); -} - -template -void launch_mla_kernel_sm80(const Params& params, cudaStream_t stream) { - const auto batch_size = params.batch_size; - const auto max_q_packed_len = params.max_q_len * params.n_heads; - - const auto smem_size = sizeof(MLASharedStorage); - // print("smem_size: %d\n", smem_size); - - auto mla_kernel = mla_kernel_sm80; - C10_CUDA_CHECK(cudaFuncSetAttribute( - mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - // TODO: support persistent kernels - dim3 grid(cute::ceil_div(max_q_packed_len, Traits::kBlockM), batch_size, 1); - dim3 block = Traits::kThreadNum; - mla_kernel<<>>(params); -} - -} // namespace llm \ No newline at end of file diff --git a/src/kernels/attention/mla_traits_sm80.h b/src/kernels/attention/mla_traits_sm80.h deleted file mode 100644 index c221d8c6..00000000 --- a/src/kernels/attention/mla_traits_sm80.h +++ /dev/null @@ -1,220 +0,0 @@ -#pragma once -#include -#include - -#include 
"cute_extensions.cuh" - -namespace llm { -using namespace cute; - -template -struct MLATraitsSM80 { - static constexpr int kHeadDim = HEAD_DIM; - static constexpr int kRopeHeadDim = ROPE_HEAD_DIM; - static constexpr int kBlockM = BLK_M; - static constexpr int kBlockN = BLK_N; - static constexpr int kBlockK = BLK_K; - static constexpr int kStages = STAGES; - static constexpr int kRowsPerMMA = 2; - - static_assert(kHeadDim % 64 == 0); - static_assert(kRopeHeadDim % 64 == 0); - - static_assert(kBlockM % 64 == 0); - static_assert(kBlockN % 16 == 0); - static_assert(kBlockK % 64 == 0); - static_assert(kStages == 1 || kStages == 2); - - static_assert(kHeadDim % kBlockK == 0); - // number of steps per stage - static constexpr int kSteps = kHeadDim / kBlockK; - - // helpful aliases - using DType = DTYPE; - using _BLK_M = Int; - using _BLK_N = Int; - using _BLK_K = Int; - using _STEPS = Int; - using _STAGES = Int; - using _HEAD_DIM = Int; - using _ROPE_HEAD_DIM = Int; - - // ******* Mainloop ******* - // TiledMMA (64x16x16) for gemm-I and gemm-II - // choose MMA_Atom based on Element type - using MMA_Atom_ = - std::conditional_t, - MMA_Atom, - MMA_Atom>; - - using TiledMma_64x32x16_ = - TiledMMA>, // warp layout 4x2x1 - Tile<_64, _32, _16>>; // Shape 64x32x16 - using TiledMma_64x16x16_ = - TiledMMA>, // warp layout 4x2x1 - Tile<_64, _16, _16>>; // Shape 64x16x16 - - // TiledMma for P = Softmax(Q*K^T), warp layout 4x2x1 - using TiledMma_QK = std::conditional_t; - - // TiledMma for O = P*V^T, warp layout 4x2x1 - using TiledMma_PV = TiledMma_64x32x16_; - - // Shared memory LayoutAtom for differnt block sizes - using SmemLayoutAtom_8x64 = - decltype(composition(Swizzle<3, 3, 3>{}, - Layout, Stride<_64, _1>>{})); - using SmemLayoutAtom_8x32 = - decltype(composition(Swizzle<2, 3, 3>{}, - Layout, Stride<_32, _1>>{})); - using SmemLayoutAtom_8x16 = - decltype(composition(Swizzle<1, 3, 3>{}, - Layout, Stride<_16, _1>>{})); - - using SmemLayoutAtomK = std::conditional_t; - using SmemLayoutAtomN = - std::conditional_t>; - using SmemLayoutAtomR = std::conditional_t; - - // SMEM layout for Q/K/V/P - // Q smem: (BLK_M, BLK_K, STEPS) - using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomK{}, - Shape<_BLK_M, _BLK_K, _STEPS>{})); - - // KV smem: (BLK_N, BLK_K, STEPS, STAGES) - using SmemLayoutKV = - decltype(tile_to_shape(SmemLayoutAtomK{}, - Shape<_BLK_N, _BLK_K, _STEPS, _STAGES>{})); - - // P smem: (BLK_M, BLK_N) - using SmemLayoutP = - decltype(tile_to_shape(SmemLayoutAtomN{}, Shape<_BLK_M, _BLK_N>{})); - - // V^T smem: (BLK_K, BLK_N, STEPS, STAGES) - using SmemLayoutVt = decltype(select<1, 0, 2, 3>(SmemLayoutKV{})); - - // QRope smem: (BLK_M, ROPE_HEAD_DIM) - using SmemLayoutQRope = - decltype(tile_to_shape(SmemLayoutAtomR{}, - Shape<_BLK_M, _ROPE_HEAD_DIM>{})); - - // KRoep smem: (BLK_N, ROPE_HEAD_DIM, STAGES) - using SmemLayoutKRope = - decltype(tile_to_shape(SmemLayoutAtomR{}, - Shape<_BLK_N, _ROPE_HEAD_DIM, _STAGES>{})); - - // Shared memory for reduce between warps - // rowmax/rowsum smem: (_BLK_M, _2) - using SmemLayoutRowmax = Layout>>; - using SmemLayoutRowsum = Layout>>; - - // Tiled copy for differnt block sizes - using GmemTiledCopy_32x64_ = decltype(make_tiled_copy( - Copy_Atom, DType>{}, - Layout, Stride<_8, _1>>{}, // Thr layout: (_32, _8) - Layout>{} // Val layout: 8 vals per read - )); - using GmemTiledCopy_16x128_ = decltype(make_tiled_copy( - Copy_Atom, DType>{}, - Layout, Stride<_16, _1>>{}, // Thr layout: (_16, _16) - Layout>{} // Val layout: 8 vals per read - )); - using 
GmemTiledCopy_16x64_ = decltype(make_tiled_copy( - Copy_Atom, DType>{}, - Layout, Stride<_16, _1>>{}, // Thr layout: (_16, _16) - Layout>{} // Val layout: 4 vals per read - )); - - // g2s tiled copy for q - using GmemTiledCopyQ = GmemTiledCopy_32x64_; - // g2s tiled copy for q_rope - using GmemTiledCopyQRope = GmemTiledCopy_32x64_; - - // g2s tiled copy for kv: (32x64), (16x64) or (16x128), - using GmemTiledCopyKV = - std::conditional_t>; - - // g2s tiled copy for k_rope: (32x64) or (16x64) - using GmemTiledCopyKRope = std::conditional_t; - - // s2r tiled copy for gemm-I S = Q*K^T - // warp layout: 4x2x1, tiledmma mxnxk: 64x32x16 or 64x16x16 - // Smem tiled copy for Q, 4 warps mxk: 64x16 - using SmemTiledCopyQ = - decltype(make_tiled_copy_A(Copy_Atom{}, - TiledMma_QK{})); - - using Copy_Atom_K_ = std::conditional_t, - Copy_Atom>; - // Smem tiled copy for KV, 2 warps nxk: 32x16 or 16x16 - using SmemTiledCopyK = - decltype(make_tiled_copy_B(Copy_Atom_K_{}, TiledMma_QK{})); - - // r2s tiled copy for gemm-I S - // use 128-bit vectorizing copy - using VectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<128>; - - using SmemTiledCopyS = - decltype(make_tiled_copy_C(Copy_Atom{}, - TiledMma_QK{})); - - // s2r tiled copy for gemm-II: O = P*V^T - // warp layout: 4x2x1, TiledMma mxnxk: 64x32x16 - // Smem tiled copy for P, 4 warps mxk: 64x16 - using SmemTiledCopyP = - decltype(make_tiled_copy_A(Copy_Atom{}, - TiledMma_PV{})); - - // Smem tiled copy for V^T, 2 warps nxk: 32x16 - using SmemTiledCopyVt = - decltype(make_tiled_copy_B(Copy_Atom{}, - TiledMma_PV{})); - - // r2s tiled copy for gemm-II O - using SmemTiledCopyO = - decltype(make_tiled_copy_C(Copy_Atom{}, - TiledMma_PV{})); - - // ******* Epilogue ******* - - // O smem: (BLK_M, BLK_K, STEPS) - using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomK{}, - Shape<_BLK_M, _BLK_K, _STEPS>{})); - - // s2g tiled copy for O (32x64) - using GmemTiledCopyO = decltype(make_tiled_copy( - Copy_Atom{}, - Layout, Stride<_8, _1>>{}, // Thr layout: (_32, _8) - Layout>{} // Val layout: 8 vals per read - )); - - // constexpr values for kernel launch - static constexpr size_t kThreadNum = size(TiledMma_PV{}); -}; - -} // namespace llm \ No newline at end of file diff --git a/src/kernels/attention/mla_traits_test.cpp b/src/kernels/attention/mla_traits_test.cpp deleted file mode 100644 index 68cd79c9..00000000 --- a/src/kernels/attention/mla_traits_test.cpp +++ /dev/null @@ -1,65 +0,0 @@ -#include - -#include - -#include "cute_extensions.cuh" -#include "gather_tensor.hpp" -#include "mla_traits_sm80.h" - -namespace llm { - -using namespace cute; - -template -void test_mla_traits() { - // type alias - using TiledMma_QK = typename Traits::TiledMma_QK; - - using SmemLayoutQ = typename Traits::SmemLayoutQ; - using SmemLayoutKV = typename Traits::SmemLayoutKV; - using SmemLayoutQRope = typename Traits::SmemLayoutQRope; - using SmemLayoutKRope = typename Traits::SmemLayoutKRope; - using SmemLayoutVt = typename Traits::SmemLayoutVt; - using SmemLayoutO = typename Traits::SmemLayoutO; - - using GmemTiledCopyQ = typename Traits::GmemTiledCopyQ; - using GmemTiledCopyKV = typename Traits::GmemTiledCopyKV; - using GmemTiledCopyO = typename Traits::GmemTiledCopyO; - - using SmemTiledCopyQ = typename Traits::SmemTiledCopyQ; - using SmemTiledCopyK = typename Traits::SmemTiledCopyK; - using SmemTiledCopyVt = typename Traits::SmemTiledCopyVt; - using SmemTiledCopyO = typename Traits::SmemTiledCopyO; - - // test layout conversation - Tensor sQ = 
make_tensor(counting_iterator(0), SmemLayoutQ{}); - Tensor sKV = make_tensor(counting_iterator(0), SmemLayoutKV{}); - Tensor sVt = make_tensor(sKV.data(), SmemLayoutVt{}); - - Tensor sQ_rope = make_tensor(counting_iterator(0), SmemLayoutQRope{}); - Tensor sKV_rope = make_tensor(counting_iterator(0), SmemLayoutKRope{}); - - // print("sQ:"); print(sQ);print("\n"); - // print("sKV:"); print(sKV);print("\n"); - // print("sVt:"); print(sVt);print("\n"); - - // print("sQ_rope:"); print(sQ_rope);print("\n"); - // print("sKV_rope:"); print(sKV_rope);print("\n"); - - TiledMma_QK tiled_mma_qk; - auto thr_mma_qk = tiled_mma_qk.get_slice(0); - // auto tOrVt = thr_mma_qk.partition_fragment_B(sVt); - // TODO: add tests for layout conformance -} - -TEST(MLATraitsTest, TraitsSM80) { - test_mla_traits>(); -} - -} // namespace llm \ No newline at end of file diff --git a/src/kernels/attention/cute_extensions.cuh b/src/kernels/attention/safe_copy.h similarity index 69% rename from src/kernels/attention/cute_extensions.cuh rename to src/kernels/attention/safe_copy.h index 14273bad..8312fe1c 100644 --- a/src/kernels/attention/cute_extensions.cuh +++ b/src/kernels/attention/safe_copy.h @@ -73,15 +73,15 @@ template CUTE_HOST_DEVICE void safe_copy( const TiledCopy& tiled_copy, const TensorS& src, // (CPY, CPY_M/N, CPY_K) TensorD& dst, // (CPY, CPY_M/N, CPY_K) - const TensorC& identity, // (CPY, CPY_M/N, CPY_K) -> (blk_m/n, blk_k) - const Coord& max_coord // max_coord(blk_m/n, blk_k) + const TensorC& identity, // (CPY, CPY_M/N, CPY_K) -> (m/n, k) + const Residue& residue // (m/n, k) ) { CUTE_STATIC_ASSERT_V(size<0>(src) == size<0>(dst)); // CPY == CPY CUTE_STATIC_ASSERT_V(size<0>(src) == size<0>(identity)); // CPY == CPY @@ -96,10 +96,10 @@ CUTE_HOST_DEVICE void safe_copy( // handle both m/n and k oob CUTE_UNROLL for (int mi = 0; mi < size<1>(src); ++mi) { - if (elem_less<0>(identity(_0{}, mi, _0{}), max_coord)) { + if (elem_less<0>(identity(_0{}, mi, _0{}), residue)) { CUTE_UNROLL for (int ki = 0; ki < size<2>(src); ++ki) { - if (elem_less<1>(identity(_0{}, _0{}, ki), max_coord)) { + if (elem_less<1>(identity(_0{}, _0{}, ki), residue)) { copy(copy_atom, src(_, mi, ki), dst(_, mi, ki)); } else if constexpr (ZFILL_K) { zfill(copy_atom, src(_, mi, ki), dst(_, mi, ki)); @@ -111,7 +111,7 @@ CUTE_HOST_DEVICE void safe_copy( // still need to handle k oob even if m/n is not zfilled CUTE_UNROLL for (int ki = 0; ki < size<2>(src); ++ki) { - if (!elem_less<1>(identity(_0{}, _0{}, ki), max_coord)) { + if (!elem_less<1>(identity(_0{}, _0{}, ki), residue)) { zfill(copy_atom, src(_, mi, ki), dst(_, mi, ki)); } } @@ -121,7 +121,7 @@ CUTE_HOST_DEVICE void safe_copy( // only handle m/n oob CUTE_UNROLL for (int mi = 0; mi < size<1>(src); ++mi) { - if (elem_less<0>(identity(_0{}, mi, _0{}), max_coord)) { + if (elem_less<0>(identity(_0{}, mi, _0{}), residue)) { copy(copy_atom, src(_, mi, _), dst(_, mi, _)); } else if constexpr (ZFILL_MN) { zfill(copy_atom, src(_, mi, _), dst(_, mi, _)); @@ -131,7 +131,7 @@ CUTE_HOST_DEVICE void safe_copy( // only handle k oob CUTE_UNROLL for (int ki = 0; ki < size<2>(src); ++ki) { - if (elem_less<1>(identity(_0{}, _0{}, ki), max_coord)) { + if (elem_less<1>(identity(_0{}, _0{}, ki), residue)) { copy(copy_atom, src(_, _, ki), dst(_, _, ki)); } else if constexpr (ZFILL_K) { zfill(copy_atom, src(_, _, ki), dst(_, _, ki)); @@ -143,6 +143,60 @@ CUTE_HOST_DEVICE void safe_copy( } } +template +CUTE_HOST_DEVICE void safe_copy( + const TiledCopy& tiled_copy, + const TensorS& src, // (CPY, CPY_M/N, CPY_K) + 
TensorD&& dst, // (CPY, CPY_M/N, CPY_K) + const TensorC& identity, // (CPY, CPY_M/N, CPY_K) -> (m/n, k) + const Residue& residue // (m/n, k) +) { + safe_copy( + tiled_copy, src, dst, identity, residue); +} + +template +CUTE_HOST_DEVICE void safe_copy( + const TiledCopy& tiled_copy, + const TensorS& src, // (CPY, CPY_M/N, CPY_K, k) + TensorD& dst, // (CPY, CPY_M/N, CPY_K, k) + const TensorC& identity, // (CPY, CPY_M/N, CPY_K) -> (m/n, k) + const Residue& residue // (m/n, k) +) { + CUTE_UNROLL + for (int k = 0; k < size<3>(src); ++k) { + safe_copy(tiled_copy, + src(_, _, _, k), + dst(_, _, _, k), + identity(_, _, _, k), + residue); + } +} + template CUTE_HOST_DEVICE void safe_copy( const TiledCopy& tiled_copy, const TensorS& src, // (CPY, CPY_K) TensorD& dst, // (CPY, CPY_K) - const TensorC& identity, // (CPY, CPY_K) -> (blk_k) - const Coord& max_coord // max_coord(blk_k) + const TensorC& identity, // (CPY, CPY_K) -> (k) + const Residue& residue // (k) ) { CUTE_STATIC_ASSERT_V(size<0>(src) == size<0>(dst)); // CPY == CPY CUTE_STATIC_ASSERT_V(size<0>(src) == size<0>(identity)); // CPY == CPY @@ -172,7 +226,7 @@ CUTE_HOST_DEVICE void safe_copy( // handle k oob CUTE_UNROLL for (int ki = 0; ki < size<1>(src); ++ki) { - if (elem_less<0>(identity(_0{}, ki), max_coord)) { + if (elem_less<0>(identity(_0{}, ki), residue)) { copy(copy_atom, src(_, ki), dst(_, ki)); } else if constexpr (ZFILL_K) { zfill(copy_atom, src(_, ki), dst(_, ki)); @@ -184,4 +238,4 @@ CUTE_HOST_DEVICE void safe_copy( } } -} // namespace cute \ No newline at end of file +} // namespace cute diff --git a/src/kernels/attention/sm80_collective_epilogue.cuh b/src/kernels/attention/sm80_collective_epilogue.cuh index 6a00ffc7..7093e703 100644 --- a/src/kernels/attention/sm80_collective_epilogue.cuh +++ b/src/kernels/attention/sm80_collective_epilogue.cuh @@ -7,8 +7,8 @@ #include #include -#include "cute_extensions.cuh" #include "fast_cast.cuh" +#include "safe_copy.h" namespace llm { using namespace cute; diff --git a/src/kernels/attention/sm80_collective_mha.cuh b/src/kernels/attention/sm80_collective_mha.cuh index d4413569..81b03cda 100644 --- a/src/kernels/attention/sm80_collective_mha.cuh +++ b/src/kernels/attention/sm80_collective_mha.cuh @@ -8,10 +8,10 @@ #include #include -#include "cute_extensions.cuh" #include "fast_cast.cuh" #include "layout_convertor.h" #include "mask.h" +#include "safe_copy.h" namespace llm { diff --git a/src/kernels/attention/sm80_collective_mla.cuh b/src/kernels/attention/sm80_collective_mla.cuh new file mode 100644 index 00000000..670cdfd2 --- /dev/null +++ b/src/kernels/attention/sm80_collective_mla.cuh @@ -0,0 +1,740 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include + +#include "fast_cast.cuh" +#include "layout_convertor.h" +#include "mask.h" +#include "safe_copy.h" + +namespace llm { + +using namespace cute; + +template +struct Sm80CollectiveMla { + using TileShape = TileShape_; + using Element = Element_; + using ElementAccum = float; + + static constexpr int kHeadDim = HeadDim_; + static constexpr int kRopeHeadDim = RopeHeadDim_; + static constexpr int kBlockM = get<0>(TileShape{}); + static constexpr int kBlockN = get<1>(TileShape{}); + static constexpr int kBlockK = get<2>(TileShape{}); + static constexpr int kStages = Stages; + + static_assert(kHeadDim % 64 == 0); + static_assert(kRopeHeadDim % 64 == 0); + + static_assert(kBlockM % 64 == 0); + static_assert(kBlockN % 16 == 0); + static_assert(kBlockK % 64 == 0); + static_assert(kStages == 1 || kStages == 2); + + // 
number of steps per stage + static_assert(kHeadDim % kBlockK == 0); + static constexpr int kSteps = kHeadDim / kBlockK; + + using BLK_M = Int; + using BLK_N = Int; + using BLK_K = Int; + using HEAD_DIM = Int; + using ROPE_HEAD_DIM = Int; + using STEPS = Int; + using STAGES = Int; + + // TiledMMA (64x16x16) for gemm-I and gemm-II + using MMA_Atom_ = + std::conditional_t, + MMA_Atom, + MMA_Atom>; + + using TiledMma_64x32x16_ = + TiledMMA>, // warp layout 4x2x1 + Tile<_64, _32, _16>>; // Shape 64x32x16 + using TiledMma_64x16x16_ = + TiledMMA>, // warp layout 4x2x1 + Tile<_64, _16, _16>>; // Shape 64x16x16 + + // TiledMma for P = Softmax(Q*K^T), warp layout 4x2x1 + using TiledMma_QK = std::conditional_t; + + // TiledMma for O = P*V^T, warp layout 4x2x1 + using TiledMma_PV = TiledMma_64x32x16_; + + static constexpr int kRowsPerMMA = 2; + static constexpr int kMmaThreads = + max(size(TiledMma_QK{}), size(TiledMma_PV{})); + + // Shared memory LayoutAtom for differnt block sizes + using SmemLayoutAtom_8x64 = + decltype(composition(Swizzle<3, 3, 3>{}, + Layout, Stride<_64, _1>>{})); + using SmemLayoutAtom_8x32 = + decltype(composition(Swizzle<2, 3, 3>{}, + Layout, Stride<_32, _1>>{})); + using SmemLayoutAtom_8x16 = + decltype(composition(Swizzle<1, 3, 3>{}, + Layout, Stride<_16, _1>>{})); + + using SmemLayoutAtomK = std::conditional_t; + using SmemLayoutAtomN = + std::conditional_t>; + using SmemLayoutAtomR = std::conditional_t; + + // SMEM layout for Q/K/V/P + // Q smem: (BLK_M, BLK_K, k) + using SmemLayoutQ = + decltype(tile_to_shape(SmemLayoutAtomK{}, Shape{})); + + // KV smem: (BLK_N, BLK_K, k, STAGES) + using SmemLayoutKV = + decltype(tile_to_shape(SmemLayoutAtomK{}, + Shape{})); + + // P smem: (BLK_M, BLK_N) + using SmemLayoutP = + decltype(tile_to_shape(SmemLayoutAtomN{}, Shape{})); + + // V^T smem: (BLK_K, BLK_N, k, STAGES) + using SmemLayoutVt = decltype(select<1, 0, 2, 3>(SmemLayoutKV{})); + + // QRope smem: (BLK_M, ROPE_HEAD_DIM) + using SmemLayoutQRope = + decltype(tile_to_shape(SmemLayoutAtomR{}, Shape{})); + + // KRoep smem: (BLK_N, ROPE_HEAD_DIM, STAGES) + using SmemLayoutKRope = + decltype(tile_to_shape(SmemLayoutAtomR{}, + Shape{})); + + // Shared memory for reduce between warps + // rowmax/rowsum smem: (_BLK_M, _2) + using SmemLayoutRowmax = Layout>>; + using SmemLayoutRowsum = Layout>>; + + // Tiled copy for differnt block sizes + using GmemTiledCopy_32x64_ = decltype(make_tiled_copy( + Copy_Atom, Element>{}, + Layout, Stride<_8, _1>>{}, // Thr layout: (_32, _8) + Layout>{} // Val layout: 8 vals per read + )); + using GmemTiledCopy_16x128_ = decltype(make_tiled_copy( + Copy_Atom, Element>{}, + Layout, Stride<_16, _1>>{}, // Thr layout: (_16, _16) + Layout>{} // Val layout: 8 vals per read + )); + using GmemTiledCopy_16x64_ = decltype(make_tiled_copy( + Copy_Atom, Element>{}, + Layout, Stride<_16, _1>>{}, // Thr layout: (_16, _16) + Layout>{} // Val layout: 4 vals per read + )); + + // g2s tiled copy for q + using GmemTiledCopyQ = GmemTiledCopy_32x64_; + // g2s tiled copy for q_rope + using GmemTiledCopyQRope = GmemTiledCopy_32x64_; + + // g2s tiled copy for kv: (32x64), (16x64) or (16x128), + using GmemTiledCopyKV = + std::conditional_t>; + + // g2s tiled copy for k_rope: (32x64) or (16x64) + using GmemTiledCopyKRope = std::conditional_t; + + // s2r tiled copy for gemm-I S = Q*K^T + // warp layout: 4x2x1, tiledmma mxnxk: 64x32x16 or 64x16x16 + // Smem tiled copy for Q, 4 warps mxk: 64x16 + using SmemTiledCopyQ = + decltype(make_tiled_copy_A(Copy_Atom{}, + TiledMma_QK{})); + + 
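+  // Note: the s2r copies defined here (SmemTiledCopyQ above, SmemTiledCopyK
+  // below) are reused for both parts of GEMM-I: the per-step BLK_K slices of
+  // Q/KV and the rope parts Q_rope/K_rope, since all of them feed the same
+  // TiledMma_QK in compute_qk()/compute_qk_rope().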
using Copy_Atom_K_ = + std::conditional_t, + Copy_Atom>; + // Smem tiled copy for KV, 2 warps nxk: 32x16 or 16x16 + using SmemTiledCopyK = + decltype(make_tiled_copy_B(Copy_Atom_K_{}, TiledMma_QK{})); + + // r2s tiled copy for gemm-I S + // use 128-bit vectorizing copy + using VectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<128>; + + using SmemTiledCopyS = + decltype(make_tiled_copy_C(Copy_Atom{}, + TiledMma_QK{})); + + // s2r tiled copy for gemm-II: O = P*V^T + // warp layout: 4x2x1, TiledMma mxnxk: 64x32x16 + // Smem tiled copy for P, 4 warps mxk: 64x16 + using SmemTiledCopyP = + decltype(make_tiled_copy_A(Copy_Atom{}, + TiledMma_PV{})); + + // Smem tiled copy for V^T, 2 warps nxk: 32x16 + using SmemTiledCopyVt = + decltype(make_tiled_copy_B(Copy_Atom{}, + TiledMma_PV{})); + + struct SharedStorage : cute::aligned_struct<128> { + cute::array_aligned> q_smem; + union { + cute::array_aligned> kv_smem; + cute::array_aligned> vt_smem; + }; + cute::array_aligned> p_smem; + cute::array_aligned> q_rope_smem; + cute::array_aligned> k_rope_smem; + union { + cute::array_aligned> row_max_smem; + cute::array_aligned> row_sum_smem; + }; + }; + + // Host side arguments + struct Arguments { + FastDivmod group_size; + }; + + // Device side params + using Params = Arguments; + + // Convert host side arguments to device side params + static Params to_underlying_arguments(Arguments const& args) { + // no convertion needed. + return args; + } + + // returns false if the block has been skipped + template + CUTE_DEVICE void operator()( + const Params& params, + const TensorQ& gQ, // (BLK_M, BLK_K, k) + const TensorCQ& cQ, // (BLK_M, BLK_K, k) => (M, K) + const TensorKV& gKV, // (BLK_N, BLK_K, n, k) + const TensorCKV& cKV, // (BLK_N, BLK_K, n, k) => (N, K) + const TensorQR& gQ_rope, // (BLK_M, ROPE_HEAD_DIM) + const TensorCQR& cQ_rope, // (BLK_M, ROPE_HEAD_DIM) =>(M, K) + const TensorKR& gK_rope, // (BLK_N, HEAD_DIM, n) + const TensorCKR& cK_rope, // (BLK_N, HEAD_DIM, n) => (N, K) + FrgTensor& tOrO, // (BLK_N, ROPE_HEAD_DIM, n) + Softmax& softmax, + int tidx, + const BlockCoordMNK& block_coord_mnk, + const ResidueMNK& residue_mnk, + const RopeResidueMNK& rope_residue_mnk, + char* smem) { + static_assert(is_rmem::value, + "Accum tensor must be rmem resident."); + static_assert(is_gmem::value, "Q tensor must be gmem resident."); + static_assert(is_gmem::value, "KV tensor must be gmem resident."); + static_assert(is_gmem::value, + "Q_Rope tensor must be gmem resident."); + static_assert(is_gmem::value, + "K_Rope tensor must be gmem resident."); + + static constexpr int kBlockM = get<0>(TileShape{}); + static constexpr int kBlockN = get<1>(TileShape{}); + static constexpr int kBlockK = get<2>(TileShape{}); + + const int m_block_idx = get<1>(block_coord_mnk); + const int q_packed_len = get<0>(residue_mnk); + const int kv_len = get<1>(residue_mnk); + + const auto& group_size = params.group_size; + + const int q_len = q_packed_len / group_size; + + // Construct shared memory tiles + auto& ss = *reinterpret_cast(smem); + + // (BLK_M, BLK_K, k), k-major + Tensor sQ = make_tensor(make_smem_ptr(ss.q_smem.data()), SmemLayoutQ{}); + // (BLK_N, BLK_K, k, STAGES), k-major + Tensor sKV = make_tensor(make_smem_ptr(ss.kv_smem.data()), SmemLayoutKV{}); + + // (BLK_M, BLK_N), k-major + Tensor sP = make_tensor(make_smem_ptr(ss.p_smem.data()), SmemLayoutP{}); + + // (BLK_M, ROPE_HEAD_DIM), k-major + Tensor sQ_rope = + make_tensor(make_smem_ptr(ss.q_rope_smem.data()), SmemLayoutQRope{}); + // (BLK_N, ROPE_HEAD_DIM, STAGES), 
k-major + Tensor sK_rope = + make_tensor(make_smem_ptr(ss.k_rope_smem.data()), SmemLayoutKRope{}); + + // Tensor for V^t; used in GEMM-II. + // (BLK_K, BLK_N, k, STAGES) + Tensor sVt = make_tensor(make_smem_ptr(ss.vt_smem.data()), SmemLayoutVt{}); + + // (BLK_M, 2) + Tensor sRowmax = + make_tensor(make_smem_ptr(ss.row_max_smem.data()), SmemLayoutRowmax{}); + Tensor sRowsum = + make_tensor(make_smem_ptr(ss.row_max_smem.data()), SmemLayoutRowsum{}); + + // reduce rowmax/rowsum accross 2 warps via shared memory + // thread layout: (32, (4, 2)), each thread process 2 rows + // (store_idx, load_idx) = (0, 64) or (1, 65), ... + const int row_store_idx = tidx / 4 * 2; + const int row_load_idx = row_store_idx ^ kBlockM; + auto reduce_rowmax = [&](auto& row_max) { + CUTE_UNROLL + for (int i = 0; i < size(row_max); ++i) { + sRowmax(row_store_idx + i) = row_max(i); + } + __syncthreads(); + CUTE_UNROLL + for (int i = 0; i < size(row_max); ++i) { + row_max(i) = max(row_max(i), sRowmax(row_load_idx + i)); + } + }; + auto reduce_rowsum = [&](auto& row_sum) { + CUTE_UNROLL + for (int i = 0; i < size(row_sum); ++i) { + sRowsum(row_store_idx + i) = row_sum(i); + } + __syncthreads(); + CUTE_UNROLL + for (int i = 0; i < size(row_sum); ++i) { + row_sum(i) += sRowsum(row_load_idx + i); + } + }; + + // g2s tiled copy for q + GmemTiledCopyQ gmem_tiled_copy_Q; + auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx); + + // coordinate tensor for oob handling + // (CPY, CPY_M, CPY_K, k) => (M, K) + Tensor tGcQ = gmem_thr_copy_Q.partition_S(cQ); + // (CPY, CPY_M, CPY_K, k) + Tensor tGgQ = gmem_thr_copy_Q.partition_S(gQ); + Tensor tGsQ = gmem_thr_copy_Q.partition_D(sQ); + const auto residue_mk = select<0, 2>(residue_mnk); + + auto produce_q = [&](int ki) { + safe_copy(gmem_tiled_copy_Q, + tGgQ(_, _, _, ki), + tGsQ(_, _, _, ki), + tGcQ(_, _, _, ki), + residue_mk); + }; + + // g2s tiled copy for q_rope + GmemTiledCopyQRope gmem_tiled_copy_Q_rope; + auto gmem_thr_copy_Q_rope = gmem_tiled_copy_Q_rope.get_slice(tidx); + + // (CPY, CPY_M, CPY_K) => (blk_m, rope_head_dim) + Tensor tGcQ_rope = gmem_thr_copy_Q_rope.partition_S(cQ_rope); + Tensor tGgQ_rope = gmem_thr_copy_Q_rope.partition_S(gQ_rope); + Tensor tGsQ_rope = gmem_thr_copy_Q_rope.partition_D(sQ_rope); + const auto rope_residue_mk = select<0, 2>(rope_residue_mnk); + + auto produce_q_rope = [&]() { + safe_copy(gmem_tiled_copy_Q_rope, + tGgQ_rope, + tGsQ_rope, + tGcQ_rope, + rope_residue_mk); + }; + + // g2s tiled copy for kv + GmemTiledCopyKV gmem_tiled_copy_KV; + auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_slice(tidx); + + // (CPY, CPY_M, CPY_K, n, k) => (N, K) + Tensor tGcKV = gmem_thr_copy_KV.partition_S(cKV); + // (CPY, CPY_M, CPY_K, n, k) + auto tGgKV = gmem_thr_copy_KV.partition_S(gKV); + // (CPY, CPY_M, CPY_K, k, STAGES) + auto tGsKV = gmem_thr_copy_KV.partition_D(sKV); + const auto residue_nk = select<1, 2>(residue_mnk); + + auto produce_kv = [&](int ni, int ki, int stage) { + safe_copy(gmem_tiled_copy_KV, + tGgKV(_, _, _, ni, ki), + tGsKV(_, _, _, ki, stage), + tGcKV(_, _, _, ni, ki), + residue_nk); + }; + + auto produce_kv_no_oob = [&](int ni, int ki, int stage) { + cute::copy(gmem_tiled_copy_KV, + tGgKV(_, _, _, ni, ki), + tGsKV(_, _, _, ki, stage)); + }; + + // g2s tiled copy for k_rope + GmemTiledCopyKRope gmem_tiled_copy_K_rope; + auto gmem_thr_copy_K_rope = gmem_tiled_copy_K_rope.get_slice(tidx); + + // (CPY, CPY_M, CPY_K, n) => (N, K) + Tensor tGcK_rope = gmem_thr_copy_K_rope.partition_S(cK_rope); + // (CPY, CPY_M, CPY_K, n) + Tensor tGgK_rope = 
gmem_thr_copy_K_rope.partition_S(gK_rope); + // (CPY, CPY_M, CPY_K, STAGES) + Tensor tGsK_rope = gmem_thr_copy_K_rope.partition_D(sK_rope); + const auto rope_residue_nk = select<1, 2>(rope_residue_mnk); + + auto produce_k_rope = [&](int ni, int stage) { + safe_copy(gmem_tiled_copy_K_rope, + tGgK_rope(_, _, _, ni), + tGsK_rope(_, _, _, stage), + tGcK_rope(_, _, _, ni), + rope_residue_nk); + }; + + auto produce_k_rope_no_oob = [&](int ni, int stage) { + cute::copy(gmem_tiled_copy_K_rope, + tGgK_rope(_, _, _, ni), + tGsK_rope(_, _, _, stage)); + }; + + // GEMM-I: S = Q@K.T + TiledMma_QK tiled_mma_qk; + auto thr_mma_qk = tiled_mma_qk.get_slice(tidx); + // sQ: (BLK_M, BLK_K, k) + auto tSrQ = thr_mma_qk.partition_fragment_A(sQ(_, _, _0{})); + // sKV: (BLK_N, BLK_K, k, STAGES) + auto tSrK = thr_mma_qk.partition_fragment_B(sKV(_, _, _0{}, _0{})); + + // s2r tiled copy for q/q_rope + SmemTiledCopyQ smem_tiled_copy_Q; + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_slice(tidx); + // (CPY, CPY_M, CPY_K, k) + auto tCsQ = smem_thr_copy_Q.partition_S(sQ); + // (CPY, CPY_M, CPY_K) + auto tCrQ = smem_thr_copy_Q.retile_D(tSrQ); + + // s2r tiled copy for k/k_rope + SmemTiledCopyK smem_tiled_copy_K; + auto smem_thr_copy_K = smem_tiled_copy_K.get_slice(tidx); + // (CPY, CPY_N, CPY_K, k, STAGES) + auto tCsK = smem_thr_copy_K.partition_S(sKV); + // (CPY, CPY_N, CPY_K) + auto tCrK = smem_thr_copy_K.retile_D(tSrK); + + // S = Q@K.T + // tSrS: (MMA,MMA_M,MMA_N) + auto compute_qk = [&](auto& tSrS, int step, int stage) { + // tCsQ: (CPY, CPY_M, CPY_K, k) + auto tCsQ_s = tCsQ(_, _, _, step); + // TCsK: (CPY, CPY_N, CPY_K, k, STAGES) + auto tCsK_s = tCsK(_, _, _, step, stage); + // prefetch kv + cute::copy(smem_tiled_copy_Q, tCsQ_s(_, _, _0{}), tCrQ(_, _, _0{})); + cute::copy(smem_tiled_copy_K, tCsK_s(_, _, _0{}), tCrK(_, _, _0{})); + + CUTE_UNROLL + for (int k = 0; k < size<2>(tCsQ_s); ++k) { + // prefetch next kv + if (k != size<2>(tCsQ_s) - 1) { + const auto next_k = k + 1; + cute::copy( + smem_tiled_copy_Q, tCsQ_s(_, _, next_k), tCrQ(_, _, next_k)); + cute::copy( + smem_tiled_copy_K, tCsK_s(_, _, next_k), tCrK(_, _, next_k)); + } + cute::gemm(tiled_mma_qk, tSrQ(_, _, k), tSrK(_, _, k), tSrS); + } + }; + + // sQ_rope: (BLK_N, ROPE_HEAD_DIM) + auto tSrQ_rope = thr_mma_qk.partition_fragment_A(sQ_rope); + // sK_rope: (BLK_N, ROPE_HEAD_DIM, STAGES) + auto tSrK_rope = thr_mma_qk.partition_fragment_B(sK_rope(_, _, _0{})); + // (CPY, CPY_M, CPY_K) + auto tCsQ_rope = smem_thr_copy_Q.partition_S(sQ_rope); + // (CPY, CPY_M, CPY_K) + auto tCrQ_rope = smem_thr_copy_Q.retile_D(tSrQ_rope); + // (CPY, CPY_N, CPY_K, STAGES) + auto tCsK_rope = smem_thr_copy_K.partition_S(sK_rope); + // (CPY, CPY_N, CPY_K) + auto tCrK_rope = smem_thr_copy_K.retile_D(tSrK_rope); + auto compute_qk_rope = [&](auto& tSrS, int stage) { + auto tCsK_rope_s = tCsK_rope(_, _, _, stage); + cute::copy( + smem_tiled_copy_Q, tCsQ_rope(_, _, _0{}), tCrQ_rope(_, _, _0{})); + cute::copy( + smem_tiled_copy_K, tCsK_rope_s(_, _, _0{}), tCrK_rope(_, _, _0{})); + + CUTE_UNROLL + for (int k = 0; k < size<2>(tCsQ_rope); ++k) { + if (k != size<2>(tCsQ_rope) - 1) { + const auto next_k = k + 1; + cute::copy(smem_tiled_copy_Q, + tCsQ_rope(_, _, next_k), + tCrQ_rope(_, _, next_k)); + cute::copy(smem_tiled_copy_K, + tCsK_rope_s(_, _, next_k), + tCrK_rope(_, _, next_k)); + } + cute::gemm(tiled_mma_qk, tSrQ_rope(_, _, k), tSrK_rope(_, _, k), tSrS); + } + }; + + // GEMM-II: O = softmax(S)@V + TiledMma_PV tiled_mma_pv; + auto thr_mma_pv = tiled_mma_pv.get_slice(tidx); + // sP: 
(BLK_M, BLK_N) + auto tOrP = thr_mma_pv.partition_fragment_A(sP); + // sVt: (BLK_K, BLK_N, k, STAGES) + auto tOrVt = thr_mma_pv.partition_fragment_B(sVt(_, _, _0{}, _0{})); + + // s2r tiled copy for p + SmemTiledCopyP smem_tiled_copy_P; + auto smem_thr_copy_P = smem_tiled_copy_P.get_slice(tidx); + // (CPY, CPY_M, CPY_K) + auto tCsP = smem_thr_copy_P.partition_S(sP); + // (CPY, CPY_M, CPY_K) + auto tCrP = smem_thr_copy_P.retile_D(tOrP); + + // s2r tiled copy for vt + SmemTiledCopyVt smem_tiled_copy_Vt; + auto smem_thr_copy_Vt = smem_tiled_copy_Vt.get_slice(tidx); + // (CPY, CPY_N, CPY_K, k, STAGES) + auto tCsVt = smem_thr_copy_Vt.partition_S(sVt); + // (CPY, CPY_N, CPY_K) + auto tCrVt = smem_thr_copy_Vt.retile_D(tOrVt); + + // O = P*V = softmax(S)*V + // tOrO: (MMA,MMA_M,MMA_K,k) + auto compute_pv = [&](auto& tOrO, int step, int stage) { + auto tOrO_s = tOrO(_, _, _, step); + auto tCsVt_s = tCsVt(_, _, _, step, stage); + + cute::copy(smem_tiled_copy_P, tCsP(_, _, _0{}), tCrP(_, _, _0{})); + cute::copy(smem_tiled_copy_Vt, tCsVt_s(_, _, _0{}), tCrVt(_, _, _0{})); + + CUTE_UNROLL + for (int k = 0; k < size<2>(tCsVt_s); ++k) { + if (k != size<2>(tCsVt_s) - 1) { + const auto next_k = k + 1; + cute::copy(smem_tiled_copy_P, tCsP(_, _, next_k), tCrP(_, _, next_k)); + cute::copy( + smem_tiled_copy_Vt, tCsVt_s(_, _, next_k), tCrVt(_, _, next_k)); + } + cute::gemm(tiled_mma_pv, tCrP(_, _, k), tOrVt(_, _, k), tOrO_s); + } + }; + + // r2s tiled copy for S/P + SmemTiledCopyS smem_tiled_copy_S; + auto smem_thr_copy_S = smem_tiled_copy_S.get_slice(tidx); + auto store_s_to_smem = [&](const auto& tSrS) { + // cast Accumulator to Element type + auto tSrS_ = make_tensor_like(tSrS); + fast_cast(tSrS, tSrS_); + // copy scores from rmem to smem + auto tCrS = smem_thr_copy_S.retile_S(tSrS_); + auto tCsS = smem_thr_copy_S.partition_D(sP); + cute::copy(smem_tiled_copy_S, tCrS, tCsS); + }; + + // output accumulator: (MMA,MMA_M,MMA_K,k) + // auto tOrO = + // partition_fragment_C(tiled_mma_pv, Shape{}); + // clear(tOrO); + auto tOrO_mn = + make_tensor(tOrO.data(), LayoutConvertor::to_mns(tOrO.layout())); + + const int n_block_min = 0; + // process kv in range: [0, kv_idx_max) + const int diagonal = (m_block_idx * kBlockM) / group_size + kv_len - q_len; + const int kv_idx_max = std::min(kv_len, diagonal + kBlockM); + const int n_block_max = cute::ceil_div(kv_idx_max, kBlockN); + + if (n_block_min >= n_block_max) { + return; + } + + // ############### Prologue ############### + // g2s async data copy pipelines + // | stage | queue | + // | 1 | [k_r, kv0, kv1] | + // | 2 | [k_r, kv0, kv1, (nop, nop, nop, k_r, kv0, kv1)] | + // ^ kWait = (kSteps + 1) * (2*kStages - 1) - 1 + constexpr int kWait = (kSteps + 1) * (2 * kStages - 1) - 1; + // produce q_rope/q + produce_q_rope(); + CUTE_UNROLL + for (int ki = 0; ki < kSteps; ++ki) { + produce_q(ki); + } + // produce k_rope/kv + CUTE_UNROLL + for (int stage = 0; stage < kStages; ++stage) { + const int ni = n_block_max - 1 - stage; + // insert nops between stages for a perfect pipeline + if (stage != 0) { + cp_async_fence(); + CUTE_UNROLL + for (int ki = 0; ki < kSteps; ++ki) { + cp_async_fence(); + } + } + // handle oob kv + if (ni >= n_block_min) { + stage == 0 ? produce_k_rope(ni, stage) + : produce_k_rope_no_oob(ni, stage); + cp_async_fence(); + CUTE_UNROLL + for (int ki = 0; ki < kSteps; ++ki) { + stage == 0 ? 
produce_kv(ni, ki, stage) + : produce_kv_no_oob(ni, ki, stage); + cp_async_fence(); + } + } else { + cp_async_fence(); + CUTE_UNROLL + for (int ki = 0; ki < kSteps; ++ki) { + cp_async_fence(); + } + } + } + + // ############### Mainloop ############### + // attention score accumulator, (MMA,MMA_M,MMA_N) + auto tSrS = partition_fragment_C(tiled_mma_qk, Shape{}); + auto tScS = + thr_mma_qk.partition_C(make_identity_tensor(Shape{})); + auto tSrS_mn = + make_tensor(tSrS.data(), LayoutConvertor::to_mn(tSrS.layout())); + auto tScS_mn = + make_tensor(tScS.data(), LayoutConvertor::to_mn(tScS.layout())); + + constexpr int kRowsPerThr = kRowsPerMMA * size<1>(tSrS); + using Mask = Mask; + Mask mask(q_len, kv_len, group_size, /*sliding_window=*/kv_len); + + constexpr int n_oob_mask = cute::ceil_div(kBlockM, kBlockN) + 1; + const int n_blocks = n_block_max - n_block_min; + int stage = 0; + CUTE_NO_UNROLL + for (int i = 0; i < n_blocks; ++i) { + const int ni = n_block_max - 1 - i; + clear(tSrS); + + cp_async_wait(); + __syncthreads(); + + // 1> S = Q_rope@K_rope.T + compute_qk_rope(tSrS, stage); + cp_async_fence(); + + // 2> S += Q@K.T + CUTE_UNROLL + for (int ki = 0; ki < kSteps; ++ki) { + cp_async_wait(); + __syncthreads(); + + compute_qk(tSrS, ki, stage); + cp_async_fence(); + } + + // apply mask + if (i < n_oob_mask) { + mask.apply(tSrS_mn, tScS_mn, m_block_idx * kBlockM, ni * kBlockN); + } else { + mask.apply( + tSrS_mn, tScS_mn, m_block_idx * kBlockM, ni * kBlockN); + } + + softmax.rescale(tSrS_mn, tOrO_mn, reduce_rowmax); + + // save tSrS from rmem to smem + store_s_to_smem(tSrS); + __syncthreads(); + + // 3> O = softmax(S)*V + const auto next_ni = ni - kStages; + if (next_ni >= n_block_min) { + produce_k_rope_no_oob(next_ni, stage); + cp_async_fence(); + + CUTE_UNROLL + for (int ki = 0; ki < kSteps; ++ki) { + compute_pv(tOrO, ki, stage); + __syncthreads(); + produce_kv_no_oob(next_ni, ki, stage); + cp_async_fence(); + } + } else { + cp_async_fence(); + CUTE_UNROLL + for (int ki = 0; ki < kSteps; ++ki) { + compute_pv(tOrO, ki, stage); + cp_async_fence(); + } + } + + // move to next stage + if constexpr (kStages == 1) { + // do nothing + } else if constexpr (kStages == 2) { + stage = stage ^ 1; + } else { + stage = (stage + 1) % kStages; + } + } + // normalize output: o /= rowsum + softmax.finalize(tOrO_mn, reduce_rowsum); + } +}; + +} // namespace llm diff --git a/src/kernels/attention/sm80_collective_mla_epilogue.cuh b/src/kernels/attention/sm80_collective_mla_epilogue.cuh new file mode 100644 index 00000000..c53cbba7 --- /dev/null +++ b/src/kernels/attention/sm80_collective_mla_epilogue.cuh @@ -0,0 +1,127 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#include "fast_cast.cuh" +#include "safe_copy.h" + +namespace llm { +using namespace cute; + +template +struct Sm80CollectiveMlaEpilogue { + using TileShape = TileShape_; + using Element = Element_; + + static constexpr int kHeadDim = HeadDim_; + + static constexpr int kBlockM = get<0>(TileShape{}); + static constexpr int kBlockK = get<2>(TileShape{}); + // number of steps per stage + static constexpr int kSteps = kHeadDim / kBlockK; + + using BLK_M = Int; + using BLK_K = Int; + using HEAD_DIM = Int; + using STEPS = Int; + + // Shared memory LayoutAtom for differnt block sizes + using SmemLayoutAtom_8x64 = + decltype(composition(Swizzle<3, 3, 3>{}, + Layout, Stride<_64, _1>>{})); + using SmemLayoutAtom_8x32 = + decltype(composition(Swizzle<2, 3, 3>{}, + Layout, Stride<_32, _1>>{})); + using SmemLayoutAtom_ = 
std::conditional_t;
+
+  // O smem: (BLK_M, HEAD_DIM)
+  using SmemLayoutO =
+      decltype(tile_to_shape(SmemLayoutAtom_{}, Shape{}));
+
+  // use 128-bit vectorizing copy
+  using VectorizingCopy_ = AutoVectorizingCopyWithAssumedAlignment<128>;
+
+  // r2s copy atom for O
+  using SmemCopyAtom_ = Copy_Atom;
+
+  // s2g tiled copy for O
+  using GmemTiledCopyO = decltype(make_tiled_copy(
+      Copy_Atom{},
+      Layout, Stride<_8, _1>>{},  // Thr layout: (_32, _8)
+      Layout>{}                   // Val layout: 8 vals per read
+      ));
+
+  struct SharedStorage : cute::aligned_struct<128> {
+    cute::array_aligned> smem_o;
+  };
+
+  // Host side kernel arguments
+  struct Arguments {};
+
+  // Device side kernel params
+  using Params = Arguments;
+
+  // Convert host side arguments to device side params
+  static Params to_underlying_arguments(Arguments const& args) { return args; }
+
+  template
+  CUTE_DEVICE void operator()(
+      const Params& /*params*/,
+      const FrgTensor& tOrAccO,  // (MMA, MMA_M, MMA_N, k)
+      TiledMma tiled_mma,
+      TensorO& gO,               // (BLK_M, BLK_K, k)
+      const TensorCO& cO,        // (BLK_M, BLK_K, k) => (m, k)
+      int tidx,
+      const ResidueMNK& residue_mnk,
+      char* smem) {
+    static constexpr int kBlockM = get<0>(TileShape{});
+    static constexpr int kBlockK = get<2>(TileShape{});
+
+    // Smem
+    auto& ss = *reinterpret_cast(smem);
+    // (BLK_M, BLK_K, k)
+    Tensor sO = make_tensor(make_smem_ptr(ss.smem_o.data()), SmemLayoutO{});
+
+    // 1. cast output from ElementAccumulator to Element
+    auto tOrO = make_tensor_like(tOrAccO);
+    fast_cast(tOrAccO, tOrO);
+
+    // 2. copy output from reg to smem
+    auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtom_{}, tiled_mma);
+    auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
+    auto tSrO = smem_thr_copy_O.retile_S(tOrO);
+    auto tSsO = smem_thr_copy_O.partition_D(sO);
+    cute::copy(smem_tiled_copy_O, tSrO, tSsO);
+
+    // wait for smem copy done before gmem copy
+    __syncthreads();
+
+    // 3.
copy output from smem to gmem + GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + + auto tOsO = gmem_thr_copy_O.partition_S(sO); // (CPY,CPY_M,CPY_K, k) + auto tOgO = gmem_thr_copy_O.partition_D(gO); // (CPY,CPY_M,CPY_K, k) + + // (CPY,CPY_M,CPY_K, k) -> (m, k) + auto tOcO = gmem_thr_copy_O.partition_D(cO); + auto residue_mk = select<0, 2>(residue_mnk); + safe_copy( + gmem_tiled_copy_O, tOsO, tOgO, tOcO, residue_mk); + } +}; +} // namespace llm diff --git a/src/kernels/attention/sm80_kernel_mla.cuh b/src/kernels/attention/sm80_kernel_mla.cuh new file mode 100644 index 00000000..c5b127df --- /dev/null +++ b/src/kernels/attention/sm80_kernel_mla.cuh @@ -0,0 +1,172 @@ +#pragma once + +#include +#include + +#include +#include + +#include "mla_tile.h" +#include "online_softmax.cuh" + +namespace llm { + +using namespace cute; + +template +class Sm80KernelMla { + public: + using CollectiveMainloop = CollectiveMainloop_; + using CollectiveEpilogue = CollectiveEpilogue_; + using TileScheduler = TileScheduler_; + + using TiledMma_PV = typename CollectiveMainloop::TiledMma_PV; + + using Element = typename CollectiveMainloop::Element; + using BLK_M = typename CollectiveMainloop::BLK_M; + using BLK_N = typename CollectiveMainloop::BLK_N; + using BLK_K = typename CollectiveMainloop::BLK_K; + using HEAD_DIM = typename CollectiveMainloop::HEAD_DIM; + using ROPE_HEAD_DIM = typename CollectiveMainloop::ROPE_HEAD_DIM; + using STEPS = typename CollectiveMainloop::STEPS; + + static constexpr int kBlockM = CollectiveMainloop::kBlockM; + + static constexpr int kRowsPerMMA = CollectiveMainloop::kRowsPerMMA; + + static constexpr int kSharedStorageSize = + cute::max(sizeof(typename CollectiveMainloop::SharedStorage), + sizeof(typename CollectiveEpilogue::SharedStorage)); + + static constexpr int kMmaThreads = CollectiveMainloop::kMmaThreads; + + // Kernel params + using MainloopParams = typename CollectiveMainloop::Params; + using EpilogueParams = typename CollectiveEpilogue::Params; + using TileSchedulerParams = typename TileScheduler::Params; + + // returns grid and block shape for kernel launch + using TileSchedulerArgs = typename TileScheduler::Arguments; + static dim3 get_grid_shape(TileSchedulerArgs const& args) { + return TileScheduler::get_grid_shape(args); + } + static dim3 get_block_shape() { return kMmaThreads; } + + template + CUTE_DEVICE void operator()(const Params& params, + const TileSchedulerParams& scheduler_params, + char* smem) { + CollectiveMainloop mha; + CollectiveEpilogue epilogue; + TileScheduler scheduler(scheduler_params); + + // construct params + MainloopParams mainloop_params{params.group_size}; + EpilogueParams epilogue_params; + + // process each block + const auto& group_size = params.group_size; + + for (const auto block_coord : scheduler) { + // block coord: (batch_idx, m_block_idx, kv_head_idx) + const auto [batch_idx, m_block_idx, kv_head_idx] = block_coord; + const auto tidx = threadIdx.x; + + // Q/O: (q_packed_len, HEAD_DIM) + // Q_ROPE: (q_packed_len, ROPE_HEAD_DIM) + MLATile tile(params, batch_idx); + auto [Q, Q_ROPE, O] = tile.template get_qo_tile(); + // KV: (kv_len, HEAD_DIM) + // K_ROPE: (kv_len, ROPE_HEAD_DIM) + auto [KV, K_ROPE] = tile.template get_kv_tile(); + + // problem shape + const int q_packed_len = size<0>(Q); + const int q_len = q_packed_len / group_size; + const int kv_len = size<0>(KV); + + if (m_block_idx * kBlockM >= size<0>(Q)) { + // m out of bound, return + return; + } + + const auto head_dim = 
params.head_dim;
+      auto problem_shape_mnk = make_shape(q_packed_len, kv_len, head_dim);
+      const auto residue_mnk = make_tuple(q_packed_len, kv_len, head_dim);
+      const auto rope_residue_mnk =
+          make_tuple(q_packed_len, kv_len, ROPE_HEAD_DIM{});
+
+      // (BLK_M, BLK_K, k)
+      Tensor gQ =
+          local_tile(Q, Shape{}, make_coord(m_block_idx, _));
+      Tensor gO =
+          local_tile(O, Shape{}, make_coord(m_block_idx, _));
+      // (BLK_M, BLK_K, k) => (M, K)
+      Tensor cQ = local_tile(make_identity_tensor(Q.shape()),
+                             Shape{},
+                             make_coord(m_block_idx, _));
+      // (BLK_N, BLK_K, n, k)
+      Tensor gKV = local_tile(KV, Shape{}, make_coord(_, _));
+      // (BLK_N, BLK_K, n, k) => (N, K)
+      Tensor cKV = local_tile(make_identity_tensor(KV.shape()),
+                              Shape{},
+                              make_coord(_, _));
+
+      // (BLK_M, ROPE_HEAD_DIM)
+      Tensor gQ_rope = local_tile(
+          Q_ROPE, Shape{}, make_coord(m_block_idx, _0{}));
+      // (BLK_M, ROPE_HEAD_DIM) => (M, K)
+      Tensor cQ_rope = local_tile(make_identity_tensor(Q_ROPE.shape()),
+                                  Shape{},
+                                  make_coord(m_block_idx, _0{}));
+      // (BLK_N, ROPE_HEAD_DIM, n)
+      Tensor gK_rope = local_tile(
+          K_ROPE, Shape{}, make_coord(_, _0{}));
+      // (BLK_N, ROPE_HEAD_DIM, n) => (N, K)
+      Tensor cK_rope = local_tile(make_identity_tensor(K_ROPE.shape()),
+                                  Shape{},
+                                  make_coord(_, _0{}));
+
+      TiledMma_PV tiled_mma_pv;
+      // accumulator: (MMA, MMA_M, MMA_K, k)
+      auto tOrAccO =
+          partition_fragment_C(tiled_mma_pv, Shape{});
+      clear(tOrAccO);
+
+      constexpr int kRowsPerThr = kRowsPerMMA * size<1>(tOrAccO);
+      OnlineSoftmax softmax(params.sm_scale_log2);
+
+      // mainloop
+      mha(mainloop_params,
+          gQ,
+          cQ,
+          gKV,
+          cKV,
+          gQ_rope,
+          cQ_rope,
+          gK_rope,
+          cK_rope,
+          tOrAccO,
+          softmax,
+          tidx,
+          block_coord,
+          residue_mnk,
+          rope_residue_mnk,
+          smem);
+
+      // epilogue
+      epilogue(epilogue_params,
+               tOrAccO,
+               tiled_mma_pv,
+               gO,
+               cQ,
+               tidx,
+               residue_mnk,
+               smem);
+    }
+  }
+};
+
+} // namespace llm
diff --git a/src/kernels/attention/mla_sm80_bench.cu b/src/kernels/attention/sm80_mla_bench.cu
similarity index 100%
rename from src/kernels/attention/mla_sm80_bench.cu
rename to src/kernels/attention/sm80_mla_bench.cu
diff --git a/src/kernels/attention/sm80_mla_dispatch.cuh b/src/kernels/attention/sm80_mla_dispatch.cuh
new file mode 100644
index 00000000..fb4ebe35
--- /dev/null
+++ b/src/kernels/attention/sm80_mla_dispatch.cuh
@@ -0,0 +1,55 @@
+#pragma once
+
+#include
+#include
+
+#include "sm80_mla_launch.cuh"
+#include "static_dispatch.h"
+
+namespace llm {
+
+#define DISPATCH_HEAD_DIM_(HEAD_DIM_V, HEAD_DIM_NAME, ...) \
+  [&] { \
+    if (HEAD_DIM_V == 128) { \
+      constexpr static int HEAD_DIM_NAME = 128; \
+      return __VA_ARGS__(); \
+    } else if (HEAD_DIM_V == 256) { \
+      constexpr static int HEAD_DIM_NAME = 256; \
+      return __VA_ARGS__(); \
+    } else if (HEAD_DIM_V == 512) { \
+      constexpr static int HEAD_DIM_NAME = 512; \
+      return __VA_ARGS__(); \
+    } else { \
+      assert(false); \
+    } \
+  }()
+
+#define DISPATCH_ROPE_HEAD_DIM_(ROPE_HEAD_DIM_V, ROPE_HEAD_DIM_NAME, ...) \
+  [&] { \
+    if (ROPE_HEAD_DIM_V == 64) { \
+      constexpr static int ROPE_HEAD_DIM_NAME = 64; \
+      return __VA_ARGS__(); \
+    } else { \
+      assert(false); \
+    } \
+  }()
+
+// forward declaration
+// template
+// void sm80_launch_mla_kernel(const Params& params, cudaStream_t stream);
+
+// user-facing function to run the attention kernel
+template
+void sm80_run_mla(Params& params, cudaStream_t stream = nullptr) {
+  // normalize params for performance optimization
+  params.normalize();
+
+  // dispatch to proper kernel instantiation based on params
+  DISPATCH_HEAD_DIM_(params.head_dim, HEAD_DIM, [&] {
+    DISPATCH_ROPE_HEAD_DIM_(params.rope_head_dim, ROPE_HEAD_DIM, [&] {
+      sm80_launch_mla_kernel(params, stream);
+    });
+  });
+}
+
+} // namespace llm
diff --git a/src/kernels/attention/sm80_mla_launch.cuh b/src/kernels/attention/sm80_mla_launch.cuh
new file mode 100644
index 00000000..8dd8975c
--- /dev/null
+++ b/src/kernels/attention/sm80_mla_launch.cuh
@@ -0,0 +1,112 @@
+#pragma once
+
+#include
+#include
+
+#include
+#include
+
+#include "sm80_collective_mla.cuh"
+#include "sm80_collective_mla_epilogue.cuh"
+#include "sm80_kernel_mla.cuh"
+#include "tile_scheduler.cuh"
+
+namespace llm {
+
+namespace detail {
+/// Generic kernel template.
+template
+__global__ __launch_bounds__(Operator::kMmaThreads) void device_kernel(
+    __grid_constant__ const Params params,
+    __grid_constant__ const typename Operator::TileSchedulerParams
+        scheduler_params) {
+  extern __shared__ char smem[];
+  Operator op;
+  op(params, scheduler_params, smem);
+}
+
+template
+CUTE_HOST_DEVICE constexpr auto tile_shape_selector() {
+  using namespace cute;
+  // TODO: tune block shape MNK based on the head dim and smem size
+  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
+  // SM           | 7.0 | 7.2 | 7.5 | 8.0 | 8.6 | 8.7 | 8.9 | 9.0 | 10.x | 12.0|
+  // Max SMEM (KB)| 96  | 64  | 164 | 100 | 164 | 100 | 228 | 100 |
+  // valid dynamic shared memory sizes for different compute capabilities:
+  // * 7.0 | 7.2 : 0, 8, 16, 32, 64, 96
+  // * 7.5       : 0, 32, 64
+  // * 8.0 | 8.7 : 0, 8, 16, 32, 64, 100, 132, 164
+  // * 8.6 | 8.9 : 0, 8, 16, 32, 64, 100
+  // * 9.0 | 10.x: 0, 8, 16, 32, 64, 100, 132, 164, 196, 228
+  // * 12.0      : 0, 8, 16, 32, 64, 100
+  // constexpr int stages = 1;
+  if constexpr (HEAD_DIM <= 128) {
+    return Shape<_64, _64, _128>{};
+  } else if constexpr (HEAD_DIM <= 256) {
+    return Shape<_64, _32, _128>{};
+  } else if constexpr (HEAD_DIM <= 512) {
+    return Shape<_64, _16, _128>{};
+  } else {
+    return Shape<_64, _16, _128>{};
+  }
+}
+
+template
+CUTE_HOST_DEVICE constexpr auto stages_selector() {
+  using namespace cute;
+  if constexpr (HEAD_DIM <= 128) {
+    return 2;
+  } else if constexpr (HEAD_DIM <= 256) {
+    return 2;
+  } else if constexpr (HEAD_DIM <= 512) {
+    return 1;
+  } else {
+    return 1;
+  }
+}
+
+} // namespace detail
+
+template
+void sm80_launch_mla_kernel(const Params& params, cudaStream_t stream) {
+  const auto batch_size = params.batch_size;
+  const auto max_q_packed_len = params.max_q_len * params.group_size;
+
+  constexpr int stages = detail::stages_selector<80, HEAD_DIM, ROPE_HEAD_DIM>();
+  using TileShape =
+      decltype(detail::tile_shape_selector<80, HEAD_DIM, ROPE_HEAD_DIM>());
+  using CollectiveMainloop =
+      Sm80CollectiveMla;
+  using CollectiveEpilogue =
+      Sm80CollectiveMlaEpilogue;
+
+  // TODO: support persistent kernels
+  using TileScheduler = SingleTileScheduler;
+
+  constexpr int BLK_M =
get<0>(TileShape{}); + + const auto m_blocks = cute::ceil_div(max_q_packed_len, BLK_M); + typename TileScheduler::Arguments scheduler_args{ + batch_size, m_blocks, /*n_kv_heads=*/1}; + auto scheduler_params = + TileScheduler::to_underlying_arguments(scheduler_args); + + using AttnKernel = + Sm80KernelMla; + + auto mla_kernel = detail::device_kernel; + + const auto smem_size = AttnKernel::kSharedStorageSize; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } + + const dim3 grid = AttnKernel::get_grid_shape(scheduler_args); + const dim3 block = AttnKernel::get_block_shape(); + + mla_kernel<<>>(params, scheduler_params); + // TODO: check launch status +} + +} // namespace llm diff --git a/src/kernels/attention/mla_kernel_sm80_pagedkv_test.cu b/src/kernels/attention/sm80_mla_pagedkv_test.cu similarity index 75% rename from src/kernels/attention/mla_kernel_sm80_pagedkv_test.cu rename to src/kernels/attention/sm80_mla_pagedkv_test.cu index c64e928e..1debe7a2 100644 --- a/src/kernels/attention/mla_kernel_sm80_pagedkv_test.cu +++ b/src/kernels/attention/sm80_mla_pagedkv_test.cu @@ -3,50 +3,12 @@ #include #include "cute/layout.hpp" -#include "mla_kernel_sm80.cuh" #include "mla_params.h" #include "mla_ref.h" -#include "mla_traits_sm80.h" +#include "sm80_mla_dispatch.cuh" namespace llm { - -#define DISPATCH_HEAD_DIM_(HEAD_DIM_V, HEAD_DIM_NAME, ...) \ - [&] { \ - if (HEAD_DIM_V == 128) { \ - constexpr static int HEAD_DIM_NAME = 128; \ - constexpr static int BLK_M = 64; \ - constexpr static int BLK_N = 64; \ - constexpr static int BLK_K = 128; \ - constexpr static int STAGES = 2; \ - return __VA_ARGS__(); \ - } else if (HEAD_DIM_V == 256) { \ - constexpr static int HEAD_DIM_NAME = 256; \ - constexpr static int BLK_M = 64; \ - constexpr static int BLK_N = 32; \ - constexpr static int BLK_K = 128; \ - constexpr static int STAGES = 2; \ - return __VA_ARGS__(); \ - } else if (HEAD_DIM_V == 512) { \ - constexpr static int HEAD_DIM_NAME = 512; \ - constexpr static int BLK_M = 64; \ - constexpr static int BLK_N = 16; \ - constexpr static int BLK_K = 128; \ - constexpr static int STAGES = 1; \ - return __VA_ARGS__(); \ - } else { \ - assert(false); \ - } \ - }() - -#define DISPATCH_ROPE_HEAD_DIM_(ROPE_HEAD_DIM_V, ROPE_HEAD_DIM_NAME, ...) \ - [&] { \ - if (ROPE_HEAD_DIM_V == 64) { \ - constexpr static int ROPE_HEAD_DIM_NAME = 64; \ - return __VA_ARGS__(); \ - } else { \ - assert(false); \ - } \ - }() +using namespace cute; #define DISPATCH_TORCH_DTYPE_(TORCH_DTYPE, TYPE_NAME, ...) 
\ [&] { \ @@ -109,23 +71,7 @@ torch::Tensor mla_pagedkv_sm80( params.block_cu_lens = block_cu_lens.const_data_ptr(); params.block_size = block_size; - params.normalize(); - - DISPATCH_TORCH_DTYPE_(q.dtype(), DTYPE, [&] { - DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] { - DISPATCH_ROPE_HEAD_DIM_(rope_head_dim, ROPE_HEAD_DIM, [&] { - using Traits = MLATraitsSM80; - - launch_mla_kernel_sm80(params, nullptr); - }); - }); - }); + DISPATCH_TORCH_DTYPE_(q.dtype(), DTYPE, [&] { sm80_run_mla(params); }); return out; } diff --git a/src/kernels/attention/mla_kernel_sm80_test.cu b/src/kernels/attention/sm80_mla_test.cu similarity index 63% rename from src/kernels/attention/mla_kernel_sm80_test.cu rename to src/kernels/attention/sm80_mla_test.cu index ad72abf5..b934731a 100644 --- a/src/kernels/attention/mla_kernel_sm80_test.cu +++ b/src/kernels/attention/sm80_mla_test.cu @@ -3,54 +3,12 @@ #include #include -#include -#include "cute/numeric/numeric_types.hpp" -#include "mla_kernel_sm80.cuh" #include "mla_params.h" #include "mla_ref.h" -#include "mla_traits_sm80.h" +#include "sm80_mla_dispatch.cuh" namespace llm { - -#define DISPATCH_HEAD_DIM_(HEAD_DIM_V, HEAD_DIM_NAME, ...) \ - [&] { \ - if (HEAD_DIM_V == 128) { \ - constexpr static int HEAD_DIM_NAME = 128; \ - constexpr static int BLK_M = 64; \ - constexpr static int BLK_N = 64; \ - constexpr static int BLK_K = 128; \ - constexpr static int STAGES = 2; \ - return __VA_ARGS__(); \ - } else if (HEAD_DIM_V == 256) { \ - constexpr static int HEAD_DIM_NAME = 256; \ - constexpr static int BLK_M = 64; \ - constexpr static int BLK_N = 32; \ - constexpr static int BLK_K = 128; \ - constexpr static int STAGES = 2; \ - return __VA_ARGS__(); \ - } else if (HEAD_DIM_V == 512) { \ - constexpr static int HEAD_DIM_NAME = 512; \ - constexpr static int BLK_M = 64; \ - constexpr static int BLK_N = 16; \ - constexpr static int BLK_K = 128; \ - constexpr static int STAGES = 1; \ - return __VA_ARGS__(); \ - } else { \ - assert(false); \ - } \ - }() - -#define DISPATCH_ROPE_HEAD_DIM_(ROPE_HEAD_DIM_V, ROPE_HEAD_DIM_NAME, ...) \ - [&] { \ - if (ROPE_HEAD_DIM_V == 64) { \ - constexpr static int ROPE_HEAD_DIM_NAME = 64; \ - return __VA_ARGS__(); \ - } else { \ - assert(false); \ - } \ - }() - #define DISPATCH_TORCH_DTYPE_(TORCH_DTYPE, TYPE_NAME, ...) \ [&] { \ if (TORCH_DTYPE == torch::kHalf) { \ @@ -107,21 +65,7 @@ torch::Tensor mla_sm80( params.sm_scale = sm_scale; params.normalize(); - DISPATCH_TORCH_DTYPE_(q.dtype(), DTYPE, [&] { - DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] { - DISPATCH_ROPE_HEAD_DIM_(rope_head_dim, ROPE_HEAD_DIM, [&] { - using Traits = MLATraitsSM80; - - launch_mla_kernel_sm80(params, nullptr); - }); - }); - }); + DISPATCH_TORCH_DTYPE_(q.dtype(), DTYPE, [&] { sm80_run_mla(params); }); return out; }