
Commit ce2c59a

update kv_head_base logic
1 parent 56e8bed commit ce2c59a
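
All four files make the same change: the kv-head base (kv_head_idx * group_size) was previously added to the head coordinate inside the packed_idx_to_coord gather lambda; it is now folded once into the base-pointer offset (kv_head_idx * group_size * head stride), so the lambda returns only the in-group coordinate. The two schemes address the same elements; the sketch below checks that numerically in plain C++, with hypothetical seq_stride/head_stride stand-ins for the CuTe strides.

    // Equivalence of the old and new kv_head_base schemes, as a minimal
    // host-side sketch. seq_stride/head_stride are hypothetical stand-ins
    // for the q_stride modes used in the real kernels.
    #include <cassert>

    int main() {
      const int group_size = 4;   // q heads per kv head
      const int kv_head_idx = 2;
      const int seq_stride = 64;
      const int head_stride = 8;

      const int head_base = kv_head_idx * group_size;   // old scheme
      const int base_offset = head_base * head_stride;  // new scheme

      for (int packed_idx = 0; packed_idx < 32; ++packed_idx) {
        const int seq = packed_idx / group_size;  // as group_size.divmod(...)
        const int off = packed_idx % group_size;
        // old: head base folded into the gathered coordinate
        const int old_addr = seq * seq_stride + (head_base + off) * head_stride;
        // new: head base folded into the base pointer
        const int new_addr = base_offset + seq * seq_stride + off * head_stride;
        assert(old_addr == new_addr);
      }
      return 0;
    }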

File tree

  src/kernels/attention/collective/sm120_collective_load_cpasync_ws.cuh
  src/kernels/attention/collective/sm120_collective_load_tma_ws.cuh
  src/kernels/attention/common/fmha_block.h
  src/kernels/attention/kernel/sm80_kernel_mha.cuh

4 files changed: +52 -44 lines changed

src/kernels/attention/collective/sm120_collective_load_cpasync_ws.cuh

Lines changed: 4 additions & 4 deletions
@@ -55,15 +55,15 @@ struct Sm120CollectiveLoadCpAsyncWs {
     // (M, N, K)
     const auto residue_mnk = block.get_residue_mnk();

-    // (BLK_M, HEAD_DIM) => (M, K)
+    // (BLK_M, BLK_K) => (M, K)
     auto [gQ, cQ] = block.get_q_tile();
-    // (BLK_N, HEAD_DIM, n) => (N, K)
+    // (BLK_N, BLK_K, n) => (N, K)
     auto [gK, gV, cKV] = block.get_kv_tile();

     // Construct smem tensors
-    // (BLK_M, HEAD_DIM), k-major
+    // (BLK_M, BLK_K), k-major
     Tensor sQ = make_tensor(make_smem_ptr(ss.smem_q.data()), SmemLayoutQ{});
-    // (BLK_N, HEAD_DIM, KVStages), k-major
+    // (BLK_N, BLK_K, KVStages), k-major
     Tensor sK = make_tensor(make_smem_ptr(ss.smem_k.data()), SmemLayoutK{});
     Tensor sV = make_tensor(make_smem_ptr(ss.smem_v.data()), SmemLayoutV{});
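
The comment rename (HEAD_DIM to BLK_K) reflects that these tiles span the K extent of the tile shape, which need not equal the full head dimension. A minimal CuTe sketch (the concrete TileShape here is an assumption, not taken from this diff):

    // kBlockK is the third mode of the (BLK_M, BLK_N, BLK_K) tile shape,
    // read the same way as in the TMA loader below; a head_dim of 128 would
    // then be covered by two BLK_K=64 tiles along K.
    #include <cute/tensor.hpp>
    using namespace cute;

    using TileShape = Shape<_128, _64, _64>;             // (BLK_M, BLK_N, BLK_K)
    static constexpr int kBlockK = get<2>(TileShape{});  // 64
    static_assert(kBlockK == 64);
    int main() { return 0; }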

src/kernels/attention/collective/sm120_collective_load_tma_ws.cuh

Lines changed: 9 additions & 15 deletions
@@ -10,6 +10,7 @@
 #include <cute/tensor.hpp>

 #include "common/safe_copy.h"
+#include "common/selector.h"

 namespace llm {

@@ -26,19 +27,12 @@ template <class TileShape,
           class PipelineKV,
           bool EVEN_K>
 struct Sm120CollectiveLoadTmaWs {
+  static constexpr int kThreads = 128;
   static constexpr int kBlockK = get<2>(TileShape{});
-  // Thr layout for gmem copy
-  using GmemCopyThrLayout_ =
-      std::conditional_t<kBlockK == 32,
-                         Layout<Shape<_32, _4>, Stride<_4, _1>>,
-                         Layout<Shape<_16, _8>, Stride<_8, _1>>>;
-
-  // g2s tiled copy for q
-  using GmemTiledCopyQ = decltype(make_tiled_copy(
-      Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, Element>{},
-      GmemCopyThrLayout_{},    // Thr layout: (_16,_8)/(_32, _4)
-      Layout<Shape<_1, _8>>{}  // Val layout: 8 vals per read
-      ));
+  // g2s tiled copy for Q
+  using GmemTiledCopyQ =
+      decltype(gmem_tiled_copy_selector<Element, kThreads, kBlockK>(
+          Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, Element>{}));

   // using StrideK = ...;

@@ -74,13 +68,13 @@ struct Sm120CollectiveLoadTmaWs {
     // (M, N, K)
     const auto residue_mnk = block.get_residue_mnk();

-    // (BLK_M, HEAD_DIM) => (M, K)
+    // (BLK_M, BLK_K) => (M, K)
     auto [gQ, cQ] = block.get_q_tile();

     // Construct smem tensors
-    // (BLK_M, HEAD_DIM), k-major
+    // (BLK_M, BLK_K), k-major
     Tensor sQ = make_tensor(make_smem_ptr(ss.smem_q.data()), SmemLayoutQ{});
-    // (BLK_N, HEAD_DIM, KVStages), k-major
+    // (BLK_N, BLK_K, KVStages), k-major
     Tensor sK = make_tensor(make_smem_ptr(ss.smem_k.data()), SmemLayoutK{});
     Tensor sV = make_tensor(make_smem_ptr(ss.smem_v.data()), SmemLayoutV{});
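
common/selector.h is not part of this diff, so the body of gmem_tiled_copy_selector is an assumption here; judging from the deleted conditional, it presumably derives the same thread layouts from kThreads and kBlockK instead of hard-coding them. A sketch of what such a selector might look like:

    // Hypothetical sketch of gmem_tiled_copy_selector (the real definition
    // lives in common/selector.h, which this diff does not show). With 8
    // Elements per 128-bit cp.async read, kBlockK/8 threads cover the K
    // extent and the remaining threads tile the rows; for kThreads=128 this
    // reproduces the deleted (_32,_4) layout for kBlockK==32 and (_16,_8)
    // for kBlockK==64.
    #include <cute/tensor.hpp>
    using namespace cute;

    template <class Element, int kThreads, int kBlockK, class CopyAtom>
    auto gmem_tiled_copy_selector(CopyAtom atom) {
      constexpr int kThrK = kBlockK / 8;       // threads along K
      constexpr int kThrM = kThreads / kThrK;  // threads along rows
      using ThrLayout =
          Layout<Shape<Int<kThrM>, Int<kThrK>>, Stride<Int<kThrK>, _1>>;
      return make_tiled_copy(atom, ThrLayout{},
                             Layout<Shape<_1, _8>>{});  // 8 vals per read
    }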

src/kernels/attention/common/fmha_block.h

Lines changed: 18 additions & 11 deletions
@@ -94,17 +94,20 @@ struct FmhaBlock {
     const auto& [batch_idx, m_block_idx, kv_head_idx] = blk_coord_;

     // packing all q in the same kv head group together
-    const auto head_base = kv_head_idx * params_.group_size;
-    auto packed_idx_to_coord = [this, head_base](int packed_idx) {
+    auto packed_idx_to_coord = [this](int packed_idx) {
       // packed_idx => (seq, kv_heads):(group_size, 1)
       int idx, offset;
       params_.group_size.divmod(packed_idx, idx, offset);
-      return make_coord(idx, head_base + offset);
+      return make_coord(idx, offset);
     };

-    // (batch, seq, head, dim) => ((seq, kv_head), dim)
-    const auto offset = batch_idx * get<0>(params_.q_stride);
-    // (q_packed_len, head_dim) gmem tensor
+    // (batch, seq, head, dim)
+    //   => (batch, seq, (kv_heads, group), dim)
+    //   => (seq, group, dim)
+    const auto offset =
+        batch_idx * get<0>(params_.q_stride) +
+        kv_head_idx * params_.group_size * get<2>(params_.q_stride);
+    // gmem tensor: (packed_len, dim) => ((seq, group), dim)
     auto Q = make_gather_tensor(
         make_gmem_ptr((const Element*)params_.q_ptr + offset),
         make_shape(packed_len_, params_.head_dim),
@@ -126,16 +129,20 @@ struct FmhaBlock {
     const auto& [batch_idx, m_block_idx, kv_head_idx] = blk_coord_;

     // packing all q in the same kv head group together
-    const auto head_base = kv_head_idx * params_.group_size;
-    auto packed_idx_to_coord = [this, head_base](int packed_idx) {
+    auto packed_idx_to_coord = [this](int packed_idx) {
       // packed_idx => (seq, kv_heads):(group_size, 1)
       int idx, offset;
       params_.group_size.divmod(packed_idx, idx, offset);
-      return make_coord(idx, head_base + offset);
+      return make_coord(idx, offset);
     };

-    // (batch, seq, head, dim) => ((seq, head), dim)
-    const auto offset = batch_idx * get<0>(params_.o_stride);
+    // (batch, seq, head, dim)
+    //   => (batch, seq, (kv_heads, group), dim)
+    //   => (seq, group, dim)
+    const auto offset =
+        batch_idx * get<0>(params_.o_stride) +
+        kv_head_idx * params_.group_size * get<2>(params_.o_stride);
+    // gmem tensor: (packed_len, dim) => ((seq, group), dim)
     auto O = make_gather_tensor(
         make_gmem_ptr((Element*)params_.o_ptr + offset),
         make_shape(packed_len_, params_.head_dim),
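
The new comments spell out the reshape: the q head axis is viewed as (kv_heads, group), so q head h of kv head g decomposes as h = g * group_size + offset; once the base pointer absorbs g * group_size * head stride, the gather only needs (seq, offset). A worked example of the divmod mapping, assuming group_size = 4:

    // packed_idx enumerates (seq, group):(group_size, 1): the group_size q
    // heads of one kv-head group for seq 0, then for seq 1, and so on.
    #include <cstdio>

    int main() {
      const int group_size = 4;  // assumed, e.g. 32 q heads over 8 kv heads
      for (int packed_idx = 0; packed_idx < 8; ++packed_idx) {
        const int seq = packed_idx / group_size;  // quotient: divmod's idx
        const int off = packed_idx % group_size;  // remainder: divmod's offset
        std::printf("packed_idx %d -> (seq %d, group %d)\n",
                    packed_idx, seq, off);
      }
      // prints 0->(0,0), 1->(0,1), 2->(0,2), 3->(0,3), 4->(1,0), ...
      return 0;
    }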

src/kernels/attention/kernel/sm80_kernel_mha.cuh

Lines changed: 21 additions & 14 deletions
@@ -41,22 +41,25 @@ struct MHATile<MHAParams> {
     // (batch, seq, head, dim)

     // packed all q/o in the same kv head group together
-    const auto head_base = kv_head_idx_ * params_.group_size;
-    auto packed_idx_to_coord = [this, head_base](int packed_idx) {
+    auto packed_idx_to_coord = [this](int packed_idx) {
       int idx, offset;
       params_.group_size.divmod(packed_idx, idx, offset);
-      return make_coord(idx, head_base + offset);
+      return make_coord(idx, offset);
     };

     const auto packed_len = params_.q_len * params_.group_size;
-    const auto q_offset = batch_idx_ * get<0>(params_.q_stride);
+    const auto q_offset =
+        (batch_idx_ * get<0>(params_.q_stride)) +
+        (kv_head_idx_ * params_.group_size * get<2>(params_.q_stride));
     auto q = make_gather_tensor(
         make_gmem_ptr((const Element*)params_.q_ptr + q_offset),
         make_shape(packed_len, params_.head_dim),
         make_stride(select<1, 2>(params_.q_stride), get<3>(params_.q_stride)),
         packed_idx_to_coord);

-    const auto o_offset = batch_idx_ * get<0>(params_.o_stride);
+    const auto o_offset =
+        (batch_idx_ * get<0>(params_.o_stride)) +
+        (kv_head_idx_ * params_.group_size * get<2>(params_.o_stride));
     auto o = make_gather_tensor(
         make_gmem_ptr((Element*)params_.o_ptr + o_offset),
         make_shape(packed_len, params_.head_dim),
@@ -69,10 +72,10 @@ struct MHATile<MHAParams> {
   template <typename Element>
   CUTE_HOST_DEVICE auto get_kv_tile() const {
     // (batch, seq, kv_head, dim)
-    const auto k_offset = batch_idx_ * get<0>(params_.k_stride) +
-                          kv_head_idx_ * get<2>(params_.k_stride);
-    const auto v_offset = batch_idx_ * get<0>(params_.v_stride) +
-                          kv_head_idx_ * get<2>(params_.v_stride);
+    const auto k_offset = (batch_idx_ * get<0>(params_.k_stride)) +
+                          (kv_head_idx_ * get<2>(params_.k_stride));
+    const auto v_offset = (batch_idx_ * get<0>(params_.v_stride)) +
+                          (kv_head_idx_ * get<2>(params_.v_stride));
     // k[batch_idx, :, kv_head_idx, :]
     auto k =
         make_tensor(make_gmem_ptr((const Element*)params_.k_ptr + k_offset),
@@ -105,22 +108,26 @@ struct MHATile<MHAPagedKVParams> {
   CUTE_HOST_DEVICE auto get_qo_tile() const {
     const auto begin = params_.q_cu_lens[batch_idx_];
     const auto qo_len = params_.q_cu_lens[batch_idx_ + 1] - begin;
-    const auto head_base = kv_head_idx_ * params_.group_size;
-    auto packed_idx_to_coord = [this, head_base](int packed_idx) {
+
+    auto packed_idx_to_coord = [this](int packed_idx) {
       int idx, offset;
       params_.group_size.divmod(packed_idx, idx, offset);
-      return make_coord(idx, head_base + offset);
+      return make_coord(idx, offset);
     };

     const auto packed_len = qo_len * params_.group_size;
-    const auto q_offset = begin * get<0>(params_.q_stride);
+    const auto q_offset =
+        (begin * get<0>(params_.q_stride)) +
+        (kv_head_idx_ * params_.group_size * get<1>(params_.q_stride));
     auto q = make_gather_tensor(
         make_gmem_ptr((const Element*)params_.q_ptr + q_offset),
         make_shape(packed_len, params_.head_dim),
         make_stride(select<0, 1>(params_.q_stride), get<2>(params_.q_stride)),
         packed_idx_to_coord);

-    const auto o_offset = begin * get<0>(params_.o_stride);
+    const auto o_offset =
+        (begin * get<0>(params_.o_stride)) +
+        (kv_head_idx_ * params_.group_size * get<1>(params_.o_stride));
     auto o = make_gather_tensor(
         make_gmem_ptr((Element*)params_.o_ptr + o_offset),
         make_shape(packed_len, params_.head_dim),
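
In the paged-KV specialization q/o are packed variable-length tensors with no batch mode, so the batch offset comes from the cumulative lengths q_cu_lens and the head stride sits at mode 1 of q_stride instead of mode 2. A sketch of that offset arithmetic (the shapes and strides below are assumed for illustration):

    // Varlen q assumed as (total_seq, num_q_heads, head_dim) with strides
    // (seq_stride, head_stride, 1); q_cu_lens gives per-batch boundaries.
    #include <cassert>

    int main() {
      const int q_cu_lens[] = {0, 5, 12, 20};  // cumulative q lens per batch
      const int group_size = 4;
      const int head_stride = 16;              // head_dim
      const int seq_stride = 8 * 16;           // num_q_heads * head_dim

      const int batch_idx = 1, kv_head_idx = 1;
      const int begin = q_cu_lens[batch_idx];               // 5
      const int qo_len = q_cu_lens[batch_idx + 1] - begin;  // 7
      const int packed_len = qo_len * group_size;           // 28 gathered rows
      // skip earlier sequences, then earlier kv-head groups
      const int q_offset =
          begin * seq_stride + kv_head_idx * group_size * head_stride;
      assert(q_offset == 5 * 128 + 64 && packed_len == 28);
      return 0;
    }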
