@@ -24,21 +24,43 @@ struct AttentionTile<AttentionParams> {
 
   // return the query/output tile: (q_len, head_dim)
   template <typename Element>
-  CUTE_HOST_DEVICE auto get_qo_tile(int batch_idx, int head_idx) const {
+  CUTE_HOST_DEVICE auto get_qo_tile(int batch_idx, int kv_head_idx) const {
     // (batch, seq, head, dim)
-    const auto q_offset = batch_idx * get<0>(params_.q_stride) +
-                          head_idx * get<2>(params_.q_stride);
-    const auto o_offset = batch_idx * get<0>(params_.o_stride) +
-                          head_idx * get<2>(params_.o_stride);
-
-    // q[batch_idx, :, head_idx, :]
-    auto q =
-        make_tensor(make_gmem_ptr((const Element*)params_.q_ptr + q_offset),
-                    make_shape(params_.q_len, params_.head_dim),
-                    make_stride(get<1>(params_.q_stride), _1{}));
-    auto o = make_tensor(make_gmem_ptr((Element*)params_.o_ptr + o_offset),
-                         make_shape(params_.q_len, params_.head_dim),
-                         make_stride(get<1>(params_.o_stride), _1{}));
+
+    // pack all q/o heads in the same kv head group together
+    // q/o [batch, n_tokens, n_heads, dim]
+    // => q/o [*batch_idx, n_tokens, n_heads, dim]
+    // => q/o [n_tokens, group_size, n_kv_heads, dim]
+    // => q/o [n_tokens, group_size, *kv_head_idx, dim]
+    // => q/o [(group_size, n_tokens), dim]
+    // => q/o [packed_len, dim]
+    const auto group_size = params_.group_size;
+    const auto head_base = kv_head_idx * group_size;
+    auto packed_idx_to_coord = [group_size, head_base](int packed_idx) {
+      const int idx = packed_idx / group_size;
+      const int offset = packed_idx % group_size;
+      // (group_size, n_tokens)
+      return make_coord(head_base + offset, idx);
+    };
+
+    const auto packed_len = params_.q_len * group_size;
+    const auto q_offset = batch_idx * get<0>(params_.q_stride);
+    auto q = make_gather_tensor(
+        make_gmem_ptr((const Element*)params_.q_ptr + q_offset),
+        make_shape(packed_len, params_.head_dim),
+        make_stride(
+            make_stride(get<2>(params_.q_stride), get<1>(params_.q_stride)),
+            _1{}),
+        packed_idx_to_coord);
+
+    const auto o_offset = batch_idx * get<0>(params_.o_stride);
+    auto o = make_gather_tensor(
+        make_gmem_ptr((Element*)params_.o_ptr + o_offset),
+        make_shape(packed_len, params_.head_dim),
+        make_stride(
+            make_stride(get<2>(params_.o_stride), get<1>(params_.o_stride)),
+            _1{}),
+        packed_idx_to_coord);
     return make_tuple(q, o);
   }
 
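The core of this change is that every query/output head belonging to one kv head group is flattened into a single packed row dimension of length q_len * group_size, with the head offset as the fastest-varying component. Below is a minimal standalone sketch (not part of the diff) of that index arithmetic, assuming a contiguous [batch, n_tokens, n_heads, dim] layout; the concrete sizes and names are illustrative only, not the kernel's API.

// sketch: packed row index -> (token, head) and element offset, assuming
// q_stride = (n_tokens*n_heads*dim, n_heads*dim, dim) for a contiguous layout
#include <cassert>
#include <cstdio>

int main() {
  const int n_tokens = 3, n_heads = 8, head_dim = 4;
  const int n_kv_heads = 2;                        // GQA: 8 query heads share 2 kv heads
  const int group_size = n_heads / n_kv_heads;     // 4 query heads per kv head
  const int kv_head_idx = 1;
  const int head_base = kv_head_idx * group_size;  // first query head in this group

  // per-batch strides of q[n_tokens, n_heads, dim]
  const int token_stride = n_heads * head_dim;     // plays the role of get<1>(q_stride)
  const int head_stride = head_dim;                // plays the role of get<2>(q_stride)

  const int packed_len = n_tokens * group_size;
  for (int packed_idx = 0; packed_idx < packed_len; ++packed_idx) {
    // same arithmetic as packed_idx_to_coord: head offset varies fastest
    const int idx = packed_idx / group_size;       // token index
    const int offset = packed_idx % group_size;    // head within the kv group
    const int head_idx = head_base + offset;

    // row offset selected by the gathered (head, token) coordinate
    const int row_offset = head_idx * head_stride + idx * token_stride;
    std::printf("packed %2d -> token %d, head %d, elem offset %d\n",
                packed_idx, idx, head_idx, row_offset);
    assert(head_idx >= head_base && head_idx < head_base + group_size);
  }
  return 0;
}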
@@ -75,24 +97,37 @@ struct AttentionTile<VarLenAttentionParams> {
 
   // return the query tile: (q_len, head_dim)
   template <typename Element>
-  CUTE_HOST_DEVICE auto get_qo_tile(int batch_idx, int head_idx) const {
+  CUTE_HOST_DEVICE auto get_qo_tile(int batch_idx, int kv_head_idx) const {
     const auto begin = params_.q_cu_lens[batch_idx];
     const auto qo_len = params_.q_cu_lens[batch_idx + 1] - begin;
-    // (seq, head, dim)
-    const auto q_offset =
-        begin * get<0>(params_.q_stride) + head_idx * get<1>(params_.q_stride);
-    const auto o_offset =
-        begin * get<0>(params_.o_stride) + head_idx * get<1>(params_.o_stride);
-
-    // q[begin:begin + q_len, head_idx, :]
-    auto q =
-        make_tensor(make_gmem_ptr((const Element*)params_.q_ptr + q_offset),
-                    make_shape(qo_len, params_.head_dim),
-                    make_stride(get<0>(params_.q_stride), _1{}));
-    // o[begin:begin + o_len, head_idx, :]
-    auto o = make_tensor(make_gmem_ptr((Element*)params_.o_ptr + o_offset),
-                         make_shape(qo_len, params_.head_dim),
-                         make_stride(get<0>(params_.o_stride), _1{}));
+
+    const auto group_size = params_.group_size;
+    const auto head_base = kv_head_idx * group_size;
+    auto packed_idx_to_coord = [group_size, head_base](int packed_idx) {
+      const int idx = packed_idx / group_size;
+      const int offset = packed_idx % group_size;
+      // (group_size, n_tokens)
+      return make_coord(head_base + offset, idx);
+    };
+
+    const auto packed_len = qo_len * group_size;
+    const auto q_offset = begin * get<0>(params_.q_stride);
+    auto q = make_gather_tensor(
+        make_gmem_ptr((const Element*)params_.q_ptr + q_offset),
+        make_shape(packed_len, params_.head_dim),
+        make_stride(
+            make_stride(get<1>(params_.q_stride), get<0>(params_.q_stride)),
+            _1{}),
+        packed_idx_to_coord);
+
+    const auto o_offset = begin * get<0>(params_.o_stride);
+    auto o = make_gather_tensor(
+        make_gmem_ptr((Element*)params_.o_ptr + o_offset),
+        make_shape(packed_len, params_.head_dim),
+        make_stride(
+            make_stride(get<1>(params_.o_stride), get<0>(params_.o_stride)),
+            _1{}),
+        packed_idx_to_coord);
     return make_tuple(q, o);
   }
 
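In the variable-length variant the per-sequence offset comes from cumulative query lengths rather than a batch stride, and only the (seq, head, dim) stride indices change. A small sketch of how begin, qo_len, and the packed length relate, using illustrative q_cu_lens values that are not taken from the diff:

// sketch: per-sequence ranges and packed lengths in var-len mode
#include <cstdio>
#include <vector>

int main() {
  // 3 sequences of lengths 5, 2, 7 laid out back to back: cu_lens has n+1 entries
  const std::vector<int> q_cu_lens = {0, 5, 7, 14};
  const int group_size = 4;  // query heads per kv head (illustrative)

  for (size_t batch_idx = 0; batch_idx + 1 < q_cu_lens.size(); ++batch_idx) {
    const int begin = q_cu_lens[batch_idx];
    const int qo_len = q_cu_lens[batch_idx + 1] - begin;  // tokens in this sequence
    const int packed_len = qo_len * group_size;           // rows seen by the kernel
    std::printf("seq %zu: tokens [%d, %d), qo_len=%d, packed_len=%d\n",
                batch_idx, begin, begin + qo_len, qo_len, packed_len);
  }
  return 0;
}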
@@ -132,24 +167,36 @@ struct AttentionTile<PagedKVAttentionParams> {
 
   // return the query/output tile: (q_len, head_dim)
   template <typename Element>
-  CUTE_HOST_DEVICE auto get_qo_tile(int batch_idx, int head_idx) const {
+  CUTE_HOST_DEVICE auto get_qo_tile(int batch_idx, int kv_head_idx) const {
     const auto begin = params_.q_cu_lens[batch_idx];
     const auto qo_len = params_.q_cu_lens[batch_idx + 1] - begin;
-    // (seq, head, dim)
-    const auto q_offset =
-        begin * get<0>(params_.q_stride) + head_idx * get<1>(params_.q_stride);
-    const auto o_offset =
-        begin * get<0>(params_.o_stride) + head_idx * get<1>(params_.o_stride);
-
-    // q[begin:begin + q_len, head_idx, :]
-    auto q =
-        make_tensor(make_gmem_ptr((const Element*)params_.q_ptr + q_offset),
-                    make_shape(qo_len, params_.head_dim),
-                    make_stride(get<0>(params_.q_stride), _1{}));
-    // o[begin:begin + o_len, head_idx, :]
-    auto o = make_tensor(make_gmem_ptr((Element*)params_.o_ptr + o_offset),
-                         make_shape(qo_len, params_.head_dim),
-                         make_stride(get<0>(params_.o_stride), _1{}));
+    const auto group_size = params_.group_size;
+    const auto head_base = kv_head_idx * group_size;
+    auto packed_idx_to_coord = [group_size, head_base](int packed_idx) {
+      const int idx = packed_idx / group_size;
+      const int offset = packed_idx % group_size;
+      // (group_size, n_tokens)
+      return make_coord(head_base + offset, idx);
+    };
+
+    const auto packed_len = qo_len * group_size;
+    const auto q_offset = begin * get<0>(params_.q_stride);
+    auto q = make_gather_tensor(
+        make_gmem_ptr((const Element*)params_.q_ptr + q_offset),
+        make_shape(packed_len, params_.head_dim),
+        make_stride(
+            make_stride(get<1>(params_.q_stride), get<0>(params_.q_stride)),
+            _1{}),
+        packed_idx_to_coord);
+
+    const auto o_offset = begin * get<0>(params_.o_stride);
+    auto o = make_gather_tensor(
+        make_gmem_ptr((Element*)params_.o_ptr + o_offset),
+        make_shape(packed_len, params_.head_dim),
+        make_stride(
+            make_stride(get<1>(params_.o_stride), get<0>(params_.o_stride)),
+            _1{}),
+        packed_idx_to_coord);
     return make_tuple(q, o);
   }
 
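make_gather_tensor is a project-specific helper; the diff relies on it applying packed_idx_to_coord to the row index before composing with the nested (head_stride, token_stride) stride. The plain C++ sketch below shows the gather semantics this change appears to assume; it is a rough illustration, not CuTe's API or the helper's actual implementation.

// rough sketch of the assumed gather semantics: element (packed_idx, d)
// of the view reads base[head*head_stride + token*token_stride + d], where
// (head, token) comes from a user-supplied index mapping
#include <cstdio>
#include <functional>
#include <utility>
#include <vector>

template <typename Element>
struct GatherView2D {
  Element* base;
  int head_stride;   // stride of the gathered head coordinate
  int token_stride;  // stride of the gathered token coordinate
  std::function<std::pair<int, int>(int)> idx_to_coord;  // packed row -> (head, token)

  Element& operator()(int packed_idx, int d) const {
    const auto [head, token] = idx_to_coord(packed_idx);
    return base[head * head_stride + token * token_stride + d];
  }
};

int main() {
  const int n_tokens = 2, n_heads = 4, head_dim = 2;
  const int group_size = 2, kv_head_idx = 1;
  std::vector<float> q(n_tokens * n_heads * head_dim);
  for (size_t i = 0; i < q.size(); ++i) q[i] = float(i);

  const int head_base = kv_head_idx * group_size;
  GatherView2D<float> view{
      q.data(), /*head_stride=*/head_dim, /*token_stride=*/n_heads * head_dim,
      [=](int packed_idx) {
        return std::make_pair(head_base + packed_idx % group_size,
                              packed_idx / group_size);
      }};

  // walk the packed (group_size * n_tokens, head_dim) view
  for (int i = 0; i < group_size * n_tokens; ++i)
    for (int d = 0; d < head_dim; ++d)
      std::printf("view(%d,%d) = %g\n", i, d, view(i, d));
  return 0;
}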