diff --git a/src/kernels/attention/layout_convertor.h b/src/kernels/attention/layout_convertor.h
index 5b08591f..30cb2944 100644
--- a/src/kernels/attention/layout_convertor.h
+++ b/src/kernels/attention/layout_convertor.h
@@ -9,21 +9,23 @@ using namespace cute;
 // Only works for TiledMMA (64x16x16) with SM80_16x8x16_F32F16F16F32_TN
 struct LayoutConvertor {
   // Convert fragment layout to rowcol layout for iterating
-  // (MMA=4, MMA_M, MMA_N) => ((2, MMA_M), (2, MMA_N))
+  // (MMA=4, MMA_M, MMA_N, ...) => ((2, MMA_M), (2, MMA_N), ...)
   template
   CUTE_HOST_DEVICE static constexpr auto to_mn(const LayoutC& layout) {
-    auto l = logical_divide(layout, Shape<_2>{});
-    return make_layout(make_layout(get<0, 1>(l), get<1>(l)),
-                       make_layout(get<0, 0>(l), get<2>(l)));
-  }
+    constexpr int R = LayoutC::rank;
+    static_assert(R >= 3, "Expected at least 3 modes in LayoutC.");
 
-  // (MMA=4, MMA_M, MMA_N, STEPS) => ((2, MMA_M), (2, MMA_N), STEPS)
-  template
-  CUTE_HOST_DEVICE static constexpr auto to_mns(const LayoutC& layout) {
+    // ((2, 2), MMA_M, MMA_N, ...)
     auto l = logical_divide(layout, Shape<_2>{});
-    return make_layout(make_layout(get<0, 1>(l), get<1>(l)),
-                       make_layout(get<0, 0>(l), get<2>(l)),
-                       get<3>(l));
+    // ((2, MMA_M), (2, MMA_N), ...)
+    if constexpr (R > 3) {
+      return make_layout(make_layout(get<0, 1>(l), get<1>(l)),
+                         make_layout(get<0, 0>(l), get<2>(l)),
+                         take<3, R>(l));
+    } else {
+      return make_layout(make_layout(get<0, 1>(l), get<1>(l)),
+                         make_layout(get<0, 0>(l), get<2>(l)));
+    }
   }
 
   // Convert fragment layout from gemm-I C to gemm-II A
@@ -36,4 +38,4 @@ struct LayoutConvertor {
   }
 };
 
-}  // namespace llm
\ No newline at end of file
+}  // namespace llm
diff --git a/src/kernels/attention/mask.h b/src/kernels/attention/mask.h
index 754c5a33..f869d889 100644
--- a/src/kernels/attention/mask.h
+++ b/src/kernels/attention/mask.h
@@ -12,7 +12,6 @@ struct Mask {
   // Fragment type for alibi slopes
   using FragmentT = decltype(make_tensor(Int{}));
 
-  const int q_len_;
   const int kv_len_;
   const FastDivmod& group_size_;
   const int sliding_window_;
@@ -24,46 +23,43 @@ struct Mask {
                         int kv_len,
                         const FastDivmod& group_size,
                         int sliding_window)
-      : q_len_(q_len),
-        kv_len_(kv_len),
+      : kv_len_(kv_len),
         group_size_(group_size),
         sliding_window_(sliding_window),
         diagonal_offset_(kv_len - q_len) {}
 
   // cS_mn: ((2, MMA_M), (2, MMA_N))
   template
-  CUTE_HOST_DEVICE void init_alibi(IdentityS& cS_mn,
-                                   int m_base_idx,
-                                   int kv_head_idx,
-                                   float sm_scale,
-                                   const float* alibi_slops_ptr) {
+  CUTE_HOST_DEVICE void init_alibi(
+      IdentityS& cS_mn,  // ((2, MMA_M), (2, MMA_N)) => (M, N)
+      int kv_head_idx,
+      float sm_scale,
+      const float* alibi_slops_ptr) {
     // copy alibi slopes to registers
     CUTE_UNROLL
     for (int i = 0; i < size<0>(cS_mn); ++i) {
-      const auto [m, n] = cS_mn(i, _0{});
-      const int q_packed_idx = m_base_idx + m;
+      const int q_packed_idx = get<0>(cS_mn(i, _0{}));
       const int offset = q_packed_idx % group_size_;
-      const int head_idx = kv_head_idx * group_size_ + offset;
+      const int head_idx = (kv_head_idx * group_size_) + offset;
       alibi_slopes_(i) = alibi_slops_ptr[head_idx] / sm_scale;
     }
   }
 
   // rS_mn/cS_mn: ((2, MMA_M), (2, MMA_N))
-  template
-  CUTE_HOST_DEVICE void apply(FragmentS& rS_mn,
-                              IdentityS& cS_mn,
-                              int m_base_idx,
-                              int n_base_idx) const {
+  template
+  CUTE_HOST_DEVICE void apply(
+      FragmentS& rS_mn,  // ((2, MMA_M), (2, MMA_N))
+      IdentityS& cS_mn   // ((2, MMA_M), (2, MMA_N)) => (M, N)
+  ) const {
     CUTE_UNROLL
     for (int i = 0; i < size<0>(rS_mn); ++i) {
      const auto alibi_slope = ALIBI ? alibi_slopes_(i) : 0.0f;
      CUTE_UNROLL
      for (int j = 0; j < size<1>(rS_mn); ++j) {
-        auto [m, n] = cS_mn(i, j);
-        const int q_packed_idx = m_base_idx + m;
-        const int kv_idx = n_base_idx + n;
-
-        const int q_idx = q_packed_idx / group_size_ + diagonal_offset_;
+        const auto [m, n] = cS_mn(i, j);
+        const int q_packed_idx = m;
+        const int kv_idx = n;
+        const int q_idx = (q_packed_idx / group_size_) + diagonal_offset_;
 
         const bool out_of_boundary = [&]() {
           if constexpr (OOB_MASK && LOCAL) {
@@ -93,4 +89,4 @@ struct Mask {
   }
 };
 
-}  // namespace llm
\ No newline at end of file
+}  // namespace llm
diff --git a/src/kernels/attention/sm80_collective_mha.cuh b/src/kernels/attention/sm80_collective_mha.cuh
index 1a5c2dc1..62756e7d 100644
--- a/src/kernels/attention/sm80_collective_mha.cuh
+++ b/src/kernels/attention/sm80_collective_mha.cuh
@@ -148,9 +148,10 @@ struct Sm80CollectiveMha {
             class TensorK,
             class TensorV,
             class TensorCKV,
+            class TensorCMN,
             class FrgTensor,
             class Softmax,
-            class BlockCoordMNK,
+            class BlockCoord,
             class ResidueMNK>
   CUTE_DEVICE void operator()(
       const Params& params,
@@ -159,10 +160,11 @@ struct Sm80CollectiveMha {
       const TensorK& gK,      // (BLK_N, HEAD_DIM, n)
       const TensorV& gV,      // (BLK_N, HEAD_DIM, n)
       const TensorCKV& cKV,   // (BLK_N, HEAD_DIM, n) => (N, K)
+      const TensorCMN& cMN,   // (BLK_M, BLK_N, n) => (M, N)
       FrgTensor& tOrO,        // (MMA, MMA_M, MMA_N)
       Softmax& softmax,
       int tidx,
-      const BlockCoordMNK& block_coord_mnk,
+      const BlockCoord& blk_coord,
       const ResidueMNK& residue_mnk,  // (M, N, K)
       char* smem) {
     static_assert(is_rmem::value,
@@ -174,7 +176,8 @@ struct Sm80CollectiveMha {
     static constexpr int kBlockM = get<0>(TileShape{});
     static constexpr int kBlockN = get<1>(TileShape{});
 
-    const auto [batch_idx, m_block_idx, kv_head_idx] = block_coord_mnk;
+    const int q_idx = get<0>(blk_coord);
+    const int kv_head_idx = get<1>(blk_coord);
 
     const int q_packed_len = get<0>(residue_mnk);
     const int kv_len = get<1>(residue_mnk);
@@ -342,7 +345,7 @@ struct Sm80CollectiveMha {
     auto tOrO_mn =
         make_tensor(tOrO.data(), LayoutConvertor::to_mn(tOrO.layout()));
 
-    const int diagonal = (m_block_idx * kBlockM) / group_size + kv_len - q_len;
+    const int diagonal = q_idx + kv_len - q_len;
     // process kv in range: [kv_idx_min, kv_idx_max)
     const int kv_idx_min = std::max(0, diagonal - sliding_window);
     const int kv_idx_max = std::min(kv_len, diagonal + kBlockM);
@@ -385,31 +388,29 @@ struct Sm80CollectiveMha {
     constexpr int n_oob_mask = cute::ceil_div(kBlockM, kBlockN) + 1;
     const int n_blocks = n_block_max - n_block_min;
 
-    // attention score accumulator, (MMA,MMA_M,MMA_N)
+    // attention score accumulator, (MMA, MMA_M, MMA_N)
     auto tSrS = partition_fragment_C(tiled_mma, Shape{});
+    // ((2, MMA_M), (2, MMA_N))
     auto tSrS_mn =
         make_tensor(tSrS.data(), LayoutConvertor::to_mn(tSrS.layout()));
 
-    // identity tensor for score accumulator
-    auto tScS =
-        thr_mma.partition_C(make_identity_tensor(Shape{}));
-    auto tScS_mn =
-        make_tensor(tScS.data(), LayoutConvertor::to_mn(tScS.layout()));
+    // (MMA, MMA_M, MMA_N, n)
+    auto tScMN = thr_mma.partition_C(cMN);
+    // ((2, MMA_M), (2, MMA_N), n) => (M, N)
+    auto tScMN_mn =
+        make_tensor(tScMN.data(), LayoutConvertor::to_mn(tScMN.layout()));
 
     constexpr int kRowsPerThr = kRowsPerMMA * size<1>(tSrS);
     using Mask = Mask;
     Mask mask(q_len, kv_len, group_size, sliding_window);
     if constexpr (ALIBI) {
-      mask.init_alibi(tScS_mn,
-                      m_block_idx * kBlockM,
-                      kv_head_idx,
-                      sm_scale,
-                      params.alibi_slopes_ptr);
+      const auto tScS_mn = tScMN_mn(_, _, _0{});
+      mask.init_alibi(tScS_mn, kv_head_idx, sm_scale, params.alibi_slopes_ptr);
     }
 
     CUTE_NO_UNROLL
     for (int i = 0; i < n_blocks; ++i) {
-      const int n_block_idx = n_block_max - 1 - i;
+      const int ni = n_block_max - 1 - i;
       clear(tSrS);
 
       // wait key, queue: [q, k] => []
@@ -418,9 +419,9 @@ struct Sm80CollectiveMha {
 
       // produce value, [] => [v]
       if (i == 0) {
-        produce_value(n_block_idx);
+        produce_value(ni);
      } else {
-        produce_value_no_oob(n_block_idx);
+        produce_value_no_oob(ni);
      }
      cp_async_fence();
 
@@ -435,18 +436,19 @@ struct Sm80CollectiveMha {
         apply_logits_soft_cap(tSrS);
       }
 
+      // apply mask
+      // ((2, MMA_M), (2, MMA_N)) => (M, N)
+      const auto tScS_mn = tScMN_mn(_, _, ni);
       if (i < n_oob_mask) {
-        mask.template apply(
-            tSrS_mn, tScS_mn, m_block_idx * kBlockM, n_block_idx * kBlockN);
+        mask.template apply(tSrS_mn, tScS_mn);
       } else {
-        mask.template apply(
-            tSrS_mn, tScS_mn, m_block_idx * kBlockM, n_block_idx * kBlockN);
+        mask.template apply(tSrS_mn, tScS_mn);
       }
 
       softmax.rescale(tSrS_mn, tOrO_mn);
 
       // produce next key: [] => [k]
-      if (n_block_idx > n_block_min) {
-        produce_key_no_oob(n_block_idx - 1);
+      if (ni > n_block_min) {
+        produce_key_no_oob(ni - 1);
       }
       cp_async_fence();
diff --git a/src/kernels/attention/sm80_collective_mla.cuh b/src/kernels/attention/sm80_collective_mla.cuh
index 670cdfd2..55af9d33 100644
--- a/src/kernels/attention/sm80_collective_mla.cuh
+++ b/src/kernels/attention/sm80_collective_mla.cuh
@@ -244,9 +244,10 @@ struct Sm80CollectiveMla {
             class TensorCQR,
             class TensorKR,
             class TensorCKR,
+            class TensorCMN,
             class FrgTensor,
             class Softmax,
-            class BlockCoordMNK,
+            class BlockCoord,
             class ResidueMNK,
             class RopeResidueMNK>
   CUTE_DEVICE void operator()(
@@ -259,10 +260,11 @@ struct Sm80CollectiveMla {
       const TensorCQR& cQ_rope,  // (BLK_M, ROPE_HEAD_DIM) =>(M, K)
       const TensorKR& gK_rope,   // (BLK_N, HEAD_DIM, n)
       const TensorCKR& cK_rope,  // (BLK_N, HEAD_DIM, n) => (N, K)
+      const TensorCMN& cMN,      // (BLK_M, BLK_N, n) => (M, N)
       FrgTensor& tOrO,           // (BLK_N, ROPE_HEAD_DIM, n)
       Softmax& softmax,
       int tidx,
-      const BlockCoordMNK& block_coord_mnk,
+      const BlockCoord& blk_coord,
       const ResidueMNK& residue_mnk,
       const RopeResidueMNK& rope_residue_mnk,
       char* smem) {
@@ -279,12 +281,11 @@ struct Sm80CollectiveMla {
     static constexpr int kBlockN = get<1>(TileShape{});
     static constexpr int kBlockK = get<2>(TileShape{});
 
-    const int m_block_idx = get<1>(block_coord_mnk);
+    const int q_idx = get<0>(blk_coord);
 
     const int q_packed_len = get<0>(residue_mnk);
     const int kv_len = get<1>(residue_mnk);
 
     const auto& group_size = params.group_size;
-    const int q_len = q_packed_len / group_size;
 
     // Construct shared memory tiles
@@ -586,16 +587,13 @@ struct Sm80CollectiveMla {
       cute::copy(smem_tiled_copy_S, tCrS, tCsS);
     };
 
-    // output accumulator: (MMA,MMA_M,MMA_K,k)
-    // auto tOrO =
-    //     partition_fragment_C(tiled_mma_pv, Shape{});
-    // clear(tOrO);
+    // (MMA,MMA_M,MMA_K,k) => ((2, MMA_M), (2, MMA_K), k)
     auto tOrO_mn =
-        make_tensor(tOrO.data(), LayoutConvertor::to_mns(tOrO.layout()));
+        make_tensor(tOrO.data(), LayoutConvertor::to_mn(tOrO.layout()));
 
     const int n_block_min = 0;
     // process kv in range: [0, kv_idx_max)
-    const int diagonal = (m_block_idx * kBlockM) / group_size + kv_len - q_len;
+    const int diagonal = q_idx + kv_len - q_len;
     const int kv_idx_max = std::min(kv_len, diagonal + kBlockM);
     const int n_block_max = cute::ceil_div(kv_idx_max, kBlockN);
 
@@ -649,14 +647,17 @@ struct Sm80CollectiveMla {
     }
 
     // ############### Mainloop ###############
-    // attention score accumulator, (MMA,MMA_M,MMA_N)
+    // attention score accumulator, (MMA, MMA_M, MMA_N)
     auto tSrS = partition_fragment_C(tiled_mma_qk, Shape{});
-    auto tScS =
-        thr_mma_qk.partition_C(make_identity_tensor(Shape{}));
+    // ((2, MMA_M), (2, MMA_N))
     auto tSrS_mn =
         make_tensor(tSrS.data(), LayoutConvertor::to_mn(tSrS.layout()));
-    auto tScS_mn =
-        make_tensor(tScS.data(), LayoutConvertor::to_mn(tScS.layout()));
+
+    // (MMA, MMA_M, MMA_N, n) => (M, N)
+    auto tScMN = thr_mma_qk.partition_C(cMN);
+    // ((2, MMA_M), (2, MMA_N), n) => (M, N)
+    auto tScMN_mn =
+        make_tensor(tScMN.data(), LayoutConvertor::to_mn(tScMN.layout()));
 
     constexpr int kRowsPerThr = kRowsPerMMA * size<1>(tSrS);
     using Mask = Mask;
@@ -688,11 +689,12 @@ struct Sm80CollectiveMla {
       }
 
       // apply mask
+      // ((2, MMA_M), (2, MMA_N)) => (M, N)
+      const auto tScS_mn = tScMN_mn(_, _, ni);
       if (i < n_oob_mask) {
-        mask.apply(tSrS_mn, tScS_mn, m_block_idx * kBlockM, ni * kBlockN);
+        mask.apply(tSrS_mn, tScS_mn);
       } else {
-        mask.apply(
-            tSrS_mn, tScS_mn, m_block_idx * kBlockM, ni * kBlockN);
+        mask.apply(tSrS_mn, tScS_mn);
       }
 
       softmax.rescale(tSrS_mn, tOrO_mn, reduce_rowmax);
diff --git a/src/kernels/attention/sm80_kernel_mha.cuh b/src/kernels/attention/sm80_kernel_mha.cuh
index 5d131f80..1ffd8c79 100644
--- a/src/kernels/attention/sm80_kernel_mha.cuh
+++ b/src/kernels/attention/sm80_kernel_mha.cuh
@@ -218,6 +218,7 @@ class Sm80KernelMha {
     EpilogueParams epilogue_params;
 
     // process each block
+    const auto& group_size = params.group_size;
     for (const auto block_coord : scheduler) {
       // block coord: (batch_idx, m_block_idx, kv_head_idx)
       const auto [batch_idx, m_block_idx, kv_head_idx] = block_coord;
@@ -232,12 +233,15 @@ class Sm80KernelMha {
       // problem shape
       const int q_packed_len = size<0>(Q);
       const int kv_len = size<0>(K);
-      const int head_dim = params.head_dim;
-      if (m_block_idx * kBlockM >= q_packed_len) {
+      const int m_block_base = m_block_idx * kBlockM;
+      if (m_block_base >= q_packed_len) {
         // m out of bound, skip this block
         continue;
       }
-      const auto residue_mnk = make_tuple(q_packed_len, kv_len, head_dim);
+
+      const int q_idx = m_block_base / group_size;
+      const auto residue_mnk =
+          make_tuple(q_packed_len, kv_len, params.head_dim);
 
       // (BLK_M, HEAD_DIM)
       Tensor gQ = local_tile(
@@ -257,6 +261,12 @@ class Sm80KernelMha {
          Shape{},
          make_coord(_, _0{}));
 
+      // (BLK_M, BLK_N, n) => (M, N)
+      Tensor cMN =
+          local_tile(make_identity_tensor(make_shape(q_packed_len, kv_len)),
+                     Shape{},
+                     make_coord(m_block_idx, _));
+
       TiledMma tiled_mma;
       // accumulator: MMA,MMA_M,MMA_K)
       auto tOrAccO = partition_fragment_C(tiled_mma, Shape{});
@@ -266,16 +276,18 @@ class Sm80KernelMha {
       OnlineSoftmax softmax(params.sm_scale_log2);
 
       // mainloop
+      const auto blk_coord = make_coord(q_idx, kv_head_idx);
       mha(mainloop_params,
           gQ,
           cQ,
           gK,
           gV,
           cKV,
+          cMN,
           tOrAccO,
           softmax,
           tidx,
-          block_coord,
+          blk_coord,
           residue_mnk,
           smem);
 
diff --git a/src/kernels/attention/sm80_kernel_mla.cuh b/src/kernels/attention/sm80_kernel_mla.cuh
index d026a946..8f65e58d 100644
--- a/src/kernels/attention/sm80_kernel_mla.cuh
+++ b/src/kernels/attention/sm80_kernel_mla.cuh
@@ -208,7 +208,6 @@ class Sm80KernelMla {
 
     // process each block
     const auto& group_size = params.group_size;
-
     for (const auto block_coord : scheduler) {
       // block coord: (batch_idx, m_block_idx, kv_head_idx)
       const auto [batch_idx, m_block_idx, kv_head_idx] = block_coord;
@@ -224,17 +223,16 @@ class Sm80KernelMla {
 
       // problem shape
       const int q_packed_len = size<0>(Q);
-      const int q_len = q_packed_len / group_size;
       const int kv_len = size<0>(KV);
+      const int m_block_base = m_block_idx * kBlockM;
 
-      if (m_block_idx * kBlockM >= size<0>(Q)) {
+      if (m_block_base >= q_packed_len) {
        // m out of bound, return
        return;
      }
-
-      const auto head_dim = params.head_dim;
-      auto problem_shape_mnk = make_shape(q_packed_len, kv_len, head_dim);
-      const auto residue_mnk = make_tuple(q_packed_len, kv_len, head_dim);
+      const int q_idx = m_block_base / group_size;
+      const auto residue_mnk =
+          make_tuple(q_packed_len, kv_len, params.head_dim);
       const auto rope_residue_mnk =
           make_tuple(q_packed_len, kv_len, ROPE_HEAD_DIM{});
 
@@ -269,6 +267,12 @@ class Sm80KernelMla {
          Shape{},
          make_coord(_, _0{}));
 
+      // (BLK_M, BLK_N, n) => (M, N)
+      Tensor cMN =
+          local_tile(make_identity_tensor(make_shape(q_packed_len, kv_len)),
+                     Shape{},
+                     make_coord(m_block_idx, _));
+
       TiledMma_PV tiled_mma_pv;
       // accumulator: MMA,MMA_M,MMA_K, k)
       auto tOrAccO = partition_fragment_C(
@@ -279,6 +283,7 @@ class Sm80KernelMla {
       OnlineSoftmax softmax(params.sm_scale_log2);
 
       // mainloop
+      const auto blk_coord = make_coord(q_idx, _0{});
       mha(mainloop_params,
           gQ,
           cQ,
@@ -288,10 +293,11 @@ class Sm80KernelMla {
           cQ_rope,
           gK_rope,
           cK_rope,
+          cMN,
           tOrAccO,
           softmax,
           tidx,
-          block_coord,
+          blk_coord,
           residue_mnk,
           rope_residue_mnk,
           smem);
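
Reviewer note, not part of the patch: a minimal host-side sketch of the merged to_mn/to_mns behavior, assuming only that the CuTe headers are on the include path. The fragment shapes below are made-up stand-ins for the SM80 accumulator layout; only the mode-regrouping logic mirrors the new LayoutConvertor::to_mn above.

#include <cute/tensor.hpp>

using namespace cute;

// Same regrouping as the new LayoutConvertor::to_mn, written out standalone.
template <class LayoutC>
constexpr auto to_mn_sketch(const LayoutC& layout) {
  constexpr int R = LayoutC::rank;
  static_assert(R >= 3, "Expected at least 3 modes in LayoutC.");
  // ((2, 2), MMA_M, MMA_N, ...)
  auto l = logical_divide(layout, Shape<_2>{});
  // ((2, MMA_M), (2, MMA_N), ...)
  if constexpr (R > 3) {
    return make_layout(make_layout(get<0, 1>(l), get<1>(l)),
                       make_layout(get<0, 0>(l), get<2>(l)),
                       take<3, R>(l));
  } else {
    return make_layout(make_layout(get<0, 1>(l), get<1>(l)),
                       make_layout(get<0, 0>(l), get<2>(l)));
  }
}

int main() {
  // Stand-in for a (MMA=4, MMA_M=2, MMA_N=4) accumulator fragment layout:
  // the old to_mn case (rank 3), regrouped into a ((2, MMA_M), (2, MMA_N)) view.
  auto acc = make_layout(make_shape(_4{}, _2{}, _4{}));
  print(to_mn_sketch(acc));
  print("\n");

  // Stand-in for (MMA=4, MMA_M=2, MMA_N=4, STEPS=2): the old to_mns case,
  // now handled by the same function through the trailing-mode branch.
  auto acc_steps = make_layout(make_shape(_4{}, _2{}, _4{}, _2{}));
  print(to_mn_sketch(acc_steps));
  print("\n");
  return 0;
}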
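
A second standalone sketch, again outside the patch, of the cMN identity-tensor idea used in the kernels: after local_tile, every element of cMN already evaluates to its global (m, n) coordinate, which is why Mask::apply and init_alibi no longer need the m_base_idx/n_base_idx offsets. The 64x64 tile and the problem sizes here are illustrative assumptions, not values taken from the kernels.

#include <cstdio>

#include <cute/tensor.hpp>

using namespace cute;

int main() {
  const int q_packed_len = 100;  // M (stand-in value)
  const int kv_len = 192;        // N (stand-in value)
  const int m_block_idx = 1;

  // (M, N) identity tensor: element (m, n) evaluates to the coordinate (m, n).
  auto idMN = make_identity_tensor(make_shape(q_packed_len, kv_len));

  // (BLK_M, BLK_N, n) => (M, N): fix the M-block, keep all N-blocks.
  auto cMN = local_tile(idMN, Shape<_64, _64>{}, make_coord(m_block_idx, _));

  // Element (2, 3) of N-block 1 carries its global coordinates directly,
  // so masking code can read (m, n) instead of adding block base offsets.
  const auto [m, n] = cMN(2, 3, 1);
  printf("m = %d, n = %d\n", int(m), int(n));  // expected: m = 66, n = 67
  return 0;
}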