@@ -238,14 +238,16 @@ struct Sm120CollectiveFMhaWs {
   }
 
     // (m_block_idx, ((kv_head_idx, _0), batch_idx))
-    const auto& block_coord = block.get_block_coord();
-    const int m_block_idx = get<0>(block_coord);
-    const int kv_head_idx = get<1, 0, 0>(block_coord);
+    const auto& blk_coord = block.get_coord();
+    const int m_block_idx = get<0>(blk_coord);
+    const int kv_head_idx = get<1, 0, 0>(blk_coord);
 
-    const auto q_packed_len = block.get_packed_len();
-    const auto q_len = block.get_q_len();
-    const auto kv_len = block.get_kv_len();
+    const auto& problem_shape = block.get_problem_shape();
+    const int q_len = get<0>(problem_shape);
+    const int kv_len = get<1>(problem_shape);
+
     const auto& group_size = block.get_group_size();
+    const int q_packed_len = block.get_packed_len();
 
     // Construct smem tensors
     // (BLK_M, BLK_K), k-major
103 changes: 38 additions & 65 deletions src/kernels/attention/common/fmha_block.h
@@ -12,7 +12,8 @@ namespace llm {
 using namespace cute;
 
 // AttentionTile specialization for AttentionParams
-template <typename TileShape,     // (BLK_M, BLK_N, BLK_K)
+template <typename ProblemShape,  // (Q, K, D, ((KH, G), B))
+          typename TileShape,     // (BLK_M, BLK_N, BLK_K)
           typename BlocKCoord,    // (m_block_idx, ((kv_head_idx, _0), batch_idx))
           typename Element,       // Element type
           typename StrideQ,       // (Q, D, ((KH, G), B))
@@ -46,27 +47,16 @@ struct FmhaBlock {
     StrideV v_stride;
     StrideO o_stride;
 
-    // Parameters from problem shape
-    int batch_size;
-    int q_len;
-    int kv_len;
-    int head_dim;
-    // int n_heads;
-    int n_kv_heads;  // number of kv heads
+    ProblemShape problem_shape;
     // for fast divmod
     FastDivmod group_size;
   };
 
-  template <class ProblemShape>
   static Params to_underlying_arguments(const ProblemShape& problem_shape,
                                         const Arguments& args,
-                                        void* workspace = nullptr) {
+                                        void* /*workspace*/) {
     // ProblemShape: (Q, K, D, ((KH, G), B))
-    const int q_len = get<0>(problem_shape);
-    const int kv_len = get<1>(problem_shape);
-    const int head_dim = get<2>(problem_shape);
-    const int n_kv_heads = get<3, 0, 0>(problem_shape);
     const int group_size = get<3, 0, 1>(problem_shape);
-    const int batch_size = get<3, 1>(problem_shape);
 
     // TODO: construct tma_load for k/v tensors
     return {
Expand All @@ -78,11 +68,7 @@ struct FmhaBlock {
         .k_stride = args.k_stride,
         .v_stride = args.v_stride,
         .o_stride = args.o_stride,
-        .batch_size = batch_size,
-        .q_len = q_len,
-        .kv_len = kv_len,
-        .head_dim = head_dim,
-        .n_kv_heads = n_kv_heads,
+        .problem_shape = problem_shape,
         .group_size = FastDivmod(group_size),
     };
   }
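Note on the field kept above: the refactor folds six scalar fields into the original problem_shape tuple, but group_size stays materialized as a FastDivmod because the kernel divides and takes remainders by it on the hot path (get_kv_blocks and packed_idx_to_coord below). A minimal host-side sketch of the fast-divmod idea, assuming GCC/Clang's unsigned __int128; the project's actual FastDivmod may precompute differently (device-side versions typically use 32-bit magics and __umulhi):

#include <cstdint>

// Sketch: replace runtime division by a fixed divisor with multiply + shift.
// magic = ceil(2^64 / d) makes (n * magic) >> 64 == n / d exact for all
// 32-bit n when 1 < d < 2^32; d == 1 falls back to the identity.
struct FastDivmodSketch {
  uint32_t divisor;
  uint64_t magic;

  explicit FastDivmodSketch(uint32_t d)
      : divisor(d), magic(d > 1 ? ~uint64_t(0) / d + 1 : 0) {}

  void divmod(uint32_t n, uint32_t& q, uint32_t& r) const {
    q = divisor > 1 ? uint32_t(((unsigned __int128)n * magic) >> 64) : n;
    r = n - q * divisor;  // remainder falls out of the quotient
  }
};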
@@ -97,18 +83,18 @@ struct FmhaBlock {

   // hold a reference to the parameters and block coordinate
   const Params& params_;
-  const BlocKCoord& blk_coord_;
+  const BlocKCoord& coord_;
 
   // derived parameters to avoid recomputation
   int m_block_base_;
   int packed_len_;
 
   // Constructor
-  CUTE_HOST_DEVICE FmhaBlock(const Params& params, const BlocKCoord& blk_coord)
-      : params_(params), blk_coord_(blk_coord) {
-    // derive parameters
-    m_block_base_ = get<0>(blk_coord) * get<0>(TileShape{});
-    packed_len_ = params_.q_len * params_.group_size;
+  CUTE_HOST_DEVICE FmhaBlock(const Params& params, const BlocKCoord& coord)
+      : params_(params), coord_(coord) {
+    // derived parameters
+    m_block_base_ = get<0>(coord) * get<0>(TileShape{});
+    packed_len_ = get<0>(params_.problem_shape) * params_.group_size;
   }
 
   // check if the m_block is valid
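Note: packed_len_ fuses the query and head-group dimensions. With grouped-query attention, each kv head serves group_size query heads, so one pass over the packed M extent covers q_len * group_size rows; e.g. (sizes made up) 32 query heads over 8 kv heads gives group_size = 4, and q_len = 128 packs to 512 rows per kv head.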
@@ -120,36 +106,33 @@ struct FmhaBlock {
   // returns packed_len
   CUTE_HOST_DEVICE int get_packed_len() const { return packed_len_; }
 
-  // returns actual query length
-  CUTE_HOST_DEVICE int get_q_len() const { return params_.q_len; }
-
-  // returns actual kv length
-  CUTE_HOST_DEVICE int get_kv_len() const { return params_.kv_len; }
+  // returns problem shape: (Q, K, D, ((KH, G), B))
+  CUTE_HOST_DEVICE const auto& get_problem_shape() const {
+    return params_.problem_shape;
+  }
 
-  // returns head_dim
-  CUTE_HOST_DEVICE int get_head_dim() const { return params_.head_dim; }
+  // returns (m_block_idx, ((kv_head_idx, _0), batch_idx))
+  CUTE_HOST_DEVICE const auto& get_coord() const { return coord_; }
 
-  // returns group size
+  // returns group size fast divmod
   CUTE_HOST_DEVICE const FastDivmod& get_group_size() const {
     return params_.group_size;
   }
 
   // returns residue mnk
   CUTE_HOST_DEVICE auto get_residue_mnk() const {
-    return make_tuple(packed_len_, params_.kv_len, params_.head_dim);
+    auto residue_mnk = select<0, 1, 2>(params_.problem_shape);
+    get<0>(residue_mnk) = packed_len_;
+    return residue_mnk;
   }
 
-  // returns (m_block_idx, ((kv_head_idx, _0), batch_idx))
-  CUTE_HOST_DEVICE const auto& get_block_coord() const { return blk_coord_; }
-
   // returns kv block range: (n_block_min, n_block_max]
   template <bool kLocal>
   CUTE_HOST_DEVICE auto get_kv_blocks(int sliding_window) const {
     static constexpr int kBlockM = get<0>(TileShape{});
     static constexpr int kBlockN = get<1>(TileShape{});
 
-    const int q_len = params_.q_len;
-    const int kv_len = params_.kv_len;
+    const auto [q_len, kv_len] = select<0, 1>(params_.problem_shape);
     const int q_idx = m_block_base_ / params_.group_size;
     // take care of causal mask
     const int diagonal = q_idx + kv_len - q_len;
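Note: diagonal = q_idx + kv_len - q_len is what keeps the causal mask aligned when q_len != kv_len, the usual decode case of a few new queries appended to a long KV cache. A tiny standalone illustration with made-up lengths:

#include <algorithm>
#include <cstdio>

int main() {
  // Two new queries over a 10-token KV cache: the last query row must line
  // up with the last key column, so row q_idx sits on diagonal q_idx + 8.
  const int q_len = 2, kv_len = 10;
  for (int q_idx = 0; q_idx < q_len; ++q_idx) {
    const int diagonal = q_idx + kv_len - q_len;
    std::printf("q_idx=%d attends to keys [0, %d]\n", q_idx,
                std::min(diagonal, kv_len - 1));
  }
  // prints: q_idx=0 attends to keys [0, 8]
  //         q_idx=1 attends to keys [0, 9]
}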
@@ -168,15 +151,11 @@ struct FmhaBlock {
   // return the query tile: (BLK_M, BLK_K) => (M, K)
   CUTE_HOST_DEVICE auto get_q_tile() const {
     // (Q, D, ((KH, G), B))
-    auto q_shape = make_shape(
-        params_.q_len,
-        params_.head_dim,
-        make_shape(make_shape(params_.n_kv_heads, (int)params_.group_size),
-                   params_.batch_size));
+    auto q_shape = select<0, 2, 3>(params_.problem_shape);
     auto mQ =
         make_tensor(make_gmem_ptr(params_.q_ptr), q_shape, params_.q_stride);
     // (Q, D, G*)
-    auto Q = mQ(_, _, get<1>(blk_coord_));
+    auto Q = mQ(_, _, get<1>(coord_));
 
     // packing all q in the same kv head group together
     auto packed_idx_to_coord = [this](int packed_idx) {
@@ -192,12 +171,13 @@ struct FmhaBlock {
         get<1>(params_.q_stride));
 
     // packed tensor: (pQ, D) => ((Q, G), D)
+    const int head_dim = get<2>(params_.problem_shape);
     auto pQ = make_gather_tensor(Q.data(),
-                                 make_shape(packed_len_, params_.head_dim),
+                                 make_shape(packed_len_, head_dim),
                                  q_stride,
                                  packed_idx_to_coord);
 
-    const auto m_block_idx = get<0>(blk_coord_);
+    const auto m_block_idx = get<0>(coord_);
     // (BLK_M, BLK_K)
     Tensor gQ =
         local_tile(pQ, Shape<BLK_M, BLK_K>{}, make_coord(m_block_idx, _0{}));
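Note: the body of packed_idx_to_coord is truncated in this diff; from packed_len_ = q_len * group_size and the ((Q, G), D) gather shape, its assumed job is to scatter a packed row index back to a (q_idx, group_idx) coordinate. A sketch of that mapping, with plain / and % standing in for the FastDivmod the real code would use (group_size here is a plain int):

// Assumed packing: packed_idx = q_idx * group_size + group_idx, which is
// consistent with q_idx = m_block_base_ / group_size in get_kv_blocks.
auto packed_idx_to_coord = [group_size](int packed_idx) {
  const int q_idx = packed_idx / group_size;      // query position
  const int group_idx = packed_idx % group_size;  // query head within kv group
  return cute::make_coord(q_idx, group_idx);
};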
@@ -211,15 +191,11 @@ struct FmhaBlock {
   // return the output tile: (BLK_M, BLK_K) => (M, K)
   CUTE_HOST_DEVICE auto get_o_tile() const {
     // (Q, D, ((KH, G), B))
-    auto o_shape = make_shape(
-        params_.q_len,
-        params_.head_dim,
-        make_shape(make_shape(params_.n_kv_heads, (int)params_.group_size),
-                   params_.batch_size));
+    auto o_shape = select<0, 2, 3>(params_.problem_shape);
     auto mO =
         make_tensor(make_gmem_ptr(params_.o_ptr), o_shape, params_.o_stride);
     // (Q, D, G*)
-    auto O = mO(_, _, get<1>(blk_coord_));
+    auto O = mO(_, _, get<1>(coord_));
 
     // packing all q in the same kv head group together
     auto packed_idx_to_coord = [this](int packed_idx) {
@@ -231,39 +207,36 @@ struct FmhaBlock {
     auto o_stride = make_stride(
         make_stride(get<0>(params_.o_stride), get<2, 0, 1>(params_.o_stride)),
         get<1>(params_.o_stride));
+    const int head_dim = get<2>(params_.problem_shape);
     // packed tensor: (pO, D) => ((O, G), D)
     auto pO = make_gather_tensor(O.data(),
-                                 make_shape(packed_len_, params_.head_dim),
+                                 make_shape(packed_len_, head_dim),
                                  o_stride,
                                  packed_idx_to_coord);
 
-    const auto m_block_idx = get<0>(blk_coord_);
+    const auto m_block_idx = get<0>(coord_);
     // (BLK_M, BLK_K)
     Tensor gO =
         local_tile(pO, Shape<BLK_M, BLK_K>{}, make_coord(m_block_idx, _0{}));
     // (BLK_M, BLK_K) => (M, K)
-    Tensor cQ = local_tile(make_identity_tensor(shape(pO)),
+    Tensor cO = local_tile(make_identity_tensor(shape(pO)),
                            Shape<BLK_M, BLK_K>{},
                            make_coord(m_block_idx, _0{}));
-    return make_tuple(gO, cQ);
+    return make_tuple(gO, cO);
   }
 
   // return the key/value tile: (BLK_N, BLK_K, n) => (N, K)
   CUTE_HOST_DEVICE auto get_kv_tile() const {
     // (KV, D, ((KH, G), B))
-    auto kv_shape = make_shape(
-        params_.kv_len,
-        params_.head_dim,
-        make_shape(make_shape(params_.n_kv_heads, (int)params_.group_size),
-                   params_.batch_size));
+    auto kv_shape = select<1, 2, 3>(params_.problem_shape);
     auto mK =
         make_tensor(make_gmem_ptr(params_.k_ptr), kv_shape, params_.k_stride);
     auto mV =
         make_tensor(make_gmem_ptr(params_.v_ptr), kv_shape, params_.v_stride);
 
     // (K/V, D)
-    auto K = mK(_, _, get<1>(blk_coord_));
-    auto V = mV(_, _, get<1>(blk_coord_));
+    auto K = mK(_, _, get<1>(coord_));
+    auto V = mV(_, _, get<1>(coord_));
 
     // (BLK_N, BLK_K, n)
     Tensor gK = local_tile(K, Shape<BLK_N, BLK_K>{}, make_coord(_, _0{}));
@@ -42,7 +42,8 @@ struct KernelBuilder<cutlass::arch::Sm120,
   // TODO: support persistent kernels
   using TileScheduler = SingleTileScheduler;
   using BlocKCoord = TileScheduler::BlocKCoord;
-  using Block = FmhaBlock<TileShape,
+  using Block = FmhaBlock<ProblemShape,
+                          TileShape,
                           BlocKCoord,
                           Element,
                           StrideQ,
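Note: with ProblemShape now a template parameter threaded through KernelBuilder, callers hand the block one coordinate-style tuple instead of six scalars. A hypothetical sketch (helper name and include are illustrative, not from this PR) of the shape the getters above expect:

#include <cute/tensor.hpp>

// (Q, K, D, ((KH, G), B)): get<0> = q_len, get<1> = kv_len, get<2> = head_dim,
// get<3, 0, 0> = n_kv_heads, get<3, 0, 1> = group_size, get<3, 1> = batch_size.
auto make_fmha_problem_shape(int q_len, int kv_len, int head_dim,
                             int n_kv_heads, int group_size, int batch_size) {
  return cute::make_shape(
      q_len, kv_len, head_dim,
      cute::make_shape(cute::make_shape(n_kv_heads, group_size), batch_size));
}

This mirrors the make_shape calls the PR deletes from get_q_tile/get_kv_tile, now built once at the call site.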