1414#include " common/mask.h"
1515#include " common/online_softmax.cuh"
1616#include " common/safe_copy.h"
17+ #include " common/selector.h"
1718#include " sm120_collective_load_cpasync_ws.cuh"
1819#include " sm120_collective_load_tma_ws.cuh"
1920
@@ -23,21 +24,20 @@ using namespace cute;
2324
2425template <class TileShape_ ,
2526 class Element_ ,
26- int HeadDim_,
2727 bool EVEN_K,
2828 bool ALIBI,
2929 bool SOFT_CAP,
3030 bool LOCAL,
3131 bool KV_USE_TMA = false // whether to use TMA for K/V loading
3232 >
3333struct Sm120CollectiveFMhaWs {
34+ // exposed template parameters
3435 using TileShape = TileShape_;
3536 using Element = Element_;
3637 using ElementAccum = float ;
3738
3839 using ClusterShape = Shape<_1, _1, _1>;
3940
40- static constexpr int kHeadDim = HeadDim_;
4141 static constexpr int kBlockM = get<0 >(TileShape{});
4242 static constexpr int kBlockN = get<1 >(TileShape{});
4343 static constexpr int kBlockK = get<2 >(TileShape{});
@@ -46,13 +46,9 @@ struct Sm120CollectiveFMhaWs {
4646 static constexpr bool kLocal = LOCAL;
4747 static constexpr bool kKVUseTma = KV_USE_TMA;
4848
49- static_assert (kBlockK == 32 || kBlockK == 64 );
50- static_assert (kHeadDim % kBlockK == 0 );
51-
5249 using BLK_M = Int<kBlockM >;
5350 using BLK_N = Int<kBlockN >;
5451 using BLK_K = Int<kBlockK >;
55- using HEAD_DIM = Int<kHeadDim >;
5652
5753 // TiledMMA (64x16x16) for gemm-I and gemm-II
5854 using MMA_Atom_ =
@@ -70,27 +66,25 @@ struct Sm120CollectiveFMhaWs {
7066 static constexpr int StageCountQ = 1 ;
7167 static constexpr int StageCountKV = 3 ;
7268
73- // Atom layout: (8, BLK_K):(BLK_K, 1) k-major
7469 using SmemLayoutAtom_ =
75- decltype (composition(Swizzle<3 , 3 , 3 >{},
76- Layout<Shape<_8, BLK_K>, Stride<BLK_K, _1>>{}));
70+ decltype (smem_layout_atom_selector<Element, kBlockK >());
7771
78- // Q smem: (BLK_M, HEAD_DIM )
72+ // Q smem: (BLK_M, BLK_K )
7973 using SmemLayoutQ =
80- decltype (tile_to_shape(SmemLayoutAtom_{}, Shape<BLK_M, HEAD_DIM >{}));
74+ decltype (tile_to_shape(SmemLayoutAtom_{}, Shape<BLK_M, BLK_K >{}));
8175
82- // KV smem: (BLK_N, HEAD_DIM , KVStages)
76+ // KV smem: (BLK_N, BLK_K , KVStages)
8377 using SmemLayoutK =
8478 decltype (tile_to_shape(SmemLayoutAtom_{},
85- Shape<BLK_N, HEAD_DIM , Int<StageCountKV>>{}));
79+ Shape<BLK_N, BLK_K , Int<StageCountKV>>{}));
8680 using SmemLayoutV = SmemLayoutK;
8781
88- // V^T smem: (HEAD_DIM , BLK_N, KVStages)
82+ // V^T smem: (BLK_K , BLK_N, KVStages)
8983 using SmemLayoutVt = decltype (select<1 , 0 , 2 >(SmemLayoutV{}));
9084
91- // tma transaction bytes for (BLK_N, HEAD_DIM )
92- static constexpr uint32_t kTmaTransactionBytes =
93- size (take<0 , 2 >(SmemLayoutK{})) * cutlass::sizeof_bits_v<Element> / 8 ;
85+ // tma transaction bytes for (BLK_N, BLK_K )
86+ static constexpr uint32_t kTmaTransactionBytes = cutlass::bits_to_bytes(
87+ cosize (take<0 , 2 >(SmemLayoutK{})) * cutlass::sizeof_bits_v<Element>) ;
9488
9589 struct TensorStorage {
9690 cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
@@ -201,12 +195,12 @@ struct Sm120CollectiveFMhaWs {
201195 const auto kv_len = block.get_kv_len ();
202196
203197 // Construct smem tensors
204- // (BLK_M, HEAD_DIM ), k-major
198+ // (BLK_M, BLK_K ), k-major
205199 Tensor sQ = make_tensor (make_smem_ptr (ss.smem_q .data ()), SmemLayoutQ{});
206- // (BLK_N, HEAD_DIM , KVStages), k-major
200+ // (BLK_N, BLK_K , KVStages), k-major
207201 Tensor sK = make_tensor (make_smem_ptr (ss.smem_k .data ()), SmemLayoutK{});
208202 // Tensor for V^t; used in GEMM-II.
209- // (HEAD_DIM , BLK_N, KVStages), k-major
203+ // (BLK_K , BLK_N, KVStages), k-major
210204 Tensor sVt = make_tensor (make_smem_ptr (ss.smem_vt .data ()), SmemLayoutVt{});
211205
212206 TiledMma tiled_mma;
0 commit comments