@@ -41,22 +41,25 @@ struct MHATile<MHAParams> {
4141 // (batch, seq, head, dim)
4242
4343 // packed all q/o in the same kv head group together
44- const auto head_base = kv_head_idx_ * params_.group_size ;
45- auto packed_idx_to_coord = [this , head_base](int packed_idx) {
44+ auto packed_idx_to_coord = [this ](int packed_idx) {
4645 int idx, offset;
4746 params_.group_size .divmod (packed_idx, idx, offset);
48- return make_coord (idx, head_base + offset);
47+ return make_coord (idx, offset);
4948 };
5049
5150 const auto packed_len = params_.q_len * params_.group_size ;
52- const auto q_offset = batch_idx_ * get<0 >(params_.q_stride );
51+ const auto q_offset =
52+ (batch_idx_ * get<0 >(params_.q_stride )) +
53+ (kv_head_idx_ * params_.group_size * get<2 >(params_.q_stride ));
5354 auto q = make_gather_tensor (
5455 make_gmem_ptr ((const Element*)params_.q_ptr + q_offset),
5556 make_shape (packed_len, params_.head_dim ),
5657 make_stride (select<1 , 2 >(params_.q_stride ), get<3 >(params_.q_stride )),
5758 packed_idx_to_coord);
5859
59- const auto o_offset = batch_idx_ * get<0 >(params_.o_stride );
60+ const auto o_offset =
61+ (batch_idx_ * get<0 >(params_.o_stride )) +
62+ (kv_head_idx_ * params_.group_size * get<2 >(params_.o_stride ));
6063 auto o = make_gather_tensor (
6164 make_gmem_ptr ((Element*)params_.o_ptr + o_offset),
6265 make_shape (packed_len, params_.head_dim ),
@@ -69,10 +72,10 @@ struct MHATile<MHAParams> {
6972 template <typename Element>
7073 CUTE_HOST_DEVICE auto get_kv_tile () const {
7174 // (batch, seq, kv_head, dim)
72- const auto k_offset = batch_idx_ * get<0 >(params_.k_stride ) +
73- kv_head_idx_ * get<2 >(params_.k_stride );
74- const auto v_offset = batch_idx_ * get<0 >(params_.v_stride ) +
75- kv_head_idx_ * get<2 >(params_.v_stride );
75+ const auto k_offset = ( batch_idx_ * get<0 >(params_.k_stride ) ) +
76+ ( kv_head_idx_ * get<2 >(params_.k_stride ) );
77+ const auto v_offset = ( batch_idx_ * get<0 >(params_.v_stride ) ) +
78+ ( kv_head_idx_ * get<2 >(params_.v_stride ) );
7679 // k[batch_idx, :, kv_head_idx, :]
7780 auto k =
7881 make_tensor (make_gmem_ptr ((const Element*)params_.k_ptr + k_offset),
@@ -105,22 +108,26 @@ struct MHATile<MHAPagedKVParams> {
105108 CUTE_HOST_DEVICE auto get_qo_tile () const {
106109 const auto begin = params_.q_cu_lens [batch_idx_];
107110 const auto qo_len = params_.q_cu_lens [batch_idx_ + 1 ] - begin;
108- const auto head_base = kv_head_idx_ * params_. group_size ;
109- auto packed_idx_to_coord = [this , head_base ](int packed_idx) {
111+
112+ auto packed_idx_to_coord = [this ](int packed_idx) {
110113 int idx, offset;
111114 params_.group_size .divmod (packed_idx, idx, offset);
112- return make_coord (idx, head_base + offset);
115+ return make_coord (idx, offset);
113116 };
114117
115118 const auto packed_len = qo_len * params_.group_size ;
116- const auto q_offset = begin * get<0 >(params_.q_stride );
119+ const auto q_offset =
120+ (begin * get<0 >(params_.q_stride )) +
121+ (kv_head_idx_ * params_.group_size * get<1 >(params_.q_stride ));
117122 auto q = make_gather_tensor (
118123 make_gmem_ptr ((const Element*)params_.q_ptr + q_offset),
119124 make_shape (packed_len, params_.head_dim ),
120125 make_stride (select<0 , 1 >(params_.q_stride ), get<2 >(params_.q_stride )),
121126 packed_idx_to_coord);
122127
123- const auto o_offset = begin * get<0 >(params_.o_stride );
128+ const auto o_offset =
129+ (begin * get<0 >(params_.o_stride )) +
130+ (kv_head_idx_ * params_.group_size * get<1 >(params_.o_stride ));
124131 auto o = make_gather_tensor (
125132 make_gmem_ptr ((Element*)params_.o_ptr + o_offset),
126133 make_shape (packed_len, params_.head_dim ),
0 commit comments