
Commit 9324ddf

[FA2] Update flash-attn-mma shared-qkv🎉 (#168)
* Update flash_attn_mma_share_qkv.cu
* Update flash_attn_mma_share_kv.cu
* Update flash_attn_mma_share_kv.cu
* Update flash_attn_mma_share_qkv.cu
* Update flash_attn_mma.py
* Update README.md
* Update README.md
* Update README.md
1 parent db8b8e8 commit 9324ddf

File tree (5 files changed: +49, -43 lines changed)

* README.md
* kernels/flash-attn/README.md
* kernels/flash-attn/flash_attn_mma.py
* kernels/flash-attn/mma/flash_attn_mma_share_kv.cu
* kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu

README.md

Lines changed: 2 additions & 2 deletions

@@ -119,7 +119,7 @@ flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D]
 int QKV_seqlen);
 ```

-- 📚 Split Q + Shared KV SMEM (Faster+)
+- 📚 Split Q + Shared KV SMEM (**1/2 SRAM** vs FA2)
 <div id="mma-share-kv"></div>

 ```C++
@@ -131,7 +131,7 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
 half* O,
 int QKV_seqlen);
 ```
-- 📚 Split Q + Fully Shared QKV SMEM (Faster++)
+- 📚 Split Q + Fully Shared QKV SMEM (**1/4 SRAM** vs FA2)

 <div id="mma-share-qkv"></div>

kernels/flash-attn/README.md

Lines changed: 3 additions & 2 deletions

@@ -93,7 +93,7 @@ flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D]
 int QKV_seqlen);
 ```

-- 📚 Split Q + Shared KV SMEM (Faster+)
+- 📚 Split Q + Shared KV SMEM (**1/2 SRAM** vs FA2)
 <div id="mma-share-kv"></div>

 ```C++
@@ -105,7 +105,7 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
 half* O,
 int QKV_seqlen);
 ```
-- 📚 Split Q + Fully Shared QKV SMEM (Faster++)
+- 📚 Split Q + Fully Shared QKV SMEM (**1/4 SRAM** vs FA2)

 <div id="mma-share-qkv"></div>

@@ -119,6 +119,7 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
 int QKV_seqlen);
 ```

+
 ## 📖 Prerequisites
 <div id="prerequisites"></div>

kernels/flash-attn/flash_attn_mma.py

Lines changed: 4 additions & 3 deletions

@@ -45,10 +45,11 @@ def get_args():
     parser.add_argument("--N", type=int, default=None)
     parser.add_argument("--D", type=int, default=None)
     parser.add_argument("--seed", type=int, default=None)
+    parser.add_argument("--sleep", type=float, default=0.05)
     parser.add_argument("--debug", action="store_true")
     parser.add_argument("--verbose", '--v', action="store_true")
-    parser.add_argument("--warmup", type=int, default=1)
-    parser.add_argument("--iters", type=int, default=5)
+    parser.add_argument("--warmup", "--w", type=int, default=1)
+    parser.add_argument("--iters", "--i", type=int, default=5)
     parser.add_argument("--range-k", '--gk', action="store_true")
     return parser.parse_args()

@@ -178,7 +179,7 @@ def run_benchmark(perf_func: callable,
     print(f"{out_info:>30}: {out_val}, time:{mean_time:<.6f}ms, TFLOPS:{TFLOPS:<6.2f}")
     if show_all:
         print(out)
-    time.sleep(0.05)
+    time.sleep(args.sleep)
     torch.cuda.synchronize()
     return out.clone(), mean_time

kernels/flash-attn/mma/flash_attn_mma_share_kv.cu

Lines changed: 34 additions & 32 deletions

@@ -213,38 +213,8 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
 for (int tile_K_seqlen = 0; tile_K_seqlen < Tc; ++tile_K_seqlen) {
   // TODO: process last tile_K_seqlen ? pad to multiple of 8.

-  // <Prefetch Q s2r>: Load Q tile from smem -> regs, before Q@K^T.
-  if constexpr (kCanPrefetchQs2r) {
-    // Wait Q ready and let K copy async, then prefetch Q from smem -> regs.
-    // NOTE: we only need to load Q once from smem -> regs, and then reuse it.
-    if (tile_K_seqlen == 0) {
-      CP_ASYNC_WAIT_GROUP(0);
-      __syncthreads();
-
-      #pragma unroll
-      for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) {
-        // Allocate R_Q[(kHeadDim / kMmaAtomK)][1][4], e.g R_Q[4][1][4] 16 regs.
-        // By the way, we have to reduce R_Z to 0 regs and reuse R_Q for collective store.
-        // Then we can load Q from smem only once and reuse it for <loop over K seqlen>
-        // processes. This will reduce large io-access for Q smem while N is large.
-        #pragma unroll
-        for (int i = 0; i < kWarpTileSeqLenQ; ++i) { // Q[Br,d]=[M,K]
-          int warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM;
-          int lane_smem_Q_Br = warp_smem_Q_Br + lane_id % 16; // 0~15
-          int lane_smem_Q_d = tile_K_d * kMmaAtomK + (lane_id / 16) * 8; // 0,8
-          uint32_t lane_smem_Q_ptr = (
-            smem_Q_base_ptr + (lane_smem_Q_Br * (kHeadDim + kPad) +
-                               lane_smem_Q_d) * sizeof(half)
-          );
-          LDMATRIX_X4(R_Q[tile_K_d][i][0], R_Q[tile_K_d][i][1],
-                      R_Q[tile_K_d][i][2], R_Q[tile_K_d][i][3],
-                      lane_smem_Q_ptr); // now, R_Q[1/2/4/8][1][4]
-        }
-      }
-    } // end if tile_K_seqlen == 0
-  } // end if kCanPrefetchQs2r
-
-  // Load K tile from gmem -> smem, always use smem part 0.
+  // Load K tile from gmem -> smem, always use smem part 0, send g2s
+  // memory issues before Prefetch Q s2r to enable time overlap.
   if constexpr (kCanPrefetchKVg2s) {
     if (tile_K_seqlen == 0) {
       load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
@@ -301,6 +271,38 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
     __syncthreads();
   }

+  // <Prefetch Q s2r>: Load Q tile from smem -> regs, before Q@K^T.
+  if constexpr (kCanPrefetchQs2r) {
+    // Wait Q ready and let K copy async, then prefetch Q from smem -> regs.
+    // NOTE: we only need to load Q once from smem -> regs, and then reuse it.
+    if (tile_K_seqlen == 0) {
+      CP_ASYNC_WAIT_GROUP(0);
+      __syncthreads();
+
+      #pragma unroll
+      for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) {
+        // Allocate R_Q[(kHeadDim / kMmaAtomK)][1][4], e.g R_Q[4][1][4] 16 regs.
+        // By the way, we have to reduce R_Z to 0 regs and reuse R_Q for collective store.
+        // Then we can load Q from smem only once and reuse it for <loop over K seqlen>
+        // processes. This will reduce large io-access for Q smem while N is large.
+        #pragma unroll
+        for (int i = 0; i < kWarpTileSeqLenQ; ++i) { // Q[Br,d]=[M,K]
+          int warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM;
+          int lane_smem_Q_Br = warp_smem_Q_Br + lane_id % 16; // 0~15
+          int lane_smem_Q_d = tile_K_d * kMmaAtomK + (lane_id / 16) * 8; // 0,8
+          uint32_t lane_smem_Q_ptr = (
+            smem_Q_base_ptr + (lane_smem_Q_Br * (kHeadDim + kPad) +
+                               lane_smem_Q_d) * sizeof(half)
+          );
+          LDMATRIX_X4(R_Q[tile_K_d][i][0], R_Q[tile_K_d][i][1],
+                      R_Q[tile_K_d][i][2], R_Q[tile_K_d][i][3],
+                      lane_smem_Q_ptr); // now, R_Q[1/2/4/8][1][4]
+        }
+      }
+      __syncthreads(); // wait all warps ready.
+    } // end if tile_K_seqlen == 0
+  } // end if kCanPrefetchQs2r
+
   // <loop over K d>: tile_K_d, kMmaAtomK = 16, K_tile_d[kMmaAtomK,Bc]
   // Matmul with NN layout, Q row major, K row major.
   // S_tile[Br,Bc]=Q_tile[Br,d]@K[d,Bc]
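
The net effect of this reordering in flash_attn_mma_share_kv.cu is that the cp.async global-to-shared (g2s) copy for the K tile is issued before the Q smem-to-register (s2r) prefetch, so the register loads from the already-resident Q tile overlap with the in-flight K copy. Below is a minimal standalone sketch of that ordering, not the repository kernel: the kernel name, tile sizes, and the plain register loads are illustrative, and CUDA's __pipeline_* primitives stand in for the repo's CP_ASYNC_* macros.

```C++
// Sketch only: issue the async K g2s copy first, then prefetch Q s2r while the
// copy is in flight, then wait before consuming K (mirrors the commit's ordering).
#include <cuda_fp16.h>
#include <cuda_pipeline.h>  // __pipeline_memcpy_async / _commit / _wait_prior

// Example launch (hypothetical sizes): overlap_g2s_then_s2r<1024, 1024><<<grid, 128>>>(dQ, dK, dO);
template<int kQTileElems, int kKTileElems>
__global__ void overlap_g2s_then_s2r(const half* __restrict__ Q,  // assumed 16B-aligned
                                     const half* __restrict__ K,  // assumed 16B-aligned
                                     float* __restrict__ out) {
  __shared__ __align__(16) half smem_Q[kQTileElems];
  __shared__ __align__(16) half smem_K[kKTileElems];
  const int tid = threadIdx.x;

  // Q tile -> smem (plain synchronous copy for brevity).
  for (int i = tid; i < kQTileElems; i += blockDim.x) smem_Q[i] = Q[i];
  __syncthreads();

  // (1) Issue the K g2s copy first: 16 bytes (8 halves) per cp.async call.
  for (int i = tid * 8; i < kKTileElems; i += blockDim.x * 8) {
    __pipeline_memcpy_async(&smem_K[i], &K[i], 16);
  }
  __pipeline_commit();

  // (2) "Prefetch Q s2r" while the K copy is still in flight: it only reads
  //     smem_Q, so it does not have to wait on the outstanding K copy.
  half r_q[8];
  #pragma unroll
  for (int j = 0; j < 8; ++j) r_q[j] = smem_Q[(tid * 8 + j) % kQTileElems];

  // (3) Only now wait for the K tile and sync, before anything consumes smem_K.
  __pipeline_wait_prior(0);
  __syncthreads();

  // Stand-in for Q@K^T so the prefetched registers and smem_K are actually used.
  float acc = 0.f;
  #pragma unroll
  for (int j = 0; j < 8; ++j)
    acc += __half2float(r_q[j]) * __half2float(smem_K[(tid * 8 + j) % kKTileElems]);
  out[blockIdx.x * blockDim.x + tid] = acc;
}
```

The __pipeline_wait_prior(0) plus __syncthreads() pair plays the role of CP_ASYNC_WAIT_GROUP(0) plus __syncthreads() in the real kernel; on pre-Ampere GPUs the pipeline primitives simply fall back to synchronous copies.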

kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu

Lines changed: 6 additions & 4 deletions

@@ -240,20 +240,22 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
                       lane_smem_Q_ptr); // now, R_Q[1/2/4/8][1][4]
         }
       }
+      __syncthreads(); // wait all warps ready.
     } // end if tile_K_seqlen == 0
   } // end if kCanPrefetchQs2r

-  // Load K tile from gmem -> smem, always use smem part 0.
+  // Load K tile from gmem -> smem, always use smem part 0.
+  // must after prefetch Q s2r in order to reuse Q smem.
   if constexpr (kCanPrefetchKVg2s) {
     if (tile_K_seqlen == 0) {
       load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
       int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
       int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
       int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
       uint32_t load_smem_K_ptr = (
-        smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size +
-        load_smem_K_d * (Bc + kPad) +
-        load_smem_K_Bc) * sizeof(half));
+        smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size +
+                           load_smem_K_d * (Bc + kPad) +
+                           load_smem_K_Bc) * sizeof(half));
       #pragma unroll
       for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
         CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
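
The changes to flash_attn_mma_share_qkv.cu enforce an ordering constraint: with fully shared QKV SMEM the K tile is written into the smem that held Q, so every warp must finish its Q s2r prefetch before the K g2s copy may overwrite that buffer; hence the added __syncthreads() ("wait all warps ready") and the note that the K load must come after the Q prefetch. A minimal standalone sketch of that hazard and its fix follows; the names (reuse_q_smem_for_k, smem_QK) and tile size are illustrative, and CUDA's __pipeline_* primitives again stand in for the repo's CP_ASYNC_* macros.

```C++
// Sketch only: one smem buffer is reused for Q and then K, so a block-wide
// barrier must separate the Q s2r prefetch from the K g2s overwrite.
#include <cuda_fp16.h>
#include <cuda_pipeline.h>

// Example launch (hypothetical size): reuse_q_smem_for_k<1024><<<grid, 128>>>(dQ, dK, dO);
template<int kTileElems>
__global__ void reuse_q_smem_for_k(const half* __restrict__ Q,  // assumed 16B-aligned
                                   const half* __restrict__ K,  // assumed 16B-aligned
                                   float* __restrict__ out) {
  // Single buffer shared by the Q and K tiles (the "fully shared QKV" idea).
  __shared__ __align__(16) half smem_QK[kTileElems];
  const int tid = threadIdx.x;

  // Q tile -> smem (plain synchronous copy for brevity).
  for (int i = tid; i < kTileElems; i += blockDim.x) smem_QK[i] = Q[i];
  __syncthreads();

  // Prefetch Q from smem into registers once; it is reused for every K tile,
  // so the smem copy of Q is no longer needed afterwards.
  half r_q[8];
  #pragma unroll
  for (int j = 0; j < 8; ++j) r_q[j] = smem_QK[(tid * 8 + j) % kTileElems];

  // All warps must be done reading Q before the buffer is recycled: this is
  // the barrier the commit adds ("wait all warps ready"). Without it, a fast
  // warp could start overwriting smem_QK with K while others still read Q.
  __syncthreads();

  // Only now is it safe to load the K tile into the same buffer.
  for (int i = tid * 8; i < kTileElems; i += blockDim.x * 8) {
    __pipeline_memcpy_async(&smem_QK[i], &K[i], 16);
  }
  __pipeline_commit();
  __pipeline_wait_prior(0);
  __syncthreads();

  // Stand-in for Q@K^T so the prefetched registers and the reloaded buffer are used.
  float acc = 0.f;
  #pragma unroll
  for (int j = 0; j < 8; ++j)
    acc += __half2float(r_q[j]) * __half2float(smem_QK[(tid * 8 + j) % kTileElems]);
  out[blockIdx.x * blockDim.x + tid] = acc;
}
```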
