@@ -57,9 +57,9 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
   using SmemTiledCopyO = typename Traits::SmemTiledCopyO;
 
   const int m_block = blockIdx.x;
-  const auto batch_idx = blockIdx.y;
-  const auto head_idx = blockIdx.z;
-  const auto tidx = threadIdx.x;
+  const int batch_idx = blockIdx.y;
+  const int head_idx = blockIdx.z;
+  const int tidx = threadIdx.x;
 
   AttentionTile<Params> tile(params);
 
@@ -75,7 +75,7 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
   const int kv_len = size<0>(K);
 
   if (m_block * kBlockM >= q_len) {
-    // out of bound, return
+    // m out of bound, return
     return;
   }
 
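The `m_block * kBlockM >= q_len` guard makes sense under a launch where grid.x is sized for the longest sequence in the batch. A hypothetical launcher for illustration (the real one is outside this diff, and the template arguments here are guesses):

    // one CTA per (q block, batch, head); grid.x covers the longest sequence
    dim3 grid(cute::ceil_div(max_q_len, kBlockM), batch_size, n_heads);
    mha_kernel_sm80<Traits, Params><<<grid, kThreadsPerBlock, smem_size>>>(params);
    // a shorter sequence in the batch leaves trailing m_block values with no
    // query rows to process, hence the early return above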
@@ -134,46 +134,51 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
   // (BLK_M, HEAD_DIM) -> (blk_m, head_dim)
   Tensor cQ = make_identity_tensor(Shape<_BLK_M, _HEAD_DIM>{});
   Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);
-  // (BLK_N, HEAD_DIM) -> (blk_n, head_dim)
-  Tensor cKV = make_identity_tensor(Shape<_BLK_N, _HEAD_DIM>{});
-  Tensor tKcKV = gmem_thr_copy_KV.partition_S(cKV);
 
   auto produce_q = [&]() {
     auto tQgQ = gmem_thr_copy_Q.partition_S(gQ);
     auto tQsQ = gmem_thr_copy_Q.partition_D(sQ);
+    auto max_coord = make_coord(q_len - m_block * kBlockM, head_dim);
     safe_copy</*EVEN_MN=*/false, EVEN_K>(
-        gmem_tiled_copy_Q,
-        tQgQ,
-        tQsQ,
-        tQcQ,
-        make_coord(q_len - m_block * kBlockM, head_dim));
+        gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, max_coord);
   };
 
-  // TODO: seperate mask iterations
+  // (BLK_N, HEAD_DIM) -> (blk_n, head_dim)
+  Tensor cKV = make_identity_tensor(Shape<_BLK_N, _HEAD_DIM>{});
+  Tensor tKVcKV = gmem_thr_copy_KV.partition_S(cKV);
+
   Tensor tKsK = gmem_thr_copy_KV.partition_D(sK);
   auto produce_k = [&](int ni) {
     auto tKgK = gmem_thr_copy_KV.partition_S(gK(_, _, ni));
+    auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim);
     // skip zfill_mn for k since mask will mask out oob with -inf
     safe_copy</*EVEN_MN=*/false,
               EVEN_K,
-              /*ZERO_FILL_MN=*/false>(
-        gmem_tiled_copy_KV,
-        tKgK,
-        tKsK,
-        tKcKV,
-        make_coord(kv_len - ni * kBlockN, head_dim));
+              /*ZFILL_MN=*/false>(
+        gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, max_coord);
+  };
+
+  auto produce_k_no_oob = [&](int ni) {
+    auto tKgK = gmem_thr_copy_KV.partition_S(gK(_, _, ni));
+    auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim);
+    safe_copy</*EVEN_MN=*/true, EVEN_K>(
+        gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, max_coord);
   };
 
   Tensor tVsV = gmem_thr_copy_KV.partition_D(sV);
   auto produce_v = [&](int ni) {
     auto tVgV = gmem_thr_copy_KV.partition_S(gV(_, _, ni));
+    auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim);
     // skipping ZFILL_MN for v may cause nan issue
     safe_copy</*EVEN_MN=*/false, EVEN_K>(
-        gmem_tiled_copy_KV,
-        tVgV,
-        tVsV,
-        tKcKV,
-        make_coord(kv_len - ni * kBlockN, head_dim));
+        gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, max_coord);
+  };
+
+  auto produce_v_no_oob = [&](int ni) {
+    auto tVgV = gmem_thr_copy_KV.partition_S(gV(_, _, ni));
+    auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim);
+    safe_copy</*EVEN_MN=*/true, EVEN_K>(
+        gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, max_coord);
   };
 
   TiledMma tiled_mma;
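All producers funnel through `safe_copy`, whose template flags trade bounds checks for speed: `EVEN_MN=true` promises the row range is fully in-bounds (the new `_no_oob` variants), and `ZFILL_MN=false` skips zero-filling out-of-bounds rows, which is safe for K since the mask writes -inf there anyway, but not for V, per the comment above. A scalar sketch of those semantics, inferred from the call sites rather than the helper's actual implementation:

    // simplified analogue of the predicated copy; not the repo's safe_copy
    template <bool EVEN_MN, bool ZFILL_MN = true>
    void copy_rows(const float* src, float* dst, int rows, int cols, int max_row) {
      for (int r = 0; r < rows; ++r) {
        if (EVEN_MN || r < max_row) {
          for (int c = 0; c < cols; ++c) dst[r * cols + c] = src[r * cols + c];
        } else if (ZFILL_MN) {
          for (int c = 0; c < cols; ++c) dst[r * cols + c] = 0.0f;  // zero-fill oob rows
        }  // else: leave destination stale; downstream masking handles it
      }
    }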
@@ -281,84 +286,131 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
 
     // wait for smem copy done before gmem copy
     __syncthreads();
+
+    auto max_coord = make_coord(q_len - m_block * kBlockM, head_dim);
     safe_copy</*EVEN_MN=*/false,
               EVEN_K,
-              /*ZERO_FILL_MN=*/false,
-              /*ZERO_FILL_K=*/false>(
-        gmem_tiled_copy_O,
-        tOsO,
-        tOgO,
-        tOcO,
-        make_coord(q_len - m_block * kBlockM, head_dim));
+              /*ZFILL_MN=*/false,
+              /*ZFILL_K=*/false>(
+        gmem_tiled_copy_O, tOsO, tOgO, tOcO, max_coord);
   };
 
+  // output accumulator, (MMA,MMA_M,MMA_K)
+  auto tOrAccO = partition_fragment_C(tiled_mma, Shape<_BLK_M, _HEAD_DIM>{});
+  auto tOrAccO_rc_view =
+      make_tensor(tOrAccO.data(), Layout::to_rowcol(tOrAccO.layout()));
+  clear(tOrAccO);
+
   const int diagonal = m_block * kBlockM + kv_len - q_len;
   // process kv in range: [kv_idx_min, kv_idx_max)
   const int kv_idx_min = std::max(0, diagonal - sliding_window);
   const int kv_idx_max = std::min(kv_len, diagonal + kBlockM);
   const int n_block_min = LOCAL ? kv_idx_min / kBlockN : 0;
   const int n_block_max = cute::ceil_div(kv_idx_max, kBlockN);
-  // TODO: handle n_block_min >= n_block_max
 
-  // ############### Prologue ###############
+  if (n_block_min >= n_block_max) {
+    // write output to gmem
+    epilogue(tOrAccO);
+    return;
+  }
 
+  // ############### Prologue ###############
+  int n_block_idx = n_block_max - 1;
   // produce q: [] => [q]
   produce_q();
   cp_async_fence();
   // produce k: [q] => [q, k]
-  produce_k(n_block_min);
+  produce_k(n_block_idx);
   cp_async_fence();
 
   // ############### Mainloop ###############
 
-  // output accumulator, (MMA,MMA_M,MMA_K)
-  auto tOrAccO = partition_fragment_C(tiled_mma, Shape<_BLK_M, _HEAD_DIM>{});
-  auto tOrAccO_rc_view =
-      make_tensor(tOrAccO.data(), Layout::to_rowcol(tOrAccO.layout()));
+  OnlineSoftmax<kRowsPerMMA * size<1>(tOrAccO)> softmax(sm_scale_log2);
+  Mask<kBlockM, kBlockM, ALIBI, LOCAL> mask(
+      q_len, kv_len, sliding_window, alibi_slope);
 
   // attention score accumulator, (MMA,MMA_M,MMA_N)
   auto tSrAccS = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{});
   auto tSrAccS_rc_view =
       make_tensor(tSrAccS.data(), Layout::to_rowcol(tSrAccS.layout()));
+  // separate oob mask iterations for better performance
+  constexpr int n_oob_mask = cute::ceil_div(kBlockM, kBlockN) + 1;
 
-  OnlineSoftmax<kRowsPerMMA * size<1>(tOrAccO)> softmax(sm_scale_log2);
-  Mask<kBlockM, kBlockM, ALIBI, LOCAL> mask(
-      q_len, kv_len, sliding_window, alibi_slope);
-
-  clear(tOrAccO);
-  CUTE_NO_UNROLL
-  for (int ni = n_block_min; ni < n_block_max; ++ni) {
+  // oob mask iterations
+  CUTE_UNROLL
+  for (int i = 0; i < n_oob_mask; ++i) {
     clear(tSrAccS);
 
     // wait k, queue: [q, k] => []
     cp_async_wait<0>();
     __syncthreads();
 
     // produce v, [] => [v]
-    produce_v(ni);
+    if (i == 0) {
+      produce_v(n_block_idx);
+    } else {
+      produce_v_no_oob(n_block_idx);
+    }
     cp_async_fence();
 
     // 1> S = Q@K.T
     compute_qk(tSrAccS);
 
-    // apply soft cap if needed
     if constexpr (SOFT_CAP) {
       apply_logits_soft_cap(tSrAccS);
     }
+    mask.apply(tSrAccS_rc_view, m_block, n_block_idx, tidx);
+    softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view);
 
-    // apply mask for block (m_block, ni)
-    mask.apply(tSrAccS_rc_view, m_block, ni, tidx);
+    // wait v, [v] => []
+    cp_async_wait<0>();
+    __syncthreads();
+
+    // produce next k: [] => [k]
+    if (n_block_idx > n_block_min) {
+      produce_k_no_oob(n_block_idx - 1);
+    }
+    cp_async_fence();
+
+    // 2> O = softmax(S)*V
+    compute_sv(tSrAccS, tOrAccO);
+
+    --n_block_idx;
+    if (n_block_idx < n_block_min) {
+      // no more kv blocks to process
+      break;
+    }
+  }
 
-    // apply softmax and rescale
+  // non-oob mask iterations
+  CUTE_NO_UNROLL
+  for (; n_block_idx >= n_block_min; --n_block_idx) {
+    clear(tSrAccS);
+
+    // wait k, queue: [q, k] => []
+    cp_async_wait<0>();
+    __syncthreads();
+
+    // produce v, [] => [v]
+    produce_v_no_oob(n_block_idx);
+    cp_async_fence();
+
+    // 1> S = Q@K.T
+    compute_qk(tSrAccS);
+
+    if constexpr (SOFT_CAP) {
+      apply_logits_soft_cap(tSrAccS);
+    }
+    mask.apply</*OOB_MASK=*/false>(tSrAccS_rc_view, m_block, n_block_idx, tidx);
     softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view);
 
     // wait v, [v] => []
     cp_async_wait<0>();
     __syncthreads();
 
     // produce next k: [] => [k]
-    if (ni != n_block_max - 1) {
-      produce_k(ni + 1);
+    if (n_block_idx > n_block_min) {
+      produce_k_no_oob(n_block_idx - 1);
     }
     cp_async_fence();
 
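Net effect of the restructuring: the single mainloop becomes two, walking kv blocks from last to first. Only blocks that can intersect the causal diagonal or the kv tail need the expensive out-of-bounds masking, and for a query tile of kBlockM rows that is at most ceil_div(kBlockM, kBlockN) + 1 kv tiles (e.g. kBlockM = 64, kBlockN = 32 gives n_oob_mask = 3). Those few iterations are unrolled with full masking and predicated copies; the remaining interior blocks take the cheaper path (`mask.apply</*OOB_MASK=*/false>`, `produce_k/v_no_oob`). A condensed view of the schedule, paraphrasing the diff (an outline, not compilable code):

    // prologue: kick off async loads of q and the last k block
    //   produce_q(); produce_k(n_block_max - 1);
    // phase 1 (unrolled): up to n_oob_mask boundary blocks, full masking
    // phase 2 (not unrolled): remaining blocks, no oob masking
    // each iteration overlaps compute with cp.async prefetch:
    //   wait k -> prefetch v -> S = Q@K^T -> mask -> online-softmax rescale
    //   -> wait v -> prefetch next k -> O += softmax(S)@V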