Commit 34887b4

refactor: keep problem shape in fmha block for q/k/v shape
1 parent 5d0132c commit 34887b4

File tree: 4 files changed, +58 -79 lines


src/kernels/attention/collective/sm120_collective_fmha_mainloop_ws.cuh

Lines changed: 8 additions & 6 deletions
@@ -238,14 +238,16 @@ struct Sm120CollectiveFMhaWs {
     }
 
     // (m_block_idx, ((kv_head_idx, _0), batch_idx))
-    const auto& block_coord = block.get_block_coord();
-    const int m_block_idx = get<0>(block_coord);
-    const int kv_head_idx = get<1, 0, 0>(block_coord);
+    const auto& blk_coord = block.get_coord();
+    const int m_block_idx = get<0>(blk_coord);
+    const int kv_head_idx = get<1, 0, 0>(blk_coord);
+
+    const auto& problem_shape = block.get_problem_shape();
+    const int q_len = get<0>(problem_shape);
+    const int kv_len = get<1>(problem_shape);
 
-    const auto q_packed_len = block.get_packed_len();
-    const auto q_len = block.get_q_len();
-    const auto kv_len = block.get_kv_len();
     const auto& group_size = block.get_group_size();
+    const int q_packed_len = block.get_packed_len();
 
     // Construct smem tensors
     // (BLK_M, BLK_K), k-major
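
For orientation: the mainloop now reads q_len/kv_len straight out of the hierarchically nested problem shape instead of cached scalar fields. A minimal sketch of CuTe's nested-shape indexing, assuming the cute headers are available; the concrete extents below are illustrative only, not taken from this commit:

    #include <cute/tensor.hpp>
    using namespace cute;

    int main() {
      // ProblemShape: (Q, K, D, ((KH, G), B)) -- illustrative extents
      auto problem_shape =
          make_shape(128, 1024, 64, make_shape(make_shape(8, 4), 2));

      const int q_len      = get<0>(problem_shape);        // Q  = 128
      const int kv_len     = get<1>(problem_shape);        // K  = 1024
      const int head_dim   = get<2>(problem_shape);        // D  = 64
      const int n_kv_heads = get<3, 0, 0>(problem_shape);  // KH = 8
      const int group_size = get<3, 0, 1>(problem_shape);  // G  = 4
      const int batch_size = get<3, 1>(problem_shape);     // B  = 2
      return 0;
    }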

src/kernels/attention/common/fmha_block.h

Lines changed: 39 additions & 65 deletions
@@ -12,7 +12,8 @@ namespace llm {
 using namespace cute;
 
 // AttentionTile specialization for AttentionParams
-template <typename TileShape,     // (BLK_M, BLK_N, BLK_K)
+template <typename ProblemShape,  // (Q, K, D, ((KH, G), B))
+          typename TileShape,     // (BLK_M, BLK_N, BLK_K)
           typename BlocKCoord,    // (m_block_idx, ((kv_head_idx, _0), batch_idx))
           typename Element,       // Element type
           typename StrideQ,       // (Q, D, ((KH, G), B))
@@ -46,27 +47,15 @@ struct FmhaBlock {
     StrideV v_stride;
     StrideO o_stride;
 
-    // Parameters from problem shape
-    int batch_size;
-    int q_len;
-    int kv_len;
-    int head_dim;
-    // int n_heads;
-    int n_kv_heads;  // number of kv heads
+    // for fast divmod
     FastDivmod group_size;
   };
 
-  template <class ProblemShape>
   static Params to_underlying_arguments(const ProblemShape& problem_shape,
                                         const Arguments& args,
-                                        void* workspace = nullptr) {
+                                        void* /*workspace*/) {
     // ProblemShape: (Q, K, D, ((KH, G), B))
-    const int q_len = get<0>(problem_shape);
-    const int kv_len = get<1>(problem_shape);
-    const int head_dim = get<2>(problem_shape);
-    const int n_kv_heads = get<3, 0, 0>(problem_shape);
     const int group_size = get<3, 0, 1>(problem_shape);
-    const int batch_size = get<3, 1>(problem_shape);
 
     // TODO: construct tma_load for k/v tensors
     return {
@@ -78,11 +67,6 @@ struct FmhaBlock {
         .k_stride = args.k_stride,
         .v_stride = args.v_stride,
         .o_stride = args.o_stride,
-        .batch_size = batch_size,
-        .q_len = q_len,
-        .kv_len = kv_len,
-        .head_dim = head_dim,
-        .n_kv_heads = n_kv_heads,
         .group_size = FastDivmod(group_size),
     };
   }
@@ -96,19 +80,22 @@ struct FmhaBlock {
   using BLK_K = Int<kBlockK>;
 
   // hold a reference to the parameters and block coordination
+  const ProblemShape& problem_shape_;
   const Params& params_;
-  const BlocKCoord& blk_coord_;
+  const BlocKCoord& coord_;
 
   // derived parameters to avoid recomputation
   int m_block_base_;
   int packed_len_;
 
   // Constructor
-  CUTE_HOST_DEVICE FmhaBlock(const Params& params, const BlocKCoord& blk_coord)
-      : params_(params), blk_coord_(blk_coord) {
-    // derive parameters
-    m_block_base_ = get<0>(blk_coord) * get<0>(TileShape{});
-    packed_len_ = params_.q_len * params_.group_size;
+  CUTE_HOST_DEVICE FmhaBlock(const ProblemShape& problem_shape,
+                             const Params& params,
+                             const BlocKCoord& coord)
+      : problem_shape_(problem_shape), params_(params), coord_(coord) {
+    // derived parameters
+    m_block_base_ = get<0>(coord) * get<0>(TileShape{});
+    packed_len_ = get<0>(problem_shape_) * params_.group_size;
   }
 
   // check if the m_block is valid
@@ -120,36 +107,33 @@ struct FmhaBlock {
   // returns packed_len
   CUTE_HOST_DEVICE int get_packed_len() const { return packed_len_; }
 
-  // returns actual query length
-  CUTE_HOST_DEVICE int get_q_len() const { return params_.q_len; }
-
-  // returns actual kv length
-  CUTE_HOST_DEVICE int get_kv_len() const { return params_.kv_len; }
+  // returns problem shape: (Q, K, D, ((KH, G), B))
+  CUTE_HOST_DEVICE const auto& get_problem_shape() const {
+    return problem_shape_;
+  }
 
-  // returns head_dim
-  CUTE_HOST_DEVICE int get_head_dim() const { return params_.head_dim; }
+  // returns (m_block_idx, ((kv_head_idx, _0), batch_idx))
+  CUTE_HOST_DEVICE const auto& get_coord() const { return coord_; }
 
-  // returns group size
+  // returns group size fast divmod
   CUTE_HOST_DEVICE const FastDivmod& get_group_size() const {
     return params_.group_size;
   }
 
   // returns residue mnk
   CUTE_HOST_DEVICE auto get_residue_mnk() const {
-    return make_tuple(packed_len_, params_.kv_len, params_.head_dim);
+    auto residue_mnk = select<0, 1, 2>(problem_shape_);
+    get<0>(residue_mnk) = packed_len_;
+    return residue_mnk;
   }
 
-  // returns (m_block_idx, ((kv_head_idx, _0), batch_idx))
-  CUTE_HOST_DEVICE const auto& get_block_coord() const { return blk_coord_; }
-
   // returns kv block range: (n_block_min, n_block_max]
   template <bool kLocal>
   CUTE_HOST_DEVICE auto get_kv_blocks(int sliding_window) const {
     static constexpr int kBlockM = get<0>(TileShape{});
     static constexpr int kBlockN = get<1>(TileShape{});
 
-    const int q_len = params_.q_len;
-    const int kv_len = params_.kv_len;
+    const auto [q_len, kv_len] = select<0, 1>(problem_shape_);
     const int q_idx = m_block_base_ / params_.group_size;
     // take care of causal mask
     const int diagonal = q_idx + kv_len - q_len;
@@ -168,15 +152,11 @@ struct FmhaBlock {
   // return the query tile: (BLK_M, BLK_K) => (M, K)
   CUTE_HOST_DEVICE auto get_q_tile() const {
     // (Q, D, ((KH, G), B))
-    auto q_shape = make_shape(
-        params_.q_len,
-        params_.head_dim,
-        make_shape(make_shape(params_.n_kv_heads, (int)params_.group_size),
-                   params_.batch_size));
+    auto q_shape = select<0, 2, 3>(problem_shape_);
     auto mQ =
         make_tensor(make_gmem_ptr(params_.q_ptr), q_shape, params_.q_stride);
     // (Q, D, G*)
-    auto Q = mQ(_, _, get<1>(blk_coord_));
+    auto Q = mQ(_, _, get<1>(coord_));
 
     // packing all q in the same kv head group together
     auto packed_idx_to_coord = [this](int packed_idx) {
@@ -192,12 +172,13 @@ struct FmhaBlock {
         get<1>(params_.q_stride));
 
     // packed tensor: (pQ, D) => ((Q, G), D)
+    const int head_dim = get<2>(problem_shape_);
     auto pQ = make_gather_tensor(Q.data(),
-                                 make_shape(packed_len_, params_.head_dim),
+                                 make_shape(packed_len_, head_dim),
                                  q_stride,
                                  packed_idx_to_coord);
 
-    const auto m_block_idx = get<0>(blk_coord_);
+    const auto m_block_idx = get<0>(coord_);
     // (BLK_M, BLK_K)
     Tensor gQ =
         local_tile(pQ, Shape<BLK_M, BLK_K>{}, make_coord(m_block_idx, _0{}));
@@ -211,15 +192,11 @@ struct FmhaBlock {
   // return the output tile: (BLK_M, BLK_K) => (M, K)
   CUTE_HOST_DEVICE auto get_o_tile() const {
     // (Q, D, ((KH, G), B))
-    auto o_shape = make_shape(
-        params_.q_len,
-        params_.head_dim,
-        make_shape(make_shape(params_.n_kv_heads, (int)params_.group_size),
-                   params_.batch_size));
+    auto o_shape = select<0, 2, 3>(problem_shape_);
     auto mO =
         make_tensor(make_gmem_ptr(params_.o_ptr), o_shape, params_.o_stride);
     // (Q, D, G*)
-    auto O = mO(_, _, get<1>(blk_coord_));
+    auto O = mO(_, _, get<1>(coord_));
 
     // packing all q in the same kv head group together
     auto packed_idx_to_coord = [this](int packed_idx) {
@@ -231,39 +208,36 @@ struct FmhaBlock {
     auto o_stride = make_stride(
         make_stride(get<0>(params_.o_stride), get<2, 0, 1>(params_.o_stride)),
         get<1>(params_.o_stride));
+    const int head_dim = get<2>(problem_shape_);
     // packed tensor: (pO, D) => ((O, G), D)
     auto pO = make_gather_tensor(O.data(),
-                                 make_shape(packed_len_, params_.head_dim),
+                                 make_shape(packed_len_, head_dim),
                                  o_stride,
                                  packed_idx_to_coord);
 
-    const auto m_block_idx = get<0>(blk_coord_);
+    const auto m_block_idx = get<0>(coord_);
     // (BLK_M, BLK_K)
     Tensor gO =
         local_tile(pO, Shape<BLK_M, BLK_K>{}, make_coord(m_block_idx, _0{}));
     // (BLK_M, BLK_K) => (M, K)
-    Tensor cQ = local_tile(make_identity_tensor(shape(pO)),
+    Tensor cO = local_tile(make_identity_tensor(shape(pO)),
                            Shape<BLK_M, BLK_K>{},
                            make_coord(m_block_idx, _0{}));
-    return make_tuple(gO, cQ);
+    return make_tuple(gO, cO);
  }
 
   // return the key/value tile: (BLK_N, BLK_K, n) => (N, K)
   CUTE_HOST_DEVICE auto get_kv_tile() const {
     // (KV, D, ((KH, G), B))
-    auto kv_shape = make_shape(
-        params_.kv_len,
-        params_.head_dim,
-        make_shape(make_shape(params_.n_kv_heads, (int)params_.group_size),
-                   params_.batch_size));
+    auto kv_shape = select<1, 2, 3>(problem_shape_);
     auto mK =
         make_tensor(make_gmem_ptr(params_.k_ptr), kv_shape, params_.k_stride);
     auto mV =
         make_tensor(make_gmem_ptr(params_.v_ptr), kv_shape, params_.v_stride);
 
     // (K/V, D)
-    auto K = mK(_, _, get<1>(blk_coord_));
-    auto V = mV(_, _, get<1>(blk_coord_));
+    auto K = mK(_, _, get<1>(coord_));
+    auto V = mV(_, _, get<1>(coord_));
 
     // (BLK_N, BLK_K, n)
     Tensor gK = local_tile(K, Shape<BLK_N, BLK_K>{}, make_coord(_, _0{}));
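
The repeated make_shape(...) reconstructions removed above become possible because the full problem shape now travels with the block, so each tensor shape is just a select<...> over its modes. A minimal sketch of that slicing plus the get_residue_mnk() trick, under the same assumptions as before (CuTe available, extents illustrative):

    #include <cute/tensor.hpp>
    using namespace cute;

    int main() {
      // ProblemShape: (Q, K, D, ((KH, G), B)) -- illustrative extents
      auto problem_shape =
          make_shape(128, 1024, 64, make_shape(make_shape(8, 4), 2));

      auto q_shape  = select<0, 2, 3>(problem_shape);  // (Q, D, ((KH, G), B))
      auto kv_shape = select<1, 2, 3>(problem_shape);  // (K, D, ((KH, G), B))

      // get_residue_mnk(): reuse (Q, K, D), then swap the packed query
      // length (q_len * group_size) into mode 0, as the diff does
      const int packed_len = get<0>(problem_shape) * 4;  // group_size = 4
      auto residue_mnk = select<0, 1, 2>(problem_shape);
      get<0>(residue_mnk) = packed_len;                  // (pQ, K, D)
      return 0;
    }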

src/kernels/attention/kernel/builders/sm120_kernel_builder.inl

Lines changed: 2 additions & 1 deletion
@@ -42,7 +42,8 @@ struct KernelBuilder<cutlass::arch::Sm120,
   // TODO: support persistent kernels
   using TileScheduler = SingleTileScheduler;
   using BlocKCoord = TileScheduler::BlocKCoord;
-  using Block = FmhaBlock<TileShape,
+  using Block = FmhaBlock<ProblemShape,
+                          TileShape,
                           BlocKCoord,
                           Element,
                           StrideQ,

src/kernels/attention/kernel/sm120_kernel_fmha_ws.cuh

Lines changed: 9 additions & 7 deletions
@@ -114,6 +114,7 @@ class Sm120KernelFmhaWs {
   };
 
   struct Params {
+    ProblemShape problem_shape;  // (Q, K, D, ((KH, G), B))
     typename Block::Params input;
     typename CollectiveMainloop::Params mainloop;
     typename CollectiveEpilogue::Params epilogue;
@@ -123,14 +124,15 @@ class Sm120KernelFmhaWs {
   // convert arguments to params
   static Params to_underlying_arguments(Arguments const& args,
                                         void* workspace) {
-    return Params{Block::to_underlying_arguments(
+    return Params{.problem_shape = args.problem_shape,
+                  .input = Block::to_underlying_arguments(
                       args.problem_shape, args.input, workspace),
-                  CollectiveMainloop::to_underlying_arguments(
+                  .mainloop = CollectiveMainloop::to_underlying_arguments(
                       args.problem_shape, args.mainloop, workspace),
-                  CollectiveEpilogue::to_underlying_arguments(
+                  .epilogue = CollectiveEpilogue::to_underlying_arguments(
                       args.problem_shape, args.epilogue, workspace),
-                  TileScheduler::to_underlying_arguments(args.problem_shape,
-                                                         TileShape{})};
+                  .scheduler = TileScheduler::to_underlying_arguments(
+                      args.problem_shape, TileShape{})};
   }
 
   // returns grid and block shape for kernel launch
@@ -154,7 +156,7 @@ class Sm120KernelFmhaWs {
 
     // process each block
     for (const auto blk_coord : scheduler) {
-      const Block block(params.input, blk_coord);
+      const Block block(params.problem_shape, params.input, blk_coord);
       mainloop.load(params.mainloop,
                     block,
                     tidx,
@@ -191,7 +193,7 @@ class Sm120KernelFmhaWs {
 
     // process each block
    for (const auto blk_coord : scheduler) {
-      const Block block(params.input, blk_coord);
+      const Block block(params.problem_shape, params.input, blk_coord);
 
       TiledMma tiled_mma;
       // accumulator: (MMA,MMA_M,MMA_K)
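
One detail worth noting in to_underlying_arguments above: with problem_shape added as a new first member of Params, the commit also switches from positional to designated initializers, so each field stays explicitly named as members are added. A tiny standalone sketch of why that matters, using stand-in types rather than the kernel's real Params:

    // stand-in struct; the real Params holds ProblemShape and collective params
    struct Params {
      int problem_shape;  // stand-in for ProblemShape
      int input;          // stand-in for Block::Params
      int mainloop;       // stand-in for CollectiveMainloop::Params
    };

    Params make_params() {
      // positional aggregate init would silently shift every argument once a
      // leading member is added; designated initializers name each field:
      return Params{.problem_shape = 1, .input = 2, .mainloop = 3};
    }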
