Commit 81404c1

[FlashAttention] Update flash-attention-mma 0.0.1 🎉 (#159)
* Update flash_attn_mma_stage.cu
* Update flash_attn_mma_tiling.cu
* Update README.md
* Update README.md
* Update README.md
1 parent b1b923a · commit 81404c1

File tree: 4 files changed, +4 −18 lines

README.md

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's d
 |Collective Store (Warp Shfl)|Row Major (NN)|Col Major (TN)| SGEMM FP32/TF32|
 |✔️|✔️|✔️|✔️|

-I have also implemented **FlashAttention-2** using pure MMA PTX instructions, which supports features such as Multi-Stages, Tile MMA, Tile Warp and Collective Store. Performance is continuously being optimized. Stay tuned for updates ~ Please refer to [flash-atttention-mma⚡️⚡️](./kernels/flash-attn) for more details.
+I have also implemented **FlashAttention-2** using pure MMA PTX instructions, which supports features such as Multi-Stages, Tile MMA, Tile Warp and Collective Store. Performance is continuously being optimized. Stay tuned for updates ~ Please refer to [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for more details.

 ![flash-attn-mma](https://github.com/user-attachments/assets/3e20fdaa-9b31-4dcd-91d5-204905842dce)
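As background for the README line touched above: "pure MMA PTX instructions" means the kernels drive the Tensor Cores through inline `mma.sync` PTX rather than the higher-level WMMA API. Below is a minimal sketch of what an `HMMA16816`-style wrapper typically looks like; the repo's actual macro may differ in name and argument order, so treat this as an illustration of the m16n8k16 FP16 MMA shape, not the project's exact definition.

```C++
// Illustrative sketch (not necessarily the repo's exact macro): one warp-wide
// m16n8k16 Tensor Core MMA in FP16, D = A*B + C. Each operand is passed as
// packed half2 values living in 32-bit registers: 2 for D/C, 4 for A, 2 for B.
#define HMMA16816(RD0, RD1, RA0, RA1, RA2, RA3, RB0, RB1, RC0, RC1)        \
  asm volatile(                                                            \
      "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "                 \
      "{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n"                  \
      : "=r"(RD0), "=r"(RD1)                                               \
      : "r"(RA0), "r"(RA1), "r"(RA2), "r"(RA3),                            \
        "r"(RB0), "r"(RB1), "r"(RC0), "r"(RC1))
```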

kernels/flash-attn/README.md

Lines changed: 0 additions & 8 deletions
@@ -10,16 +10,8 @@

 ## 📖 Notes

-Includes the following (performance is being continuously optimized, stay tuned...):
-
-- [X] flash_attn_cuda_kernel (F32)
-- [x] flash_attn_mma_naive_kernel (ldmatrix + MMA)
-- [X] flash_attn_mma_stage_kernel (ldmatrix + MMA, Stages, Tile MMA/Warp, Copy Async, Collective Store, SMEM Padding)
-
 The FlashAttention in this repository is for learning CUDA programming only; for the best performance, please use the official FlashAttention: [flash-attention](https://github.com/Dao-AILab/flash-attention)

-## 📖 Kernel Invocation
-- flash_attn_mma_stage_kernel:
 ```C++
 template<
 const int kHeadDim, // Headdim, 32,64,128
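The deleted "Kernel Invocation" section survives here only up to the first template parameter. As a hedged illustration of how a kernel templated this way is usually launched (names other than `kHeadDim`, `Br`, and the `launch_flash_attn_mma_stages` entry point visible in the diffs below are made up for this sketch): one block per `[Br, kHeadDim]` Q tile of each (batch, head) pair, with dynamic shared memory opted in above the 48 KB default.

```C++
// Hypothetical host-side launch pattern for a flash_attn_mma-style kernel.
// example_kernel and launch_example are illustrative names; only kHeadDim,
// Br and the [B,H,N,d] layout come from the diffs themselves.
#include <cuda_fp16.h>
#include <cuda_runtime.h>

template <const int kHeadDim>
__global__ void example_kernel(half* Q, half* K, half* V, half* O, int N) {
  // ... attention tile computation elided in this sketch ...
}

template <const int kHeadDim>
void launch_example(half* Q, half* K, half* V, half* O,
                    int B, int H, int N, int Br, int smem_bytes) {
  dim3 grid((N + Br - 1) / Br, B * H);  // one Q tile per block, per (b, h)
  dim3 block(256);                      // e.g. 8 warps per block
  // Opt in to more than 48 KB of dynamic shared memory if the tiles need it.
  cudaFuncSetAttribute(example_kernel<kHeadDim>,
                       cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
  example_kernel<kHeadDim><<<grid, block, smem_bytes>>>(Q, K, V, O, N);
}
```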

kernels/flash-attn/mma/flash_attn_mma_stage.cu

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ using namespace nvcuda;
 // The output is O, a 4D tensor with shape [batch_size, num_heads, seq_len, head_dim].

 // The FlashAttention-2 algorithm is described in the following paper:
-// https://arxiv.org/abs/2110.08210
+// https://arxiv.org/pdf/2307.08691

 // Q,K,V,O: [batch_size, num_heads, seq_len, head_dim], [B,H,N,d]
 // each block processes Q_tile with shape [Br,d] and full K,V with shape [N,d]
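The comments in this hunk summarize the data layout: Q, K, V, O are [B,H,N,d], and each block keeps a [Br,d] Q tile resident while streaming over K and V. What makes this a single pass is FlashAttention-2's online softmax: the running row max m and denominator l are rescaled as each new score tile arrives, and normalization is deferred to the end. A scalar CPU reference of that update is sketched below; it is an illustration only, not the kernel's code.

```C++
// CPU reference for the online-softmax update used by FlashAttention-2
// (illustrative only). m = running row max, l = running denominator,
// o = unnormalized output row; one row of scores is processed tile by tile.
#include <algorithm>
#include <cmath>
#include <vector>

void online_softmax_row(const std::vector<std::vector<float>>& score_tiles,
                        const std::vector<std::vector<float>>& v_tiles,
                        int d, std::vector<float>& o_out) {
  float m = -INFINITY, l = 0.0f;
  std::vector<float> o(d, 0.0f);
  for (size_t t = 0; t < score_tiles.size(); ++t) {
    const auto& s = score_tiles[t];          // scores for this K tile (size Bc)
    float m_tile = *std::max_element(s.begin(), s.end());
    float m_new = std::max(m, m_tile);
    float scale = std::exp(m - m_new);       // rescale previous partial sums
    l *= scale;
    for (int k = 0; k < d; ++k) o[k] *= scale;
    for (size_t j = 0; j < s.size(); ++j) {  // accumulate this tile's P @ V
      float p = std::exp(s[j] - m_new);
      l += p;
      for (int k = 0; k < d; ++k) o[k] += p * v_tiles[t][j * d + k];
    }
    m = m_new;
  }
  for (int k = 0; k < d; ++k) o[k] /= l;     // final normalization
  o_out = o;
}
```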

kernels/flash-attn/mma/flash_attn_mma_tiling.cu

Lines changed: 2 additions & 8 deletions
@@ -17,7 +17,7 @@ using namespace nvcuda;
 // The output is O, a 4D tensor with shape [batch_size, num_heads, seq_len, head_dim].

 // The FlashAttention-2 algorithm is described in the following paper:
-// https://arxiv.org/abs/2110.08210
+// https://arxiv.org/pdf/2307.08691

 // Q,K,V,O: [batch_size, num_heads, seq_len, head_dim], [B,H,N,d]
 // each block processes Q_tile with shape [Br,d] and full K,V with shape [N,d]

@@ -609,14 +609,8 @@ flash_attn_mma_stages_kernel(half* Q,
 // 15 T28: {a2, a3} T29: {a2, a3} T30: {a2, a3} T31: {a2, a3} T28: {a6, a7} T29: {a6, a7} T30: {a6, a7} T31: {a6, a7}
 #pragma unroll
 for (int i = 0; i < kWarpTileSeqLenP; ++i) { // kWarpTileSeqLenQ=2
-  FA_MMA_CHECK_PRINT_REG(R_S[i][0][0], R_Q[i][0], "Check failed, R_S[%d][0][0], R_Q[%d][0], tile_V_Bc: %d, tid: %d, lane: %d", i, i, tile_V_Bc, tid, lane_id);
-  FA_MMA_CHECK_PRINT_REG(R_S[i][0][1], R_Q[i][1], "Check failed, R_S[%d][0][1], R_Q[%d][1], tile_V_Bc: %d, tid: %d, lane: %d", i, i, tile_V_Bc, tid, lane_id);
-  FA_MMA_CHECK_PRINT_REG(R_S[i][1][0], R_Q[i][2], "Check failed, R_S[%d][1][0], R_Q[%d][2], tile_V_Bc: %d, tid: %d, lane: %d", i, i, tile_V_Bc, tid, lane_id);
-  FA_MMA_CHECK_PRINT_REG(R_S[i][1][1], R_Q[i][3], "Check failed, R_S[%d][1][1], R_Q[%d][3], tile_V_Bc: %d, tid: %d, lane: %d", i, i, tile_V_Bc, tid, lane_id);
 #pragma unroll
   for (int j = 0; j < kWarpTileHeadDimV; ++j) { // kWarpTileHeadDimV=1,2,3,4,...
-    FA_MMA_PRINT_REG(R_V[j][0], "[Before] MMA P@V, R_V[%d][0], tile_V_Bc: %d, tid: %d, lane: %d", j, tile_V_Bc, tid, lane_id);
-    FA_MMA_PRINT_REG(R_V[j][1], "[Before] MMA P@V, R_V[%d][1], tile_V_Bc: %d, tid: %d, lane: %d", j, tile_V_Bc, tid, lane_id);
     HMMA16816(R_O[i][j][0], R_O[i][j][1],
     // FIXME(DefTruth): Still have some error while using R_S
     // as registers for P(A) matrix directly. I will remove this

@@ -795,7 +789,7 @@ void launch_flash_attn_mma_stages(
 constexpr int kWarpTileHeadDimV = (kHeadDim / (kMmaAtomN*kMmaTileHeadDimV));
 constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*2*2=64
 constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK; // 8*4*2=64
-constexpr int kPad = 0;
+constexpr int kPad = 8;

 // Calculate SRAM size needed per block, Q,K,V,S smem size
 const int smem_max_size = ((Br * (kHeadDim + kPad)) +
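The last hunk is the functional change in this file: `kPad` goes from 0 to 8, padding every shared-memory row by 8 halves (16 bytes). With a head dim of 64, for example, the row stride becomes 144 bytes instead of 128, so same-column elements of successive rows land in different shared-memory banks and the strided `ldmatrix`-style accesses stop conflicting. A quick footprint check is sketched below; note that the full `smem_max_size` expression is truncated in this diff and also covers the K, V and S tiles.

```C++
// Back-of-the-envelope shared-memory footprint for one [Br, kHeadDim] half
// tile before and after the kPad change (illustrative; kHeadDim = 64 is just
// an example, and the kernel's smem_max_size additionally covers the K, V
// and S tiles in the expression truncated above).
#include <cstdio>

int main() {
  constexpr int Br = 64, kHeadDim = 64;             // example tile shape
  constexpr int kPadOld = 0, kPadNew = 8;           // padding in half elements
  constexpr int rowOld = (kHeadDim + kPadOld) * 2;  // 128 B per row
  constexpr int rowNew = (kHeadDim + kPadNew) * 2;  // 144 B per row
  // 144 % 128 != 0, so each successive row starts 4 banks later and
  // same-column accesses across rows no longer collide in a single bank.
  printf("Q tile smem: %d B (kPad=0) vs %d B (kPad=8)\n",
         Br * rowOld, Br * rowNew);                 // 8192 B vs 9216 B
  return 0;
}
```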
