Skip to content

Commit 15c09aa

Browse files
authored
feat: added smem and gmem layout selector for attn kernel (#490)
1 parent 96be5da commit 15c09aa

16 files changed

+517
-304
lines changed

src/kernels/attention/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ cc_test(
6161
INCLUDES
6262
${CMAKE_CURRENT_SOURCE_DIR}
6363
SRCS
64-
# sm80_mha_test.cu
64+
# tests/sm80_mha_test.cu
6565
tests/sm80_mha_pagedkv_test.cu
6666
DEPS
6767
:attention.kernels

src/kernels/attention/collective/sm120_collective_epilogue.cuh

Lines changed: 12 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,51 +9,38 @@
99

1010
#include "common/fast_cast.cuh"
1111
#include "common/safe_copy.h"
12+
#include "common/selector.h"
1213

1314
namespace llm {
1415
using namespace cute;
1516

16-
template <class TileShape_, class Element_, int HeadDim_, bool EVEN_K_>
17+
template <class TileShape, class Element, bool EVEN_K>
1718
struct Sm120CollectiveEpilogue {
18-
using TileShape = TileShape_;
19-
using Element = Element_;
20-
21-
static constexpr int kHeadDim = HeadDim_;
22-
static constexpr bool EVEN_K = EVEN_K_;
23-
19+
static constexpr int kThreads = 128;
2420
static constexpr int kBlockM = get<0>(TileShape{});
2521
static constexpr int kBlockK = get<2>(TileShape{});
2622

2723
using BLK_M = Int<kBlockM>;
2824
using BLK_K = Int<kBlockK>;
29-
using HEAD_DIM = Int<kHeadDim>;
3025

3126
using SmemLayoutAtom_ =
32-
decltype(composition(Swizzle<3, 3, 3>{},
33-
Layout<Shape<_8, BLK_K>, Stride<BLK_K, _1>>{}));
27+
decltype(smem_layout_atom_selector<Element, kBlockK>());
28+
static constexpr int kSmemBlockK = size<1>(SmemLayoutAtom_{});
3429

35-
// Q smem: (BLK_M, HEAD_DIM)
30+
// Q smem: (BLK_M, BLK_K)
3631
using SmemLayoutO =
37-
decltype(tile_to_shape(SmemLayoutAtom_{}, Shape<BLK_M, HEAD_DIM>{}));
32+
decltype(tile_to_shape(SmemLayoutAtom_{}, Shape<BLK_M, BLK_K>{}));
3833

3934
// use 128-bit vectorizing copy
4035
using VectorizingCopy_ = AutoVectorizingCopyWithAssumedAlignment<128>;
4136

4237
// r2s copy atom for O
4338
using SmemCopyAtom_ = Copy_Atom<VectorizingCopy_, Element>;
4439

45-
// Thr layout for gmem copy
46-
using GmemCopyThrLayout_ =
47-
std::conditional_t<kBlockK == 32,
48-
Layout<Shape<_32, _4>, Stride<_4, _1>>,
49-
Layout<Shape<_16, _8>, Stride<_8, _1>>>;
50-
5140
// s2g tiled copy for O
52-
using GmemTiledCopyO = decltype(make_tiled_copy(
53-
Copy_Atom<VectorizingCopy_, Element>{},
54-
GmemCopyThrLayout_{}, // Thr layout: (_16,_8)/(_32, _4)
55-
Layout<Shape<_1, _8>>{} // Val layout: 8 vals per read
56-
));
41+
using GmemTiledCopyO =
42+
decltype(gmem_tiled_copy_selector<Element, kThreads, kBlockK>(
43+
Copy_Atom<VectorizingCopy_, Element>{}));
5744

5845
struct TensorStorage {
5946
cute::array_aligned<Element, cute::cosize_v<SmemLayoutO>> smem_o;
@@ -80,11 +67,11 @@ struct Sm120CollectiveEpilogue {
8067
return;
8168
}
8269

83-
// (BLK_M, HEAD_DIM) => (M, K)
70+
// (BLK_M, BLK_K) => (M, K)
8471
auto [gO, cO] = block.get_o_tile();
8572
auto residue_mnk = block.get_residue_mnk();
8673

87-
// (BLK_M, HEAD_DIM)
74+
// (BLK_M, BLK_K)
8875
Tensor sO = make_tensor(make_smem_ptr(ss.smem_o.data()), SmemLayoutO{});
8976

9077
// 1. cast output from ElementAccumulator to Element

src/kernels/attention/collective/sm120_collective_fmha_mainloop_ws.cuh

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "common/mask.h"
1515
#include "common/online_softmax.cuh"
1616
#include "common/safe_copy.h"
17+
#include "common/selector.h"
1718
#include "sm120_collective_load_cpasync_ws.cuh"
1819
#include "sm120_collective_load_tma_ws.cuh"
1920

@@ -23,21 +24,20 @@ using namespace cute;
2324

2425
template <class TileShape_,
2526
class Element_,
26-
int HeadDim_,
2727
bool EVEN_K,
2828
bool ALIBI,
2929
bool SOFT_CAP,
3030
bool LOCAL,
3131
bool KV_USE_TMA = false // whether to use TMA for K/V loading
3232
>
3333
struct Sm120CollectiveFMhaWs {
34+
// exposed template parameters
3435
using TileShape = TileShape_;
3536
using Element = Element_;
3637
using ElementAccum = float;
3738

3839
using ClusterShape = Shape<_1, _1, _1>;
3940

40-
static constexpr int kHeadDim = HeadDim_;
4141
static constexpr int kBlockM = get<0>(TileShape{});
4242
static constexpr int kBlockN = get<1>(TileShape{});
4343
static constexpr int kBlockK = get<2>(TileShape{});
@@ -46,13 +46,9 @@ struct Sm120CollectiveFMhaWs {
4646
static constexpr bool kLocal = LOCAL;
4747
static constexpr bool kKVUseTma = KV_USE_TMA;
4848

49-
static_assert(kBlockK == 32 || kBlockK == 64);
50-
static_assert(kHeadDim % kBlockK == 0);
51-
5249
using BLK_M = Int<kBlockM>;
5350
using BLK_N = Int<kBlockN>;
5451
using BLK_K = Int<kBlockK>;
55-
using HEAD_DIM = Int<kHeadDim>;
5652

5753
// TiledMMA (64x16x16) for gemm-I and gemm-II
5854
using MMA_Atom_ =
@@ -70,27 +66,25 @@ struct Sm120CollectiveFMhaWs {
7066
static constexpr int StageCountQ = 1;
7167
static constexpr int StageCountKV = 3;
7268

73-
// Atom layout: (8, BLK_K):(BLK_K, 1) k-major
7469
using SmemLayoutAtom_ =
75-
decltype(composition(Swizzle<3, 3, 3>{},
76-
Layout<Shape<_8, BLK_K>, Stride<BLK_K, _1>>{}));
70+
decltype(smem_layout_atom_selector<Element, kBlockK>());
7771

78-
// Q smem: (BLK_M, HEAD_DIM)
72+
// Q smem: (BLK_M, BLK_K)
7973
using SmemLayoutQ =
80-
decltype(tile_to_shape(SmemLayoutAtom_{}, Shape<BLK_M, HEAD_DIM>{}));
74+
decltype(tile_to_shape(SmemLayoutAtom_{}, Shape<BLK_M, BLK_K>{}));
8175

82-
// KV smem: (BLK_N, HEAD_DIM, KVStages)
76+
// KV smem: (BLK_N, BLK_K, KVStages)
8377
using SmemLayoutK =
8478
decltype(tile_to_shape(SmemLayoutAtom_{},
85-
Shape<BLK_N, HEAD_DIM, Int<StageCountKV>>{}));
79+
Shape<BLK_N, BLK_K, Int<StageCountKV>>{}));
8680
using SmemLayoutV = SmemLayoutK;
8781

88-
// V^T smem: (HEAD_DIM, BLK_N, KVStages)
82+
// V^T smem: (BLK_K, BLK_N, KVStages)
8983
using SmemLayoutVt = decltype(select<1, 0, 2>(SmemLayoutV{}));
9084

91-
// tma transaction bytes for (BLK_N, HEAD_DIM)
92-
static constexpr uint32_t kTmaTransactionBytes =
93-
size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<Element> / 8;
85+
// tma transaction bytes for (BLK_N, BLK_K)
86+
static constexpr uint32_t kTmaTransactionBytes = cutlass::bits_to_bytes(
87+
cosize(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<Element>);
9488

9589
struct TensorStorage {
9690
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
@@ -201,12 +195,12 @@ struct Sm120CollectiveFMhaWs {
201195
const auto kv_len = block.get_kv_len();
202196

203197
// Construct smem tensors
204-
// (BLK_M, HEAD_DIM), k-major
198+
// (BLK_M, BLK_K), k-major
205199
Tensor sQ = make_tensor(make_smem_ptr(ss.smem_q.data()), SmemLayoutQ{});
206-
// (BLK_N, HEAD_DIM, KVStages), k-major
200+
// (BLK_N, BLK_K, KVStages), k-major
207201
Tensor sK = make_tensor(make_smem_ptr(ss.smem_k.data()), SmemLayoutK{});
208202
// Tensor for V^t; used in GEMM-II.
209-
// (HEAD_DIM, BLK_N, KVStages), k-major
203+
// (BLK_K, BLK_N, KVStages), k-major
210204
Tensor sVt = make_tensor(make_smem_ptr(ss.smem_vt.data()), SmemLayoutVt{});
211205

212206
TiledMma tiled_mma;

0 commit comments

Comments
 (0)