@@ -44,18 +44,14 @@ struct MHATile<MHAParams> {
4444 auto q = make_gather_tensor (
4545 make_gmem_ptr ((const Element*)params_.q_ptr + q_offset),
4646 make_shape (packed_len, params_.head_dim ),
47- make_stride (
48- make_stride (get<1 >(params_.q_stride ), get<2 >(params_.q_stride )),
49- _1{}),
47+ make_stride (select<1 , 2 >(params_.q_stride ), get<3 >(params_.q_stride )),
5048 packed_idx_to_coord);
5149
5250 const auto o_offset = batch_idx_ * get<0 >(params_.o_stride );
5351 auto o = make_gather_tensor (
5452 make_gmem_ptr ((Element*)params_.o_ptr + o_offset),
5553 make_shape (packed_len, params_.head_dim ),
56- make_stride (
57- make_stride (get<1 >(params_.o_stride ), get<2 >(params_.o_stride )),
58- _1{}),
54+ make_stride (select<1 , 2 >(params_.o_stride ), get<3 >(params_.o_stride )),
5955 packed_idx_to_coord);
6056 return make_tuple (q, o);
6157 }
@@ -72,12 +68,12 @@ struct MHATile<MHAParams> {
7268 auto k =
7369 make_tensor (make_gmem_ptr ((const Element*)params_.k_ptr + k_offset),
7470 make_shape (params_.kv_len , params_.head_dim ),
75- make_stride (get< 1 >(params_.k_stride ), _1{} ));
71+ select< 1 , 3 >(params_.k_stride ));
7672 // v[batch_idx, :, kv_head_idx, :]
7773 auto v =
7874 make_tensor (make_gmem_ptr ((const Element*)params_.v_ptr + v_offset),
7975 make_shape (params_.kv_len , params_.head_dim ),
80- make_stride (get< 1 >(params_.v_stride ), _1{} ));
76+ select< 1 , 3 >(params_.v_stride ));
8177 return make_tuple (k, v);
8278 }
8379};
@@ -112,18 +108,14 @@ struct MHATile<MHAPagedKVParams> {
112108 auto q = make_gather_tensor (
113109 make_gmem_ptr ((const Element*)params_.q_ptr + q_offset),
114110 make_shape (packed_len, params_.head_dim ),
115- make_stride (
116- make_stride (get<0 >(params_.q_stride ), get<1 >(params_.q_stride )),
117- _1{}),
111+ make_stride (select<0 , 1 >(params_.q_stride ), get<2 >(params_.q_stride )),
118112 packed_idx_to_coord);
119113
120114 const auto o_offset = begin * get<0 >(params_.o_stride );
121115 auto o = make_gather_tensor (
122116 make_gmem_ptr ((Element*)params_.o_ptr + o_offset),
123117 make_shape (packed_len, params_.head_dim ),
124- make_stride (
125- make_stride (get<0 >(params_.o_stride ), get<1 >(params_.o_stride )),
126- _1{}),
118+ make_stride (select<0 , 1 >(params_.o_stride ), get<2 >(params_.o_stride )),
127119 packed_idx_to_coord);
128120 return make_tuple (q, o);
129121 }
@@ -148,18 +140,18 @@ struct MHATile<MHAPagedKVParams> {
148140 auto k = make_gather_tensor (
149141 make_gmem_ptr ((const Element*)params_.k_ptr + k_offset),
150142 make_shape (kv_len, params_.head_dim ),
151- make_stride (get< 0 >(params_.k_stride ), _1{} ),
143+ select< 0 , 2 >(params_.k_stride ),
152144 idx_to_slot);
153145
154146 // v[:, kv_head_idx, :]
155147 const auto v_offset = kv_head_idx_ * get<1 >(params_.v_stride );
156148 auto v = make_gather_tensor (
157149 make_gmem_ptr ((const Element*)params_.v_ptr + v_offset),
158150 make_shape (kv_len, params_.head_dim ),
159- make_stride (get< 0 >(params_.v_stride ), _1{} ),
151+ select< 0 , 2 >(params_.v_stride ),
160152 idx_to_slot);
161153 return make_tuple (k, v);
162154 }
163155};
164156
165- } // namespace llm
157+ } // namespace llm
0 commit comments