diff --git a/src/kernels/attention/collective/sm120_collective_fmha_mainloop_ws.cuh b/src/kernels/attention/collective/sm120_collective_fmha_mainloop_ws.cuh
index 640fd34f..287ebf6a 100644
--- a/src/kernels/attention/collective/sm120_collective_fmha_mainloop_ws.cuh
+++ b/src/kernels/attention/collective/sm120_collective_fmha_mainloop_ws.cuh
@@ -238,14 +238,16 @@ struct Sm120CollectiveFMhaWs {
     }
 
     // (m_block_idx, ((kv_head_idx, _0), batch_idx))
-    const auto& block_coord = block.get_block_coord();
-    const int m_block_idx = get<0>(block_coord);
-    const int kv_head_idx = get<1, 0, 0>(block_coord);
+    const auto& blk_coord = block.get_coord();
+    const int m_block_idx = get<0>(blk_coord);
+    const int kv_head_idx = get<1, 0, 0>(blk_coord);
+
+    const auto& problem_shape = block.get_problem_shape();
+    const int q_len = get<0>(problem_shape);
+    const int kv_len = get<1>(problem_shape);
 
-    const auto q_packed_len = block.get_packed_len();
-    const auto q_len = block.get_q_len();
-    const auto kv_len = block.get_kv_len();
     const auto& group_size = block.get_group_size();
+    const int q_packed_len = block.get_packed_len();
 
     // Construct smem tensors
     // (BLK_M, BLK_K), k-major
diff --git a/src/kernels/attention/common/fmha_block.h b/src/kernels/attention/common/fmha_block.h
index 5c2aaec7..512f0d98 100644
--- a/src/kernels/attention/common/fmha_block.h
+++ b/src/kernels/attention/common/fmha_block.h
@@ -12,7 +12,8 @@ namespace llm {
 using namespace cute;
 
 // AttentionTile specialization for AttentionParams
-template 
   static Params to_underlying_arguments(const ProblemShape& problem_shape,
                                         const Arguments& args,
-                                        void* workspace = nullptr) {
+                                        void* /*workspace*/) {
     // ProblemShape: (Q, K, D, ((KH, G), B))
-    const int q_len = get<0>(problem_shape);
-    const int kv_len = get<1>(problem_shape);
-    const int head_dim = get<2>(problem_shape);
-    const int n_kv_heads = get<3, 0, 0>(problem_shape);
     const int group_size = get<3, 0, 1>(problem_shape);
-    const int batch_size = get<3, 1>(problem_shape);
 
     // TODO: construct tma_load for k/v tensors
     return {
@@ -78,11 +68,7 @@ struct FmhaBlock {
         .k_stride = args.k_stride,
         .v_stride = args.v_stride,
         .o_stride = args.o_stride,
-        .batch_size = batch_size,
-        .q_len = q_len,
-        .kv_len = kv_len,
-        .head_dim = head_dim,
-        .n_kv_heads = n_kv_heads,
+        .problem_shape = problem_shape,
         .group_size = FastDivmod(group_size),
     };
   }
@@ -97,18 +83,18 @@ struct FmhaBlock {
   // hold a reference to the parameters and block coordination
   const Params& params_;
-  const BlocKCoord& blk_coord_;
+  const BlocKCoord& coord_;
 
   // derived parameters to avoid recomputation
   int m_block_base_;
   int packed_len_;
 
   // Constructor
-  CUTE_HOST_DEVICE FmhaBlock(const Params& params, const BlocKCoord& blk_coord)
-      : params_(params), blk_coord_(blk_coord) {
-    // derive parameters
-    m_block_base_ = get<0>(blk_coord) * get<0>(TileShape{});
-    packed_len_ = params_.q_len * params_.group_size;
+  CUTE_HOST_DEVICE FmhaBlock(const Params& params, const BlocKCoord& coord)
+      : params_(params), coord_(coord) {
+    // derived parameters
+    m_block_base_ = get<0>(coord) * get<0>(TileShape{});
+    packed_len_ = get<0>(params_.problem_shape) * params_.group_size;
   }
 
   // check if the m_block is valid
@@ -120,36 +106,33 @@ struct FmhaBlock {
   // returns packed_len
   CUTE_HOST_DEVICE int get_packed_len() const { return packed_len_; }
 
-  // returns actual query length
-  CUTE_HOST_DEVICE int get_q_len() const { return params_.q_len; }
-
-  // returns actual kv length
-  CUTE_HOST_DEVICE int get_kv_len() const { return params_.kv_len; }
+  // returns problem shape: (Q, K, D, ((KH, G), B))
+  CUTE_HOST_DEVICE const auto& get_problem_shape() const {
+    return params_.problem_shape;
+  }
 
-  // returns head_dim
-  CUTE_HOST_DEVICE int get_head_dim() const { return params_.head_dim; }
+  // returns (m_block_idx, ((kv_head_idx, _0), batch_idx))
+  CUTE_HOST_DEVICE const auto& get_coord() const { return coord_; }
 
-  // returns group size
+  // returns group size fast divmod
   CUTE_HOST_DEVICE const FastDivmod& get_group_size() const {
     return params_.group_size;
   }
 
   // returns redidue mnk
   CUTE_HOST_DEVICE auto get_residue_mnk() const {
-    return make_tuple(packed_len_, params_.kv_len, params_.head_dim);
+    auto residue_mnk = select<0, 1, 2>(params_.problem_shape);
+    get<0>(residue_mnk) = packed_len_;
+    return residue_mnk;
   }
 
-  // returns (m_block_idx, ((kv_head_idx, _0), batch_idx))
-  CUTE_HOST_DEVICE const auto& get_block_coord() const { return blk_coord_; }
-
   // returns kv block range: (n_block_min, n_block_max]
   template 
   CUTE_HOST_DEVICE auto get_kv_blocks(int sliding_window) const {
     static constexpr int kBlockM = get<0>(TileShape{});
     static constexpr int kBlockN = get<1>(TileShape{});
-    const int q_len = params_.q_len;
-    const int kv_len = params_.kv_len;
+    const auto [q_len, kv_len] = select<0, 1>(params_.problem_shape);
     const int q_idx = m_block_base_ / params_.group_size;
     // take care of causal mask
     const int diagonal = q_idx + kv_len - q_len;
@@ -168,15 +151,11 @@ struct FmhaBlock {
   // return the query tile: (BLK_M, BLK_K) => (M, K)
   CUTE_HOST_DEVICE auto get_q_tile() const {
     // (Q, D, ((KH, G), B))
-    auto q_shape = make_shape(
-        params_.q_len,
-        params_.head_dim,
-        make_shape(make_shape(params_.n_kv_heads, (int)params_.group_size),
-                   params_.batch_size));
+    auto q_shape = select<0, 2, 3>(params_.problem_shape);
     auto mQ =
         make_tensor(make_gmem_ptr(params_.q_ptr), q_shape, params_.q_stride);
     // (Q, D, G*)
-    auto Q = mQ(_, _, get<1>(blk_coord_));
+    auto Q = mQ(_, _, get<1>(coord_));
 
     // packing all q in the same kv head group together
     auto packed_idx_to_coord = [this](int packed_idx) {
@@ -192,12 +171,13 @@ struct FmhaBlock {
         get<1>(params_.q_stride));
 
     // packed tensor: (pQ, D) => ((Q, G), D)
+    const int head_dim = get<2>(params_.problem_shape);
     auto pQ = make_gather_tensor(Q.data(),
-                                 make_shape(packed_len_, params_.head_dim),
+                                 make_shape(packed_len_, head_dim),
                                  q_stride,
                                  packed_idx_to_coord);
 
-    const auto m_block_idx = get<0>(blk_coord_);
+    const auto m_block_idx = get<0>(coord_);
     // (BLK_M, BLK_K)
     Tensor gQ = local_tile(pQ, Shape{}, make_coord(m_block_idx, _0{}));
@@ -211,15 +191,11 @@ struct FmhaBlock {
   // return the output tile: (BLK_M, BLK_K) => (M, K)
   CUTE_HOST_DEVICE auto get_o_tile() const {
     // (Q, D, ((KH, G), B))
-    auto o_shape = make_shape(
-        params_.q_len,
-        params_.head_dim,
-        make_shape(make_shape(params_.n_kv_heads, (int)params_.group_size),
-                   params_.batch_size));
+    auto o_shape = select<0, 2, 3>(params_.problem_shape);
     auto mO =
         make_tensor(make_gmem_ptr(params_.o_ptr), o_shape, params_.o_stride);
     // (Q, D, G*)
-    auto O = mO(_, _, get<1>(blk_coord_));
+    auto O = mO(_, _, get<1>(coord_));
 
     // packing all q in the same kv head group together
     auto packed_idx_to_coord = [this](int packed_idx) {
@@ -231,39 +207,36 @@ struct FmhaBlock {
     auto o_stride = make_stride(
         make_stride(get<0>(params_.o_stride), get<2, 0, 1>(params_.o_stride)),
         get<1>(params_.o_stride));
+    const int head_dim = get<2>(params_.problem_shape);
 
     // packed tensor: (pO, D) => ((O, G), D)
    auto pO = make_gather_tensor(O.data(),
-                                 make_shape(packed_len_, params_.head_dim),
+                                 make_shape(packed_len_, head_dim),
                                  o_stride,
                                  packed_idx_to_coord);
 
-    const auto m_block_idx = get<0>(blk_coord_);
+    const auto m_block_idx = get<0>(coord_);
     // (BLK_M, BLK_K)
     Tensor gO = local_tile(pO, Shape{}, make_coord(m_block_idx, _0{}));
     // (BLK_M, BLK_K) => (M, K)
-    Tensor cQ = local_tile(make_identity_tensor(shape(pO)),
+    Tensor cO = local_tile(make_identity_tensor(shape(pO)),
                            Shape{},
                            make_coord(m_block_idx, _0{}));
-    return make_tuple(gO, cQ);
+    return make_tuple(gO, cO);
   }
 
   // return the key/value tile: (BLK_N, BLK_K, n) => (N, K)
   CUTE_HOST_DEVICE auto get_kv_tile() const {
     // (KV, D, ((KH, G), B))
-    auto kv_shape = make_shape(
-        params_.kv_len,
-        params_.head_dim,
-        make_shape(make_shape(params_.n_kv_heads, (int)params_.group_size),
-                   params_.batch_size));
+    auto kv_shape = select<1, 2, 3>(params_.problem_shape);
     auto mK =
         make_tensor(make_gmem_ptr(params_.k_ptr), kv_shape, params_.k_stride);
     auto mV =
         make_tensor(make_gmem_ptr(params_.v_ptr), kv_shape, params_.v_stride);
     // (K/V, D)
-    auto K = mK(_, _, get<1>(blk_coord_));
-    auto V = mV(_, _, get<1>(blk_coord_));
+    auto K = mK(_, _, get<1>(coord_));
+    auto V = mV(_, _, get<1>(coord_));
 
     // (BLK_N, BLK_K, n)
     Tensor gK = local_tile(K, Shape{}, make_coord(_, _0{}));
diff --git a/src/kernels/attention/kernel/builders/sm120_kernel_builder.inl b/src/kernels/attention/kernel/builders/sm120_kernel_builder.inl
index 2563f7dc..06abd859 100644
--- a/src/kernels/attention/kernel/builders/sm120_kernel_builder.inl
+++ b/src/kernels/attention/kernel/builders/sm120_kernel_builder.inl
@@ -42,7 +42,8 @@ struct KernelBuilder
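For reference, a minimal host-side sketch of the CuTe problem-shape pattern this diff moves to: one nested shape (Q, K, D, ((KH, G), B)) carried in Params and queried with get<>/select<>, instead of separate q_len/kv_len/head_dim/n_kv_heads/batch_size fields. This is not code from the repository; it only assumes the CUTLASS/CuTe headers the kernels already depend on, and all extents below are made up.

// Illustration only: nested problem shape with get<>/select<> accessors.
#include <cute/tensor.hpp>

#include <cstdio>

int main() {
  using namespace cute;

  // ProblemShape: (Q, K, D, ((KH, G), B))
  auto problem_shape =
      make_shape(1024, 2048, 128, make_shape(make_shape(8, 4), 2));

  // Individual extents are read back with (nested) get<>.
  const int q_len = get<0>(problem_shape);
  const int kv_len = get<1>(problem_shape);
  const int head_dim = get<2>(problem_shape);
  const int n_kv_heads = get<3, 0, 0>(problem_shape);
  const int group_size = get<3, 0, 1>(problem_shape);

  // select<> keeps whole modes, e.g. the (Q, D, ((KH, G), B)) shape shared by
  // the Q/O tensors, or the (M, N, K) residue extents with M overwritten by
  // the packed query length.
  auto q_shape = select<0, 2, 3>(problem_shape);
  auto residue_mnk = select<0, 1, 2>(problem_shape);
  get<0>(residue_mnk) = q_len * group_size;

  std::printf("q=%d kv=%d d=%d kv_heads=%d packed=%d rank(q_shape)=%d\n",
              q_len, kv_len, head_dim, n_kv_heads, get<0>(residue_mnk),
              int(rank(q_shape)));
  return 0;
}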