diff --git a/src/kernels/attention/CMakeLists.txt b/src/kernels/attention/CMakeLists.txt index 80de7486..41976976 100644 --- a/src/kernels/attention/CMakeLists.txt +++ b/src/kernels/attention/CMakeLists.txt @@ -61,7 +61,7 @@ cc_test( INCLUDES ${CMAKE_CURRENT_SOURCE_DIR} SRCS - # sm80_mha_test.cu + # tests/sm80_mha_test.cu tests/sm80_mha_pagedkv_test.cu DEPS :attention.kernels diff --git a/src/kernels/attention/collective/sm120_collective_epilogue.cuh b/src/kernels/attention/collective/sm120_collective_epilogue.cuh index 86a83a6d..8ce427ec 100644 --- a/src/kernels/attention/collective/sm120_collective_epilogue.cuh +++ b/src/kernels/attention/collective/sm120_collective_epilogue.cuh @@ -9,32 +9,27 @@ #include "common/fast_cast.cuh" #include "common/safe_copy.h" +#include "common/selector.h" namespace llm { using namespace cute; -template +template struct Sm120CollectiveEpilogue { - using TileShape = TileShape_; - using Element = Element_; - - static constexpr int kHeadDim = HeadDim_; - static constexpr bool EVEN_K = EVEN_K_; - + static constexpr int kThreads = 128; static constexpr int kBlockM = get<0>(TileShape{}); static constexpr int kBlockK = get<2>(TileShape{}); using BLK_M = Int; using BLK_K = Int; - using HEAD_DIM = Int; using SmemLayoutAtom_ = - decltype(composition(Swizzle<3, 3, 3>{}, - Layout, Stride>{})); + decltype(smem_layout_atom_selector()); + static constexpr int kSmemBlockK = size<1>(SmemLayoutAtom_{}); - // Q smem: (BLK_M, HEAD_DIM) + // Q smem: (BLK_M, BLK_K) using SmemLayoutO = - decltype(tile_to_shape(SmemLayoutAtom_{}, Shape{})); + decltype(tile_to_shape(SmemLayoutAtom_{}, Shape{})); // use 128-bit vectorizing copy using VectorizingCopy_ = AutoVectorizingCopyWithAssumedAlignment<128>; @@ -42,18 +37,10 @@ struct Sm120CollectiveEpilogue { // r2s copy atom for O using SmemCopyAtom_ = Copy_Atom; - // Thr layout for gmem copy - using GmemCopyThrLayout_ = - std::conditional_t, Stride<_4, _1>>, - Layout, Stride<_8, _1>>>; - // s2g tiled copy for O - using GmemTiledCopyO = decltype(make_tiled_copy( - Copy_Atom{}, - GmemCopyThrLayout_{}, // Thr layout: (_16,_8)/(_32, _4) - Layout>{} // Val layout: 8 vals per read - )); + using GmemTiledCopyO = + decltype(gmem_tiled_copy_selector( + Copy_Atom{})); struct TensorStorage { cute::array_aligned> smem_o; @@ -80,11 +67,11 @@ struct Sm120CollectiveEpilogue { return; } - // (BLK_M, HEAD_DIM) => (M, K) + // (BLK_M, BLK_K) => (M, K) auto [gO, cO] = block.get_o_tile(); auto residue_mnk = block.get_residue_mnk(); - // (BLK_M, HEAD_DIM) + // (BLK_M, BLK_K) Tensor sO = make_tensor(make_smem_ptr(ss.smem_o.data()), SmemLayoutO{}); // 1. 
cast output from ElementAccumulator to Element diff --git a/src/kernels/attention/collective/sm120_collective_fmha_mainloop_ws.cuh b/src/kernels/attention/collective/sm120_collective_fmha_mainloop_ws.cuh index 8512d48d..ba5e8f39 100644 --- a/src/kernels/attention/collective/sm120_collective_fmha_mainloop_ws.cuh +++ b/src/kernels/attention/collective/sm120_collective_fmha_mainloop_ws.cuh @@ -14,6 +14,7 @@ #include "common/mask.h" #include "common/online_softmax.cuh" #include "common/safe_copy.h" +#include "common/selector.h" #include "sm120_collective_load_cpasync_ws.cuh" #include "sm120_collective_load_tma_ws.cuh" @@ -23,7 +24,6 @@ using namespace cute; template struct Sm120CollectiveFMhaWs { + // exposed template parameters using TileShape = TileShape_; using Element = Element_; using ElementAccum = float; using ClusterShape = Shape<_1, _1, _1>; - static constexpr int kHeadDim = HeadDim_; static constexpr int kBlockM = get<0>(TileShape{}); static constexpr int kBlockN = get<1>(TileShape{}); static constexpr int kBlockK = get<2>(TileShape{}); @@ -46,13 +46,9 @@ struct Sm120CollectiveFMhaWs { static constexpr bool kLocal = LOCAL; static constexpr bool kKVUseTma = KV_USE_TMA; - static_assert(kBlockK == 32 || kBlockK == 64); - static_assert(kHeadDim % kBlockK == 0); - using BLK_M = Int; using BLK_N = Int; using BLK_K = Int; - using HEAD_DIM = Int; // TiledMMA (64x16x16) for gemm-I and gemm-II using MMA_Atom_ = @@ -70,27 +66,25 @@ struct Sm120CollectiveFMhaWs { static constexpr int StageCountQ = 1; static constexpr int StageCountKV = 3; - // Atom layout: (8, BLK_K):(BLK_K, 1) k-major using SmemLayoutAtom_ = - decltype(composition(Swizzle<3, 3, 3>{}, - Layout, Stride>{})); + decltype(smem_layout_atom_selector()); - // Q smem: (BLK_M, HEAD_DIM) + // Q smem: (BLK_M, BLK_K) using SmemLayoutQ = - decltype(tile_to_shape(SmemLayoutAtom_{}, Shape{})); + decltype(tile_to_shape(SmemLayoutAtom_{}, Shape{})); - // KV smem: (BLK_N, HEAD_DIM, KVStages) + // KV smem: (BLK_N, BLK_K, KVStages) using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtom_{}, - Shape>{})); + Shape>{})); using SmemLayoutV = SmemLayoutK; - // V^T smem: (HEAD_DIM, BLK_N, KVStages) + // V^T smem: (BLK_K, BLK_N, KVStages) using SmemLayoutVt = decltype(select<1, 0, 2>(SmemLayoutV{})); - // tma transaction bytes for (BLK_N, HEAD_DIM) - static constexpr uint32_t kTmaTransactionBytes = - size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v / 8; + // tma transaction bytes for (BLK_N, BLK_K) + static constexpr uint32_t kTmaTransactionBytes = cutlass::bits_to_bytes( + cosize(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v); struct TensorStorage { cute::array_aligned> smem_q; @@ -201,12 +195,12 @@ struct Sm120CollectiveFMhaWs { const auto kv_len = block.get_kv_len(); // Construct smem tensors - // (BLK_M, HEAD_DIM), k-major + // (BLK_M, BLK_K), k-major Tensor sQ = make_tensor(make_smem_ptr(ss.smem_q.data()), SmemLayoutQ{}); - // (BLK_N, HEAD_DIM, KVStages), k-major + // (BLK_N, BLK_K, KVStages), k-major Tensor sK = make_tensor(make_smem_ptr(ss.smem_k.data()), SmemLayoutK{}); // Tensor for V^t; used in GEMM-II. 
- // (HEAD_DIM, BLK_N, KVStages), k-major + // (BLK_K, BLK_N, KVStages), k-major Tensor sVt = make_tensor(make_smem_ptr(ss.smem_vt.data()), SmemLayoutVt{}); TiledMma tiled_mma; diff --git a/src/kernels/attention/collective/sm120_collective_load_cpasync_ws.cuh b/src/kernels/attention/collective/sm120_collective_load_cpasync_ws.cuh index 85efc143..6008e4d8 100644 --- a/src/kernels/attention/collective/sm120_collective_load_cpasync_ws.cuh +++ b/src/kernels/attention/collective/sm120_collective_load_cpasync_ws.cuh @@ -10,6 +10,8 @@ #include #include "common/safe_copy.h" +#include "common/selector.h" +#include "common/static_dispatch.h" namespace llm { @@ -25,6 +27,14 @@ template struct Sm120CollectiveLoadCpAsyncWs { + static constexpr int kThreads = 128; + static constexpr int kBlockK = get<2>(TileShape{}); + // g2s tiled copy for Q/K/V + using GmemTiledCopy = + decltype(gmem_tiled_copy_selector( + Copy_Atom, + Element>{})); + // load Q/K/V tiles from gmem to smem using cp_async template CUTE_DEVICE void operator()(const Block& block, @@ -34,7 +44,7 @@ struct Sm120CollectiveLoadCpAsyncWs { PipelineKV& kv_pipeline, typename PipelineKV::PipelineState& kv_state, TensorStorage& ss) { - static constexpr int kBlockK = get<2>(TileShape{}); + static constexpr int kStages = size<2>(SmemLayoutK{}); if (!block.is_valid()) { // skip invalid block @@ -49,30 +59,20 @@ struct Sm120CollectiveLoadCpAsyncWs { // (M, N, K) const auto residue_mnk = block.get_residue_mnk(); - // (BLK_M, HEAD_DIM) => (M, K) + // (BLK_M, BLK_K) => (M, K) auto [gQ, cQ] = block.get_q_tile(); - // (BLK_N, HEAD_DIM, n) => (N, K) + // (BLK_N, BLK_K, n) => (N, K) auto [gK, gV, cKV] = block.get_kv_tile(); // Construct smem tensors - // (BLK_M, HEAD_DIM), k-major + // (BLK_M, BLK_K), k-major Tensor sQ = make_tensor(make_smem_ptr(ss.smem_q.data()), SmemLayoutQ{}); - // (BLK_N, HEAD_DIM, KVStages), k-major + // (BLK_N, BLK_K, KVStages), k-major Tensor sK = make_tensor(make_smem_ptr(ss.smem_k.data()), SmemLayoutK{}); Tensor sV = make_tensor(make_smem_ptr(ss.smem_v.data()), SmemLayoutV{}); - // Thr thread layout for gmem copy (4 warps = 128 threads) - using GmemCopyThrLayout_ = - std::conditional_t, Stride<_4, _1>>, - Layout, Stride<_8, _1>>>; - // g2s tiled copy for q/kv - auto gmem_tiled_copy = make_tiled_copy( - Copy_Atom, Element>{}, - GmemCopyThrLayout_{}, // Thr layout: (_16,_8)/(_32, _4) - Layout>{} // Val layout: 8 vals per read - ); + GmemTiledCopy gmem_tiled_copy; auto gmem_thr_copy = gmem_tiled_copy.get_slice(tidx); // (CPY, CPY_N, CPY_K) => (M, K) @@ -94,84 +94,115 @@ struct Sm120CollectiveLoadCpAsyncWs { const auto residue_mk = select<0, 2>(residue_mnk); const auto residue_nk = select<1, 2>(residue_mnk); - auto load_query = [&](auto& state) { - q_pipeline.producer_acquire(state); - safe_copy( + auto load_query = [&]() { + q_pipeline.producer_acquire(q_state); + safe_copy( gmem_tiled_copy, tGgQ, tGsQ, tGcQ, residue_mk); - q_pipeline.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); - ++state; + q_pipeline.producer_commit(q_state, + cutlass::arch::cpasync_barrier_arrive); + ++q_state; }; - auto load_key = [&](int ni, auto& state) { - kv_pipeline.producer_acquire(state); + auto load_key = [&](int ni) { + kv_pipeline.producer_acquire(kv_state); // skip ZFILL_MN for key since Mask will mask out oob with -inf safe_copy( gmem_tiled_copy, tGgK(_, _, _, ni), - tGsK(_, _, _, state.index()), + tGsK(_, _, _, kv_state.index()), tGcKV(_, _, _, ni), residue_nk); - kv_pipeline.producer_commit(state, 
cutlass::arch::cpasync_barrier_arrive); - ++state; + kv_pipeline.producer_commit(kv_state, + cutlass::arch::cpasync_barrier_arrive); + ++kv_state; }; // load key without oob handling - auto load_key_no_oob = [&](int ni, auto& state) { - kv_pipeline.producer_acquire(state); - safe_copy( - gmem_tiled_copy, - tGgK(_, _, _, ni), - tGsK(_, _, _, state.index()), - tGcKV(_, _, _, ni), - residue_nk); - kv_pipeline.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); - ++state; + auto load_key_no_oob = [&](int ni) { + kv_pipeline.producer_acquire(kv_state); + if constexpr (EVEN_K) { + safe_copy(gmem_tiled_copy, + tGgK(_, _, _, ni), + tGsK(_, _, _, kv_state.index()), + tGcKV(_, _, _, ni), + residue_nk); + } else { + DISPATCH_BOOL(kv_state.count() < kStages, ZFILL_K, [&] { + safe_copy( + gmem_tiled_copy, + tGgK(_, _, _, ni), + tGsK(_, _, _, kv_state.index()), + tGcKV(_, _, _, ni), + residue_nk); + }); + } + kv_pipeline.producer_commit(kv_state, + cutlass::arch::cpasync_barrier_arrive); + ++kv_state; }; - auto load_value = [&](int ni, auto& state) { - kv_pipeline.producer_acquire(state); + auto load_value = [&](int ni) { + kv_pipeline.producer_acquire(kv_state); // skipping ZFILL_MN for v may cause nan issue safe_copy( gmem_tiled_copy, tGgV(_, _, _, ni), - tGsV(_, _, _, state.index()), + tGsV(_, _, _, kv_state.index()), tGcKV(_, _, _, ni), residue_nk); - kv_pipeline.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); - ++state; + kv_pipeline.producer_commit(kv_state, + cutlass::arch::cpasync_barrier_arrive); + ++kv_state; }; // load value without oob handling - auto load_value_no_oob = [&](int ni, auto& state) { - kv_pipeline.producer_acquire(state); - safe_copy( - gmem_tiled_copy, - tGgV(_, _, _, ni), - tGsV(_, _, _, state.index()), - tGcKV(_, _, _, ni), - residue_nk); - kv_pipeline.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); - ++state; + auto load_value_no_oob = [&](int ni) { + kv_pipeline.producer_acquire(kv_state); + if constexpr (EVEN_K) { + safe_copy(gmem_tiled_copy, + tGgV(_, _, _, ni), + tGsV(_, _, _, kv_state.index()), + tGcKV(_, _, _, ni), + residue_nk); + } else { + DISPATCH_BOOL(kv_state.count() < kStages, ZFILL_K, [&] { + safe_copy( + gmem_tiled_copy, + tGgV(_, _, _, ni), + tGsV(_, _, _, kv_state.index()), + tGcKV(_, _, _, ni), + residue_nk); + }); + } + kv_pipeline.producer_commit(kv_state, + cutlass::arch::cpasync_barrier_arrive); + ++kv_state; }; // async copy gmem to smem in following order: // Q0, Kn-1, Vn-1, ..., K1, V1, K0, V0 // load Q1 - load_query(q_state); + load_query(); // load Kn-1, Vn-1 with oob handling int ni = n_block_max - 1; - load_key(ni, kv_state); - load_value(ni, kv_state); + load_key(ni); + load_value(ni); --ni; CUTE_NO_UNROLL while (ni >= n_block_min) { // load Ki - load_key_no_oob(ni, kv_state); + load_key_no_oob(ni); // load Vi - load_value_no_oob(ni, kv_state); + load_value_no_oob(ni); // advance to next kv block --ni; } diff --git a/src/kernels/attention/collective/sm120_collective_load_tma_ws.cuh b/src/kernels/attention/collective/sm120_collective_load_tma_ws.cuh index 68c300d8..421e0e2c 100644 --- a/src/kernels/attention/collective/sm120_collective_load_tma_ws.cuh +++ b/src/kernels/attention/collective/sm120_collective_load_tma_ws.cuh @@ -10,12 +10,14 @@ #include #include "common/safe_copy.h" +#include "common/selector.h" namespace llm { using namespace cute; template struct Sm120CollectiveLoadTmaWs { + static constexpr int kThreads = 128; + static constexpr int kBlockK = get<2>(TileShape{}); + // g2s tiled copy for Q 
+ using GmemTiledCopyQ = + decltype(gmem_tiled_copy_selector( + Copy_Atom, + Element>{})); + + // using StrideK = ...; + + // using TMA_K = decltype(make_tma_copy( + // GmemTiledCopy{}, // TMA_COPY + // make_tensor(static_cast(nullptr), + // repeat_like(StrideK{}, int32_t(0)), StrideK{}), + // SmemLayoutK{}(_,_,_0{}))); + + // Tensor tensor_k = make_tensor(ptr_k, make_layout(make_shape(M,K,L), + // args.stride_k)); auto tma_load_k = make_tma_copy(SM90_TMA_LOAD{}, + // gtensor_k, SmemLayoutK{}(_,_,_0{})); + // load Q using cp_async and K/V using tma template CUTE_DEVICE void operator()(const Block& block, @@ -34,8 +56,6 @@ struct Sm120CollectiveLoadTmaWs { PipelineKV& kv_pipeline, typename PipelineKV::PipelineState& kv_state, TensorStorage& ss) { - static constexpr int kBlockK = get<2>(TileShape{}); - if (!block.is_valid()) { // skip invalid block return; @@ -49,76 +69,74 @@ struct Sm120CollectiveLoadTmaWs { // (M, N, K) const auto residue_mnk = block.get_residue_mnk(); - // (BLK_M, HEAD_DIM) => (M, K) + // (BLK_M, BLK_K) => (M, K) auto [gQ, cQ] = block.get_q_tile(); // Construct smem tensors - // (BLK_M, HEAD_DIM), k-major + // (BLK_M, BLK_K), k-major Tensor sQ = make_tensor(make_smem_ptr(ss.smem_q.data()), SmemLayoutQ{}); - // (BLK_N, HEAD_DIM, KVStages), k-major + // (BLK_N, BLK_K, KVStages), k-major Tensor sK = make_tensor(make_smem_ptr(ss.smem_k.data()), SmemLayoutK{}); Tensor sV = make_tensor(make_smem_ptr(ss.smem_v.data()), SmemLayoutV{}); - // Thr thread layout for gmem copy (4 warps = 128 threads) - using GmemCopyThrLayout_ = - std::conditional_t, Stride<_4, _1>>, - Layout, Stride<_8, _1>>>; - - // g2s tiled copy for q/kv - auto gmem_tiled_copy = make_tiled_copy( - Copy_Atom, Element>{}, - GmemCopyThrLayout_{}, // Thr layout: (_16,_8)/(_32, _4) - Layout>{} // Val layout: 8 vals per read - ); - auto gmem_thr_copy = gmem_tiled_copy.get_slice(tidx); + // g2s tiled copy for q + GmemTiledCopyQ gmem_tiled_copy_q; + auto gmem_thr_copy_q = gmem_tiled_copy_q.get_slice(tidx); // (CPY, CPY_N, CPY_K) => (M, K) - Tensor tGcQ = gmem_thr_copy.partition_S(cQ); + Tensor tGcQ = gmem_thr_copy_q.partition_S(cQ); // (CPY, CPY_N, CPY_K) - Tensor tGgQ = gmem_thr_copy.partition_S(gQ); - Tensor tGsQ = gmem_thr_copy.partition_D(sQ); + Tensor tGgQ = gmem_thr_copy_q.partition_S(gQ); + Tensor tGsQ = gmem_thr_copy_q.partition_D(sQ); // TODO: copy k/v using TMA + // ??? where to define TMA copy? + // 1> block, need smem layout (pass) + // 2> tma_load, need gmem tensor (pass) + // 3> mainloop, has smem layout and gtensor (x) + + // where to keep tma_load_kv? + // as args in load_tma_ws? or pass in as parameters? 
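+    // One option, following the CUTLASS 3.x collective convention (not wired
+    // up in this patch): declare the TMA_K/TMA_V copy types next to
+    // SmemLayoutK/SmemLayoutV in this load collective, build the descriptors
+    // on the host in to_underlying_arguments() via
+    //   make_tma_copy(SM90_TMA_LOAD{}, gmem_tensor_k, SmemLayoutK{}(_, _, _0{}))
+    // and keep them in Params; the kernel then passes them (or Params) into
+    // this operator(), and FmhaBlock::get_kv_tma_tile() can tile the
+    // coordinate tensor returned by tma.get_tma_tensor(...) as sketched there.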
const auto residue_mk = select<0, 2>(residue_mnk); - auto load_query = [&](auto& state) { - q_pipeline.producer_acquire(state); + auto load_query = [&]() { + q_pipeline.producer_acquire(q_state); safe_copy( - gmem_tiled_copy, tGgQ, tGsQ, tGcQ, residue_mk); - q_pipeline.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); - ++state; + gmem_tiled_copy_q, tGgQ, tGsQ, tGcQ, residue_mk); + q_pipeline.producer_commit(q_state, + cutlass::arch::cpasync_barrier_arrive); + ++q_state; }; - auto load_key = [&](int ni, auto& state) { - kv_pipeline.producer_acquire(state); + auto load_key = [&](int ni) { + kv_pipeline.producer_acquire(kv_state); // TMA copy - kv_pipeline.producer_commit(state); // no op for tma - ++state; + // kv_pipeline.producer_commit(state); + ++kv_state; }; - auto load_value = [&](int ni, auto& state) { - kv_pipeline.producer_acquire(state); + auto load_value = [&](int ni) { + kv_pipeline.producer_acquire(kv_state); // TMA copy - kv_pipeline.producer_commit(state); // no op for tma - ++state; + // kv_pipeline.producer_commit(state); + ++kv_state; }; // async copy gmem to smem in following order: // Q0, Kn-1, Vn-1, ..., K1, V1, K0, V0 // load Q1 - load_query(q_state); + load_query(); // load Kn-1, Vn-1 CUTE_NO_UNROLL for (int ni = n_block_max - 1; ni >= n_block_min; --ni) { // load Ki - load_key(ni, kv_state); + load_key(ni); // load Vi - load_value(ni, kv_state); + load_value(ni); } } }; diff --git a/src/kernels/attention/collective/sm80_collective_epilogue.cuh b/src/kernels/attention/collective/sm80_collective_epilogue.cuh index 05329a3f..0504e11a 100644 --- a/src/kernels/attention/collective/sm80_collective_epilogue.cuh +++ b/src/kernels/attention/collective/sm80_collective_epilogue.cuh @@ -9,16 +9,17 @@ #include "common/fast_cast.cuh" #include "common/safe_copy.h" +#include "common/selector.h" namespace llm { using namespace cute; -template +template struct Sm80CollectiveEpilogue { using TileShape = TileShape_; using Element = Element_; - static constexpr int kHeadDim = HeadDim_; + static constexpr int kThreads = 128; static constexpr bool EVEN_K = EVEN_K_; static constexpr int kBlockM = get<0>(TileShape{}); @@ -26,15 +27,13 @@ struct Sm80CollectiveEpilogue { using BLK_M = Int; using BLK_K = Int; - using HEAD_DIM = Int; using SmemLayoutAtom_ = - decltype(composition(Swizzle<3, 3, 3>{}, - Layout, Stride>{})); + decltype(smem_layout_atom_selector()); - // Q smem: (BLK_M, HEAD_DIM) + // Q smem: (BLK_M, BLK_K) using SmemLayoutO = - decltype(tile_to_shape(SmemLayoutAtom_{}, Shape{})); + decltype(tile_to_shape(SmemLayoutAtom_{}, Shape{})); // use 128-bit vectorizing copy using VectorizingCopy_ = AutoVectorizingCopyWithAssumedAlignment<128>; @@ -42,18 +41,10 @@ struct Sm80CollectiveEpilogue { // r2s copy atom for O using SmemCopyAtom_ = Copy_Atom; - // Thr layout for gmem copy - using GmemCopyThrLayout_ = - std::conditional_t, Stride<_4, _1>>, - Layout, Stride<_8, _1>>>; - // s2g tiled copy for O - using GmemTiledCopyO = decltype(make_tiled_copy( - Copy_Atom{}, - GmemCopyThrLayout_{}, // Thr layout: (_16,_8)/(_32, _4) - Layout>{} // Val layout: 8 vals per read - )); + using GmemTiledCopyO = + decltype(gmem_tiled_copy_selector( + Copy_Atom{})); struct SharedStorage : cute::aligned_struct<128> { cute::array_aligned> smem_o; @@ -73,20 +64,19 @@ struct Sm80CollectiveEpilogue { class TensorO, class TensorCO, class ResidueMNK> - CUTE_DEVICE void operator()( - const Params& /*params*/, - const FrgTensor& tOrAccO, // (MMA, MMA_M, MMA_N) - TiledMma tiled_mma, - TensorO& gO, // (BLK_M, 
HEAD_DIM) - const TensorCO& cO, // (BLK_M, HEAD_DIM) => (M, K) - int tidx, - const ResidueMNK& residue_mnk, - char* smem) { + CUTE_DEVICE void operator()(const Params& /*params*/, + const FrgTensor& tOrAccO, // (MMA, MMA_M, MMA_N) + TiledMma tiled_mma, + TensorO& gO, // (BLK_M, BLK_K) + const TensorCO& cO, // (BLK_M, BLK_K) => (M, K) + int tidx, + const ResidueMNK& residue_mnk, + char* smem) { static constexpr int kBlockM = get<0>(TileShape{}); // Smem auto& ss = *reinterpret_cast(smem); - // (BLK_M, HEAD_DIM) + // (BLK_M, BLK_K) Tensor sO = make_tensor(make_smem_ptr(ss.smem_o.data()), SmemLayoutO{}); // 1. cast output from ElementAccumulator to Element diff --git a/src/kernels/attention/collective/sm80_collective_mha.cuh b/src/kernels/attention/collective/sm80_collective_mha.cuh index 19e6a6be..e326c0bb 100644 --- a/src/kernels/attention/collective/sm80_collective_mha.cuh +++ b/src/kernels/attention/collective/sm80_collective_mha.cuh @@ -11,6 +11,7 @@ #include "common/fast_cast.cuh" #include "common/layout_convertor.h" #include "common/safe_copy.h" +#include "common/selector.h" namespace llm { @@ -18,7 +19,6 @@ using namespace cute; template (TileShape{}); static constexpr int kBlockN = get<1>(TileShape{}); static constexpr int kBlockK = get<2>(TileShape{}); @@ -37,13 +36,9 @@ struct Sm80CollectiveMha { static constexpr bool kAlibi = ALIBI; static constexpr bool kLocal = LOCAL; - static_assert(kBlockK == 32 || kBlockK == 64); - static_assert(kHeadDim % kBlockK == 0); - using BLK_M = Int; using BLK_N = Int; using BLK_K = Int; - using HEAD_DIM = Int; // TiledMMA (64x16x16) for gemm-I and gemm-II using MMA_Atom_ = @@ -57,36 +52,26 @@ struct Sm80CollectiveMha { static constexpr int kRowsPerMMA = 2; static constexpr int kMmaThreads = size(TiledMma{}); - // Atom layout: (8, BLK_K):(BLK_K, 1) k-major + // Atom layout for shared memory using SmemLayoutAtom_ = - decltype(composition(Swizzle<3, 3, 3>{}, - Layout, Stride>{})); + decltype(smem_layout_atom_selector()); - // Q smem: (BLK_M, HEAD_DIM) + // Q smem: (BLK_M, BLK_K) using SmemLayoutQ = - decltype(tile_to_shape(SmemLayoutAtom_{}, Shape{})); + decltype(tile_to_shape(SmemLayoutAtom_{}, Shape{})); - // KV smem: (BLK_N, HEAD_DIM) + // KV smem: (BLK_N, BLK_K) using SmemLayoutK = - decltype(tile_to_shape(SmemLayoutAtom_{}, Shape{})); + decltype(tile_to_shape(SmemLayoutAtom_{}, Shape{})); using SmemLayoutV = - decltype(tile_to_shape(SmemLayoutAtom_{}, Shape{})); + decltype(tile_to_shape(SmemLayoutAtom_{}, Shape{})); - // V^T smem: (HEAD_DIM, BLK_N) + // V^T smem: (BLK_K, BLK_N) using SmemLayoutVt = decltype(select<1, 0>(SmemLayoutV{})); - // Thr layout for gmem copy - using GmemCopyThrLayout_ = - std::conditional_t, Stride<_4, _1>>, - Layout, Stride<_8, _1>>>; - - // g2s tiled copy for q - using GmemTiledCopyQ = decltype(make_tiled_copy( - Copy_Atom, Element>{}, - GmemCopyThrLayout_{}, // Thr layout: (_16,_8)/(_32, _4) - Layout>{} // Val layout: 8 vals per read - )); + using GmemTiledCopyQ = + decltype(gmem_tiled_copy_selector( + Copy_Atom, Element>{})); // g2s tiled copy for kv using GmemTiledCopyKV = GmemTiledCopyQ; @@ -145,11 +130,11 @@ struct Sm80CollectiveMha { class ResidueMNK> CUTE_DEVICE void operator()( const Params& params, - const TensorQ& gQ, // (BLK_M, HEAD_DIM) - const TensorCQ& cQ, // (BLK_M, HEAD_DIM) => (M, K) - const TensorK& gK, // (BLK_N, HEAD_DIM, n) - const TensorV& gV, // (BLK_N, HEAD_DIM, n) - const TensorCKV& cKV, // (BLK_N, HEAD_DIM, n) => (N, K) + const TensorQ& gQ, // (BLK_M, BLK_K) + const TensorCQ& cQ, // (BLK_M, 
BLK_K) => (M, K) + const TensorK& gK, // (BLK_N, BLK_K, n) + const TensorV& gV, // (BLK_N, BLK_K, n) + const TensorCKV& cKV, // (BLK_N, BLK_K, n) => (N, K) const TensorCMN& tScMN_mn, // ((2, MMA_M), (2, MMA_N), n) => (M, N) FrgTensor& tOrO, // (MMA, MMA_M, MMA_N) Softmax& softmax, @@ -173,14 +158,14 @@ struct Sm80CollectiveMha { // Construct shared memory tiles auto& ss = *reinterpret_cast(smem); - // (BLK_M, HEAD_DIM), k-major + // (BLK_M, BLK_K), k-major Tensor sQ = make_tensor(make_smem_ptr(ss.smem_q.data()), SmemLayoutQ{}); - // (BLK_N, HEAD_DIM), k-major + // (BLK_N, BLK_K), k-major Tensor sK = make_tensor(make_smem_ptr(ss.smem_k.data()), SmemLayoutK{}); Tensor sV = make_tensor(make_smem_ptr(ss.smem_v.data()), SmemLayoutV{}); // Tensor for V^t; used in GEMM-II. - // (HEAD_DIM, BLK_N), k-major + // (BLK_K, BLK_N), k-major Tensor sVt = make_tensor(make_smem_ptr(ss.smem_vt.data()), SmemLayoutVt{}); // g2s tiled copy for qkv diff --git a/src/kernels/attention/common/fmha_block.h b/src/kernels/attention/common/fmha_block.h index 622d0dd9..8d11f818 100644 --- a/src/kernels/attention/common/fmha_block.h +++ b/src/kernels/attention/common/fmha_block.h @@ -14,15 +14,15 @@ using namespace cute; template struct FmhaBlock { static constexpr int kBlockM = get<0>(TileShape{}); static constexpr int kBlockN = get<1>(TileShape{}); + static constexpr int kBlockK = get<2>(TileShape{}); using BLK_M = Int; using BLK_N = Int; - using HEAD_DIM = Int; + using BLK_K = Int; // hold a reference to the parameters and block coordination const Params& params_; @@ -89,70 +89,77 @@ struct FmhaBlock { } } - // return the query tile: (BLK_M, HEAD_DIM) => (M, K) + // return the query tile: (BLK_M, BLK_K) => (M, K) CUTE_HOST_DEVICE auto get_q_tile() const { const auto& [batch_idx, m_block_idx, kv_head_idx] = blk_coord_; // packing all q in the same kv head group together - const auto head_base = kv_head_idx * params_.group_size; - auto packed_idx_to_coord = [this, head_base](int packed_idx) { + auto packed_idx_to_coord = [this](int packed_idx) { // packed_idx => (seq, kv_heads):(group_size, 1) int idx, offset; params_.group_size.divmod(packed_idx, idx, offset); - return make_coord(idx, head_base + offset); + return make_coord(idx, offset); }; - // (batch, seq, head, dim) => ((seq, kv_head), dim) - const auto offset = batch_idx * get<0>(params_.q_stride); - // (q_packed_len, head_dim) gmem tensor + // (batch, seq, head, dim) + // => (batch, seq, (kv_heads, group), dim) + // => (seq, group, dim) + const auto offset = + batch_idx * get<0>(params_.q_stride) + + kv_head_idx * params_.group_size * get<2>(params_.q_stride); + // gmem tensor: (packed_len, dim) => ((seq, group), dim) auto Q = make_gather_tensor( make_gmem_ptr((const Element*)params_.q_ptr + offset), make_shape(packed_len_, params_.head_dim), make_stride(select<1, 2>(params_.q_stride), get<3>(params_.q_stride)), packed_idx_to_coord); - // (BLK_M, HEAD_DIM) + // (BLK_M, BLK_K) Tensor gQ = - local_tile(Q, Shape{}, make_coord(m_block_idx, _0{})); - // (BLK_M, HEAD_DIM) => (M, K) + local_tile(Q, Shape{}, make_coord(m_block_idx, _0{})); + // (BLK_M, BLK_K) => (M, K) Tensor cQ = local_tile(make_identity_tensor(shape(Q)), - Shape{}, + Shape{}, make_coord(m_block_idx, _0{})); return make_tuple(gQ, cQ); } - // return the output tile: (BLK_M, HEAD_DIM) => (M, K) + // return the output tile: (BLK_M, BLK_K) => (M, K) CUTE_HOST_DEVICE auto get_o_tile() const { const auto& [batch_idx, m_block_idx, kv_head_idx] = blk_coord_; // packing all q in the same kv head group together - 
const auto head_base = kv_head_idx * params_.group_size; - auto packed_idx_to_coord = [this, head_base](int packed_idx) { + auto packed_idx_to_coord = [this](int packed_idx) { // packed_idx => (seq, kv_heads):(group_size, 1) int idx, offset; params_.group_size.divmod(packed_idx, idx, offset); - return make_coord(idx, head_base + offset); + return make_coord(idx, offset); }; - // (batch, seq, head, dim) => ((seq, head), dim) - const auto offset = batch_idx * get<0>(params_.o_stride); + // (batch, seq, head, dim) + // => (batch, seq, (kv_heads, group), dim) + // => (seq, group, dim) + const auto offset = + batch_idx * get<0>(params_.o_stride) + + kv_head_idx * params_.group_size * get<2>(params_.o_stride); + // gmem tensor: (packed_len, dim) => ((seq, group), dim) auto O = make_gather_tensor( make_gmem_ptr((Element*)params_.o_ptr + offset), make_shape(packed_len_, params_.head_dim), make_stride(select<1, 2>(params_.o_stride), get<3>(params_.o_stride)), packed_idx_to_coord); - // (BLK_M, HEAD_DIM) + // (BLK_M, BLK_K) Tensor gO = - local_tile(O, Shape{}, make_coord(m_block_idx, _0{})); - // (BLK_M, HEAD_DIM) => (M, K) + local_tile(O, Shape{}, make_coord(m_block_idx, _0{})); + // (BLK_M, BLK_K) => (M, K) Tensor cQ = local_tile(make_identity_tensor(shape(O)), - Shape{}, + Shape{}, make_coord(m_block_idx, _0{})); return make_tuple(gO, cQ); } - // return the key/value tile: (BLK_N, HEAD_DIM, n) => (N, K) + // return the key/value tile: (BLK_N, BLK_K, n) => (N, K) CUTE_HOST_DEVICE auto get_kv_tile() const { const auto& [batch_idx, m_block_idx, kv_head_idx] = blk_coord_; @@ -171,15 +178,42 @@ struct FmhaBlock { make_shape(params_.kv_len, params_.head_dim), select<1, 3>(params_.v_stride)); - // (BLK_N, HEAD_DIM, n) - Tensor gK = local_tile(K, Shape{}, make_coord(_, _0{})); - Tensor gV = local_tile(V, Shape{}, make_coord(_, _0{})); - // (BLK_N, HEAD_DIM, n) => (N, K) + // (BLK_N, BLK_K, n) + Tensor gK = local_tile(K, Shape{}, make_coord(_, _0{})); + Tensor gV = local_tile(V, Shape{}, make_coord(_, _0{})); + // (BLK_N, BLK_K, n) => (N, K) Tensor cKV = local_tile(make_identity_tensor(shape(K)), - Shape{}, + Shape{}, make_coord(_, _0{})); return make_tuple(gK, gV, cKV); } + + // functions for tma load + // returns kv tma tile: (BLK_N, BLK_K, n) => (1@0, 1@1, 1@2) + template + CUTE_HOST_DEVICE auto get_kv_tma_tile(TMA_K tma_k, TMA_V tma_v) const { + // 1: make_gather_tma_tensor() + // tma_tensor = (seq, dim, kv_head) => (1@0, 1@1, 1@2) + // 2: partition into tiles + // tma_tile = (BLK_N, BLK_K, n) => (1@0, 1@1, 1@2) + + // (Q, D, (B, H)) + // Tensor mQ_qdl_p = tma_k.get_tma_tensor(select<0,2,3>(problem_shape)); + // Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, + // q_offs_2_1)), mQ_qdl_p); (BLK_N, BLK_K, m, k, (b)) Tensor gQ_qdl = + // local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, + // _1>{}); + + // outside in caller part + // (BLK_N, BLK_K, n) => (TMA,TMA_M,TMA_N, n) + // auto cta_tma = tma.get_slice(Int<0>{}); // CTA slice + // (TMA,TMA_M,TMA_N,REST_M,REST_N) + // Tensor tAgA_x = cta_tma.partition_S(gA); + // (TMA,TMA_M,TMA_N) + // Tensor tAsA_x = cta_tma.partition_D(sA); + + return; + } }; } // namespace llm diff --git a/src/kernels/attention/common/selector.h b/src/kernels/attention/common/selector.h new file mode 100644 index 00000000..ce7fd7ce --- /dev/null +++ b/src/kernels/attention/common/selector.h @@ -0,0 +1,75 @@ +#pragma once +#include // cute::CUTE_HOST_DEVICE +#include // cute::smem_ptr_flag +#include // cute::Swizzle + +namespace llm { 
+using namespace cute;
+// clang-format off
+namespace detail {
+///////////////////////////////////////////
+// Common layouts for GMMA Shared Memory //
+///////////////////////////////////////////
+// K-major GMMA layouts in units of bits
+using Layout_K_INTER_Atom_Bits = ComposedLayout<Swizzle<0,4,3>, smem_ptr_flag, Layout<Shape< _8,_128>,Stride< _128,_1>>>;
+using Layout_K_SW32_Atom_Bits  = ComposedLayout<Swizzle<1,4,3>, smem_ptr_flag, Layout<Shape< _8,_256>,Stride< _256,_1>>>;
+using Layout_K_SW64_Atom_Bits  = ComposedLayout<Swizzle<2,4,3>, smem_ptr_flag, Layout<Shape< _8,_512>,Stride< _512,_1>>>;
+using Layout_K_SW128_Atom_Bits = ComposedLayout<Swizzle<3,4,3>, smem_ptr_flag, Layout<Shape< _8,_1024>,Stride<_1024,_1>>>;
+
+// K-major layouts in units of Type
+template <class Type>
+using Layout_K_INTER_Atom = decltype(upcast<sizeof_bits<Type>::value>(Layout_K_INTER_Atom_Bits{}));
+template <class Type>
+using Layout_K_SW32_Atom = decltype(upcast<sizeof_bits<Type>::value>(Layout_K_SW32_Atom_Bits{}));
+template <class Type>
+using Layout_K_SW64_Atom = decltype(upcast<sizeof_bits<Type>::value>(Layout_K_SW64_Atom_Bits{}));
+template <class Type>
+using Layout_K_SW128_Atom = decltype(upcast<sizeof_bits<Type>::value>(Layout_K_SW128_Atom_Bits{}));
+
+} // namespace detail
+
+template <typename Element, int kBlockK>
+CUTE_HOST_DEVICE constexpr auto smem_layout_atom_selector() {
+  if constexpr (kBlockK % size<1>(detail::Layout_K_SW128_Atom<Element>{}) == 0) {
+    return detail::Layout_K_SW128_Atom<Element>{};
+  }
+  else if constexpr (kBlockK % size<1>(detail::Layout_K_SW64_Atom<Element>{}) == 0) {
+    return detail::Layout_K_SW64_Atom<Element>{};
+  }
+  else if constexpr (kBlockK % size<1>(detail::Layout_K_SW32_Atom<Element>{}) == 0) {
+    return detail::Layout_K_SW32_Atom<Element>{};
+  }
+  else if constexpr (kBlockK % size<1>(detail::Layout_K_INTER_Atom<Element>{}) == 0) {
+    return detail::Layout_K_INTER_Atom<Element>{};
+  }
+  else {
+    static_assert(kBlockK % size<1>(detail::Layout_K_INTER_Atom<Element>{}) == 0,
+                  "kBlockK must be a multiple of size<1>(detail::Layout_K_INTER_Atom{})");
+  }
+}
+// clang-format on
+
+template <typename Element, int kBlockK, int kThreads, typename CopyOp>
+CUTE_HOST_DEVICE constexpr auto gmem_tiled_copy_selector(Copy_Atom<CopyOp, Element> cp_atom) {
+  // maximize vectorized load (128-bits or 16 bytes per thread)
+  constexpr int kElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
+  constexpr int kSmemBlockK =
+      size<1>(smem_layout_atom_selector<Element, kBlockK>());
+  static_assert(kSmemBlockK % kElemsPerLoad == 0,
+                "kSmemBlockK must be a multiple of kElemsPerLoad");
+
+  constexpr int kThreadsPerRow = kSmemBlockK / kElemsPerLoad;
+  static_assert(kThreads % kThreadsPerRow == 0,
+                "kThreads must be a multiple of kThreadsPerRow");
+  constexpr int kRows = kThreads / kThreadsPerRow;
+  static_assert(kRows <= 64, "kRows must be less than or equal to 64");
+
+  constexpr auto thr_layout = Layout<Shape<Int<kRows>, Int<kThreadsPerRow>>,
+                                     Stride<Int<kThreadsPerRow>, _1>>{};
+  constexpr auto val_layout = Layout<Shape<_1, Int<kElemsPerLoad>>>{};
+
+  // g2s tiled copy
+  return make_tiled_copy(cp_atom, thr_layout, val_layout);
+}
+
+} // namespace llm
diff --git a/src/kernels/attention/device/sm120_fmha_launch.cuh b/src/kernels/attention/device/sm120_fmha_launch.cuh
index 4975f8f4..a6cefe04 100644
--- a/src/kernels/attention/device/sm120_fmha_launch.cuh
+++ b/src/kernels/attention/device/sm120_fmha_launch.cuh
@@ -38,7 +38,7 @@ void sm120_launch_mha_kernel(const Params& params, cudaStream_t stream) {
   const auto n_kv_heads = params.n_kv_heads;
   const auto max_q_packed_len = params.max_q_len * params.group_size;

-  // TODO: tune block shape MNK based on the head dim and smem size
+  // TODO: tune tile shape M/N based on the head dim and smem size
   // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
   // SM           | 7.0 | 7.2 | 7.5 | 8.0 | 8.6 | 8.7 | 8.9 | 9.0 | 10.x | 12.0|
   // Max SMEM (KB)| 96  | 64  | 164 | 100 | 164 | 100 |
228 | 100 | @@ -51,18 +51,19 @@ void sm120_launch_mha_kernel(const Params& params, cudaStream_t stream) { // * 12.0 : 0, 8, 16, 32, 64, 100 constexpr int BLK_M = 64; constexpr int BLK_N = 64; - constexpr int BLK_K = HEAD_DIM % 64 == 0 ? 64 : 32; - using TileShape = Shape, Int, Int>; + // TMA is used for K/V loading + constexpr bool KV_USE_TMA = false; + + using TileShape = Shape, Int, Int>; using CollectiveMainloop = Sm120CollectiveFMhaWs; - using CollectiveEpilogue = - Sm120CollectiveEpilogue; + LOCAL, + KV_USE_TMA>; + using CollectiveEpilogue = Sm120CollectiveEpilogue; // TODO: support persistent kernels using TileScheduler = SingleTileScheduler; diff --git a/src/kernels/attention/device/sm80_mha_launch.cuh b/src/kernels/attention/device/sm80_mha_launch.cuh index 89bffb63..1c44ac90 100644 --- a/src/kernels/attention/device/sm80_mha_launch.cuh +++ b/src/kernels/attention/device/sm80_mha_launch.cuh @@ -51,18 +51,11 @@ void sm80_launch_mha_kernel(const Params& params, cudaStream_t stream) { // * 12.0 : 0, 8, 16, 32, 64, 100 constexpr int BLK_M = 64; constexpr int BLK_N = 64; - constexpr int BLK_K = HEAD_DIM % 64 == 0 ? 64 : 32; - - using TileShape = Shape, Int, Int>; - using CollectiveMainloop = Sm80CollectiveMha; - using CollectiveEpilogue = - Sm80CollectiveEpilogue; + + using TileShape = Shape, Int, Int>; + using CollectiveMainloop = + Sm80CollectiveMha; + using CollectiveEpilogue = Sm80CollectiveEpilogue; // TODO: support persistent kernels using TileScheduler = SingleTileScheduler; diff --git a/src/kernels/attention/fmha_params.h b/src/kernels/attention/fmha_params.h new file mode 100644 index 00000000..f36bfbe0 --- /dev/null +++ b/src/kernels/attention/fmha_params.h @@ -0,0 +1,93 @@ +#pragma once + +#include + +namespace llm { + +// Params for fused multi-head attention (FMHA) kernels +struct FmhaParams { + //////////////////////////////////////////////// + // Parameters for input/output tensors + //////////////////////////////////////////////// + const void* __restrict__ q_ptr = nullptr; + const void* __restrict__ k_ptr = nullptr; + const void* __restrict__ v_ptr = nullptr; + void* __restrict__ o_ptr = nullptr; + + // Parameters for input shapes + int batch_size = 0; + int n_heads = 0; + int n_kv_heads = 0; + int head_dim = 0; + + // strides for query, key, value, and output tensors + // Tensor shape: (batch, seq, head, dim): last dimension is contiguous + // N.B. for variable length sequence, the q/k/v/o_batch_stride is not used. + int64_t q_batch_stride; + int64_t q_seq_stride; + int64_t q_head_stride; + + int64_t k_batch_stride; + int64_t k_seq_stride; + int64_t k_head_stride; + + int64_t v_batch_stride; + int64_t v_seq_stride; + int64_t v_head_stride; + + int64_t o_batch_stride; + int64_t o_seq_stride; + int64_t o_head_stride; + + //////////////////////////////////////////////// + // Parameters for sequence length + //////////////////////////////////////////////// + // Only used for fix length sequence + int q_len = 0; + int kv_len = 0; + + // Only used for variable length sequence + // array of length batch_size + 1 holding starting offset of each sequence. 
+ const int* __restrict__ q_cu_lens = nullptr; + const int* __restrict__ kv_cu_lens = nullptr; + + //////////////////////////////////////////////// + // Parameters for paged KV cache + //////////////////////////////////////////////// + // size for each cache block + int block_size = 1; + // the first slot id of each block + const int* __restrict__ block_table = nullptr; + // array of length batch_size + 1 holding starting offset of each sequence. + const int* __restrict__ block_cu_lens = nullptr; + + //////////////////////////////////////////////// + // Parameters for local attention + //////////////////////////////////////////////// + // left sliding window size + int sliding_window = -1; + + //////////////////////////////////////////////// + // Parameters for logits soft cap + //////////////////////////////////////////////// + float logits_soft_cap = 0.0; + + //////////////////////////////////////////////// + // Parameters for softmax + //////////////////////////////////////////////// + // softmax scaling + float sm_scale = 1.0; + + //////////////////////////////////////////////// + // Parameters for alibi positional encoding + //////////////////////////////////////////////// + const float* __restrict__ alibi_slopes_ptr = nullptr; // [n_heads] + + //////////////////////////////////////////////// + // Parameters for scheduling + //////////////////////////////////////////////// + // TODO: remove it after persistent kernel + int max_q_len = 0; +}; + +} // namespace llm diff --git a/src/kernels/attention/kernel/sm120_kernel_fmha_ws.cuh b/src/kernels/attention/kernel/sm120_kernel_fmha_ws.cuh index d3048e23..be36b22f 100644 --- a/src/kernels/attention/kernel/sm120_kernel_fmha_ws.cuh +++ b/src/kernels/attention/kernel/sm120_kernel_fmha_ws.cuh @@ -118,9 +118,8 @@ class Sm120KernelFmhaWs { PipelineQ& q_pipeline, PipelineKV& kv_pipeline, SharedStorage& ss) { - static constexpr int kHeadDim = CollectiveMainloop::kHeadDim; static constexpr bool kLocal = CollectiveMainloop::kLocal; - using Block = FmhaBlock; + using Block = FmhaBlock; auto q_state = cutlass::make_producer_start_state(); auto kv_state = cutlass::make_producer_start_state(); @@ -150,13 +149,12 @@ class Sm120KernelFmhaWs { PipelineQ& q_pipeline, PipelineKV& kv_pipeline, SharedStorage& ss) { - static constexpr int kHeadDim = CollectiveMainloop::kHeadDim; static constexpr bool kLocal = CollectiveMainloop::kLocal; using TiledMma = typename CollectiveMainloop::TiledMma; using BLK_M = typename CollectiveMainloop::BLK_M; - using HEAD_DIM = typename CollectiveMainloop::HEAD_DIM; + using BLK_K = typename CollectiveMainloop::BLK_K; - using Block = FmhaBlock; + using Block = FmhaBlock; PipelineStateQ q_state; PipelineStateKV kv_state; @@ -187,7 +185,7 @@ class Sm120KernelFmhaWs { TiledMma tiled_mma; // accumulator: (MMA,MMA_M,MMA_K) - auto tOrAccO = partition_fragment_C(tiled_mma, Shape{}); + auto tOrAccO = partition_fragment_C(tiled_mma, Shape{}); clear(tOrAccO); mainloop.fmha(mainloop_params, diff --git a/src/kernels/attention/kernel/sm80_kernel_mha.cuh b/src/kernels/attention/kernel/sm80_kernel_mha.cuh index f3bc5850..6b67f279 100644 --- a/src/kernels/attention/kernel/sm80_kernel_mha.cuh +++ b/src/kernels/attention/kernel/sm80_kernel_mha.cuh @@ -41,22 +41,25 @@ struct MHATile { // (batch, seq, head, dim) // packed all q/o in the same kv head group together - const auto head_base = kv_head_idx_ * params_.group_size; - auto packed_idx_to_coord = [this, head_base](int packed_idx) { + auto packed_idx_to_coord = [this](int packed_idx) { int idx, 
offset; params_.group_size.divmod(packed_idx, idx, offset); - return make_coord(idx, head_base + offset); + return make_coord(idx, offset); }; const auto packed_len = params_.q_len * params_.group_size; - const auto q_offset = batch_idx_ * get<0>(params_.q_stride); + const auto q_offset = + (batch_idx_ * get<0>(params_.q_stride)) + + (kv_head_idx_ * params_.group_size * get<2>(params_.q_stride)); auto q = make_gather_tensor( make_gmem_ptr((const Element*)params_.q_ptr + q_offset), make_shape(packed_len, params_.head_dim), make_stride(select<1, 2>(params_.q_stride), get<3>(params_.q_stride)), packed_idx_to_coord); - const auto o_offset = batch_idx_ * get<0>(params_.o_stride); + const auto o_offset = + (batch_idx_ * get<0>(params_.o_stride)) + + (kv_head_idx_ * params_.group_size * get<2>(params_.o_stride)); auto o = make_gather_tensor( make_gmem_ptr((Element*)params_.o_ptr + o_offset), make_shape(packed_len, params_.head_dim), @@ -69,10 +72,10 @@ struct MHATile { template CUTE_HOST_DEVICE auto get_kv_tile() const { // (batch, seq, kv_head, dim) - const auto k_offset = batch_idx_ * get<0>(params_.k_stride) + - kv_head_idx_ * get<2>(params_.k_stride); - const auto v_offset = batch_idx_ * get<0>(params_.v_stride) + - kv_head_idx_ * get<2>(params_.v_stride); + const auto k_offset = (batch_idx_ * get<0>(params_.k_stride)) + + (kv_head_idx_ * get<2>(params_.k_stride)); + const auto v_offset = (batch_idx_ * get<0>(params_.v_stride)) + + (kv_head_idx_ * get<2>(params_.v_stride)); // k[batch_idx, :, kv_head_idx, :] auto k = make_tensor(make_gmem_ptr((const Element*)params_.k_ptr + k_offset), @@ -105,22 +108,26 @@ struct MHATile { CUTE_HOST_DEVICE auto get_qo_tile() const { const auto begin = params_.q_cu_lens[batch_idx_]; const auto qo_len = params_.q_cu_lens[batch_idx_ + 1] - begin; - const auto head_base = kv_head_idx_ * params_.group_size; - auto packed_idx_to_coord = [this, head_base](int packed_idx) { + + auto packed_idx_to_coord = [this](int packed_idx) { int idx, offset; params_.group_size.divmod(packed_idx, idx, offset); - return make_coord(idx, head_base + offset); + return make_coord(idx, offset); }; const auto packed_len = qo_len * params_.group_size; - const auto q_offset = begin * get<0>(params_.q_stride); + const auto q_offset = + (begin * get<0>(params_.q_stride)) + + (kv_head_idx_ * params_.group_size * get<1>(params_.q_stride)); auto q = make_gather_tensor( make_gmem_ptr((const Element*)params_.q_ptr + q_offset), make_shape(packed_len, params_.head_dim), make_stride(select<0, 1>(params_.q_stride), get<2>(params_.q_stride)), packed_idx_to_coord); - const auto o_offset = begin * get<0>(params_.o_stride); + const auto o_offset = + (begin * get<0>(params_.o_stride)) + + (kv_head_idx_ * params_.group_size * get<1>(params_.o_stride)); auto o = make_gather_tensor( make_gmem_ptr((Element*)params_.o_ptr + o_offset), make_shape(packed_len, params_.head_dim), @@ -178,7 +185,7 @@ class Sm80KernelMha { using Element = typename CollectiveMainloop::Element; using BLK_M = typename CollectiveMainloop::BLK_M; using BLK_N = typename CollectiveMainloop::BLK_N; - using HEAD_DIM = typename CollectiveMainloop::HEAD_DIM; + using BLK_K = typename CollectiveMainloop::BLK_K; static constexpr int kSharedStorageSize = cute::max(sizeof(typename CollectiveMainloop::SharedStorage), @@ -224,10 +231,10 @@ class Sm80KernelMha { const auto [batch_idx, m_block_idx, kv_head_idx] = block_coord; const auto tidx = threadIdx.x; - // (q_packed_len, HEAD_DIM) + // (q_packed_len, BLK_K) detail::MHATile tile(params, batch_idx, 
kv_head_idx); auto [Q, O] = tile.template get_qo_tile(); - // (kv_len, HEAD_DIM) + // (kv_len, BLK_K) auto [K, V] = tile.template get_kv_tile(); // problem shape @@ -252,22 +259,22 @@ class Sm80KernelMha { const int n_block_min = kLocal ? kv_idx_min / kBlockN : 0; const int n_block_max = cute::ceil_div(kv_idx_max, kBlockN); - // (BLK_M, HEAD_DIM) - Tensor gQ = local_tile( - Q, Shape{}, make_coord(m_block_idx, _0{})); - Tensor gO = local_tile( - O, Shape{}, make_coord(m_block_idx, _0{})); - // (BLK_M, HEAD_DIM) => (M, K) + // (BLK_M, BLK_K) + Tensor gQ = + local_tile(Q, Shape{}, make_coord(m_block_idx, _0{})); + Tensor gO = + local_tile(O, Shape{}, make_coord(m_block_idx, _0{})); + // (BLK_M, BLK_K) => (M, K) Tensor cQ = local_tile(make_identity_tensor(Q.shape()), - Shape{}, + Shape{}, make_coord(m_block_idx, _0{})); - // (BLK_N, HEAD_DIM, n) - Tensor gK = local_tile(K, Shape{}, make_coord(_, _0{})); - Tensor gV = local_tile(V, Shape{}, make_coord(_, _0{})); - // (BLK_N, HEAD_DIM, n) => (N, K) + // (BLK_N, BLK_K, n) + Tensor gK = local_tile(K, Shape{}, make_coord(_, _0{})); + Tensor gV = local_tile(V, Shape{}, make_coord(_, _0{})); + // (BLK_N, BLK_K, n) => (N, K) Tensor cKV = local_tile(make_identity_tensor(K.shape()), - Shape{}, + Shape{}, make_coord(_, _0{})); // (BLK_M, BLK_N, n) => (M, N) @@ -278,7 +285,7 @@ class Sm80KernelMha { TiledMma tiled_mma; // accumulator: (MMA,MMA_M,MMA_K) - auto tOrAccO = partition_fragment_C(tiled_mma, Shape{}); + auto tOrAccO = partition_fragment_C(tiled_mma, Shape{}); clear(tOrAccO); auto thr_mma = tiled_mma.get_slice(tidx); diff --git a/src/kernels/attention/tests/sm120_fmha_test.cu b/src/kernels/attention/tests/sm120_fmha_test.cu index b6879882..bd481276 100644 --- a/src/kernels/attention/tests/sm120_fmha_test.cu +++ b/src/kernels/attention/tests/sm120_fmha_test.cu @@ -133,18 +133,21 @@ TEST_P(MHAKernelTest, FMHA) { const auto options = torch::dtype(dtype).device(torch::kCUDA); // construct non-contiguous query, key and value - const auto data = torch::randn( - {batch_size, q_len, n_heads + 2 * n_kv_heads, head_dim}, options); - const auto qkv = - data.split(/*split_size=*/{n_heads, n_kv_heads, n_kv_heads}, /*dim=*/2); - const auto& query = qkv[0]; - const auto& key = qkv[1]; - const auto& value = qkv[2]; + const auto& query = + torch::randn({batch_size, q_len, n_heads, head_dim}, options); + + const auto kv = + torch::randn({batch_size, kv_len, 2 * n_kv_heads, head_dim}, options) + .chunk(/*chunks=*/2, /*dim=*/2); + const auto& key = kv[0]; + const auto& value = kv[1]; torch::optional alibi_slopes; if (alibi) { - alibi_slopes = torch::randn( - {n_heads}, torch::dtype(torch::kFloat32).device(torch::kCUDA)); + alibi_slopes = + torch::randn({n_heads}, + torch::dtype(torch::kFloat32).device(torch::kCUDA)) / + kv_len; } auto ref_out = mha_batch_ref( diff --git a/src/kernels/attention/tests/sm80_mha_test.cu b/src/kernels/attention/tests/sm80_mha_test.cu index 8f554117..99565980 100644 --- a/src/kernels/attention/tests/sm80_mha_test.cu +++ b/src/kernels/attention/tests/sm80_mha_test.cu @@ -104,18 +104,22 @@ TEST_P(MHAKernelTest, MHA) { const auto options = torch::dtype(dtype).device(torch::kCUDA); // construct non-contiguous query, key and value - const auto data = torch::randn( - {batch_size, q_len, n_heads + 2 * n_kv_heads, head_dim}, options); - const auto qkv = - data.split(/*split_size=*/{n_heads, n_kv_heads, n_kv_heads}, /*dim=*/2); - const auto& query = qkv[0]; - const auto& key = qkv[1]; - const auto& value = qkv[2]; + const auto& query = + 
torch::randn({batch_size, q_len, n_heads, head_dim}, options); + + const auto data = + torch::randn({batch_size, kv_len, 2 * n_kv_heads, head_dim}, options); + const auto kv = + data.split(/*split_size=*/{n_kv_heads, n_kv_heads}, /*dim=*/2); + const auto& key = kv[0]; + const auto& value = kv[1]; torch::optional alibi_slopes; if (alibi) { - alibi_slopes = torch::randn( - {n_heads}, torch::dtype(torch::kFloat32).device(torch::kCUDA)); + alibi_slopes = + torch::randn({n_heads}, + torch::dtype(torch::kFloat32).device(torch::kCUDA)) / + kv_len; } auto ref_out = mha_batch_ref(
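Note on the new selectors: below is a minimal, self-contained sketch of what smem_layout_atom_selector / gmem_tiled_copy_selector would pick for an fp16 kernel with kBlockK = head_dim = 64 and 128 threads. The template-argument order (<Element, kBlockK> and <Element, kBlockK, kThreads>) and the cp.async copy op are assumptions for illustration, not taken verbatim from this patch.

// Usage sketch only; selector template-argument order is assumed.
#include <cute/arch/copy_sm80.hpp>
#include <cute/atom/copy_atom.hpp>
#include <cute/atom/copy_traits_sm80.hpp>
#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h>

#include "common/selector.h"

namespace example {

using namespace cute;
using Element = cutlass::half_t;

constexpr int kBlockK = 64;    // head_dim tile; equals head_dim after this refactor
constexpr int kThreads = 128;  // 4 warps per CTA

// For 16-bit elements one SW128 atom row is 1024 bits = 64 elements, which
// divides kBlockK = 64, so the Swizzle<3,4,3> (SW128) atom is selected.
using SmemAtom = decltype(llm::smem_layout_atom_selector<Element, kBlockK>());
static_assert(size<1>(SmemAtom{}) == 64);

// A 128-bit load carries 8 fp16 values, so one smem row takes 64 / 8 = 8
// threads; 128 threads therefore form a (16, 8) thread layout with 8 values
// per thread, i.e. the same shape the removed GmemCopyThrLayout_ hard-coded.
using GmemTiledCopy =
    decltype(llm::gmem_tiled_copy_selector<Element, kBlockK, kThreads>(
        Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, Element>{}));

}  // namespace example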
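Note on the GQA head-offset change in fmha_block.h and sm80_kernel_mha.cuh: the kv_head_idx * group_size term moves from the gather coordinate (head_base + offset) into the base-pointer offset. The two formulations address the same element because the head stride multiplies the term in both cases; a small compile-time check with made-up example strides:

// Equivalence check for the offset refactor; strides/indices are example values.
#include <cstdint>

constexpr int64_t seq_stride = 4096, head_stride = 128;
constexpr int group_size = 4, kv_head_idx = 1;
constexpr int idx = 7, offset = 2;  // (seq position, head within the kv group)

// before: the gather coordinate carries head_base = kv_head_idx * group_size
constexpr int64_t old_off =
    idx * seq_stride + (kv_head_idx * group_size + offset) * head_stride;
// after: head_base is folded into the base pointer offset
constexpr int64_t new_off =
    kv_head_idx * group_size * head_stride + idx * seq_stride + offset * head_stride;

static_assert(old_off == new_off, "both formulations pick the same element");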