diff --git a/src/kernels/attention/collective/sm120_collective_fmha_mainloop_ws.cuh b/src/kernels/attention/collective/sm120_collective_fmha_mainloop_ws.cuh
index 1e1027cc..1de82b22 100644
--- a/src/kernels/attention/collective/sm120_collective_fmha_mainloop_ws.cuh
+++ b/src/kernels/attention/collective/sm120_collective_fmha_mainloop_ws.cuh
@@ -59,6 +59,10 @@ struct Sm120CollectiveFMhaWs {
   static constexpr int kRowsPerMMA = 2;
   static constexpr int kMmaThreads = size(TiledMma{});
 
+  // TODO: tune the number of stages based on smem size
+  static constexpr int StageCountQ = 1;
+  static constexpr int StageCountKV = 3;
+
   // Atom layout: (8, BLK_K):(BLK_K, 1) k-major
   using SmemLayoutAtom_ =
       decltype(composition(Swizzle<3, 3, 3>{},
@@ -68,14 +72,16 @@ struct Sm120CollectiveFMhaWs {
   using SmemLayoutQ =
       decltype(tile_to_shape(SmemLayoutAtom_{}, Shape<_BLK_M, _HEAD_DIM>{}));
 
-  // KV smem: (BLK_N, HEAD_DIM)
+  // KV smem: (BLK_N, HEAD_DIM, KVStages)
   using SmemLayoutK =
-      decltype(tile_to_shape(SmemLayoutAtom_{}, Shape<_BLK_N, _HEAD_DIM>{}));
+      decltype(tile_to_shape(SmemLayoutAtom_{},
+                             Shape<_BLK_N, _HEAD_DIM, Int<StageCountKV>>{}));
 
   using SmemLayoutV =
-      decltype(tile_to_shape(SmemLayoutAtom_{}, Shape<_BLK_N, _HEAD_DIM>{}));
+      decltype(tile_to_shape(SmemLayoutAtom_{},
+                             Shape<_BLK_N, _HEAD_DIM, Int<StageCountKV>>{}));
 
-  // V^T smem: (HEAD_DIM, BLK_N)
-  using SmemLayoutVt = decltype(select<1, 0>(SmemLayoutV{}));
+  // V^T smem: (HEAD_DIM, BLK_N, KVStages)
+  using SmemLayoutVt = decltype(select<1, 0, 2>(SmemLayoutV{}));
 
   // s2r tiled copy for gemm-I
   using SmemTiledCopyQ =
@@ -92,17 +98,13 @@ struct Sm120CollectiveFMhaWs {
   struct TensorStorage {
     cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
-    cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
     union {
+      cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
       cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
       cute::array_aligned<Element, cute::cosize_v<SmemLayoutVt>> smem_vt;
     };
   };
 
-  // TODO: tune the number of stages based on smem size
-  static constexpr int StageCountQ = 1;
-  static constexpr int StageCountKV = 2;
-
   using PipelineQ = cutlass::PipelineAsync<StageCountQ>;
   using PipelineKV = cutlass::PipelineAsync<StageCountKV>;
 
@@ -201,36 +203,43 @@ struct Sm120CollectiveFMhaWs {
     // Construct smem tensors
     // (BLK_M, HEAD_DIM), k-major
     Tensor sQ = make_tensor(make_smem_ptr(ss.smem_q.data()), SmemLayoutQ{});
-    // (BLK_N, HEAD_DIM), k-major
+    // (BLK_N, HEAD_DIM, KVStages), k-major
    Tensor sK = make_tensor(make_smem_ptr(ss.smem_k.data()), SmemLayoutK{});
 
     // Tensor for V^t; used in GEMM-II.
-    // (HEAD_DIM, BLK_N), k-major
+    // (HEAD_DIM, BLK_N, KVStages), k-major
     Tensor sVt = make_tensor(make_smem_ptr(ss.smem_vt.data()), SmemLayoutVt{});
 
     TiledMma tiled_mma;
     auto thr_mma = tiled_mma.get_slice(tidx);
     // GEMM-I: S = Q@K.T
-    auto tSrQ = thr_mma.partition_fragment_A(sQ);  // (MMA,MMA_M,MMA_K)
-    auto tSrK = thr_mma.partition_fragment_B(sK);  // (MMA,MMA_N,MMA_K)
+    // (MMA,MMA_M,MMA_K)
+    auto tSrQ = thr_mma.partition_fragment_A(sQ);
+    // (MMA,MMA_N,MMA_K)
+    auto tSrK = thr_mma.partition_fragment_B(sK(_, _, _0{}));
 
     // s2r tiled copy for qkv
     // copy query to rmem
     SmemTiledCopyQ smem_tiled_copy_Q;
-    auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
+    auto smem_thr_copy_Q = smem_tiled_copy_Q.get_slice(tidx);
+    // (CPY,CPY_M,CPY_K)
     auto tSsQ = smem_thr_copy_Q.partition_S(sQ);
+    // (CPY,CPY_M,CPY_K)
     auto tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
 
     SmemTiledCopyK smem_tiled_copy_K;
-    auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
+    auto smem_thr_copy_K = smem_tiled_copy_K.get_slice(tidx);
+    // (CPY,CPY_N,CPY_K, KVStages)
     auto tSsK = smem_thr_copy_K.partition_S(sK);
+    // (CPY,CPY_N,CPY_K)
     auto tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK);
 
     // S = Q@K.T
     // tSrAccS: (MMA,MMA_M,MMA_N)
-    auto compute_qk = [&](auto& tSrAccS) {
+    auto compute_qk = [&](auto& tSrAccS, int stage) {
+      auto tSsK_s = tSsK(_, _, _, stage);
       // prefetch key
       cute::copy(
-          smem_tiled_copy_K, tSsK(_, _, _0{}), tSrK_copy_view(_, _, _0{}));
+          smem_tiled_copy_K, tSsK_s(_, _, _0{}), tSrK_copy_view(_, _, _0{}));
 
       CUTE_UNROLL
       for (int ki = 0; ki < size<2>(tSrQ); ++ki) {
@@ -238,7 +247,7 @@ struct Sm120CollectiveFMhaWs {
         if (ki != size<2>(tSrQ) - 1) {
           const auto next_ki = ki + 1;
           cute::copy(smem_tiled_copy_K,
-                     tSsK(_, _, next_ki),
+                     tSsK_s(_, _, next_ki),
                      tSrK_copy_view(_, _, next_ki));
         }
         cute::gemm(tiled_mma, tSrQ(_, _, ki), tSrK(_, _, ki), tSrAccS);
@@ -246,17 +255,20 @@ struct Sm120CollectiveFMhaWs {
     };
 
     // GEMM-II: O = softmax(S)@V
-    auto tOrVt = thr_mma.partition_fragment_B(sVt);  // (MMA,MMA_K,MMA_N)
+    // (MMA,MMA_K,MMA_N)
+    auto tOrVt = thr_mma.partition_fragment_B(sVt(_, _, _0{}));
 
     SmemTiledCopyVt smem_tiled_copy_Vt;
-    auto smem_thr_copy_Vt = smem_tiled_copy_Vt.get_thread_slice(tidx);
+    auto smem_thr_copy_Vt = smem_tiled_copy_Vt.get_slice(tidx);
+    // (CPY,CPY_K,CPY_N, KVStages)
     auto tOsVt = smem_thr_copy_Vt.partition_S(sVt);
+    // (CPY,CPY_K,CPY_N)
     auto tOrVt_copy_view = smem_thr_copy_Vt.retile_D(tOrVt);
 
     // O = softmax(S)*V
     // tSrAccS: (MMA,MMA_M,MMA_N)
     // tOrAccO: (MMA,MMA_M,MMA_K)
-    auto compute_sv = [&](const auto& tSrAccS, auto& tOrAccO) {
+    auto compute_sv = [&](const auto& tSrAccS, auto& tOrAccO, int stage) {
       // cast scores from Accumulator to Element
       auto tSrS = make_tensor_like<Element>(tSrAccS);
       fast_cast(tSrAccS, tSrS);
@@ -264,17 +276,18 @@ struct Sm120CollectiveFMhaWs {
       // convert layout from gemm-I C to gemm-II A
       auto tOrS =
           make_tensor(tSrS.data(), LayoutConvertor::to_mma_a(tSrS.layout()));
-
+      // (CPY,CPY_K,CPY_N)
+      auto tOsVt_s = tOsVt(_, _, _, stage);
       // prefetch V^t
       cute::copy(
-          smem_tiled_copy_Vt, tOsVt(_, _, _0{}), tOrVt_copy_view(_, _, _0{}));
+          smem_tiled_copy_Vt, tOsVt_s(_, _, _0{}), tOrVt_copy_view(_, _, _0{}));
 
       CUTE_UNROLL
       for (int ki = 0; ki < size<2>(tOrS); ++ki) {
         // prefetch next V^t
         if (ki != size<2>(tOrS) - 1) {
           const auto next_ki = ki + 1;
           cute::copy(smem_tiled_copy_Vt,
-                     tOsVt(_, _, next_ki),
+                     tOsVt_s(_, _, next_ki),
                      tOrVt_copy_view(_, _, next_ki));
         }
         cute::gemm(tiled_mma, tOrS(_, _, ki), tOrVt(_, _, ki), tOrAccO);
@@ -321,7 +334,7 @@ struct Sm120CollectiveFMhaWs {
       kv_pipeline.consumer_wait(kv_state);
 
       // 1> S = Q@K.T
-      compute_qk(tSrS);
+      compute_qk(tSrS, kv_state.index());
 
       // release key smem
       kv_pipeline.consumer_release(kv_state);
@@ -345,7 +358,7 @@ struct Sm120CollectiveFMhaWs {
       kv_pipeline.consumer_wait(kv_state);
 
       // 2> O = softmax(S)*V
-      compute_sv(tSrS, tOrO);
+      compute_sv(tSrS, tOrO, kv_state.index());
 
       // release value smem
       kv_pipeline.consumer_release(kv_state);
diff --git a/src/kernels/attention/collective/sm120_collective_load_cpasync_ws.cuh b/src/kernels/attention/collective/sm120_collective_load_cpasync_ws.cuh
index d4180ed8..1c8e1bbe 100644
--- a/src/kernels/attention/collective/sm120_collective_load_cpasync_ws.cuh
+++ b/src/kernels/attention/collective/sm120_collective_load_cpasync_ws.cuh
@@ -58,7 +58,7 @@ struct Sm120CollectiveLoadCpAsyncWs {
     // Construct smem tensors
     // (BLK_M, HEAD_DIM), k-major
     Tensor sQ = make_tensor(make_smem_ptr(ss.smem_q.data()), SmemLayoutQ{});
-    // (BLK_N, HEAD_DIM), k-major
+    // (BLK_N, HEAD_DIM, KVStages), k-major
     Tensor sK = make_tensor(make_smem_ptr(ss.smem_k.data()), SmemLayoutK{});
     Tensor sV = make_tensor(make_smem_ptr(ss.smem_v.data()), SmemLayoutV{});
 
@@ -74,7 +74,7 @@ struct Sm120CollectiveLoadCpAsyncWs {
         GmemCopyThrLayout_{},     // Thr layout: (_16,_8)/(_32, _4)
         Layout<Shape<_1, _8>>{}   // Val layout: 8 vals per read
     );
-    auto gmem_thr_copy = gmem_tiled_copy.get_thread_slice(tidx);
+    auto gmem_thr_copy = gmem_tiled_copy.get_slice(tidx);
 
     // (CPY, CPY_N, CPY_K, n) => (N, K)
     Tensor tGcKV = gmem_thr_copy.partition_S(cKV);
@@ -82,7 +82,7 @@ struct Sm120CollectiveLoadCpAsyncWs {
     Tensor tGgK = gmem_thr_copy.partition_S(gK);
     Tensor tGgV = gmem_thr_copy.partition_S(gV);
 
-    // (CPY, CPY_N, CPY_K)
+    // (CPY, CPY_N, CPY_K, KVStages)
     Tensor tGsK = gmem_thr_copy.partition_D(sK);
     Tensor tGsV = gmem_thr_copy.partition_D(sV);
 
@@ -108,7 +108,7 @@ struct Sm120CollectiveLoadCpAsyncWs {
       safe_copy(
           gmem_tiled_copy,
           tGgK(_, _, _, ni),
-          tGsK,
+          tGsK(_, _, _, state.index()),
          tGcKV(_, _, _, ni),
          residue_nk);
      kv_pipeline.producer_commit(state, cutlass::arch::cpasync_barrier_arrive);
@@ -121,7 +121,7 @@ struct Sm120CollectiveLoadCpAsyncWs {
      safe_copy(
          gmem_tiled_copy,
          tGgK(_, _, _, ni),
-          tGsK,
+          tGsK(_, _, _, state.index()),
          tGcKV(_, _, _, ni),
          residue_nk);
      kv_pipeline.producer_commit(state, cutlass::arch::cpasync_barrier_arrive);
@@ -134,7 +134,7 @@ struct Sm120CollectiveLoadCpAsyncWs {
      safe_copy(
          gmem_tiled_copy,
          tGgV(_, _, _, ni),
-          tGsV,
+          tGsV(_, _, _, state.index()),
          tGcKV(_, _, _, ni),
          residue_nk);
      kv_pipeline.producer_commit(state, cutlass::arch::cpasync_barrier_arrive);
@@ -147,7 +147,7 @@ struct Sm120CollectiveLoadCpAsyncWs {
      safe_copy(
          gmem_tiled_copy,
          tGgV(_, _, _, ni),
-          tGsV,
+          tGsV(_, _, _, state.index()),
          tGcKV(_, _, _, ni),
          residue_nk);
      kv_pipeline.producer_commit(state, cutlass::arch::cpasync_barrier_arrive);
@@ -155,7 +155,7 @@ struct Sm120CollectiveLoadCpAsyncWs {
     };
 
     // async copy gmem to smem in following order:
-    // Q1, Kn-1, Vn-1, ..., K2, V2, K1, V1
+    // Q0, Kn-1, Vn-1, ..., K1, V1, K0, V0
 
     // produce Q1
     produce_query(q_state);
diff --git a/src/kernels/attention/tests/sm120_fmha_test.cu b/src/kernels/attention/tests/sm120_fmha_test.cu
index f76a07bc..b6879882 100644
--- a/src/kernels/attention/tests/sm120_fmha_test.cu
+++ b/src/kernels/attention/tests/sm120_fmha_test.cu
@@ -4,6 +4,7 @@
 #include <gtest/gtest.h>
 #include <torch/torch.h>
 
+#include "common/static_dispatch.h"
 #include "device/sm120_fmha_launch.cuh"
 #include "mha_params.h"
 #include "tests/mha_ref.h"
@@ -81,15 +82,17 @@ torch::Tensor sm120_fmha(
   // normalize params that for performance optimization
   params.normalize();
 
-  DISPATCH_TORCH_DTYPE_(query.dtype(), DTYPE, [&] {
+  DISPATCH_TORCH_DTYPE_(query.dtype(), Dtype, [&] {
     DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] {
-      sm120_launch_mha_kernel<DTYPE, HEAD_DIM>(params, nullptr);
+      DISPATCH_BOOL(params.head_dim == HEAD_DIM, EVEN_K, [&] {
+        sm120_launch_mha_kernel<Dtype, HEAD_DIM, EVEN_K>(params, nullptr);
+      });
     });
   });
   return out;
@@ -140,7 +143,7 @@ TEST_P(MHAKernelTest, FMHA) {
 
   torch::optional<torch::Tensor> alibi_slopes;
   if (alibi) {
-    alibi_slopes = torch::rand(
+    alibi_slopes = torch::randn(
         {n_heads}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
   }
@@ -159,16 +162,17 @@ TEST_P(MHAKernelTest, FMHA) {
 INSTANTIATE_TEST_SUITE_P(
     SM120,
     MHAKernelTest,
-    ::testing::Combine(::testing::Values(torch::kHalf),  // q_dtype
-                       ::testing::Values(1),             // batch_size
-                       ::testing::Values(62),            // q_len
-                       ::testing::Values(127),           // kv_len
-                       ::testing::Values(6),             // n_heads
-                       ::testing::Values(6),             // n_kv_heads
-                       ::testing::Values(64),            // head_dim
-                       ::testing::Values(0.0),           // logits_soft_cap
-                       ::testing::Values(false),         // alibi slope
-                       ::testing::Values(-1)             // sliding window
-                       ));
+    ::testing::Combine(
+        ::testing::Values(torch::kHalf),                     // q_dtype
+        ::testing::Values(1, 2, 4),                          // batch_size
+        ::testing::Values(1, 62, 125),                       // q_len
+        ::testing::Values(127, 287, 1000),                   // kv_len
+        ::testing::Values(6),                                // n_heads
+        ::testing::Values(6 /*mha*/, 3 /*gqa*/, 1 /*mqa*/),  // n_kv_heads
+        ::testing::Values(32, 64),                           // head_dim
+        ::testing::Values(0.0),                              // logits_soft_cap
+        ::testing::Values(false),                            // alibi slope
+        ::testing::Values(-1)                                // sliding window
+        ));
 
 }  // namespace llm
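
Reviewer sketch (not part of the patch): the core of this change is that the producer now writes each KV tile into smem slot tGsK/tGsV at state.index(), and the consumer reads the matching slot by passing kv_state.index() into compute_qk/compute_sv, so both sides walk the same StageCountKV-deep ring of buffers. The standalone C++ program below only models that index/phase arithmetic. SimplePipelineState, its advance() method, and the printed trace are illustrative assumptions, simplified relative to cutlass::PipelineState, and are not the kernel's or CUTLASS's actual implementation.

// Sketch of multi-stage circular staging-buffer indexing, assuming a
// 3-deep KV ring like StageCountKV = 3 in the patch.
#include <cstdio>

template <int Stages>
struct SimplePipelineState {
  int index = 0;  // which smem stage to touch next (what state.index() returns)
  int phase = 0;  // flips every full pass over the stages; used for barrier parity
  int count = 0;  // total tiles produced/consumed so far

  void advance() {
    ++count;
    ++index;
    if (index == Stages) {
      index = 0;
      phase ^= 1;  // parity flip distinguishes re-use of the same buffer slot
    }
  }
};

int main() {
  constexpr int kStages = 3;  // mirrors StageCountKV = 3
  SimplePipelineState<kStages> kv_state;
  // A consumer walking 8 KV tiles touches smem stages 0,1,2,0,1,2,0,1.
  for (int tile = 0; tile < 8; ++tile) {
    std::printf("tile %d -> smem stage %d (phase %d)\n",
                tile, kv_state.index, kv_state.phase);
    kv_state.advance();
  }
  return 0;
}

With three stages, up to three KV tiles can be in flight before the producer has to wait for the consumer to release a slot, which appears to be the motivation for raising StageCountKV from 2 to 3 and moving smem_k into the union.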