From 68af4767e55988a163f2c6adf25e20fe2000817b Mon Sep 17 00:00:00 2001
From: Michael Mi
Date: Tue, 24 Jun 2025 22:10:58 -0700
Subject: [PATCH 1/2] feat: use global residue_mnk for oob handling

---
 .../attention/sm80_collective_epilogue.cuh    |  31 ++--
 src/kernels/attention/sm80_collective_mha.cuh | 138 +++++++++++-------
 src/kernels/attention/sm80_kernel_mha.cuh     |  26 ++--
 3 files changed, 111 insertions(+), 84 deletions(-)

diff --git a/src/kernels/attention/sm80_collective_epilogue.cuh b/src/kernels/attention/sm80_collective_epilogue.cuh
index 6a00ffc7..24ce2beb 100644
--- a/src/kernels/attention/sm80_collective_epilogue.cuh
+++ b/src/kernels/attention/sm80_collective_epilogue.cuh
@@ -71,21 +71,19 @@ struct Sm80CollectiveEpilogue {
   template <class FrgTensor,
             class TiledMma,
             class TensorO,
-            class BlockCoordMNK,
-            class ProblemShapeMNK>
-  CUTE_DEVICE void operator()(const Params& /*params*/,
-                              const FrgTensor& tOrAccO,  // (MMA, MMA_M, MMA_N)
-                              TiledMma tiled_mma,
-                              TensorO& gO,  // (BLK_M, HEAD_DIM)
-                              int tidx,
-                              const BlockCoordMNK& block_coord_mnk,
-                              const ProblemShapeMNK& problem_shape_mnk,
-                              char* smem) {
+            class TensorCO,
+            class ResidueMNK>
+  CUTE_DEVICE void operator()(
+      const Params& /*params*/,
+      const FrgTensor& tOrAccO,  // (MMA, MMA_M, MMA_N)
+      TiledMma tiled_mma,
+      TensorO& gO,         // (BLK_M, HEAD_DIM)
+      const TensorCO& cO,  // (BLK_M, HEAD_DIM) => (M, K)
+      int tidx,
+      const ResidueMNK& residue_mnk,
+      char* smem) {
     static constexpr int kBlockM = get<0>(TileShape{});
 
-    const auto [batch_idx, m_block_idx, kv_head_idx] = block_coord_mnk;
-    const auto [q_packed_len, kv_len, head_dim] = problem_shape_mnk;
-
     // Smem
     auto& ss = *reinterpret_cast<SharedStorage*>(smem);
 
     // (BLK_M, HEAD_DIM)
@@ -106,9 +104,6 @@ struct Sm80CollectiveEpilogue {
     GmemTiledCopyO gmem_tiled_copy_O;
     auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
 
-    // (BLK_M, HEAD_DIM) -> (blk_m, head_dim)
-    auto cO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
-
     auto tOsO = gmem_thr_copy_O.partition_S(sO);  // (CPY,CPY_M,CPY_K)
     auto tOgO = gmem_thr_copy_O.partition_D(gO);  // (CPY,CPY_M,CPY_K)
     // (CPY,CPY_M,CPY_K) -> (blk_m, head_dim)
@@ -117,9 +112,9 @@ struct Sm80CollectiveEpilogue {
     // wait for smem copy done before gmem copy
     __syncthreads();
 
-    auto max_coord = make_coord(q_packed_len - m_block_idx * kBlockM, head_dim);
+    const auto residue_mk = select<0, 2>(residue_mnk);
     safe_copy</*EVEN_MN=*/false, EVEN_K, /*ZFILL_MN=*/false, /*ZFILL_K=*/false>(
-        gmem_tiled_copy_O, tOsO, tOgO, tOcO, max_coord);
+        gmem_tiled_copy_O, tOsO, tOgO, tOcO, residue_mk);
   }
 };
 }  // namespace llm
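Note: with this change the epilogue no longer derives bounds from block_coord_mnk. The coordinate tensor cO already carries global (m, k) indices, so the store is predicated directly against residue_mk = select<0, 2>(residue_mnk). The sketch below shows the kind of predication safe_copy is assumed to perform; the names are illustrative, and the in-repo safe_copy additionally handles the EVEN_*/ZFILL_* modes:

    // minimal sketch of a residue-predicated tiled copy (CuTe)
    template <class TiledCopy, class TensorS, class TensorD,
              class TensorC, class Residue>
    CUTE_DEVICE void predicated_copy(const TiledCopy& tiled_copy,
                                     const TensorS& src,    // (CPY, CPY_M, CPY_K)
                                     TensorD& dst,          // (CPY, CPY_M, CPY_K)
                                     const TensorC& coord,  // (CPY, CPY_M, CPY_K) => (m, k)
                                     const Residue& residue_mk) {
      CUTE_UNROLL
      for (int mi = 0; mi < size<1>(src); ++mi) {
        CUTE_UNROLL
        for (int ki = 0; ki < size<2>(src); ++ki) {
          // copy the fragment only if its global (m, k) coordinate is in bounds
          if (elem_less(coord(_0{}, mi, ki), residue_mk)) {
            copy(tiled_copy, src(_, mi, ki), dst(_, mi, ki));
          }
        }
      }
    }

Because the coordinates are global, the same comparison works for every tile; no per-block max_coord arithmetic is needed.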
diff --git a/src/kernels/attention/sm80_collective_mha.cuh b/src/kernels/attention/sm80_collective_mha.cuh
index d4413569..c6ef6ff2 100644
--- a/src/kernels/attention/sm80_collective_mha.cuh
+++ b/src/kernels/attention/sm80_collective_mha.cuh
@@ -144,22 +144,27 @@ struct Sm80CollectiveMha {
   // NOTE: out-of-bound m blocks are expected to be skipped by the caller
   template <class TensorQ,
+            class TensorCQ,
             class TensorK,
             class TensorV,
+            class TensorCKV,
             class FrgTensor,
             class Softmax,
             class BlockCoordMNK,
-            class ProblemShapeMNK>
-  CUTE_DEVICE void operator()(const Params& params,
-                              const TensorQ& gQ,  // (BLK_M, HEAD_DIM)
-                              const TensorK& gK,  // (BLK_N, HEAD_DIM, n)
-                              const TensorV& gV,  // (BLK_N, HEAD_DIM, n)
-                              FrgTensor& tOrO,    // (MMA, MMA_M, MMA_N)
-                              Softmax& softmax,
-                              int tidx,
-                              const BlockCoordMNK& block_coord_mnk,
-                              const ProblemShapeMNK& problem_shape_mnk,
-                              char* smem) {
+            class ResidueMNK>
+  CUTE_DEVICE void operator()(
+      const Params& params,
+      const TensorQ& gQ,     // (BLK_M, HEAD_DIM)
+      const TensorCQ& cQ,    // (BLK_M, HEAD_DIM) => (M, K)
+      const TensorK& gK,     // (BLK_N, HEAD_DIM, n)
+      const TensorV& gV,     // (BLK_N, HEAD_DIM, n)
+      const TensorCKV& cKV,  // (BLK_N, HEAD_DIM, n) => (N, K)
+      FrgTensor& tOrO,       // (MMA, MMA_M, MMA_N)
+      Softmax& softmax,
+      int tidx,
+      const BlockCoordMNK& block_coord_mnk,
+      const ResidueMNK& residue_mnk,  // (M, N, K)
+      char* smem) {
     static_assert(is_rmem<FrgTensor>::value,
                   "Accum tensor must be rmem resident.");
     static_assert(is_gmem<TensorQ>::value, "Q tensor must be gmem resident.");
@@ -170,7 +175,8 @@ struct Sm80CollectiveMha {
     static constexpr int kBlockN = get<1>(TileShape{});
 
     const auto [batch_idx, m_block_idx, kv_head_idx] = block_coord_mnk;
-    const auto [q_packed_len, kv_len, head_dim] = problem_shape_mnk;
+    const int q_packed_len = get<0>(residue_mnk);
+    const int kv_len = get<1>(residue_mnk);
 
     const int sliding_window = LOCAL ? params.sliding_window : kv_len;
     const float logits_soft_cap = params.logits_soft_cap;
@@ -198,16 +204,66 @@ struct Sm80CollectiveMha {
     auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
     auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_thread_slice(tidx);
 
-    // coordinate tensor for oob handling
-    // (BLK_M, HEAD_DIM) -> (blk_m, head_dim)
-    Tensor cQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
-    Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);
-    // (BLK_N, HEAD_DIM) -> (blk_n, head_dim)
-    Tensor cKV = make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDim>>{});
-    Tensor tKVcKV = gmem_thr_copy_KV.partition_S(cKV);
+    // (CPY, CPY_N, CPY_K, n) => (N, K)
+    Tensor tGcKV = gmem_thr_copy_KV.partition_S(cKV);
+    // (CPY, CPY_N, CPY_K, n)
+    Tensor tGgK = gmem_thr_copy_KV.partition_S(gK);
+    Tensor tGgV = gmem_thr_copy_KV.partition_S(gV);
+
+    // (CPY, CPY_N, CPY_K)
+    Tensor tGsK = gmem_thr_copy_KV.partition_D(sK);
+    Tensor tGsV = gmem_thr_copy_KV.partition_D(sV);
+
+    const auto residue_mk = select<0, 2>(residue_mnk);
+    const auto residue_nk = select<1, 2>(residue_mnk);
+
+    auto produce_query = [&]() {
+      auto tGcQ = gmem_thr_copy_Q.partition_S(cQ);
+      auto tGgQ = gmem_thr_copy_Q.partition_S(gQ);
+      auto tGsQ = gmem_thr_copy_Q.partition_D(sQ);
+      safe_copy</*EVEN_MN=*/false, EVEN_K, /*ZFILL_MN=*/true, /*ZFILL_K=*/true>(
+          gmem_tiled_copy_Q, tGgQ, tGsQ, tGcQ, residue_mk);
+    };
+
+    auto produce_key = [&](int ni) {
+      // skip ZFILL_MN for key since Mask will mask out oob with -inf
+      safe_copy</*EVEN_MN=*/false, EVEN_K, /*ZFILL_MN=*/false, /*ZFILL_K=*/true>(
+          gmem_tiled_copy_KV,
+          tGgK(_, _, _, ni),
+          tGsK,
+          tGcKV(_, _, _, ni),
+          residue_nk);
+    };
+
+    // produce key without oob handling
+    auto produce_key_no_oob = [&](int ni) {
+      safe_copy</*EVEN_MN=*/true, EVEN_K, /*ZFILL_MN=*/false, /*ZFILL_K=*/true>(
+          gmem_tiled_copy_KV,
+          tGgK(_, _, _, ni),
+          tGsK,
+          tGcKV(_, _, _, ni),
+          residue_nk);
+    };
 
-    Tensor tKsK = gmem_thr_copy_KV.partition_D(sK);
-    Tensor tVsV = gmem_thr_copy_KV.partition_D(sV);
+    auto produce_value = [&](int ni) {
+      // skipping ZFILL_MN for v may cause a NaN issue
+      safe_copy</*EVEN_MN=*/false, EVEN_K, /*ZFILL_MN=*/true, /*ZFILL_K=*/true>(
+          gmem_tiled_copy_KV,
+          tGgV(_, _, _, ni),
+          tGsV,
+          tGcKV(_, _, _, ni),
+          residue_nk);
+    };
+
+    // produce value without oob handling
+    auto produce_value_no_oob = [&](int ni) {
+      safe_copy</*EVEN_MN=*/true, EVEN_K, /*ZFILL_MN=*/false, /*ZFILL_K=*/true>(
+          gmem_tiled_copy_KV,
+          tGgV(_, _, _, ni),
+          tGsV,
+          tGcKV(_, _, _, ni),
+          residue_nk);
+    };
 
     TiledMma tiled_mma;
     auto thr_mma = tiled_mma.get_slice(tidx);
@@ -309,11 +365,7 @@ struct Sm80CollectiveMha {
 
     // ############### Prologue ###############
     // produce query: [] => [q]
-    auto tQgQ = gmem_thr_copy_Q.partition_S(gQ);
-    auto tQsQ = gmem_thr_copy_Q.partition_D(sQ);
-    auto max_coord = make_coord(q_packed_len - m_block_idx * kBlockM, head_dim);
-    safe_copy</*EVEN_MN=*/false, EVEN_K, /*ZFILL_MN=*/true, /*ZFILL_K=*/true>(
-        gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, max_coord);
+    produce_query();
     cp_async_fence();
 
     // wait g2s copy done for query
@@ -326,15 +378,7 @@ struct Sm80CollectiveMha {
     __syncthreads();
 
     // produce key: [q] => [q, k]
-    {
-      const int ni = n_block_max - 1;
-      auto tKgK = gmem_thr_copy_KV.partition_S(gK(_, _, ni));
-      auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim);
-      // skip ZFILL_MN for key since Mask will mask out oob with -inf
-      safe_copy</*EVEN_MN=*/false, EVEN_K, /*ZFILL_MN=*/false, /*ZFILL_K=*/true>(
-          gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, max_coord);
-    }
-
+    produce_key(n_block_max - 1);
    cp_async_fence();
 
     // ############### Mainloop ###############
@@ -373,18 +417,10 @@ struct Sm80CollectiveMha {
       __syncthreads();
 
       // produce value, [] => [v]
-      auto tVgV = gmem_thr_copy_KV.partition_S(gV(_, _, n_block_idx));
-      auto max_coord = make_coord(kv_len - n_block_idx * kBlockN, head_dim);
       if (i == 0) {
-        safe_copy</*EVEN_MN=*/false, EVEN_K, /*ZFILL_MN=*/true, /*ZFILL_K=*/true>(
-            gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, max_coord);
-
-      } else {  // without oob handling
-        safe_copy</*EVEN_MN=*/true, EVEN_K, /*ZFILL_MN=*/false, /*ZFILL_K=*/true>(
-            gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, max_coord);
+        produce_value(n_block_idx);
+      } else {
+        produce_value_no_oob(n_block_idx);
       }
       cp_async_fence();
@@ -410,15 +446,7 @@ struct Sm80CollectiveMha {
 
       // produce next key: [] => [k]
       if (n_block_idx > n_block_min) {
-        // without oob handling
-        const int ni = n_block_idx - 1;
-        auto tKgK = gmem_thr_copy_KV.partition_S(gK(_, _, ni));
-        auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim);
-        safe_copy</*EVEN_MN=*/true, EVEN_K, /*ZFILL_MN=*/false, /*ZFILL_K=*/true>(
-            gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, max_coord);
+        produce_key_no_oob(n_block_idx - 1);
       }
       cp_async_fence();
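Note on the ZFILL comments above: an out-of-bounds key column only affects the logits, which the Mask sets to -inf, so its softmax weight is exactly zero and the garbage key value never contributes. An out-of-bounds value row, however, is multiplied by that zero weight, and if the smem tile holds uninitialized NaNs the product poisons the accumulator; hence produce_value() keeps ZFILL_MN on the first (possibly partial) tile. A tiny host-side illustration of the failure mode:

    #include <cmath>
    #include <cstdio>

    int main() {
      float weight = 0.0f;             // softmax weight of a masked-out position
      float garbage = NAN;             // stand-in for uninitialized smem of an oob V row
      float acc = weight * garbage;    // 0 * NaN == NaN: the accumulator is poisoned
      std::printf("acc = %f\n", acc);  // prints "acc = nan"
      return 0;
    }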
diff --git a/src/kernels/attention/sm80_kernel_mha.cuh b/src/kernels/attention/sm80_kernel_mha.cuh
index 954bf40b..f85a47da 100644
--- a/src/kernels/attention/sm80_kernel_mha.cuh
+++ b/src/kernels/attention/sm80_kernel_mha.cuh
@@ -84,21 +84,29 @@ class Sm80KernelMha {
       const int q_packed_len = size<0>(Q);
       const int kv_len = size<0>(K);
       const int head_dim = params.head_dim;
-      auto problem_shape_mnk = make_shape(q_packed_len, kv_len, head_dim);
-
       if (m_block_idx * kBlockM >= q_packed_len) {
         // m out of bound, skip this block
         continue;
       }
+      const auto residue_mnk = make_tuple(q_packed_len, kv_len, head_dim);
 
       // (BLK_M, HEAD_DIM)
       Tensor gQ = local_tile(
           Q, Shape<Int<kBlockM>, Int<kHeadDim>>{}, make_coord(m_block_idx, _0{}));
       Tensor gO = local_tile(
           O, Shape<Int<kBlockM>, Int<kHeadDim>>{}, make_coord(m_block_idx, _0{}));
+      // (BLK_M, HEAD_DIM) => (M, K)
+      Tensor cQ = local_tile(make_identity_tensor(Q.shape()),
+                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                             make_coord(m_block_idx, _0{}));
+
       // (BLK_N, HEAD_DIM, n)
       Tensor gK = local_tile(K, Shape<Int<kBlockN>, Int<kHeadDim>>{}, make_coord(_, _0{}));
       Tensor gV = local_tile(V, Shape<Int<kBlockN>, Int<kHeadDim>>{}, make_coord(_, _0{}));
+      // (BLK_N, HEAD_DIM, n) => (N, K)
+      Tensor cKV = local_tile(make_identity_tensor(K.shape()),
+                              Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                              make_coord(_, _0{}));
 
       TiledMma tiled_mma;
       // accumulator: (MMA, MMA_M, MMA_K)
@@ -111,24 +119,20 @@ class Sm80KernelMha {
 
       // mainloop
       mha(mainloop_params,
          gQ,
+         cQ,
          gK,
          gV,
+         cKV,
          tOrAccO,
          softmax,
          tidx,
          block_coord,
-         problem_shape_mnk,
+         residue_mnk,
          smem);
 
       // epilogue
-      epilogue(epilogue_params,
-               tOrAccO,
-               tiled_mma,
-               gO,
-               tidx,
-               block_coord,
-               problem_shape_mnk,
-               smem);
+      epilogue(
+          epilogue_params, tOrAccO, tiled_mma, gO, cQ, tidx, residue_mnk, smem);
     }
   }
 };
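Note: the kernel-side change above is the heart of this patch. Instead of per-tile identity tensors plus a max_coord recomputed at every copy, the kernel builds coordinate tensors over the whole problem once and tiles them exactly like the data tensors. After local_tile, each tile's coordinates are already global, so the mainloop and epilogue compare directly against residue_mnk without knowing m_block_idx. A sketch of why this works (M, K and the tile constants stand in for the real problem shape):

    // (M, K) => (m, k): every element holds its own global coordinate
    Tensor cQ_full = make_identity_tensor(make_shape(M, K));
    // tile it the same way as Q itself
    Tensor cQ = local_tile(cQ_full,
                           Shape<Int<kBlockM>, Int<kHeadDim>>{},
                           make_coord(m_block_idx, _0{}));
    // cQ(i, j) == (m_block_idx * kBlockM + i, j), i.e. already global,
    // so "in bounds" is simply elem_less(cQ(i, j), residue_mk).
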
-#include "fast_math.h" -#include "gather_tensor.hpp" -#include "mha_params.h" - -namespace llm { -using namespace cute; - -template -struct MHATile { - static_assert(cute::dependent_false, "not implemented"); -}; - -// AttentionTile specialization for AttentionParams -template <> -struct MHATile { - const MHAParams& params_; - const int batch_idx_; - const int kv_head_idx_; - - CUTE_HOST_DEVICE MHATile(const MHAParams& params, - int batch_idx, - int kv_head_idx) - : params_(params), batch_idx_(batch_idx), kv_head_idx_(kv_head_idx) {} - - // return the query/output tile: (q_len, head_dim) - template - CUTE_HOST_DEVICE auto get_qo_tile() const { - // (batch, seq, head, dim) - - // packed all q/o in the same kv head group together - const auto head_base = kv_head_idx_ * params_.group_size; - auto packed_idx_to_coord = [this, head_base](int packed_idx) { - int idx, offset; - params_.group_size.divmod(packed_idx, idx, offset); - return make_coord(idx, head_base + offset); - }; - - const auto packed_len = params_.q_len * params_.group_size; - const auto q_offset = batch_idx_ * get<0>(params_.q_stride); - auto q = make_gather_tensor( - make_gmem_ptr((const Element*)params_.q_ptr + q_offset), - make_shape(packed_len, params_.head_dim), - make_stride(select<1, 2>(params_.q_stride), get<3>(params_.q_stride)), - packed_idx_to_coord); - - const auto o_offset = batch_idx_ * get<0>(params_.o_stride); - auto o = make_gather_tensor( - make_gmem_ptr((Element*)params_.o_ptr + o_offset), - make_shape(packed_len, params_.head_dim), - make_stride(select<1, 2>(params_.o_stride), get<3>(params_.o_stride)), - packed_idx_to_coord); - return make_tuple(q, o); - } - - // return the key/value tile: (kv_len, head_dim) - template - CUTE_HOST_DEVICE auto get_kv_tile() const { - // (batch, seq, kv_head, dim) - const auto k_offset = batch_idx_ * get<0>(params_.k_stride) + - kv_head_idx_ * get<2>(params_.k_stride); - const auto v_offset = batch_idx_ * get<0>(params_.v_stride) + - kv_head_idx_ * get<2>(params_.v_stride); - // k[batch_idx, :, kv_head_idx, :] - auto k = - make_tensor(make_gmem_ptr((const Element*)params_.k_ptr + k_offset), - make_shape(params_.kv_len, params_.head_dim), - select<1, 3>(params_.k_stride)); - // v[batch_idx, :, kv_head_idx, :] - auto v = - make_tensor(make_gmem_ptr((const Element*)params_.v_ptr + v_offset), - make_shape(params_.kv_len, params_.head_dim), - select<1, 3>(params_.v_stride)); - return make_tuple(k, v); - } -}; - -// paged KV cache + variable length sequence -template <> -struct MHATile { - // NOLINTNEXTLINE - const MHAPagedKVParams& params_; - const int batch_idx_; - const int kv_head_idx_; - - CUTE_HOST_DEVICE MHATile(const MHAPagedKVParams& params, - int batch_idx, - int kv_head_idx) - : params_(params), batch_idx_(batch_idx), kv_head_idx_(kv_head_idx) {} - - // return the query/output tile: (q_len, head_dim) - template - CUTE_HOST_DEVICE auto get_qo_tile() const { - const auto begin = params_.q_cu_lens[batch_idx_]; - const auto qo_len = params_.q_cu_lens[batch_idx_ + 1] - begin; - const auto head_base = kv_head_idx_ * params_.group_size; - auto packed_idx_to_coord = [this, head_base](int packed_idx) { - int idx, offset; - params_.group_size.divmod(packed_idx, idx, offset); - return make_coord(idx, head_base + offset); - }; - - const auto packed_len = qo_len * params_.group_size; - const auto q_offset = begin * get<0>(params_.q_stride); - auto q = make_gather_tensor( - make_gmem_ptr((const Element*)params_.q_ptr + q_offset), - make_shape(packed_len, params_.head_dim), - 
diff --git a/src/kernels/attention/mla_tile.h b/src/kernels/attention/mla_tile.h
deleted file mode 100644
index 9b139f62..00000000
--- a/src/kernels/attention/mla_tile.h
+++ /dev/null
@@ -1,147 +0,0 @@
-#pragma once
-#include <cute/config.hpp>
-#include <cute/tensor.hpp>
-
-#include "gather_tensor.hpp"
-#include "mla_params.h"
-
-namespace llm {
-using namespace cute;
-
-template <typename Params>
-struct MLATile {
-  static_assert(cute::dependent_false<Params>, "not implemented");
-};
-
-// AttentionTile specialization for AttentionParams
-template <>
-struct MLATile<MLAParams> {
-  const MLAParams& params_;
-  const int batch_idx_;
-
-  CUTE_HOST_DEVICE MLATile(const MLAParams& params, int batch_idx)
-      : params_(params), batch_idx_(batch_idx) {}
-
-  // return the query/output tile: (q_packed_len, head_dim)
-  // return q_rope tile: (q_packed_len, rope_head_dim)
-  template <typename Element>
-  CUTE_HOST_DEVICE auto get_qo_tile() const {
-    // (batch, seq, head, dim)
-    const auto q_packed_len = params_.q_len * params_.group_size;
-    const auto q_offset = batch_idx_ * get<0>(params_.q_stride);
-    auto q =
-        make_tensor(make_gmem_ptr((const Element*)params_.q_ptr + q_offset),
-                    make_shape(q_packed_len, params_.head_dim),
-                    select<2, 3>(params_.q_stride));
-
-    // (batch, seq, head, rope_head_dim)
-    const auto q_rope_offset = batch_idx_ * get<0>(params_.q_rope_stride);
-    auto q_rope = make_tensor(
-        make_gmem_ptr((const Element*)params_.q_rope_ptr + q_rope_offset),
-        make_shape(q_packed_len, params_.rope_head_dim),
-        select<2, 3>(params_.q_rope_stride));
-
-    // (batch, seq, head, dim)
-    const auto o_offset = batch_idx_ * get<0>(params_.o_stride);
-    auto o = make_tensor(make_gmem_ptr((Element*)params_.o_ptr + o_offset),
-                         make_shape(q_packed_len, params_.head_dim),
-                         select<2, 3>(params_.o_stride));
-    return make_tuple(q, q_rope, o);
-  }
-
-  // return the kv: (kv_len, head_dim)
-  // return k_rope: (kv_len, rope_head_dim)
-  template <typename Element>
-  CUTE_HOST_DEVICE auto get_kv_tile() const {
-    // (batch, seq, dim)
-    const auto kv_offset = batch_idx_ * get<0>(params_.kv_stride);
-    // k[batch_idx, :, :]
-    auto kv =
-        make_tensor(make_gmem_ptr((const Element*)params_.kv_ptr + kv_offset),
-                    make_shape(params_.kv_len, params_.head_dim),
-                    select<1, 2>(params_.kv_stride));
-
-    // (batch, seq, rope_head_dim)
-    const auto k_rope_offset = batch_idx_ * get<0>(params_.k_rope_stride);
-    auto k_rope = make_tensor(
-        make_gmem_ptr((const Element*)params_.k_rope_ptr + k_rope_offset),
-        make_shape(params_.kv_len, params_.rope_head_dim),
-        select<1, 2>(params_.k_rope_stride));
-    return make_tuple(kv, k_rope);
-  }
-};
-
-// paged KV cache + variable length sequence
-template <>
-struct MLATile<MLAPagedKVParams> {
-  // NOLINTNEXTLINE
-  const MLAPagedKVParams& params_;
-  const int batch_idx_;
-
-  CUTE_HOST_DEVICE MLATile(const MLAPagedKVParams& params, int batch_idx)
-      : params_(params), batch_idx_(batch_idx) {}
-
-  // return the query/output tile: (q_packed_len, head_dim)
-  // return q_rope tile: (q_packed_len, rope_head_dim)
-  template <typename Element>
-  CUTE_HOST_DEVICE auto get_qo_tile() const {
-    const auto begin = params_.q_cu_lens[batch_idx_];
-    const auto qo_len = params_.q_cu_lens[batch_idx_ + 1] - begin;
-
-    // (seq, head, dim)
-    const auto q_packed_len = qo_len * params_.group_size;
-    const auto q_offset = begin * get<0>(params_.q_stride);
-
-    auto q =
-        make_tensor(make_gmem_ptr((const Element*)params_.q_ptr + q_offset),
-                    make_shape(q_packed_len, params_.head_dim),
-                    select<1, 2>(params_.q_stride));
-
-    // (seq, head, rope_head_dim)
-    const auto q_rope_offset = begin * get<0>(params_.q_rope_stride);
-    auto q_rope = make_tensor(
-        make_gmem_ptr((const Element*)params_.q_rope_ptr + q_rope_offset),
-        make_shape(q_packed_len, params_.rope_head_dim),
-        select<1, 2>(params_.q_rope_stride));
-
-    // (seq, head, dim)
-    const auto o_offset = begin * get<0>(params_.o_stride);
-    auto o = make_tensor(make_gmem_ptr((Element*)params_.o_ptr + o_offset),
-                         make_shape(q_packed_len, params_.head_dim),
-                         select<1, 2>(params_.o_stride));
-    return make_tuple(q, q_rope, o);
-  }
-
-  // return the kv: (kv_len, head_dim)
-  // return k_rope: (kv_len, rope_head_dim)
-  template <typename Element>
-  CUTE_HOST_DEVICE auto get_kv_tile() const {
-    const auto kv_len =
-        params_.kv_cu_lens[batch_idx_ + 1] - params_.kv_cu_lens[batch_idx_];
-
-    // map seq_idx to slot_idx
-    const int* block_table =
-        params_.block_table + params_.block_cu_lens[batch_idx_];
-    auto idx_to_slot = [block_table,
-                        shr = params_.block_shift_right,
-                        mask = params_.block_mask](int idx) {
-      return block_table[idx >> shr] + (idx & mask);
-    };
-
-    // kv: (seq, dim)
-    auto kv = make_gather_tensor(make_gmem_ptr((const Element*)params_.kv_ptr),
-                                 make_shape(kv_len, params_.head_dim),
-                                 params_.kv_stride,
-                                 idx_to_slot);
-
-    // k_rope: (seq, rope_head_dim)
-    auto k_rope =
-        make_gather_tensor(make_gmem_ptr((const Element*)params_.k_rope_ptr),
-                           make_shape(kv_len, params_.rope_head_dim),
-                           params_.k_rope_stride,
-                           idx_to_slot);
-    return make_tuple(kv, k_rope);
-  }
-};
-
-}  // namespace llm
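Note, likewise for idx_to_slot in the paged-KV specializations: with a power-of-two block size (block_size = 1 << block_shift_right, block_mask = block_size - 1), a sequence position splits into a block-table index and an in-block offset. A worked example with an assumed block_size of 16, and assuming each block_table entry stores the first cache slot of its block (which the addition of the raw offset implies):

    #include <cstdio>

    int main() {
      const int shr = 4;    // block_size = 1 << 4 = 16
      const int mask = 15;  // block_size - 1
      // hypothetical table: first slot of each block owned by this sequence
      const int block_table[] = {112, 48, 144};

      const int idx = 21;  // sequence position
      int slot = block_table[idx >> shr] + (idx & mask);
      std::printf("seq %d -> slot %d\n", idx, slot);  // 48 + 5 = 53
      return 0;
    }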
diff --git a/src/kernels/attention/sm80_kernel_mha.cuh b/src/kernels/attention/sm80_kernel_mha.cuh
index f85a47da..5d131f80 100644
--- a/src/kernels/attention/sm80_kernel_mha.cuh
+++ b/src/kernels/attention/sm80_kernel_mha.cuh
@@ -6,13 +6,162 @@
 #include <cute/tensor.hpp>
 #include <cutlass/cutlass.h>
 
-#include "mha_tile.h"
+#include "gather_tensor.hpp"
+#include "mha_params.h"
 #include "online_softmax.cuh"
 
 namespace llm {
 
 using namespace cute;
 
+namespace detail {
+
+template <typename Params>
+struct MHATile {
+  static_assert(cute::dependent_false<Params>, "not implemented");
+};
+
+// MHATile specialization for MHAParams
+template <>
+struct MHATile<MHAParams> {
+  const MHAParams& params_;
+  const int batch_idx_;
+  const int kv_head_idx_;
+
+  CUTE_HOST_DEVICE MHATile(const MHAParams& params,
+                           int batch_idx,
+                           int kv_head_idx)
+      : params_(params), batch_idx_(batch_idx), kv_head_idx_(kv_head_idx) {}
+
+  // return the query/output tile: (q_len, head_dim)
+  template <typename Element>
+  CUTE_HOST_DEVICE auto get_qo_tile() const {
+    // (batch, seq, head, dim)
+
+    // pack all q/o heads in the same kv head group together
+    const auto head_base = kv_head_idx_ * params_.group_size;
+    auto packed_idx_to_coord = [this, head_base](int packed_idx) {
+      int idx, offset;
+      params_.group_size.divmod(packed_idx, idx, offset);
+      return make_coord(idx, head_base + offset);
+    };
+
+    const auto packed_len = params_.q_len * params_.group_size;
+    const auto q_offset = batch_idx_ * get<0>(params_.q_stride);
+    auto q = make_gather_tensor(
+        make_gmem_ptr((const Element*)params_.q_ptr + q_offset),
+        make_shape(packed_len, params_.head_dim),
+        make_stride(select<1, 2>(params_.q_stride), get<3>(params_.q_stride)),
+        packed_idx_to_coord);
+
+    const auto o_offset = batch_idx_ * get<0>(params_.o_stride);
+    auto o = make_gather_tensor(
+        make_gmem_ptr((Element*)params_.o_ptr + o_offset),
+        make_shape(packed_len, params_.head_dim),
+        make_stride(select<1, 2>(params_.o_stride), get<3>(params_.o_stride)),
+        packed_idx_to_coord);
+    return make_tuple(q, o);
+  }
+
+  // return the key/value tile: (kv_len, head_dim)
+  template <typename Element>
+  CUTE_HOST_DEVICE auto get_kv_tile() const {
+    // (batch, seq, kv_head, dim)
+    const auto k_offset = batch_idx_ * get<0>(params_.k_stride) +
+                          kv_head_idx_ * get<2>(params_.k_stride);
+    const auto v_offset = batch_idx_ * get<0>(params_.v_stride) +
+                          kv_head_idx_ * get<2>(params_.v_stride);
+    // k[batch_idx, :, kv_head_idx, :]
+    auto k =
+        make_tensor(make_gmem_ptr((const Element*)params_.k_ptr + k_offset),
+                    make_shape(params_.kv_len, params_.head_dim),
+                    select<1, 3>(params_.k_stride));
+    // v[batch_idx, :, kv_head_idx, :]
+    auto v =
+        make_tensor(make_gmem_ptr((const Element*)params_.v_ptr + v_offset),
+                    make_shape(params_.kv_len, params_.head_dim),
+                    select<1, 3>(params_.v_stride));
+    return make_tuple(k, v);
+  }
+};
+
+// paged KV cache + variable length sequence
+template <>
+struct MHATile<MHAPagedKVParams> {
+  // NOLINTNEXTLINE
+  const MHAPagedKVParams& params_;
+  const int batch_idx_;
+  const int kv_head_idx_;
+
+  CUTE_HOST_DEVICE MHATile(const MHAPagedKVParams& params,
+                           int batch_idx,
+                           int kv_head_idx)
+      : params_(params), batch_idx_(batch_idx), kv_head_idx_(kv_head_idx) {}
+
+  // return the query/output tile: (q_len, head_dim)
+  template <typename Element>
+  CUTE_HOST_DEVICE auto get_qo_tile() const {
+    const auto begin = params_.q_cu_lens[batch_idx_];
+    const auto qo_len = params_.q_cu_lens[batch_idx_ + 1] - begin;
+    const auto head_base = kv_head_idx_ * params_.group_size;
+    auto packed_idx_to_coord = [this, head_base](int packed_idx) {
+      int idx, offset;
+      params_.group_size.divmod(packed_idx, idx, offset);
+      return make_coord(idx, head_base + offset);
+    };
+
+    const auto packed_len = qo_len * params_.group_size;
+    const auto q_offset = begin * get<0>(params_.q_stride);
+    auto q = make_gather_tensor(
+        make_gmem_ptr((const Element*)params_.q_ptr + q_offset),
+        make_shape(packed_len, params_.head_dim),
+        make_stride(select<0, 1>(params_.q_stride), get<2>(params_.q_stride)),
+        packed_idx_to_coord);
+
+    const auto o_offset = begin * get<0>(params_.o_stride);
+    auto o = make_gather_tensor(
+        make_gmem_ptr((Element*)params_.o_ptr + o_offset),
+        make_shape(packed_len, params_.head_dim),
+        make_stride(select<0, 1>(params_.o_stride), get<2>(params_.o_stride)),
+        packed_idx_to_coord);
+    return make_tuple(q, o);
+  }
+
+  // return the key/value tile: (kv_len, head_dim)
+  template <typename Element>
+  CUTE_HOST_DEVICE auto get_kv_tile() const {
+    const auto kv_len =
+        params_.kv_cu_lens[batch_idx_ + 1] - params_.kv_cu_lens[batch_idx_];
+
+    // map seq_idx to slot_idx
+    const int* block_table =
+        params_.block_table + params_.block_cu_lens[batch_idx_];
+    auto idx_to_slot = [block_table,
+                        shr = params_.block_shift_right,
+                        mask = params_.block_mask](int idx) {
+      return block_table[idx >> shr] + (idx & mask);
+    };
+
+    // k[:, kv_head_idx, :]
+    const auto k_offset = kv_head_idx_ * get<1>(params_.k_stride);
+    auto k = make_gather_tensor(
+        make_gmem_ptr((const Element*)params_.k_ptr + k_offset),
+        make_shape(kv_len, params_.head_dim),
+        select<0, 2>(params_.k_stride),
+        idx_to_slot);
+
+    // v[:, kv_head_idx, :]
+    const auto v_offset = kv_head_idx_ * get<1>(params_.v_stride);
+    auto v = make_gather_tensor(
+        make_gmem_ptr((const Element*)params_.v_ptr + v_offset),
+        make_shape(kv_len, params_.head_dim),
+        select<0, 2>(params_.v_stride),
+        idx_to_slot);
+    return make_tuple(k, v);
+  }
+};
+}  // namespace detail
+
 template <class CollectiveMainloop,
           class CollectiveEpilogue,
           class TileScheduler>
@@ -75,7 +224,7 @@ class Sm80KernelMha {
     const auto tidx = threadIdx.x;
 
     // (q_packed_len, HEAD_DIM)
-    MHATile<Params> tile(params, batch_idx, kv_head_idx);
+    detail::MHATile<Params> tile(params, batch_idx, kv_head_idx);
     auto [Q, O] = tile.template get_qo_tile<Element>();
     // (kv_len, HEAD_DIM)
     auto [K, V] = tile.template get_kv_tile<Element>();
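Note: both tile structs lean on make_gather_tensor from gather_tensor.hpp, a tensor whose row index is remapped through a functor before the address is computed; this is what makes the packed-head and paged-KV views possible without materializing anything. A minimal usage sketch under the assumption of that in-repo API (the pointer, shape names, and functor below are illustrative):

    // a gathered view of every other row of a row-major (rows, cols) matrix;
    // mirrors the make_gather_tensor calls in the tile structs above
    auto even_rows = [](int i) { return 2 * i; };            // row index remapping
    auto t = make_gather_tensor(make_gmem_ptr(ptr),          // Element* ptr (assumed)
                                make_shape(rows / 2, cols),  // logical shape
                                make_stride(cols, _1{}),     // physical strides
                                even_rows);
    // t(i, j) dereferences ptr[even_rows(i) * cols + j]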
diff --git a/src/kernels/attention/sm80_kernel_mla.cuh b/src/kernels/attention/sm80_kernel_mla.cuh
index c5b127df..d026a946 100644
--- a/src/kernels/attention/sm80_kernel_mla.cuh
+++ b/src/kernels/attention/sm80_kernel_mla.cuh
@@ -6,13 +6,153 @@
 #include <cute/tensor.hpp>
 #include <cutlass/cutlass.h>
 
-#include "mla_tile.h"
+#include "gather_tensor.hpp"
+#include "mla_params.h"
 #include "online_softmax.cuh"
 
 namespace llm {
 
 using namespace cute;
 
+namespace detail {
+
+template <typename Params>
+struct MLATile {
+  static_assert(cute::dependent_false<Params>, "not implemented");
+};
+
+// MLATile specialization for MLAParams
+template <>
+struct MLATile<MLAParams> {
+  const MLAParams& params_;
+  const int batch_idx_;
+
+  CUTE_HOST_DEVICE MLATile(const MLAParams& params, int batch_idx)
+      : params_(params), batch_idx_(batch_idx) {}
+
+  // return the query/output tile: (q_packed_len, head_dim)
+  // return q_rope tile: (q_packed_len, rope_head_dim)
+  template <typename Element>
+  CUTE_HOST_DEVICE auto get_qo_tile() const {
+    // (batch, seq, head, dim)
+    const auto q_packed_len = params_.q_len * params_.group_size;
+    const auto q_offset = batch_idx_ * get<0>(params_.q_stride);
+    auto q =
+        make_tensor(make_gmem_ptr((const Element*)params_.q_ptr + q_offset),
+                    make_shape(q_packed_len, params_.head_dim),
+                    select<2, 3>(params_.q_stride));
+
+    // (batch, seq, head, rope_head_dim)
+    const auto q_rope_offset = batch_idx_ * get<0>(params_.q_rope_stride);
+    auto q_rope = make_tensor(
+        make_gmem_ptr((const Element*)params_.q_rope_ptr + q_rope_offset),
+        make_shape(q_packed_len, params_.rope_head_dim),
+        select<2, 3>(params_.q_rope_stride));
+
+    // (batch, seq, head, dim)
+    const auto o_offset = batch_idx_ * get<0>(params_.o_stride);
+    auto o = make_tensor(make_gmem_ptr((Element*)params_.o_ptr + o_offset),
+                         make_shape(q_packed_len, params_.head_dim),
+                         select<2, 3>(params_.o_stride));
+    return make_tuple(q, q_rope, o);
+  }
+
+  // return the kv: (kv_len, head_dim)
+  // return k_rope: (kv_len, rope_head_dim)
+  template <typename Element>
+  CUTE_HOST_DEVICE auto get_kv_tile() const {
+    // (batch, seq, dim)
+    const auto kv_offset = batch_idx_ * get<0>(params_.kv_stride);
+    // kv[batch_idx, :, :]
+    auto kv =
+        make_tensor(make_gmem_ptr((const Element*)params_.kv_ptr + kv_offset),
+                    make_shape(params_.kv_len, params_.head_dim),
+                    select<1, 2>(params_.kv_stride));
+
+    // (batch, seq, rope_head_dim)
+    const auto k_rope_offset = batch_idx_ * get<0>(params_.k_rope_stride);
+    auto k_rope = make_tensor(
+        make_gmem_ptr((const Element*)params_.k_rope_ptr + k_rope_offset),
+        make_shape(params_.kv_len, params_.rope_head_dim),
+        select<1, 2>(params_.k_rope_stride));
+    return make_tuple(kv, k_rope);
+  }
+};
+
+// paged KV cache + variable length sequence
+template <>
+struct MLATile<MLAPagedKVParams> {
+  // NOLINTNEXTLINE
+  const MLAPagedKVParams& params_;
+  const int batch_idx_;
+
+  CUTE_HOST_DEVICE MLATile(const MLAPagedKVParams& params, int batch_idx)
+      : params_(params), batch_idx_(batch_idx) {}
+
+  // return the query/output tile: (q_packed_len, head_dim)
+  // return q_rope tile: (q_packed_len, rope_head_dim)
+  template <typename Element>
+  CUTE_HOST_DEVICE auto get_qo_tile() const {
+    const auto begin = params_.q_cu_lens[batch_idx_];
+    const auto qo_len = params_.q_cu_lens[batch_idx_ + 1] - begin;
+
+    // (seq, head, dim)
+    const auto q_packed_len = qo_len * params_.group_size;
+    const auto q_offset = begin * get<0>(params_.q_stride);
+
+    auto q =
+        make_tensor(make_gmem_ptr((const Element*)params_.q_ptr + q_offset),
+                    make_shape(q_packed_len, params_.head_dim),
+                    select<1, 2>(params_.q_stride));
+
+    // (seq, head, rope_head_dim)
+    const auto q_rope_offset = begin * get<0>(params_.q_rope_stride);
+    auto q_rope = make_tensor(
+        make_gmem_ptr((const Element*)params_.q_rope_ptr + q_rope_offset),
+        make_shape(q_packed_len, params_.rope_head_dim),
+        select<1, 2>(params_.q_rope_stride));
+
+    // (seq, head, dim)
+    const auto o_offset = begin * get<0>(params_.o_stride);
+    auto o = make_tensor(make_gmem_ptr((Element*)params_.o_ptr + o_offset),
+                         make_shape(q_packed_len, params_.head_dim),
+                         select<1, 2>(params_.o_stride));
+    return make_tuple(q, q_rope, o);
+  }
+
+  // return the kv: (kv_len, head_dim)
+  // return k_rope: (kv_len, rope_head_dim)
+  template <typename Element>
+  CUTE_HOST_DEVICE auto get_kv_tile() const {
+    const auto kv_len =
+        params_.kv_cu_lens[batch_idx_ + 1] - params_.kv_cu_lens[batch_idx_];
+
+    // map seq_idx to slot_idx
+    const int* block_table =
+        params_.block_table + params_.block_cu_lens[batch_idx_];
+    auto idx_to_slot = [block_table,
+                        shr = params_.block_shift_right,
+                        mask = params_.block_mask](int idx) {
+      return block_table[idx >> shr] + (idx & mask);
+    };
+
+    // kv: (seq, dim)
+    auto kv = make_gather_tensor(make_gmem_ptr((const Element*)params_.kv_ptr),
+                                 make_shape(kv_len, params_.head_dim),
+                                 params_.kv_stride,
+                                 idx_to_slot);
+
+    // k_rope: (seq, rope_head_dim)
+    auto k_rope =
+        make_gather_tensor(make_gmem_ptr((const Element*)params_.k_rope_ptr),
+                           make_shape(kv_len, params_.rope_head_dim),
+                           params_.k_rope_stride,
+                           idx_to_slot);
+    return make_tuple(kv, k_rope);
+  }
+};
+}  // namespace detail
+
 template <class CollectiveMainloop,
           class CollectiveEpilogue,
           class TileScheduler>
@@ -76,7 +216,7 @@ class Sm80KernelMla {
 
     // Q/O: (q_packed_len, HEAD_DIM)
     // Q_ROPE: (q_packed_len, ROPE_HEAD_DIM)
-    MLATile<Params> tile(params, batch_idx);
+    detail::MLATile<Params> tile(params, batch_idx);
     auto [Q, Q_ROPE, O] = tile.template get_qo_tile<Element>();
     // KV: (kv_len, HEAD_DIM)
     // K_ROPE: (kv_len, ROPE_HEAD_DIM)
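Note on the primary templates above: static_assert(cute::dependent_false<Params>, "not implemented") is the standard idiom for making the unspecialized template a hard compile error while keeping the assert dependent on Params, so it fires only on actual instantiation rather than when the header is parsed. The idiom in isolation (dependent_false lives in CuTe's type-traits utilities; Tile and MyParams are illustrative):

    template <typename Params>
    struct Tile {
      // dependent on Params, so it only fires if this primary template is
      // actually instantiated with an unsupported Params type
      static_assert(cute::dependent_false<Params>, "not implemented");
    };

    template <>
    struct Tile<MyParams> {
      // supported case: the primary template (and its assert) is never used
    };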