@@ -12,7 +12,8 @@ namespace llm {
1212using namespace cute ;
1313
1414// AttentionTile specialization for AttentionParams
15- template <typename TileShape, // (BLK_M, BLK_N, BLK_K)
15+ template <typename ProblemShape, // (Q, K, D, ((KH, G), B))
16+ typename TileShape, // (BLK_M, BLK_N, BLK_K)
1617 typename BlocKCoord, // (m_block_idx, ((kv_head_idx, _0), batch_idx))
1718 typename Element, // Element type
1819 typename StrideQ, // (Q, D, ((KH, G), B))
@@ -46,27 +47,15 @@ struct FmhaBlock {
4647 StrideV v_stride;
4748 StrideO o_stride;
4849
49- // Parameters from problem shape
50- int batch_size;
51- int q_len;
52- int kv_len;
53- int head_dim;
54- // int n_heads;
55- int n_kv_heads; // number of kv heads
50+ // for fast divmod
5651 FastDivmod group_size;
5752 };
5853
59- template <class ProblemShape >
6054 static Params to_underlying_arguments (const ProblemShape& problem_shape,
6155 const Arguments& args,
62- void * workspace = nullptr ) {
56+ void * /* workspace*/ ) {
6357 // ProblemShape: (Q, K, D, ((KH, G), B))
64- const int q_len = get<0 >(problem_shape);
65- const int kv_len = get<1 >(problem_shape);
66- const int head_dim = get<2 >(problem_shape);
67- const int n_kv_heads = get<3 , 0 , 0 >(problem_shape);
6858 const int group_size = get<3 , 0 , 1 >(problem_shape);
69- const int batch_size = get<3 , 1 >(problem_shape);
7059
7160 // TODO: construct tma_load for k/v tensors
7261 return {
@@ -78,11 +67,6 @@ struct FmhaBlock {
7867 .k_stride = args.k_stride ,
7968 .v_stride = args.v_stride ,
8069 .o_stride = args.o_stride ,
81- .batch_size = batch_size,
82- .q_len = q_len,
83- .kv_len = kv_len,
84- .head_dim = head_dim,
85- .n_kv_heads = n_kv_heads,
8670 .group_size = FastDivmod (group_size),
8771 };
8872 }
@@ -96,19 +80,22 @@ struct FmhaBlock {
9680 using BLK_K = Int<kBlockK >;
9781
9882 // hold a reference to the parameters and block coordination
83+ const ProblemShape& problem_shape_;
9984 const Params& params_;
100- const BlocKCoord& blk_coord_ ;
85+ const BlocKCoord& coord_ ;
10186
10287 // derived parameters to avoid recomputation
10388 int m_block_base_;
10489 int packed_len_;
10590
10691 // Constructor
107- CUTE_HOST_DEVICE FmhaBlock (const Params& params, const BlocKCoord& blk_coord)
108- : params_(params), blk_coord_(blk_coord) {
109- // derive parameters
110- m_block_base_ = get<0 >(blk_coord) * get<0 >(TileShape{});
111- packed_len_ = params_.q_len * params_.group_size ;
92+ CUTE_HOST_DEVICE FmhaBlock (const ProblemShape& problem_shape,
93+ const Params& params,
94+ const BlocKCoord& coord)
95+ : problem_shape_(problem_shape), params_(params), coord_(coord) {
96+ // derived parameters
97+ m_block_base_ = get<0 >(coord) * get<0 >(TileShape{});
98+ packed_len_ = get<0 >(problem_shape_) * params_.group_size ;
11299 }
113100
114101 // check if the m_block is valid
@@ -120,36 +107,33 @@ struct FmhaBlock {
120107 // returns packed_len
121108 CUTE_HOST_DEVICE int get_packed_len () const { return packed_len_; }
122109
123- // returns actual query length
124- CUTE_HOST_DEVICE int get_q_len () const { return params_.q_len ; }
125-
126- // returns actual kv length
127- CUTE_HOST_DEVICE int get_kv_len () const { return params_.kv_len ; }
110+ // returns problem shape: (Q, K, D, ((KH, G), B))
111+ CUTE_HOST_DEVICE const auto & get_problem_shape () const {
112+ return problem_shape_;
113+ }
128114
129- // returns head_dim
130- CUTE_HOST_DEVICE int get_head_dim () const { return params_. head_dim ; }
115+ // returns (m_block_idx, ((kv_head_idx, _0), batch_idx))
116+ CUTE_HOST_DEVICE const auto & get_coord () const { return coord_ ; }
131117
132- // returns group size
118+ // returns group size fast divmod
133119 CUTE_HOST_DEVICE const FastDivmod& get_group_size () const {
134120 return params_.group_size ;
135121 }
136122
137123 // returns redidue mnk
138124 CUTE_HOST_DEVICE auto get_residue_mnk () const {
139- return make_tuple (packed_len_, params_.kv_len , params_.head_dim );
125+ auto residue_mnk = select<0 , 1 , 2 >(problem_shape_);
126+ get<0 >(residue_mnk) = packed_len_;
127+ return residue_mnk;
140128 }
141129
142- // returns (m_block_idx, ((kv_head_idx, _0), batch_idx))
143- CUTE_HOST_DEVICE const auto & get_block_coord () const { return blk_coord_; }
144-
145130 // returns kv block range: (n_block_min, n_block_max]
146131 template <bool kLocal >
147132 CUTE_HOST_DEVICE auto get_kv_blocks (int sliding_window) const {
148133 static constexpr int kBlockM = get<0 >(TileShape{});
149134 static constexpr int kBlockN = get<1 >(TileShape{});
150135
151- const int q_len = params_.q_len ;
152- const int kv_len = params_.kv_len ;
136+ const auto [q_len, kv_len] = select<0 , 1 >(problem_shape_);
153137 const int q_idx = m_block_base_ / params_.group_size ;
154138 // take care of causal mask
155139 const int diagonal = q_idx + kv_len - q_len;
@@ -168,15 +152,11 @@ struct FmhaBlock {
168152 // return the query tile: (BLK_M, BLK_K) => (M, K)
169153 CUTE_HOST_DEVICE auto get_q_tile () const {
170154 // (Q, D, ((KH, G), B))
171- auto q_shape = make_shape (
172- params_.q_len ,
173- params_.head_dim ,
174- make_shape (make_shape (params_.n_kv_heads , (int )params_.group_size ),
175- params_.batch_size ));
155+ auto q_shape = select<0 , 2 , 3 >(problem_shape_);
176156 auto mQ =
177157 make_tensor (make_gmem_ptr (params_.q_ptr ), q_shape, params_.q_stride );
178158 // (Q, D, G*)
179- auto Q = mQ (_, _, get<1 >(blk_coord_ ));
159+ auto Q = mQ (_, _, get<1 >(coord_ ));
180160
181161 // packing all q in the same kv head group together
182162 auto packed_idx_to_coord = [this ](int packed_idx) {
@@ -192,12 +172,13 @@ struct FmhaBlock {
192172 get<1 >(params_.q_stride ));
193173
194174 // packed tensor: (pQ, D) => ((Q, G), D)
175+ const int head_dim = get<2 >(problem_shape_);
195176 auto pQ = make_gather_tensor (Q.data (),
196- make_shape (packed_len_, params_. head_dim ),
177+ make_shape (packed_len_, head_dim),
197178 q_stride,
198179 packed_idx_to_coord);
199180
200- const auto m_block_idx = get<0 >(blk_coord_ );
181+ const auto m_block_idx = get<0 >(coord_ );
201182 // (BLK_M, BLK_K)
202183 Tensor gQ =
203184 local_tile (pQ, Shape<BLK_M, BLK_K>{}, make_coord (m_block_idx, _0{}));
@@ -211,15 +192,11 @@ struct FmhaBlock {
211192 // return the output tile: (BLK_M, BLK_K) => (M, K)
212193 CUTE_HOST_DEVICE auto get_o_tile () const {
213194 // (Q, D, ((KH, G), B))
214- auto o_shape = make_shape (
215- params_.q_len ,
216- params_.head_dim ,
217- make_shape (make_shape (params_.n_kv_heads , (int )params_.group_size ),
218- params_.batch_size ));
195+ auto o_shape = select<0 , 2 , 3 >(problem_shape_);
219196 auto mO =
220197 make_tensor (make_gmem_ptr (params_.o_ptr ), o_shape, params_.o_stride );
221198 // (Q, D, G*)
222- auto O = mO (_, _, get<1 >(blk_coord_ ));
199+ auto O = mO (_, _, get<1 >(coord_ ));
223200
224201 // packing all q in the same kv head group together
225202 auto packed_idx_to_coord = [this ](int packed_idx) {
@@ -231,39 +208,36 @@ struct FmhaBlock {
231208 auto o_stride = make_stride (
232209 make_stride (get<0 >(params_.o_stride ), get<2 , 0 , 1 >(params_.o_stride )),
233210 get<1 >(params_.o_stride ));
211+ const int head_dim = get<2 >(problem_shape_);
234212 // packed tensor: (pO, D) => ((O, G), D)
235213 auto pO = make_gather_tensor (O.data (),
236- make_shape (packed_len_, params_. head_dim ),
214+ make_shape (packed_len_, head_dim),
237215 o_stride,
238216 packed_idx_to_coord);
239217
240- const auto m_block_idx = get<0 >(blk_coord_ );
218+ const auto m_block_idx = get<0 >(coord_ );
241219 // (BLK_M, BLK_K)
242220 Tensor gO =
243221 local_tile (pO, Shape<BLK_M, BLK_K>{}, make_coord (m_block_idx, _0{}));
244222 // (BLK_M, BLK_K) => (M, K)
245- Tensor cQ = local_tile (make_identity_tensor (shape (pO)),
223+ Tensor cO = local_tile (make_identity_tensor (shape (pO)),
246224 Shape<BLK_M, BLK_K>{},
247225 make_coord (m_block_idx, _0{}));
248- return make_tuple (gO , cQ );
226+ return make_tuple (gO , cO );
249227 }
250228
251229 // return the key/value tile: (BLK_N, BLK_K, n) => (N, K)
252230 CUTE_HOST_DEVICE auto get_kv_tile () const {
253231 // (KV, D, ((KH, G), B))
254- auto kv_shape = make_shape (
255- params_.kv_len ,
256- params_.head_dim ,
257- make_shape (make_shape (params_.n_kv_heads , (int )params_.group_size ),
258- params_.batch_size ));
232+ auto kv_shape = select<1 , 2 , 3 >(problem_shape_);
259233 auto mK =
260234 make_tensor (make_gmem_ptr (params_.k_ptr ), kv_shape, params_.k_stride );
261235 auto mV =
262236 make_tensor (make_gmem_ptr (params_.v_ptr ), kv_shape, params_.v_stride );
263237
264238 // (K/V, D)
265- auto K = mK (_, _, get<1 >(blk_coord_ ));
266- auto V = mV (_, _, get<1 >(blk_coord_ ));
239+ auto K = mK (_, _, get<1 >(coord_ ));
240+ auto V = mV (_, _, get<1 >(coord_ ));
267241
268242 // (BLK_N, BLK_K, n)
269243 Tensor gK = local_tile (K, Shape<BLK_N, BLK_K>{}, make_coord (_, _0{}));
0 commit comments