@@ -138,6 +138,8 @@ flash_attn_mma_stages_split_q_tiling_qk_kernel(half* Q,
   int load_gmem_Q_Br = Q_tile_id * Br + load_smem_Q_Br;
   if (load_gmem_Q_Br >= QKV_seqlen) return;
   constexpr bool kIsVCanLoadIn128b = (kHeadDim / (kNumThreads / kMmaAtomK)) % 8 == 0;
+  constexpr bool kIsVCanLoadIn64b  = (kHeadDim / (kNumThreads / kMmaAtomK)) % 4 == 0;
+  static_assert(kIsVCanLoadIn128b || kIsVCanLoadIn64b, "V can't load in 128b or 64b."); // 32,64,128,192,256,...

   // Shared memory for Q,K,V; we do not need additional smem for the O
   // collective store, which is performed via register reuse and warp shuffle.
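For reference, a minimal host-side sketch (not part of this diff) of what the 64-bit fallback buys. The expression kHeadDim / (kNumThreads / kMmaAtomK) is read here as the number of half elements each thread loads per kMmaAtomK-row slice, and the concrete kNumThreads values (128 for the d<128 tiling, 256 for d>=128) are assumptions inferred from the launch constants in the second hunk, not taken from these lines.

// Sketch only: can_load_128b/can_load_64b mirror the two constexpr predicates
// above; the kNumThreads values used below are assumed, the kernel derives
// them from its tiling template parameters.
#include <cstdio>

constexpr bool can_load_128b(int kHeadDim, int kNumThreads, int kMmaAtomK = 16) {
  // 8 halves (16 bytes) per thread -> 128-bit vectorized gmem->smem load.
  return (kHeadDim / (kNumThreads / kMmaAtomK)) % 8 == 0;
}

constexpr bool can_load_64b(int kHeadDim, int kNumThreads, int kMmaAtomK = 16) {
  // 4 halves (8 bytes) per thread -> 64-bit load, the new fallback path.
  return (kHeadDim / (kNumThreads / kMmaAtomK)) % 4 == 0;
}

int main() {
  // d=64,  128 threads: 64/(128/16)  = 8  -> 128-bit loads OK.
  printf("d=64 : 128b=%d 64b=%d\n", can_load_128b(64, 128),  can_load_64b(64, 128));
  // d=96,  128 threads: 96/(128/16)  = 12 -> only 64-bit loads; needs the new fallback.
  printf("d=96 : 128b=%d 64b=%d\n", can_load_128b(96, 128),  can_load_64b(96, 128));
  // d=192, 256 threads: 192/(256/16) = 12 -> only 64-bit loads; needs the new fallback.
  printf("d=192: 128b=%d 64b=%d\n", can_load_128b(192, 256), can_load_64b(192, 256));
  return 0;
}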
@@ -763,17 +765,17 @@ flash_attn_mma_stages_split_q_tiling_qk_kernel(half* Q,
 template <const int kHeadDim, const int kStage>
 void launch_flash_attn_mma_stages_split_q_tiling_qk(
   torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O) {
-  // Now: fixed tile BrxBc=128x128
+  // Now: fixed tile BrxBc=128x128 for d>=128, 64x64 for d<128.
   // TODO: dynamic tile size for Br, Bc according to kHeadDim and shared memory size.
   constexpr int kMmaAtomM = 16;
   constexpr int kMmaAtomN = 8;
   constexpr int kMmaAtomK = 16;
-  constexpr int kMmaTileSeqLenQ = 8;
+  constexpr int kMmaTileSeqLenQ = (kHeadDim < 128) ? 4 : 8;
   constexpr int kMmaTileSeqLenK = 1;
-  constexpr int kMmaTileSeqLenP = 8;
+  constexpr int kMmaTileSeqLenP = (kHeadDim < 128) ? 4 : 8;
   constexpr int kMmaTileHeadDimV = 1;
   constexpr int kWarpTileSeqLenQ = 1;
-  constexpr int kWarpTileSeqLenK = 16;
+  constexpr int kWarpTileSeqLenK = (kHeadDim < 128) ? 8 : 16;
   constexpr int kWarpTileSeqLenP = 1;
   constexpr int kWarpTileHeadDimV = (kHeadDim / (kMmaAtomN * kMmaTileHeadDimV)); // (d=64)8,(d=128)16,32,....
   constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*4*1=64
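A small sketch of the tile sizes these constants imply, consistent with the updated "128x128 for d>=128, 64x64 for d<128" comment. Note that the Bc formula (kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK) is an assumption about the kernel's tiling layout; only Br's formula appears in the lines above.

// Sketch only: reproduces the tile-size arithmetic implied by the launch
// constants above; Bc's formula is assumed, Br's is taken from the diff.
#include <cstdio>

void print_tiles(int kHeadDim) {
  const int kMmaAtomM = 16, kMmaAtomN = 8;
  const int kMmaTileSeqLenQ  = (kHeadDim < 128) ? 4 : 8;
  const int kMmaTileSeqLenK  = 1;
  const int kWarpTileSeqLenQ = 1;
  const int kWarpTileSeqLenK = (kHeadDim < 128) ? 8 : 16;
  const int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*4*1=64 or 16*8*1=128
  const int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK; // 8*1*8=64  or 8*1*16=128
  printf("d=%d -> Br x Bc = %d x %d\n", kHeadDim, Br, Bc);
}

int main() {
  print_tiles(64);   // d<128  -> 64x64
  print_tiles(128);  // d>=128 -> 128x128
  return 0;
}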