src/kernels/attention/CMakeLists.txt (2 changes: 1 addition & 1 deletion)

@@ -61,7 +61,7 @@ cc_test(
   INCLUDES
     ${CMAKE_CURRENT_SOURCE_DIR}
   SRCS
-    # sm80_mha_test.cu
+    # tests/sm80_mha_test.cu
     tests/sm80_mha_pagedkv_test.cu
   DEPS
     :attention.kernels
src/kernels/attention/collective/sm120_collective_epilogue.cuh (37 changes: 12 additions & 25 deletions)

@@ -9,51 +9,38 @@

 #include "common/fast_cast.cuh"
 #include "common/safe_copy.h"
+#include "common/selector.h"

 namespace llm {
 using namespace cute;

-template <class TileShape_, class Element_, int HeadDim_, bool EVEN_K_>
+template <class TileShape, class Element, bool EVEN_K>
 struct Sm120CollectiveEpilogue {
-  using TileShape = TileShape_;
-  using Element = Element_;
-
-  static constexpr int kHeadDim = HeadDim_;
-  static constexpr bool EVEN_K = EVEN_K_;
-
+  static constexpr int kThreads = 128;
   static constexpr int kBlockM = get<0>(TileShape{});
   static constexpr int kBlockK = get<2>(TileShape{});

   using BLK_M = Int<kBlockM>;
   using BLK_K = Int<kBlockK>;
-  using HEAD_DIM = Int<kHeadDim>;

   using SmemLayoutAtom_ =
-      decltype(composition(Swizzle<3, 3, 3>{},
-                           Layout<Shape<_8, BLK_K>, Stride<BLK_K, _1>>{}));
+      decltype(smem_layout_atom_selector<Element, kBlockK>());
+  static constexpr int kSmemBlockK = size<1>(SmemLayoutAtom_{});

-  // Q smem: (BLK_M, HEAD_DIM)
+  // Q smem: (BLK_M, BLK_K)
   using SmemLayoutO =
-      decltype(tile_to_shape(SmemLayoutAtom_{}, Shape<BLK_M, HEAD_DIM>{}));
+      decltype(tile_to_shape(SmemLayoutAtom_{}, Shape<BLK_M, BLK_K>{}));

   // use 128-bit vectorizing copy
   using VectorizingCopy_ = AutoVectorizingCopyWithAssumedAlignment<128>;

   // r2s copy atom for O
   using SmemCopyAtom_ = Copy_Atom<VectorizingCopy_, Element>;

-  // Thr layout for gmem copy
-  using GmemCopyThrLayout_ =
-      std::conditional_t<kBlockK == 32,
-                         Layout<Shape<_32, _4>, Stride<_4, _1>>,
-                         Layout<Shape<_16, _8>, Stride<_8, _1>>>;
-
   // s2g tiled copy for O
-  using GmemTiledCopyO = decltype(make_tiled_copy(
-      Copy_Atom<VectorizingCopy_, Element>{},
-      GmemCopyThrLayout_{},  // Thr layout: (_16,_8)/(_32, _4)
-      Layout<Shape<_1, _8>>{}  // Val layout: 8 vals per read
-      ));
+  using GmemTiledCopyO =
+      decltype(gmem_tiled_copy_selector<Element, kThreads, kBlockK>(
+          Copy_Atom<VectorizingCopy_, Element>{}));

   struct TensorStorage {
     cute::array_aligned<Element, cute::cosize_v<SmemLayoutO>> smem_o;
@@ -80,11 +67,11 @@ struct Sm120CollectiveEpilogue {
       return;
     }

-    // (BLK_M, HEAD_DIM) => (M, K)
+    // (BLK_M, BLK_K) => (M, K)
     auto [gO, cO] = block.get_o_tile();
     auto residue_mnk = block.get_residue_mnk();

-    // (BLK_M, HEAD_DIM)
+    // (BLK_M, BLK_K)
     Tensor sO = make_tensor(make_smem_ptr(ss.smem_o.data()), SmemLayoutO{});

    // 1. cast output from ElementAccumulator to Element
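Note: both collectives now pull their layout atom and gmem copy layout from common/selector.h, which is not part of this excerpt. For readers following along, here is a minimal sketch of what the two selectors could look like, assuming they simply generalize the hardcoded Swizzle<3, 3, 3> atom and the 32x4 / 16x8 copy-thread layouts removed by this diff. Only the two function names and their call signatures are taken from the diff; every heuristic below is an illustrative assumption, not the PR's actual implementation.

// Sketch of common/selector.h-style helpers (assumed, not from the PR).
#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h>

namespace llm {
using namespace cute;

// Pick a k-major swizzled smem atom from the element width and block-K extent.
template <class Element, int BLK_K>
CUTE_HOST_DEVICE constexpr auto smem_layout_atom_selector() {
  // Assumed heuristic: keep the 128-byte Swizzle<3, 3, 3> when an (8, BLK_K)
  // row spans a multiple of 128 bytes, otherwise fall back to a 64-byte swizzle.
  if constexpr ((BLK_K * sizeof(Element)) % 128 == 0) {
    return composition(Swizzle<3, 3, 3>{},
                       Layout<Shape<_8, Int<BLK_K>>, Stride<Int<BLK_K>, _1>>{});
  } else {
    return composition(Swizzle<2, 3, 3>{},
                       Layout<Shape<_8, Int<BLK_K>>, Stride<Int<BLK_K>, _1>>{});
  }
}

// Build a tiled copy: 128-bit vectors along k, remaining threads along m/n.
template <class Element, int kThreads, int BLK_K, class CopyAtom>
CUTE_HOST_DEVICE constexpr auto gmem_tiled_copy_selector(CopyAtom copy_atom) {
  constexpr int kVals = 128 / cutlass::sizeof_bits_v<Element>;  // vals per thread
  constexpr int kThrK = BLK_K / kVals;     // threads along the k dimension
  constexpr int kThrM = kThreads / kThrK;  // threads along the m/n dimension
  return make_tiled_copy(copy_atom,
                         Layout<Shape<Int<kThrM>, Int<kThrK>>,
                                Stride<Int<kThrK>, _1>>{},
                         Layout<Shape<_1, Int<kVals>>>{});
}

}  // namespace llm

For half precision this sketch reproduces the removed configuration: BLK_K = 32 gives a (32, 4) thread layout and BLK_K = 64 gives (16, 8), each thread moving 8 contiguous values.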
@@ -14,6 +14,7 @@
 #include "common/mask.h"
 #include "common/online_softmax.cuh"
 #include "common/safe_copy.h"
+#include "common/selector.h"
 #include "sm120_collective_load_cpasync_ws.cuh"
 #include "sm120_collective_load_tma_ws.cuh"

@@ -23,21 +24,20 @@ using namespace cute;

 template <class TileShape_,
           class Element_,
-          int HeadDim_,
           bool EVEN_K,
           bool ALIBI,
           bool SOFT_CAP,
           bool LOCAL,
           bool KV_USE_TMA = false  // whether to use TMA for K/V loading
           >
 struct Sm120CollectiveFMhaWs {
+  // exposed template parameters
   using TileShape = TileShape_;
   using Element = Element_;
   using ElementAccum = float;

   using ClusterShape = Shape<_1, _1, _1>;

-  static constexpr int kHeadDim = HeadDim_;
   static constexpr int kBlockM = get<0>(TileShape{});
   static constexpr int kBlockN = get<1>(TileShape{});
   static constexpr int kBlockK = get<2>(TileShape{});
@@ -46,13 +46,9 @@ struct Sm120CollectiveFMhaWs {
   static constexpr bool kLocal = LOCAL;
   static constexpr bool kKVUseTma = KV_USE_TMA;

-  static_assert(kBlockK == 32 || kBlockK == 64);
-  static_assert(kHeadDim % kBlockK == 0);
-
   using BLK_M = Int<kBlockM>;
   using BLK_N = Int<kBlockN>;
   using BLK_K = Int<kBlockK>;
-  using HEAD_DIM = Int<kHeadDim>;

   // TiledMMA (64x16x16) for gemm-I and gemm-II
   using MMA_Atom_ =
@@ -70,27 +66,25 @@
   static constexpr int StageCountQ = 1;
   static constexpr int StageCountKV = 3;

-  // Atom layout: (8, BLK_K):(BLK_K, 1) k-major
   using SmemLayoutAtom_ =
-      decltype(composition(Swizzle<3, 3, 3>{},
-                           Layout<Shape<_8, BLK_K>, Stride<BLK_K, _1>>{}));
+      decltype(smem_layout_atom_selector<Element, kBlockK>());

-  // Q smem: (BLK_M, HEAD_DIM)
+  // Q smem: (BLK_M, BLK_K)
   using SmemLayoutQ =
-      decltype(tile_to_shape(SmemLayoutAtom_{}, Shape<BLK_M, HEAD_DIM>{}));
+      decltype(tile_to_shape(SmemLayoutAtom_{}, Shape<BLK_M, BLK_K>{}));

-  // KV smem: (BLK_N, HEAD_DIM, KVStages)
+  // KV smem: (BLK_N, BLK_K, KVStages)
   using SmemLayoutK =
       decltype(tile_to_shape(SmemLayoutAtom_{},
-                             Shape<BLK_N, HEAD_DIM, Int<StageCountKV>>{}));
+                             Shape<BLK_N, BLK_K, Int<StageCountKV>>{}));
   using SmemLayoutV = SmemLayoutK;

-  // V^T smem: (HEAD_DIM, BLK_N, KVStages)
+  // V^T smem: (BLK_K, BLK_N, KVStages)
   using SmemLayoutVt = decltype(select<1, 0, 2>(SmemLayoutV{}));

-  // tma transaction bytes for (BLK_N, HEAD_DIM)
-  static constexpr uint32_t kTmaTransactionBytes =
-      size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<Element> / 8;
+  // tma transaction bytes for (BLK_N, BLK_K)
+  static constexpr uint32_t kTmaTransactionBytes = cutlass::bits_to_bytes(
+      cosize(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<Element>);

   struct TensorStorage {
     cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
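The kTmaTransactionBytes change above does two things: it replaces size() with cosize(), and it folds the divide-by-8 into cutlass::bits_to_bytes(). size() counts the logical elements of one (BLK_N, BLK_K) stage, while cosize() counts the smem footprint the layout actually addresses; the two only differ once the selected atom pads a stage, but a TMA transaction has to cover the footprint. A tiny standalone illustration (not from this PR) with a deliberately padded layout:

#include <cute/tensor.hpp>
#include <cutlass/numeric_size.h>
#include <cutlass/numeric_types.h>
using namespace cute;

// An (8, 32) k-major tile whose rows are padded out to a stride of 40 elements.
using Padded = Layout<Shape<_8, _32>, Stride<Int<40>, _1>>;

static_assert(size(Padded{}) == 8 * 32);         // logical elements: 256
static_assert(cosize(Padded{}) == 7 * 40 + 32);  // addressed footprint: 312
// bits_to_bytes converts the bit count to whole bytes; for half_t this is 624.
static_assert(cutlass::bits_to_bytes(
                  cosize(Padded{}) * cutlass::sizeof_bits_v<cutlass::half_t>) == 624);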
@@ -201,12 +195,12 @@
     const auto kv_len = block.get_kv_len();

     // Construct smem tensors
-    // (BLK_M, HEAD_DIM), k-major
+    // (BLK_M, BLK_K), k-major
     Tensor sQ = make_tensor(make_smem_ptr(ss.smem_q.data()), SmemLayoutQ{});
-    // (BLK_N, HEAD_DIM, KVStages), k-major
+    // (BLK_N, BLK_K, KVStages), k-major
     Tensor sK = make_tensor(make_smem_ptr(ss.smem_k.data()), SmemLayoutK{});
     // Tensor for V^t; used in GEMM-II.
-    // (HEAD_DIM, BLK_N, KVStages), k-major
+    // (BLK_K, BLK_N, KVStages), k-major
     Tensor sVt = make_tensor(make_smem_ptr(ss.smem_vt.data()), SmemLayoutVt{});

     TiledMma tiled_mma;
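With HeadDim_ gone from both collectives, everything is derived from the (BLK_M, BLK_N, BLK_K) tile shape. A hypothetical instantiation for a 64x64x64 tile of half precision, assuming the two collective headers are already included; the real kernel traits that wire these into a kernel live elsewhere in this repo and may differ:

// Illustrative only: template arguments follow the new parameter lists above.
using TileShape = cute::Shape<cute::_64, cute::_64, cute::_64>;

using Epilogue = llm::Sm120CollectiveEpilogue<TileShape,
                                              cutlass::half_t,
                                              /*EVEN_K=*/true>;

using Mainloop = llm::Sm120CollectiveFMhaWs<TileShape,
                                            cutlass::half_t,
                                            /*EVEN_K=*/true,
                                            /*ALIBI=*/false,
                                            /*SOFT_CAP=*/false,
                                            /*LOCAL=*/false>;  // KV_USE_TMA keeps its default (false)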