Commit 56e8bed: use selector for sm80 mha

1 parent: b595267

5 files changed, +65 -100 lines

src/kernels/attention/collective/sm80_collective_epilogue.cuh

Lines changed: 18 additions & 28 deletions
@@ -9,51 +9,42 @@

 #include "common/fast_cast.cuh"
 #include "common/safe_copy.h"
+#include "common/selector.h"

 namespace llm {
 using namespace cute;

-template <class TileShape_, class Element_, int HeadDim_, bool EVEN_K_>
+template <class TileShape_, class Element_, bool EVEN_K_>
 struct Sm80CollectiveEpilogue {
   using TileShape = TileShape_;
   using Element = Element_;

-  static constexpr int kHeadDim = HeadDim_;
+  static constexpr int kThreads = 128;
   static constexpr bool EVEN_K = EVEN_K_;

   static constexpr int kBlockM = get<0>(TileShape{});
   static constexpr int kBlockK = get<2>(TileShape{});

   using BLK_M = Int<kBlockM>;
   using BLK_K = Int<kBlockK>;
-  using HEAD_DIM = Int<kHeadDim>;

   using SmemLayoutAtom_ =
-      decltype(composition(Swizzle<3, 3, 3>{},
-                           Layout<Shape<_8, BLK_K>, Stride<BLK_K, _1>>{}));
+      decltype(smem_layout_atom_selector<Element, kBlockK>());

-  // Q smem: (BLK_M, HEAD_DIM)
+  // Q smem: (BLK_M, BLK_K)
   using SmemLayoutO =
-      decltype(tile_to_shape(SmemLayoutAtom_{}, Shape<BLK_M, HEAD_DIM>{}));
+      decltype(tile_to_shape(SmemLayoutAtom_{}, Shape<BLK_M, BLK_K>{}));

   // use 128-bit vectorizing copy
   using VectorizingCopy_ = AutoVectorizingCopyWithAssumedAlignment<128>;

   // r2s copy atom for O
   using SmemCopyAtom_ = Copy_Atom<VectorizingCopy_, Element>;

-  // Thr layout for gmem copy
-  using GmemCopyThrLayout_ =
-      std::conditional_t<kBlockK == 32,
-                         Layout<Shape<_32, _4>, Stride<_4, _1>>,
-                         Layout<Shape<_16, _8>, Stride<_8, _1>>>;
-
   // s2g tiled copy for O
-  using GmemTiledCopyO = decltype(make_tiled_copy(
-      Copy_Atom<VectorizingCopy_, Element>{},
-      GmemCopyThrLayout_{},  // Thr layout: (_16,_8)/(_32, _4)
-      Layout<Shape<_1, _8>>{}  // Val layout: 8 vals per read
-      ));
+  using GmemTiledCopyO =
+      decltype(gmem_tiled_copy_selector<Element, kThreads, kBlockK>(
+          Copy_Atom<VectorizingCopy_, Element>{}));

   struct SharedStorage : cute::aligned_struct<128> {
     cute::array_aligned<Element, cute::cosize_v<SmemLayoutO>> smem_o;
@@ -73,20 +64,19 @@ struct Sm80CollectiveEpilogue {
             class TensorO,
             class TensorCO,
             class ResidueMNK>
-  CUTE_DEVICE void operator()(
-      const Params& /*params*/,
-      const FrgTensor& tOrAccO,  // (MMA, MMA_M, MMA_N)
-      TiledMma tiled_mma,
-      TensorO& gO,  // (BLK_M, HEAD_DIM)
-      const TensorCO& cO,  // (BLK_M, HEAD_DIM) => (M, K)
-      int tidx,
-      const ResidueMNK& residue_mnk,
-      char* smem) {
+  CUTE_DEVICE void operator()(const Params& /*params*/,
+                              const FrgTensor& tOrAccO,  // (MMA, MMA_M, MMA_N)
+                              TiledMma tiled_mma,
+                              TensorO& gO,  // (BLK_M, BLK_K)
+                              const TensorCO& cO,  // (BLK_M, BLK_K) => (M, K)
+                              int tidx,
+                              const ResidueMNK& residue_mnk,
+                              char* smem) {
     static constexpr int kBlockM = get<0>(TileShape{});

     // Smem
     auto& ss = *reinterpret_cast<SharedStorage*>(smem);
-    // (BLK_M, HEAD_DIM)
+    // (BLK_M, BLK_K)
     Tensor sO = make_tensor(make_smem_ptr(ss.smem_o.data()), SmemLayoutO{});

     // 1. cast output from ElementAccumulator to Element
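For reference, a minimal standalone sketch (not part of this commit) of what the selector-based GmemTiledCopyO is expected to resolve to in the common case of a 16-bit element type, 128 threads, and a head dim that is a multiple of 64. The concrete (_16, _8) thread layout and 8-element value layout are assumptions here; they happen to match what the old code hard-coded for kBlockK == 64.

// Sketch only: assumes Element = half_t, kThreads = 128, kBlockK % 64 == 0.
#include <cute/tensor.hpp>

using namespace cute;

using ThrLayout = Layout<Shape<_16, _8>, Stride<_8, _1>>;  // 128 threads, k-major rows
using ValLayout = Layout<Shape<_1, _8>>;                   // 8 x 16-bit = one 128-bit store
using VectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<128>;

using GmemTiledCopyO =
    decltype(make_tiled_copy(Copy_Atom<VectorizingCopy, half_t>{},
                             ThrLayout{},
                             ValLayout{}));

// 128 cooperating threads; one copy step covers a (16, 64) patch of the
// (BLK_M, BLK_K) O tile, so a 64-row tile is written out in four steps.
static_assert(size(ThrLayout{}) == 128);
static_assert(size<0>(raked_product(ThrLayout{}, ValLayout{})) == 16);
static_assert(size<1>(raked_product(ThrLayout{}, ValLayout{})) == 64);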

src/kernels/attention/collective/sm80_collective_mha.cuh

Lines changed: 20 additions & 35 deletions
@@ -11,14 +11,14 @@
 #include "common/fast_cast.cuh"
 #include "common/layout_convertor.h"
 #include "common/safe_copy.h"
+#include "common/selector.h"

 namespace llm {

 using namespace cute;

 template <class TileShape_,
           class Element_,
-          int HeadDim_,
           bool EVEN_K,
           bool ALIBI,
           bool SOFT_CAP,
@@ -29,21 +29,16 @@ struct Sm80CollectiveMha {
   using Element = Element_;
   using ElementAccum = float;

-  static constexpr int kHeadDim = HeadDim_;
   static constexpr int kBlockM = get<0>(TileShape{});
   static constexpr int kBlockN = get<1>(TileShape{});
   static constexpr int kBlockK = get<2>(TileShape{});

   static constexpr bool kAlibi = ALIBI;
   static constexpr bool kLocal = LOCAL;

-  static_assert(kBlockK == 32 || kBlockK == 64);
-  static_assert(kHeadDim % kBlockK == 0);
-
   using BLK_M = Int<kBlockM>;
   using BLK_N = Int<kBlockN>;
   using BLK_K = Int<kBlockK>;
-  using HEAD_DIM = Int<kHeadDim>;

   // TiledMMA (64x16x16) for gemm-I and gemm-II
   using MMA_Atom_ =
@@ -57,36 +52,26 @@ struct Sm80CollectiveMha {
   static constexpr int kRowsPerMMA = 2;
   static constexpr int kMmaThreads = size(TiledMma{});

-  // Atom layout: (8, BLK_K):(BLK_K, 1) k-major
+  // Atom layout for shared memory
   using SmemLayoutAtom_ =
-      decltype(composition(Swizzle<3, 3, 3>{},
-                           Layout<Shape<_8, BLK_K>, Stride<BLK_K, _1>>{}));
+      decltype(smem_layout_atom_selector<Element, kBlockK>());

-  // Q smem: (BLK_M, HEAD_DIM)
+  // Q smem: (BLK_M, BLK_K)
   using SmemLayoutQ =
-      decltype(tile_to_shape(SmemLayoutAtom_{}, Shape<BLK_M, HEAD_DIM>{}));
+      decltype(tile_to_shape(SmemLayoutAtom_{}, Shape<BLK_M, BLK_K>{}));

-  // KV smem: (BLK_N, HEAD_DIM)
+  // KV smem: (BLK_N, BLK_K)
   using SmemLayoutK =
-      decltype(tile_to_shape(SmemLayoutAtom_{}, Shape<BLK_N, HEAD_DIM>{}));
+      decltype(tile_to_shape(SmemLayoutAtom_{}, Shape<BLK_N, BLK_K>{}));
   using SmemLayoutV =
-      decltype(tile_to_shape(SmemLayoutAtom_{}, Shape<BLK_N, HEAD_DIM>{}));
+      decltype(tile_to_shape(SmemLayoutAtom_{}, Shape<BLK_N, BLK_K>{}));

-  // V^T smem: (HEAD_DIM, BLK_N)
+  // V^T smem: (BLK_K, BLK_N)
   using SmemLayoutVt = decltype(select<1, 0>(SmemLayoutV{}));

-  // Thr layout for gmem copy
-  using GmemCopyThrLayout_ =
-      std::conditional_t<kBlockK == 32,
-                         Layout<Shape<_32, _4>, Stride<_4, _1>>,
-                         Layout<Shape<_16, _8>, Stride<_8, _1>>>;
-
-  // g2s tiled copy for q
-  using GmemTiledCopyQ = decltype(make_tiled_copy(
-      Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, Element>{},
-      GmemCopyThrLayout_{},  // Thr layout: (_16,_8)/(_32, _4)
-      Layout<Shape<_1, _8>>{}  // Val layout: 8 vals per read
-      ));
+  using GmemTiledCopyQ =
+      decltype(gmem_tiled_copy_selector<Element, kMmaThreads, kBlockK>(
+          Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, Element>{}));

   // g2s tiled copy for kv
   using GmemTiledCopyKV = GmemTiledCopyQ;
@@ -145,11 +130,11 @@ struct Sm80CollectiveMha {
             class ResidueMNK>
   CUTE_DEVICE void operator()(
       const Params& params,
-      const TensorQ& gQ,  // (BLK_M, HEAD_DIM)
-      const TensorCQ& cQ,  // (BLK_M, HEAD_DIM) => (M, K)
-      const TensorK& gK,  // (BLK_N, HEAD_DIM, n)
-      const TensorV& gV,  // (BLK_N, HEAD_DIM, n)
-      const TensorCKV& cKV,  // (BLK_N, HEAD_DIM, n) => (N, K)
+      const TensorQ& gQ,  // (BLK_M, BLK_K)
+      const TensorCQ& cQ,  // (BLK_M, BLK_K) => (M, K)
+      const TensorK& gK,  // (BLK_N, BLK_K, n)
+      const TensorV& gV,  // (BLK_N, BLK_K, n)
+      const TensorCKV& cKV,  // (BLK_N, BLK_K, n) => (N, K)
       const TensorCMN& tScMN_mn,  // ((2, MMA_M), (2, MMA_N), n) => (M, N)
       FrgTensor& tOrO,  // (MMA, MMA_M, MMA_N)
       Softmax& softmax,
@@ -173,14 +158,14 @@
     // Construct shared memory tiles
     auto& ss = *reinterpret_cast<SharedStorage*>(smem);

-    // (BLK_M, HEAD_DIM), k-major
+    // (BLK_M, BLK_K), k-major
     Tensor sQ = make_tensor(make_smem_ptr(ss.smem_q.data()), SmemLayoutQ{});
-    // (BLK_N, HEAD_DIM), k-major
+    // (BLK_N, BLK_K), k-major
     Tensor sK = make_tensor(make_smem_ptr(ss.smem_k.data()), SmemLayoutK{});
     Tensor sV = make_tensor(make_smem_ptr(ss.smem_v.data()), SmemLayoutV{});

     // Tensor for V^t; used in GEMM-II.
-    // (HEAD_DIM, BLK_N), k-major
+    // (BLK_K, BLK_N), k-major
     Tensor sVt = make_tensor(make_smem_ptr(ss.smem_vt.data()), SmemLayoutVt{});

     // g2s tiled copy for qkv
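A side note on the shared-memory side of this change: for a 16-bit element and kBlockK a multiple of 64, the selector should hand back the standard 128-byte swizzle, which describes the same address permutation as the Swizzle<3, 3, 3> over an (8, 64) k-major tile that the old code spelled out by hand, so the bank-conflict behaviour should be unchanged in that case. A small sketch under those assumptions (half_t, a 64x64x64 tile), not code from the repo:

#include <cute/tensor.hpp>

using namespace cute;

using BLK_M = _64;
using BLK_K = _64;  // the head dim now rides in the K mode of TileShape

// The atom the old code hard-coded: a 128B swizzle over (8, 64) k-major halves.
using SmemLayoutAtom =
    decltype(composition(Swizzle<3, 3, 3>{},
                         Layout<Shape<_8, BLK_K>, Stride<BLK_K, _1>>{}));

// Tiled out to a full (BLK_M, BLK_K) Q tile, as SmemLayoutQ does above.
using SmemLayoutQ =
    decltype(tile_to_shape(SmemLayoutAtom{}, Shape<BLK_M, BLK_K>{}));

// The swizzle only permutes addresses inside the tile; SharedStorage still
// sizes its buffer with cosize_v<SmemLayoutQ>, i.e. 64 x 64 half_t elements.
static_assert(size(SmemLayoutQ{}) == 64 * 64);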

src/kernels/attention/common/selector.h

Lines changed: 6 additions & 9 deletions
@@ -15,7 +15,6 @@ using Layout_K_INTER_Atom_Bits = ComposedLayout<Swizzle<0,4,3>, smem_ptr_flag,
 using Layout_K_SW32_Atom_Bits = ComposedLayout<Swizzle<1,4,3>, smem_ptr_flag, Layout<Shape<_8, _256>,Stride< _256,_1>>>;
 using Layout_K_SW64_Atom_Bits = ComposedLayout<Swizzle<2,4,3>, smem_ptr_flag, Layout<Shape<_8, _512>,Stride< _512,_1>>>;
 using Layout_K_SW128_Atom_Bits = ComposedLayout<Swizzle<3,4,3>, smem_ptr_flag, Layout<Shape<_8,_1024>,Stride<_1024,_1>>>;
-using Layout_K_SW256_Atom_Bits = ComposedLayout<Swizzle<3,4,4>, smem_ptr_flag, Layout<Shape<_8,_2048>,Stride<_2048,_1>>>;

 // K-major layouts in units of Type
 template <class Type>
@@ -26,17 +25,12 @@ template <class Type>
 using Layout_K_SW64_Atom = decltype(upcast<sizeof_bits<Type>::value>(Layout_K_SW64_Atom_Bits{}));
 template <class Type>
 using Layout_K_SW128_Atom = decltype(upcast<sizeof_bits<Type>::value>(Layout_K_SW128_Atom_Bits{}));
-template <class Type>
-using Layout_K_SW256_Atom = decltype(upcast<sizeof_bits<Type>::value>(Layout_K_SW256_Atom_Bits{}));

 } // namespace detail

 template <class Element, int kBlockK>
 CUTE_HOST_DEVICE constexpr auto smem_layout_atom_selector() {
-  if constexpr (kBlockK % size<1>(detail::Layout_K_SW256_Atom<Element>{}) == 0) {
-    return detail::Layout_K_SW256_Atom<Element>{};
-  }
-  else if constexpr (kBlockK % size<1>(detail::Layout_K_SW128_Atom<Element>{}) == 0) {
+  if constexpr (kBlockK % size<1>(detail::Layout_K_SW128_Atom<Element>{}) == 0) {
     return detail::Layout_K_SW128_Atom<Element>{};
   }
   else if constexpr (kBlockK % size<1>(detail::Layout_K_SW64_Atom<Element>{}) == 0) {
@@ -59,13 +53,16 @@ template <class Element, int kThreads, int kBlockK, class Copy_Atom>
 CUTE_HOST_DEVICE constexpr auto gmem_tiled_copy_selector(Copy_Atom cp_atom) {
   // maxmize vectorized load (128-bits or 16 bytes per thread)
   constexpr int kElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
-  static_assert(kBlockK % kElemsPerLoad == 0,
+  constexpr int kSmemBlockK =
+      size<1>(smem_layout_atom_selector<Element, kBlockK>());
+  static_assert(kSmemBlockK % kElemsPerLoad == 0,
                 "kBlockK must be a multiple of kGmemElemsPerLoad");

-  constexpr int kThreadsPerRow = kBlockK / kElemsPerLoad;
+  constexpr int kThreadsPerRow = kSmemBlockK / kElemsPerLoad;
   static_assert(kThreads % kThreadsPerRow == 0,
                 "kThreads must be a multiple of kThreadsPerRow");
   constexpr int kRows = kThreads / kThreadsPerRow;
+  static_assert(kRows <= 64, "kRows must be less than or equal to 64");

   constexpr auto thr_layout = Layout<Shape<Int<kRows>, Int<kThreadsPerRow>>,
                                      Stride<Int<kThreadsPerRow>, _1>>{};
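The thread-layout arithmetic above, worked through for a 16-bit element type and the 128-thread configuration the sm80 kernels use. This is a standalone sketch under those assumptions, not repo code; it only confirms that the computed layout matches the (_16, _8):(_8, _1) layout that used to be hard-coded for kBlockK == 64.

#include <cute/tensor.hpp>
#include <type_traits>

using namespace cute;

// For half_t the SW128 atom is (8, 1024) bits, i.e. (8, 64) elements.
constexpr int kSmemBlockK    = 1024 / sizeof_bits<half_t>::value;         // 64
constexpr int kElemsPerLoad  = sizeof(cute::uint128_t) / sizeof(half_t);  // 8
constexpr int kThreadsPerRow = kSmemBlockK / kElemsPerLoad;               // 8
constexpr int kRows          = 128 / kThreadsPerRow;                      // 16

using ComputedThrLayout = Layout<Shape<Int<kRows>, Int<kThreadsPerRow>>,
                                 Stride<Int<kThreadsPerRow>, _1>>;

// Same gmem access pattern as the old GmemCopyThrLayout_ for kBlockK == 64,
// so fp16/bf16 head dims that are multiples of 64 should be unaffected.
static_assert(std::is_same_v<ComputedThrLayout,
                             Layout<Shape<_16, _8>, Stride<_8, _1>>>);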

src/kernels/attention/device/sm80_mha_launch.cuh

Lines changed: 5 additions & 12 deletions
@@ -51,18 +51,11 @@ void sm80_launch_mha_kernel(const Params& params, cudaStream_t stream) {
   // * 12.0 : 0, 8, 16, 32, 64, 100
   constexpr int BLK_M = 64;
   constexpr int BLK_N = 64;
-  constexpr int BLK_K = HEAD_DIM % 64 == 0 ? 64 : 32;
-
-  using TileShape = Shape<Int<BLK_M>, Int<BLK_N>, Int<BLK_K>>;
-  using CollectiveMainloop = Sm80CollectiveMha<TileShape,
-                                               Dtype,
-                                               HEAD_DIM,
-                                               EVEN_K,
-                                               ALIBI,
-                                               SOFT_CAP,
-                                               LOCAL>;
-  using CollectiveEpilogue =
-      Sm80CollectiveEpilogue<TileShape, Dtype, HEAD_DIM, EVEN_K>;
+
+  using TileShape = Shape<Int<BLK_M>, Int<BLK_N>, Int<HEAD_DIM>>;
+  using CollectiveMainloop =
+      Sm80CollectiveMha<TileShape, Dtype, EVEN_K, ALIBI, SOFT_CAP, LOCAL>;
+  using CollectiveEpilogue = Sm80CollectiveEpilogue<TileShape, Dtype, EVEN_K>;

   // TODO: support persistent kernels
   using TileScheduler = SingleTileScheduler;
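For a concrete picture of the new instantiation, here is what the launch path builds for, say, HEAD_DIM = 128 with half precision (illustrative values, not taken from the commit); the head dim now travels in the K mode of TileShape and the collectives read it back as kBlockK = get<2>(TileShape{}):

// Sketch; assumes the repo's collective headers are included and that the
// HEAD_DIM / Dtype / boolean flags below are just example values.
using Dtype = cute::half_t;
constexpr int HEAD_DIM = 128;
constexpr bool EVEN_K = true;
constexpr bool ALIBI = false;
constexpr bool SOFT_CAP = false;
constexpr bool LOCAL = false;

using TileShape =
    cute::Shape<cute::Int<64>, cute::Int<64>, cute::Int<HEAD_DIM>>;
using CollectiveMainloop =
    llm::Sm80CollectiveMha<TileShape, Dtype, EVEN_K, ALIBI, SOFT_CAP, LOCAL>;
using CollectiveEpilogue =
    llm::Sm80CollectiveEpilogue<TileShape, Dtype, EVEN_K>;

// Inside the collectives kBlockK = get<2>(TileShape{}) = 128; the smem tiles
// stay (64, 128) as before, but the swizzle atom and the g2s/s2g copy layouts
// are now derived by the selectors instead of the old kBlockK == 32/64 switch.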

src/kernels/attention/kernel/sm80_kernel_mha.cuh

Lines changed: 16 additions & 16 deletions
@@ -178,7 +178,7 @@ class Sm80KernelMha {
   using Element = typename CollectiveMainloop::Element;
   using BLK_M = typename CollectiveMainloop::BLK_M;
   using BLK_N = typename CollectiveMainloop::BLK_N;
-  using HEAD_DIM = typename CollectiveMainloop::HEAD_DIM;
+  using BLK_K = typename CollectiveMainloop::BLK_K;

   static constexpr int kSharedStorageSize =
       cute::max(sizeof(typename CollectiveMainloop::SharedStorage),
@@ -224,10 +224,10 @@
     const auto [batch_idx, m_block_idx, kv_head_idx] = block_coord;
     const auto tidx = threadIdx.x;

-    // (q_packed_len, HEAD_DIM)
+    // (q_packed_len, BLK_K)
     detail::MHATile<Params> tile(params, batch_idx, kv_head_idx);
     auto [Q, O] = tile.template get_qo_tile<Element>();
-    // (kv_len, HEAD_DIM)
+    // (kv_len, BLK_K)
     auto [K, V] = tile.template get_kv_tile<Element>();

     // problem shape
@@ -252,22 +252,22 @@
     const int n_block_min = kLocal ? kv_idx_min / kBlockN : 0;
     const int n_block_max = cute::ceil_div(kv_idx_max, kBlockN);

-    // (BLK_M, HEAD_DIM)
-    Tensor gQ = local_tile(
-        Q, Shape<BLK_M, HEAD_DIM>{}, make_coord(m_block_idx, _0{}));
-    Tensor gO = local_tile(
-        O, Shape<BLK_M, HEAD_DIM>{}, make_coord(m_block_idx, _0{}));
-    // (BLK_M, HEAD_DIM) => (M, K)
+    // (BLK_M, BLK_K)
+    Tensor gQ =
+        local_tile(Q, Shape<BLK_M, BLK_K>{}, make_coord(m_block_idx, _0{}));
+    Tensor gO =
+        local_tile(O, Shape<BLK_M, BLK_K>{}, make_coord(m_block_idx, _0{}));
+    // (BLK_M, BLK_K) => (M, K)
     Tensor cQ = local_tile(make_identity_tensor(Q.shape()),
-                           Shape<BLK_M, HEAD_DIM>{},
+                           Shape<BLK_M, BLK_K>{},
                            make_coord(m_block_idx, _0{}));

-    // (BLK_N, HEAD_DIM, n)
-    Tensor gK = local_tile(K, Shape<BLK_N, HEAD_DIM>{}, make_coord(_, _0{}));
-    Tensor gV = local_tile(V, Shape<BLK_N, HEAD_DIM>{}, make_coord(_, _0{}));
-    // (BLK_N, HEAD_DIM, n) => (N, K)
+    // (BLK_N, BLK_K, n)
+    Tensor gK = local_tile(K, Shape<BLK_N, BLK_K>{}, make_coord(_, _0{}));
+    Tensor gV = local_tile(V, Shape<BLK_N, BLK_K>{}, make_coord(_, _0{}));
+    // (BLK_N, BLK_K, n) => (N, K)
     Tensor cKV = local_tile(make_identity_tensor(K.shape()),
-                            Shape<BLK_N, HEAD_DIM>{},
+                            Shape<BLK_N, BLK_K>{},
                             make_coord(_, _0{}));

     // (BLK_M, BLK_N, n) => (M, N)
@@ -278,7 +278,7 @@

     TiledMma tiled_mma;
     // accumulator: (MMA,MMA_M,MMA_K)
-    auto tOrAccO = partition_fragment_C(tiled_mma, Shape<BLK_M, HEAD_DIM>{});
+    auto tOrAccO = partition_fragment_C(tiled_mma, Shape<BLK_M, BLK_K>{});
     clear(tOrAccO);

     auto thr_mma = tiled_mma.get_slice(tidx);
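The kernel-side edits are just the HEAD_DIM -> BLK_K rename that follows from the collectives. As a toy, host-side illustration of the local_tile calls above (made-up sizes and names, not repo code), slicing a k-major (q_len, head_dim) tensor into one (BLK_M, BLK_K) block:

#include <cute/tensor.hpp>
#include <cstdio>
#include <vector>

using namespace cute;

int main() {
  // Toy problem: 256 query tokens, head dim 64 (BLK_K == head dim after this commit).
  int q_len = 256;
  constexpr int kHeadDim = 64;
  std::vector<float> q_data(q_len * kHeadDim);

  // (q_len, head_dim), k-major; tagged as gmem only for illustration,
  // the data is never dereferenced here.
  Tensor Q = make_tensor(make_gmem_ptr(q_data.data()),
                         make_shape(q_len, Int<kHeadDim>{}),
                         make_stride(Int<kHeadDim>{}, _1{}));

  // (BLK_M, BLK_K) block of Q at row-block 2, mirroring the gQ line above.
  Tensor gQ = local_tile(Q, Shape<_64, _64>{}, make_coord(2, _0{}));

  print(shape(gQ));  // prints (_64,_64)
  printf("\n");
  return 0;
}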
