Skip to content

Commit 5d0132c

Browse files
authored
refactor: change stride for Q/K/V to MNKL (#494)
1 parent 8d41e66 commit 5d0132c

12 files changed

+248
-180
lines changed

src/kernels/attention/collective/sm120_collective_fmha_mainloop_ws.cuh

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -179,16 +179,35 @@ struct Sm120CollectiveFMhaWs {
179179

180180
// load Q/K/V from gmem to smem
181181
template <class Block>
182-
CUTE_DEVICE void load(const Block& block,
182+
CUTE_DEVICE void load(const Params& params,
183+
const Block& block,
183184
int tidx,
184185
PipelineQ& q_pipeline,
185186
typename PipelineQ::PipelineState& q_state,
186187
PipelineKV& kv_pipeline,
187188
typename PipelineKV::PipelineState& kv_state,
188189
TensorStorage& ss) {
190+
if (!block.is_valid()) {
191+
// skip invalid block
192+
return;
193+
}
194+
const auto [n_block_min, n_block_max] =
195+
block.template get_kv_blocks<LOCAL>(params.sliding_window);
196+
if (n_block_min >= n_block_max) {
197+
return; // no kv blocks to process
198+
}
199+
189200
// forward to the load implementation
190201
Load load;
191-
load(block, tidx, q_pipeline, q_state, kv_pipeline, kv_state, ss);
202+
load(block,
203+
tidx,
204+
n_block_min,
205+
n_block_max,
206+
q_pipeline,
207+
q_state,
208+
kv_pipeline,
209+
kv_state,
210+
ss);
192211
}
193212

194213
template <class Block, class FrgTensor, class PipelineQ, class PipelineKV>
@@ -212,12 +231,16 @@ struct Sm120CollectiveFMhaWs {
212231
return;
213232
}
214233

215-
const auto [n_block_min, n_block_max] = block.get_kv_blocks();
234+
const auto [n_block_min, n_block_max] =
235+
block.template get_kv_blocks<LOCAL>(params.sliding_window);
216236
if (n_block_min >= n_block_max) {
217237
return; // no kv blocks to process
218238
}
219239

220-
const auto [batch_idx, m_block_idx, kv_head_idx] = block.get_block_coord();
240+
// (m_block_idx, ((kv_head_idx, _0), batch_idx))
241+
const auto& block_coord = block.get_block_coord();
242+
const int m_block_idx = get<0>(block_coord);
243+
const int kv_head_idx = get<1, 0, 0>(block_coord);
221244

222245
const auto q_packed_len = block.get_packed_len();
223246
const auto q_len = block.get_q_len();

src/kernels/attention/collective/sm120_collective_load_cpasync_ws.cuh

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,23 +39,15 @@ struct Sm120CollectiveLoadCpAsyncWs {
3939
template <class Block>
4040
CUTE_DEVICE void operator()(const Block& block,
4141
int tidx,
42+
int n_block_min,
43+
int n_block_max,
4244
PipelineQ& q_pipeline,
4345
typename PipelineQ::PipelineState& q_state,
4446
PipelineKV& kv_pipeline,
4547
typename PipelineKV::PipelineState& kv_state,
4648
TensorStorage& ss) {
4749
static constexpr int kStages = size<2>(SmemLayoutK{});
4850

49-
if (!block.is_valid()) {
50-
// skip invalid block
51-
return;
52-
}
53-
54-
const auto [n_block_min, n_block_max] = block.get_kv_blocks();
55-
if (n_block_min >= n_block_max) {
56-
return; // no kv blocks to process
57-
}
58-
5951
// (M, N, K)
6052
const auto residue_mnk = block.get_residue_mnk();
6153

src/kernels/attention/collective/sm120_collective_load_tma_ws.cuh

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,18 +35,6 @@ struct Sm120CollectiveLoadTmaWs {
3535
Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL_ZFILL<cute::uint128_t>,
3636
Element>{}));
3737

38-
// using StrideK = ...;
39-
40-
// using TMA_K = decltype(make_tma_copy(
41-
// GmemTiledCopy{}, // TMA_COPY
42-
// make_tensor(static_cast<InternalElementA const*>(nullptr),
43-
// repeat_like(StrideK{}, int32_t(0)), StrideK{}),
44-
// SmemLayoutK{}(_,_,_0{})));
45-
46-
// Tensor tensor_k = make_tensor(ptr_k, make_layout(make_shape(M,K,L),
47-
// args.stride_k)); auto tma_load_k = make_tma_copy(SM90_TMA_LOAD{},
48-
// gtensor_k, SmemLayoutK{}(_,_,_0{}));
49-
5038
// load Q using cp_async and K/V using tma
5139
template <class Block>
5240
CUTE_DEVICE void operator()(const Block& block,

0 commit comments

Comments
 (0)