@@ -213,38 +213,8 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
   for (int tile_K_seqlen = 0; tile_K_seqlen < Tc; ++tile_K_seqlen) {
     // TODO: process last tile_K_seqlen ? pad to multiple of 8.
 
-    // <Prefetch Q s2r>: Load Q tile from smem -> regs, before Q@K^T.
-    if constexpr (kCanPrefetchQs2r) {
-      // Wait Q ready and let K copy async, then prefetch Q from smem -> regs.
-      // NOTE: we only need to load Q once from smem -> regs, and then reuse it.
-      if (tile_K_seqlen == 0) {
-        CP_ASYNC_WAIT_GROUP(0);
-        __syncthreads();
-
-        #pragma unroll
-        for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) {
-          // Allocate R_Q[(kHeadDim / kMmaAtomK)][1][4], e.g R_Q[4][1][4] 16 regs.
-          // By the way, we have to reduce R_Z to 0 regs and reuse R_Q for collective store.
-          // Then we can load Q from smem only once and reuse it for <loop over K seqlen>
-          // processes. This will reduce large io-access for Q smem while N is large.
-          #pragma unroll
-          for (int i = 0; i < kWarpTileSeqLenQ; ++i) { // Q[Br,d]=[M,K]
-            int warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM;
-            int lane_smem_Q_Br = warp_smem_Q_Br + lane_id % 16; // 0~15
-            int lane_smem_Q_d = tile_K_d * kMmaAtomK + (lane_id / 16) * 8; // 0,8
-            uint32_t lane_smem_Q_ptr = (
-              smem_Q_base_ptr + (lane_smem_Q_Br * (kHeadDim + kPad) +
-              lane_smem_Q_d) * sizeof(half)
-            );
-            LDMATRIX_X4(R_Q[tile_K_d][i][0], R_Q[tile_K_d][i][1],
-                        R_Q[tile_K_d][i][2], R_Q[tile_K_d][i][3],
-                        lane_smem_Q_ptr); // now, R_Q[1/2/4/8][1][4]
-          }
-        }
-      } // end if tile_K_seqlen == 0
-    } // end if kCanPrefetchQs2r
-
-    // Load K tile from gmem -> smem, always use smem part 0.
+    // Load K tile from gmem -> smem, always use smem part 0, send g2s
+    // memory issues before Prefetch Q s2r to enable time overlap.
     if constexpr (kCanPrefetchKVg2s) {
       if (tile_K_seqlen == 0) {
         load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
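
The intent of this hunk's reordering is that cp.async is non-blocking: the K tile g2s copies are issued first, the Q s2r prefetch (second hunk) then runs on ldmatrix while those copies are still in flight, and the kernel only waits on the copy group right before K is actually consumed. Below is a minimal sketch of this issue-early/wait-late pattern; it uses raw PTX wrappers and a made-up single-warp 16x16 tile rather than the kernel's own macros and tiling, so treat it as an assumption-laden illustration, not the actual implementation.

// Sketch only (requires sm_80+, e.g. nvcc -arch=sm_80). Illustrates: issue g2s
// copies first, prefetch s2r while they are in flight, wait right before use.
// Tile sizes, names and the single-warp launch here are assumptions.
#include <cuda_fp16.h>
#include <cstdint>

__device__ __forceinline__ uint32_t smem_u32(const void* p) {
  return static_cast<uint32_t>(__cvta_generic_to_shared(p));
}
// 16-byte asynchronous gmem -> smem copy (Ampere cp.async, cache-global path).
__device__ __forceinline__ void cp_async_16B(void* dst_smem, const void* src_gmem) {
  asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n"
               :: "r"(smem_u32(dst_smem)), "l"(src_gmem));
}
__device__ __forceinline__ void cp_async_commit() {
  asm volatile("cp.async.commit_group;\n");
}
__device__ __forceinline__ void cp_async_wait_all() {
  asm volatile("cp.async.wait_group 0;\n" : : : "memory");
}
// ldmatrix.x4: one 16x16 half tile -> 4 x 32-bit regs per lane.
__device__ __forceinline__ void ldmatrix_x4(uint32_t& r0, uint32_t& r1,
                                            uint32_t& r2, uint32_t& r3,
                                            uint32_t addr) {
  asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n"
               : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3) : "r"(addr));
}

// One warp; gmem_K is assumed to hold at least 16x16 halves, 16B aligned.
__global__ void overlap_sketch(const half* __restrict__ gmem_K, float* sink) {
  __shared__ __align__(16) half smem_Q[16 * 16];
  __shared__ __align__(16) half smem_K[16 * 16];
  const int lane = threadIdx.x & 31;

  // Stand-in for the earlier Q g2s stage: zero-fill so ldmatrix is well defined.
  for (int i = 0; i < 8; ++i) smem_Q[lane * 8 + i] = __float2half(0.f);
  __syncwarp();

  // (1) Issue the K g2s copies FIRST: 32 lanes x 8 halves = one 16x16 tile.
  cp_async_16B(&smem_K[lane * 8], &gmem_K[lane * 8]);
  cp_async_commit();

  // (2) While those copies are in flight, prefetch Q from smem -> regs.
  uint32_t R_Q[4];
  uint32_t q_addr = smem_u32(&smem_Q[(lane % 16) * 16 + (lane / 16) * 8]);
  ldmatrix_x4(R_Q[0], R_Q[1], R_Q[2], R_Q[3], q_addr);

  // (3) Only now block on the async copy group, right before K is consumed.
  cp_async_wait_all();
  __syncthreads();
  if (lane == 0) *sink = __half2float(smem_K[0]) + (float)(R_Q[0] & 1u);
}
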
@@ -301,6 +271,38 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
       __syncthreads();
     }
 
+    // <Prefetch Q s2r>: Load Q tile from smem -> regs, before Q@K^T.
+    if constexpr (kCanPrefetchQs2r) {
+      // Wait Q ready and let K copy async, then prefetch Q from smem -> regs.
+      // NOTE: we only need to load Q once from smem -> regs, and then reuse it.
+      if (tile_K_seqlen == 0) {
+        CP_ASYNC_WAIT_GROUP(0);
+        __syncthreads();
+
+        #pragma unroll
+        for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) {
+          // Allocate R_Q[(kHeadDim / kMmaAtomK)][1][4], e.g R_Q[4][1][4] 16 regs.
+          // By the way, we have to reduce R_Z to 0 regs and reuse R_Q for collective store.
+          // Then we can load Q from smem only once and reuse it for <loop over K seqlen>
+          // processes. This will reduce large io-access for Q smem while N is large.
+          #pragma unroll
+          for (int i = 0; i < kWarpTileSeqLenQ; ++i) { // Q[Br,d]=[M,K]
+            int warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM;
+            int lane_smem_Q_Br = warp_smem_Q_Br + lane_id % 16; // 0~15
+            int lane_smem_Q_d = tile_K_d * kMmaAtomK + (lane_id / 16) * 8; // 0,8
+            uint32_t lane_smem_Q_ptr = (
+              smem_Q_base_ptr + (lane_smem_Q_Br * (kHeadDim + kPad) +
+              lane_smem_Q_d) * sizeof(half)
+            );
+            LDMATRIX_X4(R_Q[tile_K_d][i][0], R_Q[tile_K_d][i][1],
+                        R_Q[tile_K_d][i][2], R_Q[tile_K_d][i][3],
+                        lane_smem_Q_ptr); // now, R_Q[1/2/4/8][1][4]
+          }
+        }
+        __syncthreads(); // wait all warps ready.
+      } // end if tile_K_seqlen == 0
+    } // end if kCanPrefetchQs2r
+
     // <loop over K d>: tile_K_d, kMmaAtomK = 16, K_tile_d[kMmaAtomK,Bc]
     // Matmul with NN layout, Q row major, K row major.
     // S_tile[Br,Bc]=Q_tile[Br,d]@K[d,Bc]
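
For the Q prefetch added in the second hunk, each LDMATRIX_X4 fills 4 x 32-bit registers per lane with one 16x16 half tile, so with kHeadDim = 64 and kMmaAtomK = 16 the persistent R_Q cache costs (64 / 16) x 1 x 4 = 16 registers per thread, matching the R_Q[4][1][4] comment. The lane addressing (lane_id % 16 selects the row inside the warp's Br slice, (lane_id / 16) * 8 selects the left or right 8-column half of the current kMmaAtomK chunk) can be sanity-checked with the small host-side sketch below; the constant values are assumptions taken from the diff's comments, not the kernel's real launch configuration.

// Sketch: reproduces the lane -> (row, col) mapping used by the Q s2r prefetch.
// kHeadDim/kMmaAtomK/kMmaAtomM/kWarpTileSeqLenQ/kPad/warp_QP are assumed values.
#include <cstdio>

int main() {
  constexpr int kHeadDim = 64, kMmaAtomK = 16, kMmaAtomM = 16;
  constexpr int kWarpTileSeqLenQ = 1, kPad = 8;
  constexpr int warp_QP = 0;  // first warp of the block
  // Per-thread register budget of the R_Q cache: one x4 ldmatrix per K-dim tile.
  printf("R_Q regs/thread = %d\n", (kHeadDim / kMmaAtomK) * kWarpTileSeqLenQ * 4); // 16

  for (int tile_K_d = 0; tile_K_d < kHeadDim / kMmaAtomK; ++tile_K_d) {
    for (int lane_id = 0; lane_id < 32; ++lane_id) {
      int warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ);  // i = 0
      int lane_smem_Q_Br = warp_smem_Q_Br + lane_id % 16;             // row 0..15
      int lane_smem_Q_d  = tile_K_d * kMmaAtomK + (lane_id / 16) * 8; // col half 0/8
      int smem_elem      = lane_smem_Q_Br * (kHeadDim + kPad) + lane_smem_Q_d;
      printf("tile_K_d=%d lane=%2d -> Q[%2d,%2d] (smem elem %d)\n",
             tile_K_d, lane_id, lane_smem_Q_Br, lane_smem_Q_d, smem_elem);
    }
  }
  return 0;
}
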