Commit 1b96c96

[FA2] fix tiling-qk misaligned address✔️ (#174)
* Update README.md
* fix misaligned address
* fix misaligned address
1 parent f9fa8f0 commit 1b96c96

3 files changed (+9, -7 lines)

kernels/flash-attn/README.md

Lines changed: 1 addition & 3 deletions
@@ -49,9 +49,7 @@ python3 flash_attn_mma.py --B 1 --H 48 --D 64 --N 8192 --iters 10 --torch # NVI
 ```
 - Example: B=1, H=48, N=8192, D=512 (NVIDIA RTX 3080 Laptop), FA2 not supported.
 ```bash
-python3 flash_attn_mma.py --B 1 --H 8 --N 8192 --iters 10 --show-all --sdpa --D 512 # NVIDIA RTX 3080 Laptop, Faster than SDPA
------------------------B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 1041, Warmup: 1, Iters: 10------------------------
-----------------------------------------------------------------------------------------------------------------------------------
+python3 flash_attn_mma.py --B 1 --H 8 --N 8192 --iters 10 --show-all --sdpa --D 512 # NVIDIA RTX 3080 Laptop, Faster than SDPA
 ------------------------------------------B=1, H=8, N=8192, D=512, Warmup: 1, Iters: 10-------------------------------------------
 mma(split-q+tiling-qk+stage1): ['-0.00433731 ', '0.02165222 ', '-0.01544189 '], time:48.775554ms, TFLOPS:22.60 (+0.00%)
 mma(split-q+tiling-qk+stage2): ['-0.00433731 ', '0.02165222 ', '-0.01544189 '], time:47.503424ms, TFLOPS:23.20 (+2.68%)

kernels/flash-attn/flash_attn_mma.py

Lines changed: 2 additions & 0 deletions
@@ -66,6 +66,8 @@ def get_args():
 
 
 args = get_args()
+if args.D and args.D >= 256:
+    args.run_torch_sdpa = True
 pretty_print_line()
 print(args)
 pretty_print_line()

kernels/flash-attn/mma/flash_attn_mma_tiling_qk.cu

Lines changed: 6 additions & 4 deletions
@@ -138,6 +138,8 @@ flash_attn_mma_stages_split_q_tiling_qk_kernel(half* Q,
   int load_gmem_Q_Br = Q_tile_id * Br + load_smem_Q_Br;
   if (load_gmem_Q_Br >= QKV_seqlen) return;
   constexpr bool kIsVCanLoadIn128b = (kHeadDim / (kNumThreads / kMmaAtomK)) % 8 == 0;
+  constexpr bool kIsVCanLoadIn64b = (kHeadDim / (kNumThreads / kMmaAtomK)) % 4 == 0;
+  static_assert(kIsVCanLoadIn128b || kIsVCanLoadIn64b, "V can't load in 128b or 64b."); // 32,64,128,192,256,...
 
   // Shared memory for Q,K,V, we don not need additional smem for O
   // collective store which perform via registers reuse and warp shuffle.
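
The two added lines above are the compile-time side of the misaligned-address fix: each thread's slice of a V row spans kHeadDim / (kNumThreads / kMmaAtomK) half elements, and vectorized copies are only safe when that slice is a whole number of 128-bit (8 halfs, 16 bytes) or 64-bit (4 halfs, 8 bytes) chunks; otherwise a wide load lands on an insufficiently aligned address and CUDA reports "misaligned address". The sketch below is a minimal, hypothetical illustration of that dispatch, not the repo's actual loader (the function name and parameters are invented for the example):

```cuda
// Hypothetical sketch (not the repo's loader): choose the vector width for
// copying one thread's slice of a V row, mirroring the kIsVCanLoadIn128b /
// kIsVCanLoadIn64b predicates added in this commit. Requires -std=c++17.
#include <cuda_fp16.h>

template <int kHeadDim, int kNumThreads, int kMmaAtomK>
__device__ void load_v_slice_sketch(const half* __restrict__ v_row,
                                    half* __restrict__ smem_row, int lane) {
  constexpr int kThreadsPerRow  = kNumThreads / kMmaAtomK;
  constexpr int kElemsPerThread = kHeadDim / kThreadsPerRow;   // halfs per thread
  constexpr bool kCanLoad128b = (kElemsPerThread % 8) == 0;    // 8 halfs = 16 bytes
  constexpr bool kCanLoad64b  = (kElemsPerThread % 4) == 0;    // 4 halfs = 8 bytes
  static_assert(kCanLoad128b || kCanLoad64b, "V can't load in 128b or 64b.");

  const int col = lane * kElemsPerThread;
  if constexpr (kCanLoad128b) {
    // Slice offsets are 16-byte multiples: one 128-bit transaction per 8 halfs.
    #pragma unroll
    for (int i = 0; i < kElemsPerThread; i += 8) {
      *reinterpret_cast<float4*>(smem_row + col + i) =
          *reinterpret_cast<const float4*>(v_row + col + i);
    }
  } else {
    // Only 8-byte alignment is guaranteed: fall back to 64-bit transactions
    // instead of issuing 128-bit loads that would fault as "misaligned address".
    #pragma unroll
    for (int i = 0; i < kElemsPerThread; i += 4) {
      *reinterpret_cast<float2*>(smem_row + col + i) =
          *reinterpret_cast<const float2*>(v_row + col + i);
    }
  }
}
```

For instance, assuming 8 threads per row, kHeadDim = 32 gives 4 halfs per thread (64-bit path) while kHeadDim = 64 gives 8 (128-bit path); the exact thread counts in the repo may differ.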
@@ -763,17 +765,17 @@ flash_attn_mma_stages_split_q_tiling_qk_kernel(half* Q,
 template<const int kHeadDim, const int kStage>
 void launch_flash_attn_mma_stages_split_q_tiling_qk(
   torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O) {
-  // Now: fixed tile BrxBc=128x128
+  // Now: fixed tile BrxBc=128x128 for d>= 128, 64x64 for d<128.
   // TODO: dynamic tile size for Br, Bc according to kHeadDim and shared memory size.
   constexpr int kMmaAtomM = 16;
   constexpr int kMmaAtomN = 8;
   constexpr int kMmaAtomK = 16;
-  constexpr int kMmaTileSeqLenQ = 8;
+  constexpr int kMmaTileSeqLenQ = (kHeadDim < 128) ? 4 : 8;
   constexpr int kMmaTileSeqLenK = 1;
-  constexpr int kMmaTileSeqLenP = 8;
+  constexpr int kMmaTileSeqLenP = (kHeadDim < 128) ? 4 : 8;
   constexpr int kMmaTileHeadDimV = 1;
   constexpr int kWarpTileSeqLenQ = 1;
-  constexpr int kWarpTileSeqLenK = 16;
+  constexpr int kWarpTileSeqLenK = (kHeadDim < 128) ? 8 : 16;
   constexpr int kWarpTileSeqLenP = 1;
   constexpr int kWarpTileHeadDimV = (kHeadDim / (kMmaAtomN * kMmaTileHeadDimV)); // (d=64)8,(d=128)16,32,....
   constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*4*1=64
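
The launcher half of the change makes the tile shape depend on kHeadDim: for d < 128 the per-CTA MMA tile factors along Q/P and the per-warp factor along K are halved, shrinking the BrxBc tile from 128x128 to 64x64, matching the updated comment. Below is a minimal, standalone sketch of how those constants compose; only the Br formula appears in the hunk above, so the Bc expression is an assumption that mirrors it:

```cuda
// Minimal sketch of the tile-size arithmetic (not the repo's launcher).
// Bc's formula is assumed to mirror Br: atom size * per-CTA factor * per-warp factor.
template <int kHeadDim>
struct TileShapeSketch {
  static constexpr int kMmaAtomM        = 16;
  static constexpr int kMmaAtomN        = 8;
  static constexpr int kMmaTileSeqLenQ  = (kHeadDim < 128) ? 4 : 8;
  static constexpr int kMmaTileSeqLenK  = 1;
  static constexpr int kWarpTileSeqLenQ = 1;
  static constexpr int kWarpTileSeqLenK = (kHeadDim < 128) ? 8 : 16;
  // Br = MMA atom rows * per-CTA tile factor * per-warp tile factor (as in the hunk).
  static constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ;
  static constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK;  // assumed
};

static_assert(TileShapeSketch<64>::Br  == 64  && TileShapeSketch<64>::Bc  == 64,
              "d < 128 -> 64x64 tile");
static_assert(TileShapeSketch<128>::Br == 128 && TileShapeSketch<128>::Bc == 128,
              "d >= 128 -> 128x128 tile");
```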
