@@ -299,21 +299,6 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
     return;
   }
 
-  // ############### Prologue ###############
-  int n_block_idx = n_block_max - 1;
-  // produce query: [] => [q]
-  produce_query();
-  cp_async_fence();
-  // produce key: [q] => [q, k]
-  produce_key(n_block_idx);
-  cp_async_fence();
-
-  // ############### Mainloop ###############
-  // attention score accumulator, (MMA,MMA_M,MMA_N)
-  auto tSrAccS = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{});
-  auto tSrAccS_rc_view =
-      make_tensor(tSrAccS.data(), Layout::to_rowcol(tSrAccS.layout()));
-
   auto apply_logits_soft_cap = [&](auto& tSrAccS) {
     if constexpr (SOFT_CAP) {
       CUTE_UNROLL
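The body of apply_logits_soft_cap falls in the gap between these two hunks; logit soft-capping conventionally squashes each score through a scaled tanh. A scalar sketch of that conventional form (an assumption here, since the lambda's body is not shown; the real lambda mutates the tSrAccS fragment in place):

#include <cmath>

// Conventional logit soft cap (assumed form; this diff does not show the
// kernel's actual lambda body): smoothly keeps scores within (-cap, cap).
inline float logits_soft_cap(float score, float cap) {
  return cap * std::tanh(score / cap);
}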
@@ -323,7 +308,7 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
     }
   };
 
-  constexpr int kMMA_M = size<1>(tSrAccS);
+  constexpr int kMMA_M = size<1>(tOrAccO);
   using Softmax = OnlineSoftmax<kRowsPerMMA * kMMA_M>;
   using Mask = Mask<kBlockM, kBlockN, kRowsPerMMA, kMMA_M, ALIBI, LOCAL>;
 
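Note on the kMMA_M change above: this commit moves the score accumulator tSrAccS inside the mainloop, so the row count is now taken from the output accumulator tOrAccO, which tiles the same BLK_M rows with the same tiled_mma. For reference, OnlineSoftmax implements the usual online-softmax recurrence that softmax.rescale applies to each accumulator row: track a running row max and normalizer, rescale the partially accumulated output whenever the max grows, and convert scores to unnormalized probabilities in place. A minimal scalar sketch of that recurrence (the OnlineSoftmaxRow name and std::vector state are illustrative; the kernel works on register fragments):

#include <algorithm>
#include <cmath>
#include <vector>

// Scalar model of one accumulator row; the kernel keeps this state in
// registers for kRowsPerMMA * kMMA_M rows at once.
struct OnlineSoftmaxRow {
  float row_max = -INFINITY;  // running max m
  float row_sum = 0.0f;       // running normalizer l

  // Called once per key block: rescale old state, absorb new scores.
  void rescale(std::vector<float>& scores, std::vector<float>& out_acc) {
    const float block_max = *std::max_element(scores.begin(), scores.end());
    const float new_max = std::max(row_max, block_max);
    const float correction = std::exp(row_max - new_max);
    for (float& o : out_acc) o *= correction;  // fix up O accumulated so far
    row_sum *= correction;
    for (float& s : scores) {
      s = std::exp(s - new_max);  // unnormalized softmax, fed to S@V next
      row_sum += s;
    }
    row_max = new_max;
  }

  // After the last key block: O /= l.
  void finalize(std::vector<float>& out_acc) {
    for (float& o : out_acc) o /= row_sum;
  }
};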
@@ -338,12 +323,26 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
             sm_scale,
             params.alibi_slopes_ptr);
 
-  // seperate oob mask iterations for better performance
+  // ############### Prologue ###############
+  // produce query: [] => [q]
+  produce_query();
+  cp_async_fence();
+  // produce key: [q] => [q, k]
+  produce_key(n_block_max - 1);
+  cp_async_fence();
+
+  // ############### Mainloop ###############
   constexpr int n_oob_mask = cute::ceil_div(kBlockM, kBlockN) + 1;
+  const int n_blocks = n_block_max - n_block_min;
 
-  // oob mask iterations
-  CUTE_UNROLL
-  for (int i = 0; i < n_oob_mask; ++i) {
+  CUTE_NO_UNROLL
+  for (int i = 0; i < n_blocks; ++i) {
+    const int n_block_idx = n_block_max - 1 - i;
+
+    // attention score accumulator, (MMA,MMA_M,MMA_N)
+    auto tSrAccS = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{});
+    auto tSrAccS_rc_view =
+        make_tensor(tSrAccS.data(), Layout::to_rowcol(tSrAccS.layout()));
     clear(tSrAccS);
 
     // wait key, queue: [q, k] => []
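The n_oob_mask bound above counts how many leading iterations can touch out-of-bounds or masked-boundary positions: the last key block may be only partially filled (one iteration), and with a causal or local mask the query tile's diagonal spans at most ceil_div(kBlockM, kBlockN) key blocks. Every later iteration is fully in bounds. A worked check (tile sizes assumed for illustration, not taken from this kernel's config):

// Same rounding-up division as cute::ceil_div.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }
static_assert(ceil_div(64, 64) + 1 == 2, "kBlockM=64, kBlockN=64");
static_assert(ceil_div(128, 64) + 1 == 3, "kBlockM=128, kBlockN=64");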
@@ -361,57 +360,20 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
     // 1> S = Q@K.T
     compute_qk(tSrAccS);
 
-    if constexpr (SOFT_CAP) {
-      apply_logits_soft_cap(tSrAccS);
-    }
-    mask.apply(tSrAccS_rc_view, n_block_idx);
-    softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view);
-
     // wait value, [v] => []
     cp_async_wait<0>();
     __syncthreads();
 
-    // produce next key: [] => [k]
-    if (n_block_idx > n_block_min) {
-      produce_key_no_oob(n_block_idx - 1);
-    }
-    cp_async_fence();
-
-    // 2> O = softmax(S)*V
-    compute_sv(tSrAccS, tOrAccO);
-
-    --n_block_idx;
-    if (n_block_idx < n_block_min) {
-      // no more kv blocks to process
-      break;
-    }
-  }
-
-  // non-oob mask iterations
-  CUTE_NO_UNROLL
-  for (; n_block_idx >= n_block_min; --n_block_idx) {
-    clear(tSrAccS);
-
-    // wait key, queue: [q, k] => []
-    cp_async_wait<0>();
-    __syncthreads();
-
-    // produce value, [] => [v]
-    produce_value_no_oob(n_block_idx);
-    cp_async_fence();
-
-    // 1> S = Q@K.T
-    compute_qk(tSrAccS);
-
     if constexpr (SOFT_CAP) {
       apply_logits_soft_cap(tSrAccS);
     }
-    mask.apply</*OOB_MASK=*/false>(tSrAccS_rc_view, n_block_idx);
-    softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view);
 
-    // wait value, [v] => []
-    cp_async_wait<0>();
-    __syncthreads();
+    if (i < n_oob_mask) {
+      mask.apply(tSrAccS_rc_view, n_block_idx);
+    } else {
+      mask.apply</*OOB_MASK=*/false>(tSrAccS_rc_view, n_block_idx);
+    }
+    softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view);
 
     // produce next key: [] => [k]
     if (n_block_idx > n_block_min) {
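Taken together, the change replaces the two specialized loops (an unrolled out-of-bounds prologue plus a steady-state loop over a shared n_block_idx) with a single CUTE_NO_UNROLL loop that derives n_block_idx from the trip count and branches on i < n_oob_mask, presumably trading one branch per iteration for not duplicating the mainloop body in the compiled kernel. A host-side sketch of the resulting iteration schedule (block counts and tile sizes assumed for illustration):

#include <cstdio>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
  // Assumed values for illustration only.
  const int kBlockM = 64, kBlockN = 64;
  const int n_block_min = 0, n_block_max = 8;

  const int n_oob_mask = ceil_div(kBlockM, kBlockN) + 1;
  const int n_blocks = n_block_max - n_block_min;

  for (int i = 0; i < n_blocks; ++i) {
    const int n_block_idx = n_block_max - 1 - i;  // newest key block first
    std::printf("i=%d -> key block %d: %s\n", i, n_block_idx,
                i < n_oob_mask ? "mask.apply (OOB checks on)"
                               : "mask.apply<OOB_MASK=false> (fast path)");
  }
  return 0;
}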