diff --git a/src/kernels/attention/CMakeLists.txt b/src/kernels/attention/CMakeLists.txt
index f1a24cfd..10d97757 100644
--- a/src/kernels/attention/CMakeLists.txt
+++ b/src/kernels/attention/CMakeLists.txt
@@ -4,7 +4,7 @@ include(cc_library)
 include(cc_test)
 
 cc_library(
-  NAME 
+  NAME
     attention.template
   HDRS
     fast_math.h
@@ -15,9 +15,8 @@ cc_library(
     static_dispatch.h
     mha_params.h
     mha_tile.h
-    mha_traits_sm80.h
     mha_kernel_sm80.cuh
-    mha_dispatch_sm80.cuh
+    sm80_mha_dispatch.cuh
     mla_params.h
     mla_tile.h
     mla_traits_sm80.h
@@ -39,11 +38,11 @@ execute_process(
   WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/
   COMMAND_ERROR_IS_FATAL ANY
 )
-# globbing all generated files in sub directory "generated"
-file(GLOB GENERATED_SRC_FILES "${CMAKE_CURRENT_BINARY_DIR}/generated/*.cu")
+# globbing all generated files in sub directory "gensrc"
+file(GLOB GENERATED_SRC_FILES "${CMAKE_CURRENT_BINARY_DIR}/gensrc/*.cu")
 
 cc_library(
-  NAME 
+  NAME
     attention.kernels
   HDRS
     attn_api.h
@@ -62,9 +61,8 @@ cc_test(
     mha_kernel_test
   SRCS
     # mha_cpu_test.cpp
-    mha_traits_test.cpp
-    mha_kernel_sm80_test.cu
-    mha_kernel_sm80_pagedkv_test.cu
+    sm80_mha_test.cu
+    sm80_mha_pagedkv_test.cu
   DEPS
     :attention.template
     absl::random_random
@@ -99,31 +97,31 @@ cc_test(
 )
 
 nvbench_binary(
-  NAME 
-    mha_sm80_bench
-  SRCS 
-    mha_sm80_bench.cu
+  NAME
+    sm80_mha_bench
+  SRCS
+    sm80_mha_bench.cu
   DEPS
-    :attention.template 
+    :attention.template
 )
 
 nvbench_binary(
-  NAME 
-    mha_sm80_pagedkv_bench
-  SRCS 
-    mha_sm80_pagedkv_bench.cu
+  NAME
+    sm80_mha_pagedkv_bench
+  SRCS
+    sm80_mha_pagedkv_bench.cu
   DEPS
     absl::random_random
     :attention.template
 )
 
 nvbench_binary(
-  NAME 
+  NAME
     mla_sm80_bench
-  SRCS 
+  SRCS
     mla_sm80_bench.cu
   DEPS
-    :attention.template 
+    :attention.template
 )
 
-add_subdirectory(tools)
\ No newline at end of file
+add_subdirectory(tools)
diff --git a/src/kernels/attention/attn_api.cpp b/src/kernels/attention/attn_api.cpp
index b5501af6..0853b77d 100644
--- a/src/kernels/attention/attn_api.cpp
+++ b/src/kernels/attention/attn_api.cpp
@@ -3,8 +3,8 @@
 #include
 
 #include "cute/layout.hpp"
-#include "mha_dispatch_sm80.cuh"
 #include "mha_params.h"
+#include "sm80_mha_dispatch.cuh"
 #include "static_dispatch.h"
 
 namespace llm {
@@ -66,7 +66,7 @@ void paged_kv_varlen_mha(
   cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
   DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] {
     DISPATCH_TORCH_DTYPE(query.scalar_type(), DTYPE, [&] {
-      run_mha_kernel_sm80<DTYPE, HEAD_DIM>(params, stream);
+      sm80_run_mha<DTYPE, HEAD_DIM>(params, stream);
    });
  });
 }
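The instantiation generator below now emits one sm80_mha_*.cu unit per flag combination under gensrc/. A sketch of what one generated unit looks like; the flag values and the exact template parameter order (dtype, head_dim, EVEN_K, ALIBI, SOFT_CAP, LOCAL) are assumptions inferred from the ek/al/sc/lc filename scheme:

    // sketch: gensrc/sm80_mha_fp16_hd64_ek1_al0_sc0_lc0.cu (hypothetical combination)
    #include "sm80_mha_launch.cuh"  // IWYU pragma: export
    #include "mha_params.h"         // IWYU pragma: export

    namespace llm {

    using Params = MHAPagedKVParams;

    // pin one explicit instantiation of the launch entry point
    template void sm80_launch_mha_kernel<cute::half_t,
                                         /*HEAD_DIM=*/64,
                                         /*EVEN_K=*/true,
                                         /*ALIBI=*/false,
                                         /*SOFT_CAP=*/false,
                                         /*LOCAL=*/false,
                                         Params>(const Params& params,
                                                 cudaStream_t stream);

    }  // namespace llm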
diff --git a/src/kernels/attention/generate_instantiation_cu.py b/src/kernels/attention/generate_instantiation_cu.py
index 9a22918e..fb1fb339 100755
--- a/src/kernels/attention/generate_instantiation_cu.py
+++ b/src/kernels/attention/generate_instantiation_cu.py
@@ -20,21 +20,21 @@
 MHA_KERNEL_TEMPLATE = """
-#include "mha_kernel_sm80.cuh"  // IWYU pragma: export
+#include "sm80_mha_launch.cuh"  // IWYU pragma: export
 #include "mha_params.h"  // IWYU pragma: export
 
 namespace llm {{
 
 using Params = MHAPagedKVParams;
 
-template void launch_mha_kernel_sm80<{DTYPE}, {HEAD_DIM}, {EVEN_K}, {ALIBI}, {SOFT_CAP}, {LOCAL}, Params>(const Params& params,
-                                     cudaStream_t stream);
+template void sm80_launch_mha_kernel<{DTYPE}, {HEAD_DIM}, {EVEN_K}, {ALIBI}, {SOFT_CAP}, {LOCAL}, Params>(const Params& params,
+                                     cudaStream_t stream);
 }}  // namespace llm
 """
@@ -79,7 +79,7 @@ def filename(self) -> str:
         def to_str(val: bool) -> str:
             return "1" if val else "0"
 
-        return f"mha_{self.dtype}_hd{self.head_dim}_ek{to_str(self.even_k)}_al{to_str(self.alibi)}_sc{to_str(self.soft_cap)}_lc{to_str(self.local)}_sm80.cu"
+        return f"sm80_mha_{self.dtype}_hd{self.head_dim}_ek{to_str(self.even_k)}_al{to_str(self.alibi)}_sc{to_str(self.soft_cap)}_lc{to_str(self.local)}.cu"
 
 
 @dataclass
@@ -164,7 +164,7 @@ def gen_mla_kernels() -> Iterator[MLAKernel]:
 
 if __name__ == "__main__":
-    output_dir = Path.cwd() / "generated"
+    output_dir = Path.cwd() / "gensrc"
     shutil.rmtree(output_dir, ignore_errors=True)
     output_dir.mkdir(parents=True, exist_ok=True)
diff --git a/src/kernels/attention/mha_kernel_sm80.cuh b/src/kernels/attention/mha_kernel_sm80.cuh
deleted file mode 100644
index 5b57a281..00000000
--- a/src/kernels/attention/mha_kernel_sm80.cuh
+++ /dev/null
@@ -1,482 +0,0 @@
-#pragma once
-
-#include
-#include
-
-#include
-#include
-
-#include "cute/config.hpp"
-#include "cute/container/array_aligned.hpp"
-#include "cute_extensions.cuh"
-#include "fast_cast.cuh"
-#include "layout_convertor.h"
-#include "mask.h"
-#include "mha_tile.h"
-#include "mha_traits_sm80.h"
-#include "online_softmax.cuh"
-
-namespace llm {
-
-template <typename Traits>
-struct MHASharedStorage {
-  using DType = typename Traits::DType;
-  using SmemLayoutQ = typename Traits::SmemLayoutQ;
-  using SmemLayoutK = typename Traits::SmemLayoutK;
-  using SmemLayoutV = typename Traits::SmemLayoutV;
-  using SmemLayoutVt = typename Traits::SmemLayoutVt;
-  using SmemLayoutO = typename Traits::SmemLayoutO;
-
-  union {
-    union {
-      cute::array_aligned<DType, cute::cosize_v<SmemLayoutQ>> q_smem;
-      struct {
-        cute::array_aligned<DType, cute::cosize_v<SmemLayoutK>> k_smem;
-        union {
-          cute::array_aligned<DType, cute::cosize_v<SmemLayoutV>> v_smem;
-          cute::array_aligned<DType, cute::cosize_v<SmemLayoutVt>> vt_smem;
-        };
-      };
-    };
-
-    cute::array_aligned<DType, cute::cosize_v<SmemLayoutO>> o_smem;
-  };
-};
-
-template <typename Traits, typename Params, bool EVEN_K, bool ALIBI, bool SOFT_CAP, bool LOCAL>
-__global__ __launch_bounds__(Traits::kThreadNum) void mha_kernel_sm80(
-    __grid_constant__ const Params params) {
-  using namespace cute;
-
-  constexpr int kBlockM = Traits::kBlockM;
-  constexpr int kBlockN = Traits::kBlockN;
-  constexpr int kHeadDim = Traits::kHeadDim;
-  constexpr int kRowsPerMMA = Traits::kRowsPerMMA;
-
-  using _BLK_M = Int<kBlockM>;
-  using _BLK_N = Int<kBlockN>;
-  using _HEAD_DIM = Int<kHeadDim>;
-
-  // type alias
-  using DType = typename Traits::DType;
-
-  using TiledMma = typename Traits::TiledMma;
-
-  using SmemLayoutQ = typename Traits::SmemLayoutQ;
-  using SmemLayoutK = typename Traits::SmemLayoutK;
-  using SmemLayoutV = typename Traits::SmemLayoutV;
-  using SmemLayoutVt = typename Traits::SmemLayoutVt;
-  using SmemLayoutO = typename Traits::SmemLayoutO;
-  using SharedStorage = MHASharedStorage<Traits>;
-
-  using GmemTiledCopyQ = typename Traits::GmemTiledCopyQ;
-  using GmemTiledCopyKV = typename Traits::GmemTiledCopyKV;
-  using GmemTiledCopyO = typename Traits::GmemTiledCopyO;
-
-  using SmemTiledCopyQ = typename Traits::SmemTiledCopyQ;
-  using SmemTiledCopyK = typename Traits::SmemTiledCopyK;
-  using SmemTiledCopyVt = typename Traits::SmemTiledCopyVt;
-  using SmemTiledCopyO = typename Traits::SmemTiledCopyO;
-
-  const int m_block_idx = blockIdx.x;
-  const int batch_idx = blockIdx.y;
-  const int kv_head_idx = blockIdx.z;
-  const int tidx = threadIdx.x;
-
-  // preprocess input parameters
-  const int head_dim = params.head_dim;
-  const float logits_soft_cap = params.logits_soft_cap;
-  const float sm_scale = params.sm_scale;
-  const float sm_scale_log2 = params.sm_scale_log2;
-
-  const auto& group_size = params.group_size;
-
-  // ProblemShape
-  // (q_packed_len, HEAD_DIM)
-  MHATile tile(params, batch_idx, kv_head_idx);
-  auto [Q, O] = tile.template get_qo_tile<DType>();
-  // (kv_len, HEAD_DIM)
-  auto [K, V] = tile.template get_kv_tile<DType>();
-
-  const int q_packed_len = size<0>(Q);
-  const int q_len = q_packed_len / group_size;
-  const int kv_len = size<0>(K);
-
-  if (m_block_idx * kBlockM >= q_packed_len) {
-    // m out of bound, return
-    return;
-  }
-
-  const int sliding_window = LOCAL ? params.sliding_window : kv_len;
-
-  // Gmem
-  // (BLK_M, HEAD_DIM)
-  Tensor gQ =
-      local_tile(Q, Shape<_BLK_M, _HEAD_DIM>{}, make_coord(m_block_idx, _0{}));
-  Tensor gO =
-      local_tile(O, Shape<_BLK_M, _HEAD_DIM>{}, make_coord(m_block_idx, _0{}));
-  // (BLK_N, HEAD_DIM, n)
-  Tensor gK = local_tile(K, Shape<_BLK_N, _HEAD_DIM>{}, make_coord(_, _0{}));
-  Tensor gV = local_tile(V, Shape<_BLK_N, _HEAD_DIM>{}, make_coord(_, _0{}));
-
-  // Smem
-  extern __shared__ char smem[];
-  auto& ss = *reinterpret_cast<SharedStorage*>(smem);
-
-  // (BLK_M, HEAD_DIM), k-major
-  Tensor sQ = make_tensor(make_smem_ptr(ss.q_smem.data()), SmemLayoutQ{});
-  // (BLK_N, HEAD_DIM), k-major
-  Tensor sK = make_tensor(make_smem_ptr(ss.k_smem.data()), SmemLayoutK{});
-  Tensor sV = make_tensor(make_smem_ptr(ss.v_smem.data()), SmemLayoutV{});
-
-  // Tensor for V^t; used in GEMM-II.
-  // (HEAD_DIM, BLK_N), m-major
-  Tensor sVt = make_tensor(make_smem_ptr(ss.vt_smem.data()), SmemLayoutVt{});
-
-  // (BLK_M, HEAD_DIM)
-  Tensor sO = make_tensor(make_smem_ptr(ss.o_smem.data()), SmemLayoutO{});
-
-  // Tiled Copy
-  // g2s tiled copy for qkv
-  GmemTiledCopyQ gmem_tiled_copy_Q;
-  GmemTiledCopyKV gmem_tiled_copy_KV;
-  auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
-  auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_thread_slice(tidx);
-
-  // coordinate tensor for oob handling
-  // (BLK_M, HEAD_DIM) -> (blk_m, head_dim)
-  Tensor cQ = make_identity_tensor(Shape<_BLK_M, _HEAD_DIM>{});
-  Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);
-
-  auto produce_query = [&]() {
-    auto tQgQ = gmem_thr_copy_Q.partition_S(gQ);
-    auto tQsQ = gmem_thr_copy_Q.partition_D(sQ);
-    auto max_coord = make_coord(q_packed_len - m_block_idx * kBlockM, head_dim);
-    safe_copy</*EVEN_MN=*/false, EVEN_K, /*ZFILL_MN=*/true, /*ZFILL_K=*/true>(
-        gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, max_coord);
-  };
-
-  // (BLK_N, HEAD_DIM) -> (blk_n, head_dim)
-  Tensor cKV = make_identity_tensor(Shape<_BLK_N, _HEAD_DIM>{});
-  Tensor tKVcKV = gmem_thr_copy_KV.partition_S(cKV);
-
-  Tensor tKsK = gmem_thr_copy_KV.partition_D(sK);
-  auto produce_key = [&](int ni) {
-    auto tKgK = gmem_thr_copy_KV.partition_S(gK(_, _, ni));
-    auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim);
-    // skip ZFILL_MN for key since Mask will mask out oob with -inf
-    safe_copy</*EVEN_MN=*/false, EVEN_K, /*ZFILL_MN=*/false, /*ZFILL_K=*/true>(
-        gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, max_coord);
-  };
-
-  // produce key without oob handling
-  auto produce_key_no_oob = [&](int ni) {
-    auto tKgK = gmem_thr_copy_KV.partition_S(gK(_, _, ni));
-    auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim);
-    safe_copy</*EVEN_MN=*/true, EVEN_K, /*ZFILL_MN=*/false, /*ZFILL_K=*/false>(
-        gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, max_coord);
-  };
-
-  Tensor tVsV = gmem_thr_copy_KV.partition_D(sV);
-  auto produce_value = [&](int ni) {
-    auto tVgV = gmem_thr_copy_KV.partition_S(gV(_, _, ni));
-    auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim);
-    // skipping ZFILL_MN for v may cause nan issue
-    safe_copy</*EVEN_MN=*/false, EVEN_K, /*ZFILL_MN=*/true, /*ZFILL_K=*/true>(
-        gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, max_coord);
-  };
-
-  // produce value without oob handling
-  auto produce_value_no_oob = [&](int ni) {
-    auto tVgV = gmem_thr_copy_KV.partition_S(gV(_, _, ni));
-    auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim);
-    safe_copy</*EVEN_MN=*/true, EVEN_K, /*ZFILL_MN=*/false, /*ZFILL_K=*/false>(
-        gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, max_coord);
-  };
-
-  TiledMma tiled_mma;
-  auto thr_mma = tiled_mma.get_slice(tidx);
-  // GEMM-I: S = Q@K.T
-  auto tSrQ = thr_mma.partition_fragment_A(sQ);  // (MMA,MMA_M,MMA_K)
-  auto tSrK = thr_mma.partition_fragment_B(sK);  // (MMA,MMA_N,MMA_K)
-
-  // s2r tiled copy for qkv
-  // copy query to rmem
-  SmemTiledCopyQ smem_tiled_copy_Q;
-  auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
-  auto tSsQ = smem_thr_copy_Q.partition_S(sQ);
-  auto tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
-
-  SmemTiledCopyK smem_tiled_copy_K;
-  auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
-  auto tSsK = smem_thr_copy_K.partition_S(sK);
-  auto tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK);
-
-  // S = Q@K.T
-  // tSrAccS: (MMA,MMA_M,MMA_N)
-  auto compute_qk = [&](auto& tSrAccS) {
-    // prefetch key
-    cute::copy(smem_tiled_copy_K, tSsK(_, _, _0{}), tSrK_copy_view(_, _, _0{}));
-
-    CUTE_UNROLL
-    for (int ki = 0; ki < size<2>(tSrQ); ++ki) {
-      // prefetch next key
-      if (ki != size<2>(tSrQ) - 1) {
-        const auto next_ki = ki + 1;
-        cute::copy(smem_tiled_copy_K,
-                   tSsK(_, _, next_ki),
-                   tSrK_copy_view(_, _, next_ki));
-      }
-      cute::gemm(tiled_mma, tSrQ(_, _, ki), tSrK(_, _, ki), tSrAccS);
-    }
-  };
-
-  // GEMM-II: O = softmax(S)@V
-  auto tOrVt = thr_mma.partition_fragment_B(sVt);  // (MMA,MMA_K,MMA_N)
-
-  SmemTiledCopyVt smem_tiled_copy_Vt;
-  auto smem_thr_copy_Vt = smem_tiled_copy_Vt.get_thread_slice(tidx);
-  auto tOsVt = smem_thr_copy_Vt.partition_S(sVt);
-  auto tOrVt_copy_view = smem_thr_copy_Vt.retile_D(tOrVt);
-
-  // O = softmax(S)*V
-  // tSrAccS: (MMA,MMA_M,MMA_N)
-  // tOrAccO: (MMA,MMA_M,MMA_K)
-  auto compute_sv = [&](const auto& tSrAccS, auto& tOrAccO) {
-    // cast scores from Accumulator to Element
-    auto tSrS = make_tensor_like<DType>(tSrAccS);
-    fast_cast(tSrAccS, tSrS);
-
-    // convert layout from gemm-I C to gemm-II A
-    auto tOrS =
-        make_tensor(tSrS.data(), LayoutConvertor::to_mma_a(tSrS.layout()));
-
-    // prefetch V^t
-    cute::copy(
-        smem_tiled_copy_Vt, tOsVt(_, _, _0{}), tOrVt_copy_view(_, _, _0{}));
-    CUTE_UNROLL
-    for (int ki = 0; ki < size<2>(tOrS); ++ki) {
-      // prefetch next V^t
-      if (ki != size<2>(tOrS) - 1) {
-        const auto next_ki = ki + 1;
-        cute::copy(smem_tiled_copy_Vt,
-                   tOsVt(_, _, next_ki),
-                   tOrVt_copy_view(_, _, next_ki));
-      }
-      cute::gemm(tiled_mma, tOrS(_, _, ki), tOrVt(_, _, ki), tOrAccO);
-    }
-  };
-
-  // tOrAccO: (MMA,MMA_M,MMA_K)
-  auto epilogue = [&](const auto& tOrAccO) {
-    // write output to gmem
-    // 1> cast output from ElementAccumulator to Element
-    auto tOrO = make_tensor_like<DType>(tOrAccO);
-    fast_cast(tOrAccO, tOrO);
-
-    // 2. copy output from reg to smem
-    SmemTiledCopyO smem_tiled_copy_O;
-    auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
-    auto taccOrO = smem_thr_copy_O.retile_S(tOrO);
-    auto taccOsO = smem_thr_copy_O.partition_D(sO);
-    cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
-
-    // 3. copy output from smem to gmem
-    GmemTiledCopyO gmem_tiled_copy_O;
-    auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
-
-    // (BLK_M, HEAD_DIM) -> (blk_m, head_dim)
-    auto cO = make_identity_tensor(Shape<_BLK_M, _HEAD_DIM>{});
-
-    auto tOsO = gmem_thr_copy_O.partition_S(sO);  // (CPY,CPY_M,CPY_K)
-    auto tOgO = gmem_thr_copy_O.partition_D(gO);  // (CPY,CPY_M,CPY_K)
-    // (CPY,CPY_M,CPY_K) -> (blk_m, head_dim)
-    auto tOcO = gmem_thr_copy_O.partition_D(cO);
-
-    // wait for smem copy done before gmem copy
-    __syncthreads();
-
-    auto max_coord = make_coord(q_packed_len - m_block_idx * kBlockM, head_dim);
-    safe_copy</*EVEN_MN=*/false, EVEN_K, /*ZFILL_MN=*/false, /*ZFILL_K=*/false>(
-        gmem_tiled_copy_O, tOsO, tOgO, tOcO, max_coord);
-  };
-
-  // output accumulator, (MMA,MMA_M,MMA_K)
-  auto tOrO = partition_fragment_C(tiled_mma, Shape<_BLK_M, _HEAD_DIM>{});
-  auto tOrO_mn =
-      make_tensor(tOrO.data(), LayoutConvertor::to_mn(tOrO.layout()));
-  clear(tOrO);
-
-  const int diagonal = (m_block_idx * kBlockM) / group_size + kv_len - q_len;
-  // process kv in range: [kv_idx_min, kv_idx_max)
-  const int kv_idx_min = std::max(0, diagonal - sliding_window);
-  const int kv_idx_max = std::min(kv_len, diagonal + kBlockM);
-  const int n_block_min = LOCAL ? kv_idx_min / kBlockN : 0;
-  const int n_block_max = cute::ceil_div(kv_idx_max, kBlockN);
-
-  if (n_block_min >= n_block_max) {
-    // write output to gmem
-    epilogue(tOrO);
-    return;
-  }
-
-  auto apply_logits_soft_cap = [&](auto& tSrAccS) {
-    if constexpr (SOFT_CAP) {
-      CUTE_UNROLL
-      for (int i = 0; i < size(tSrAccS); ++i) {
-        tSrAccS(i) = tanh(tSrAccS(i) * logits_soft_cap);
-      }
-    }
-  };
-
-  // ###############  Prologue  ###############
-  // produce query: [] => [q]
-  produce_query();
-  cp_async_fence();
-
-  // wait g2s copy done for query
-  cp_async_wait<0>();
-  __syncthreads();
-
-  // copy query from smem to rmem
-  cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
-  // wait s2r copy done for query
-  __syncthreads();
-
-  // produce key: [q] => [q, k]
-  produce_key(n_block_max - 1);
-  cp_async_fence();
-
-  // ###############  Mainloop  ###############
-  constexpr int n_oob_mask = cute::ceil_div(kBlockM, kBlockN) + 1;
-  const int n_blocks = n_block_max - n_block_min;
-
-  // attention score accumulator, (MMA,MMA_M,MMA_N)
-  auto tSrS = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{});
-  auto tSrS_mn =
-      make_tensor(tSrS.data(), LayoutConvertor::to_mn(tSrS.layout()));
-
-  // identity tensor for score accumulator
-  auto tScS =
-      thr_mma.partition_C(make_identity_tensor(Shape<_BLK_M, _BLK_N>{}));
-  auto tScS_mn =
-      make_tensor(tScS.data(), LayoutConvertor::to_mn(tScS.layout()));
-
-  constexpr int kRowsPerThr = kRowsPerMMA * size<1>(tSrS);
-  using Softmax = OnlineSoftmax<kRowsPerThr>;
-  using Mask = Mask<kRowsPerThr, ALIBI, LOCAL>;
-
-  Softmax softmax(sm_scale_log2);
-  Mask mask(q_len, kv_len, group_size, sliding_window);
-  if constexpr (ALIBI) {
-    mask.init_alibi(tScS_mn,
-                    m_block_idx * kBlockM,
-                    kv_head_idx,
-                    sm_scale,
-                    params.alibi_slopes_ptr);
-  }
-
-  CUTE_NO_UNROLL
-  for (int i = 0; i < n_blocks; ++i) {
-    const int n_block_idx = n_block_max - 1 - i;
-    clear(tSrS);
-
-    // wait key, queue: [q, k] => []
-    cp_async_wait<0>();
-    __syncthreads();
-
-    // produce value, [] => [v]
-    if (i == 0) {
-      produce_value(n_block_idx);
-    } else {
-      produce_value_no_oob(n_block_idx);
-    }
-    cp_async_fence();
-
-    // 1> S = Q@K.T
-    compute_qk(tSrS);
-
-    // wait value, [v] => []
-    cp_async_wait<0>();
-    __syncthreads();
-
-    if constexpr (SOFT_CAP) {
-      apply_logits_soft_cap(tSrS);
-    }
-
-    if (i < n_oob_mask) {
-      mask.apply</*OOB_MASK=*/true>(
-          tSrS_mn, tScS_mn, m_block_idx * kBlockM, n_block_idx * kBlockN);
-    } else {
-      mask.apply</*OOB_MASK=*/false>(
-          tSrS_mn, tScS_mn, m_block_idx * kBlockM, n_block_idx * kBlockN);
-    }
-    softmax.rescale(tSrS_mn, tOrO_mn);
-
-    // produce next key: [] => [k]
-    if (n_block_idx > n_block_min) {
-      produce_key_no_oob(n_block_idx - 1);
-    }
-    cp_async_fence();
-
-    // 2> O = softmax(S)*V
-    compute_sv(tSrS, tOrO);
-  }
-
-  // ###############  Epilogue  ###############
-
-  // normalize output: o /= rowsum
-  softmax.finalize(tOrO_mn);
-
-  // write output to gmem
-  epilogue(tOrO);
-}
-
-template <typename DTYPE, int HEAD_DIM, bool EVEN_K, bool ALIBI, bool SOFT_CAP, bool LOCAL, typename Params>
-void launch_mha_kernel_sm80(const Params& params, cudaStream_t stream) {
-  const auto batch_size = params.batch_size;
-  const auto n_kv_heads = params.n_kv_heads;
-  const auto max_q_packed_len = params.max_q_len * params.group_size;
-
-  // TODO: tune block shape MNK based on the head dim and smem size
-  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
-  // SM           | 7.0 | 7.2 | 7.5 | 8.0 | 8.6 | 8.7 | 8.9 | 9.0 | 10.x | 12.0|
-  // Max SMEM (KB)| 96  | 64  | 164 | 100 | 164 | 100 | 228 | 100 |
-  // valid dynamic shared memory sizes for different compute capabilities:
-  //   * 7.0 | 7.2 : 0, 8, 16, 32, 64, 96
-  //   * 7.5       : 0, 32, 64
-  //   * 8.0 | 8.7 : 0, 8, 16, 32, 64, 100, 132, 164
-  //   * 8.6 | 8.9 : 0, 8, 16, 32, 64, 100
-  //   * 9.0 | 10.x: 0, 8, 16, 32, 64, 100, 132, 164, 196, 228
-  //   * 12.0      : 0, 8, 16, 32, 64, 100
-
-  constexpr int BLK_M = 64;
-  constexpr int BLK_N = 64;
-  constexpr int BLK_K = HEAD_DIM % 64 == 0 ? 64 : 32;
-  using Traits = MHATraitsSM80<DTYPE, HEAD_DIM, BLK_M, BLK_N, BLK_K>;
-
-  const auto smem_size = sizeof(MHASharedStorage<Traits>);
-  auto mha_kernel =
-      mha_kernel_sm80<Traits, Params, EVEN_K, ALIBI, SOFT_CAP, LOCAL>;
-  cudaFuncSetAttribute(
-      mha_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
-  // TODO: support persistent kernels
-  dim3 grid(cute::ceil_div(max_q_packed_len, Traits::kBlockM),
-            batch_size,
-            n_kv_heads);
-  dim3 block = Traits::kThreadNum;
-  mha_kernel<<<grid, block, smem_size, stream>>>(params);
-}
-
-}  // namespace llm
diff --git a/src/kernels/attention/mha_traits_sm80.h b/src/kernels/attention/mha_traits_sm80.h
deleted file mode 100644
index 37b18684..00000000
--- a/src/kernels/attention/mha_traits_sm80.h
+++ /dev/null
@@ -1,112 +0,0 @@
-#pragma once
-#include
-#include
-
-#include "cute_extensions.cuh"
-
-namespace llm {
-using namespace cute;
-
-template <typename DTYPE, int HEAD_DIM, int BLK_M, int BLK_N, int BLK_K>
-struct MHATraitsSM80 {
-  // helpful aliases
-  static constexpr int kHeadDim = HEAD_DIM;
-  static constexpr int kBlockM = BLK_M;
-  static constexpr int kBlockN = BLK_N;
-  static constexpr int kBlockK = BLK_K;
-  static constexpr int kRowsPerMMA = 2;
-
-  static_assert(kHeadDim % kBlockK == 0);
-
-  using DType = DTYPE;
-  using _BLK_M = Int<BLK_M>;
-  using _BLK_N = Int<BLK_N>;
-  using _BLK_K = Int<BLK_K>;
-  using _HEAD_DIM = Int<HEAD_DIM>;
-
-  // ******* Mainloop *******
-  // TiledMMA (64x16x16) for gemm-I and gemm-II
-  // choose MMA_Atom based on Element type
-  using MMA_Atom_ =
-      std::conditional_t<std::is_same_v<DType, cute::half_t>,
-                         MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
-                         MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>>;
-  using TiledMma = TiledMMA<MMA_Atom_,
-                            Layout<Shape<_4, _1, _1>>,  // warp layout 4x1x1
-                            Tile<_64, _16, _16>>;       // Prom Shape 64x16x16
-
-  // SMEM layout for QKV
-  // Atom layout: (8, BLK_K):(BLK_K, 1) k-major
-  using SmemLayoutAtom =
-      decltype(composition(Swizzle<3, 3, 3>{},
-                           Layout<Shape<_8, _BLK_K>, Stride<_BLK_K, _1>>{}));
-
-  // Q smem: (BLK_M, HEAD_DIM)
-  using SmemLayoutQ =
-      decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_M, _HEAD_DIM>{}));
-
-  // KV smem: (BLK_N, HEAD_DIM)
-  using SmemLayoutK =
-      decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_N, _HEAD_DIM>{}));
-
-  using SmemLayoutV =
-      decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_N, _HEAD_DIM>{}));
-
-  // V^T smem: (HEAD_DIM, BLK_N)
-  using SmemLayoutVt = decltype(select<1, 0>(SmemLayoutV{}));
-
-  // Thr layout for gmem copy
-  using GmemCopyThrLayout =
-      std::conditional_t<BLK_K == 32,
-                         Layout<Shape<_32, _4>, Stride<_4, _1>>,
-                         Layout<Shape<_16, _8>, Stride<_8, _1>>>;
-
-  // Tiled copy for QKV
-  // g2s tiled copy for q
-  using GmemTiledCopyQ = decltype(make_tiled_copy(
-      Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, DType>{},
-      GmemCopyThrLayout{},     // Thr layout: (_16,_8)/(_32, _4)
-      Layout<Shape<_1, _8>>{}  // Val layout: 8 vals per read
-      ));
-
-  // g2s tiled copy for kv
-  using GmemTiledCopyKV = GmemTiledCopyQ;
-
-  // s2r tiled copy for gemm-I
-  using SmemTiledCopyQ =
-      decltype(make_tiled_copy_A(Copy_Atom<SM75_U32x4_LDSM_N, DType>{},
-                                 TiledMma{}));
-  using SmemTiledCopyK =
-      decltype(make_tiled_copy_B(Copy_Atom<SM75_U32x4_LDSM_N, DType>{},
-                                 TiledMma{}));
-
-  // s2r tiled copy for gemm-II
-  using SmemTiledCopyVt =
-      decltype(make_tiled_copy_B(Copy_Atom<SM75_U16x8_LDSM_T, DType>{},
-                                 TiledMma{}));
-
-  // ******* Epilogue *******
-
-  // O smem: (BLK_M, K):(K, 1), k-major, same as Q
-  using SmemLayoutO = SmemLayoutQ;
-
-  // use 128-bit vectorizing copy
-  using VectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<128>;
-
-  // s2g tiled copy for O
-  using GmemTiledCopyO = decltype(make_tiled_copy(
-      Copy_Atom<VectorizingCopy, DType>{},
-      GmemCopyThrLayout{},     // Thr layout: (_16,_8)/(_32, _4)
-      Layout<Shape<_1, _8>>{}  // Val layout: 8 vals per read
-      ));
-
-  // r2s tiled copy for O
-  using SmemTiledCopyO =
-      decltype(make_tiled_copy_C(Copy_Atom<VectorizingCopy, DType>{},
-                                 TiledMma{}));
-
-  // constexpr values for kernel launch
-  static constexpr size_t kThreadNum = size(TiledMma{});
-};
-
-}  // namespace llm
\ No newline at end of file
diff --git a/src/kernels/attention/mha_traits_test.cpp b/src/kernels/attention/mha_traits_test.cpp
deleted file mode 100644
index d7738880..00000000
--- a/src/kernels/attention/mha_traits_test.cpp
+++ /dev/null
@@ -1,57 +0,0 @@
-#include <gtest/gtest.h>
-
-#include
-
-#include "cute/layout_composed.hpp"
-#include "gather_tensor.hpp"
-#include "mha_traits_sm80.h"
-
-namespace llm {
-
-using namespace cute;
-
-template <typename Traits>
-void test_mha_traits() {
-  // type alias
-  using TiledMma = typename Traits::TiledMma;
-
-  using SmemLayoutQ = typename Traits::SmemLayoutQ;
-  using SmemLayoutK = typename Traits::SmemLayoutK;
-  using SmemLayoutV = typename Traits::SmemLayoutV;
-  using SmemLayoutVt = typename Traits::SmemLayoutVt;
-  using SmemLayoutO = typename Traits::SmemLayoutO;
-  using GmemTiledCopyQ = typename Traits::GmemTiledCopyQ;
-  using GmemTiledCopyKV = typename Traits::GmemTiledCopyKV;
-  using GmemTiledCopyO = typename Traits::GmemTiledCopyO;
-
-  using SmemTiledCopyQ = typename Traits::SmemTiledCopyQ;
-  using SmemTiledCopyK = typename Traits::SmemTiledCopyK;
-  using SmemTiledCopyVt = typename Traits::SmemTiledCopyVt;
-  using SmemTiledCopyO = typename Traits::SmemTiledCopyO;
-
-  // test layout conversation
-  Tensor sQ = make_tensor(counting_iterator(0), SmemLayoutQ{});
-  Tensor sK = make_tensor(counting_iterator(0), SmemLayoutK{});
-  Tensor sV = make_tensor(counting_iterator(0), SmemLayoutV{});
-  Tensor sVt = make_tensor(sV.data(), SmemLayoutVt{});
-
-  // print("sQ:"); print(sQ);print("\n");
-  // print("sK:"); print(sK);print("\n");
-  // print("sV:"); print(sV);print("\n");
-  // print("sVt:"); print(sVt);print("\n");
-
-  TiledMma tiled_mma;
-  auto thr_mma = tiled_mma.get_slice(0);
-  auto tOrVt = thr_mma.partition_fragment_B(sVt);
-  // TODO: add tests for layout conformance
-}
-
-TEST(MHATraitsTest, TraitsSM80) {
-  test_mha_traits<MHATraitsSM80<cute::half_t, 64, 64, 64, 64>>();
-}
-
-}  // namespace llm
\ No newline at end of file
diff --git a/src/kernels/attention/sm80_collective_epilogue.cuh b/src/kernels/attention/sm80_collective_epilogue.cuh
new file mode 100644
index 00000000..1e0a0182
--- /dev/null
+++ b/src/kernels/attention/sm80_collective_epilogue.cuh
@@ -0,0 +1,125 @@
+#pragma once
+
+#include
+#include
+
+#include
+#include
+#include
+
+#include "cute_extensions.cuh"
+#include "fast_cast.cuh"
+
+namespace llm {
+using namespace cute;
+
+template <class TileShape_, class Element_, int HeadDim_, bool EVEN_K_>
+struct Sm80CollectiveEpilogue {
+  using TileShape = TileShape_;
+  using Element = Element_;
+
+  static constexpr int kHeadDim = HeadDim_;
+  static constexpr bool EVEN_K = EVEN_K_;
+
+  static constexpr int kBlockM = get<0>(TileShape{});
+  static constexpr int kBlockK = get<2>(TileShape{});
+
+  using BLK_M = Int<kBlockM>;
+  using BLK_K = Int<kBlockK>;
+  using HEAD_DIM = Int<kHeadDim>;
+
+  using SmemLayoutAtom_ =
+      decltype(composition(Swizzle<3, 3, 3>{},
+                           Layout<Shape<_8, BLK_K>, Stride<BLK_K, _1>>{}));
+
+  // O smem: (BLK_M, HEAD_DIM)
+  using SmemLayoutO =
+      decltype(tile_to_shape(SmemLayoutAtom_{}, Shape<BLK_M, HEAD_DIM>{}));
+
+  // use 128-bit vectorizing copy
+  using VectorizingCopy_ = AutoVectorizingCopyWithAssumedAlignment<128>;
+
+  // r2s copy atom for O
+  using SmemCopyAtom_ = Copy_Atom<VectorizingCopy_, Element>;
+
+  // Thr layout for gmem copy
+  using GmemCopyThrLayout_ =
+      std::conditional_t<kBlockK == 32,
+                         Layout<Shape<_32, _4>, Stride<_4, _1>>,
+                         Layout<Shape<_16, _8>, Stride<_8, _1>>>;
+
+  // s2g tiled copy for O
+  using GmemTiledCopyO = decltype(make_tiled_copy(
+      Copy_Atom<VectorizingCopy_, Element>{},
+      GmemCopyThrLayout_{},    // Thr layout: (_16,_8)/(_32, _4)
+      Layout<Shape<_1, _8>>{}  // Val layout: 8 vals per read
+      ));
+
+  struct SharedStorage : cute::aligned_struct<128> {
+    cute::array_aligned<Element, cute::cosize_v<SmemLayoutO>> smem_o;
+  };
+
+  // Host side kernel arguments
+  struct Arguments {};
+
+  // Device side kernel params
+  using Params = Arguments;
+
+  // Convert host side arguments to device side params
+  static Params to_underlying_arguments(Arguments const& args) { return args; }
+
+  template <class FrgTensor,
+            class TiledMma,
+            class TensorO,
+            class BlockCoordMNK,
+            class ProblemShapeMNK>
+  CUTE_DEVICE void operator()(const Params& /*params*/,
+                              const FrgTensor& tOrAccO,  // (MMA, MMA_M, MMA_N)
+                              TiledMma tiled_mma,
+                              TensorO& gO,  // (BLK_M, HEAD_DIM)
+                              int tidx,
+                              const BlockCoordMNK& block_coord_mnk,
+                              const ProblemShapeMNK& problem_shape_mnk,
+                              char* smem) {
+    static constexpr int kBlockM = get<0>(TileShape{});
+
+    const auto [m_block_idx, batch_idx, kv_head_idx] = block_coord_mnk;
+    const auto [q_packed_len, kv_len, head_dim] = problem_shape_mnk;
+
+    // Smem
+    auto& ss = *reinterpret_cast<SharedStorage*>(smem);
+    // (BLK_M, HEAD_DIM)
+    Tensor sO = make_tensor(make_smem_ptr(ss.smem_o.data()), SmemLayoutO{});
+
+    // 1. cast output from ElementAccumulator to Element
+    auto tOrO = make_tensor_like<Element>(tOrAccO);
+    fast_cast(tOrAccO, tOrO);
+
+    // 2. copy output from reg to smem
+    auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtom_{}, tiled_mma);
+    auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
+    auto tSrO = smem_thr_copy_O.retile_S(tOrO);
+    auto tSsO = smem_thr_copy_O.partition_D(sO);
+    cute::copy(smem_tiled_copy_O, tSrO, tSsO);
+
+    // 3. copy output from smem to gmem
+    GmemTiledCopyO gmem_tiled_copy_O;
+    auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
+
+    // (BLK_M, HEAD_DIM) -> (blk_m, head_dim)
+    auto cO = make_identity_tensor(Shape<BLK_M, HEAD_DIM>{});
+
+    auto tOsO = gmem_thr_copy_O.partition_S(sO);  // (CPY,CPY_M,CPY_K)
+    auto tOgO = gmem_thr_copy_O.partition_D(gO);  // (CPY,CPY_M,CPY_K)
+    // (CPY,CPY_M,CPY_K) -> (blk_m, head_dim)
+    auto tOcO = gmem_thr_copy_O.partition_D(cO);
+
+    // wait for smem copy done before gmem copy
+    __syncthreads();
+
+    auto max_coord = make_coord(q_packed_len - m_block_idx * kBlockM, head_dim);
+    safe_copy</*EVEN_MN=*/false, EVEN_K, /*ZFILL_MN=*/false, /*ZFILL_K=*/false>(
+        gmem_tiled_copy_O, tOsO, tOgO, tOcO, max_coord);
+  }
+};
+}  // namespace llm
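Both collectives declare their own SharedStorage, and the kernel (sm80_kernel_mha.cuh below) sizes a single dynamic shared-memory allocation as the max of the two, since the epilogue only runs after the mainloop is done with smem. A standalone sketch of that sizing rule, with illustrative half-precision 64x64x64 tiles:

    #include <algorithm>
    #include <cstdio>

    // stand-ins for the two collectives' SharedStorage (illustrative sizes:
    // 2-byte elements, BLK_M = BLK_N = HEAD_DIM = 64)
    struct MainloopStorage {  // Q is unioned with the K/V tiles, as in the real code
      union {
        char q[64 * 64 * 2];
        struct { char k[64 * 64 * 2]; char v[64 * 64 * 2]; } kv;
      };
    };
    struct EpilogueStorage {
      char o[64 * 64 * 2];  // O tile
    };

    int main() {
      // one smem buffer serves the mainloop first, then is reused by the epilogue
      const size_t smem_size =
          std::max(sizeof(MainloopStorage), sizeof(EpilogueStorage));
      std::printf("dynamic smem: %zu bytes\n", smem_size);
      return 0;
    }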
diff --git a/src/kernels/attention/sm80_collective_mha.cuh b/src/kernels/attention/sm80_collective_mha.cuh
new file mode 100644
index 00000000..68d61833
--- /dev/null
+++ b/src/kernels/attention/sm80_collective_mha.cuh
@@ -0,0 +1,434 @@
+#pragma once
+
+#include
+#include
+
+#include
+#include
+#include
+#include
+
+#include "cute_extensions.cuh"
+#include "fast_cast.cuh"
+#include "layout_convertor.h"
+#include "mask.h"
+
+namespace llm {
+
+using namespace cute;
+
+template <class TileShape_,
+          class Element_,
+          int HeadDim_,
+          bool EVEN_K,
+          bool ALIBI,
+          bool SOFT_CAP,
+          bool LOCAL>
+struct Sm80CollectiveMha {
+  // TODO: multiple stages
+  using TileShape = TileShape_;
+  using Element = Element_;
+  using ElementAccum = float;
+
+  static constexpr int kHeadDim = HeadDim_;
+  static constexpr int kBlockM = get<0>(TileShape{});
+  static constexpr int kBlockN = get<1>(TileShape{});
+  static constexpr int kBlockK = get<2>(TileShape{});
+
+  static_assert(kBlockK == 32 || kBlockK == 64);
+  static_assert(kHeadDim % kBlockK == 0);
+
+  using BLK_M = Int<kBlockM>;
+  using BLK_N = Int<kBlockN>;
+  using BLK_K = Int<kBlockK>;
+  using HEAD_DIM = Int<kHeadDim>;
+
+  // TiledMMA (64x16x16) for gemm-I and gemm-II
+  using MMA_Atom_ =
+      std::conditional_t<std::is_same_v<Element, cute::half_t>,
+                         MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
+                         MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>>;
+  using TiledMma = TiledMMA<MMA_Atom_,
+                            Layout<Shape<_4, _1, _1>>,  // warp layout 4x1x1
+                            Tile<_64, _16, _16>>;       // Tile Shape 64x16x16
+
+  static constexpr int kRowsPerMMA = 2;
+  static constexpr int kMmaThreads = size(TiledMma{});
+
+  // Atom layout: (8, BLK_K):(BLK_K, 1) k-major
+  using SmemLayoutAtom_ =
+      decltype(composition(Swizzle<3, 3, 3>{},
+                           Layout<Shape<_8, BLK_K>, Stride<BLK_K, _1>>{}));
+
+  // Q smem: (BLK_M, HEAD_DIM)
+  using SmemLayoutQ =
+      decltype(tile_to_shape(SmemLayoutAtom_{}, Shape<BLK_M, HEAD_DIM>{}));
+
+  // KV smem: (BLK_N, HEAD_DIM)
+  using SmemLayoutK =
+      decltype(tile_to_shape(SmemLayoutAtom_{}, Shape<BLK_N, HEAD_DIM>{}));
+  using SmemLayoutV =
+      decltype(tile_to_shape(SmemLayoutAtom_{}, Shape<BLK_N, HEAD_DIM>{}));
+
+  // V^T smem: (HEAD_DIM, BLK_N)
+  using SmemLayoutVt = decltype(select<1, 0>(SmemLayoutV{}));
+
+  // Thr layout for gmem copy
+  using GmemCopyThrLayout_ =
+      std::conditional_t<kBlockK == 32,
+                         Layout<Shape<_32, _4>, Stride<_4, _1>>,
+                         Layout<Shape<_16, _8>, Stride<_8, _1>>>;
+
+  // g2s tiled copy for q
+  using GmemTiledCopyQ = decltype(make_tiled_copy(
+      Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, Element>{},
+      GmemCopyThrLayout_{},    // Thr layout: (_16,_8)/(_32, _4)
+      Layout<Shape<_1, _8>>{}  // Val layout: 8 vals per read
+      ));
+
+  // g2s tiled copy for kv
+  using GmemTiledCopyKV = GmemTiledCopyQ;
+
+  // s2r tiled copy for gemm-I
+  using SmemTiledCopyQ =
+      decltype(make_tiled_copy_A(Copy_Atom<SM75_U32x4_LDSM_N, Element>{},
+                                 TiledMma{}));
+  using SmemTiledCopyK =
+      decltype(make_tiled_copy_B(Copy_Atom<SM75_U32x4_LDSM_N, Element>{},
+                                 TiledMma{}));
+
+  // s2r tiled copy for gemm-II
+  using SmemTiledCopyVt =
+      decltype(make_tiled_copy_B(Copy_Atom<SM75_U16x8_LDSM_T, Element>{},
+                                 TiledMma{}));
+
+  struct SharedStorage : cute::aligned_struct<128> {
+    union {
+      cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
+      struct {
+        cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
+        union {
+          cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
+          cute::array_aligned<Element, cute::cosize_v<SmemLayoutVt>> smem_vt;
+        };
+      };
+    };
+  };
+
+  // Host side arguments
+  struct Arguments {
+    // mask
+    int sliding_window = -1;
+
+    // softcap
+    float logits_soft_cap = 0.0;
+
+    // softmax scaling
+    float sm_scale = 1.0;
+    float sm_scale_log2 = 0.0;
+
+    // alibi: (n_heads)
+    const float* __restrict__ alibi_slopes_ptr = nullptr;
+
+    FastDivmod group_size;
+  };
+
+  // Device side params
+  using Params = Arguments;
+
+  // Convert host side arguments to device side params
+  static Params to_underlying_arguments(Arguments const& args) {
+    // no conversion needed.
+    return args;
+  }
+
+  // processes all kv blocks for one query block; returns early if there is
+  // nothing to do
+  template <class TensorQ,
+            class TensorK,
+            class TensorV,
+            class FrgTensor,
+            class Softmax,
+            class BlockCoordMNK,
+            class ProblemShapeMNK>
+  CUTE_DEVICE void operator()(const Params& params,
+                              const TensorQ& gQ,  // (BLK_M, HEAD_DIM)
+                              const TensorK& gK,  // (BLK_N, HEAD_DIM, n)
+                              const TensorV& gV,  // (BLK_N, HEAD_DIM, n)
+                              FrgTensor& tOrO,    // (MMA, MMA_M, MMA_N)
+                              Softmax& softmax,
+                              int tidx,
+                              const BlockCoordMNK& block_coord_mnk,
+                              const ProblemShapeMNK& problem_shape_mnk,
+                              char* smem) {
+    static_assert(is_rmem<FrgTensor>::value,
+                  "Accum tensor must be rmem resident.");
+    static_assert(is_gmem<TensorQ>::value, "Q tensor must be gmem resident.");
+    static_assert(is_gmem<TensorK>::value, "K tensor must be gmem resident.");
+    static_assert(is_gmem<TensorV>::value, "V tensor must be gmem resident.");
+
+    static constexpr int kBlockM = get<0>(TileShape{});
+    static constexpr int kBlockN = get<1>(TileShape{});
+
+    const auto [m_block_idx, batch_idx, kv_head_idx] = block_coord_mnk;
+    const auto [q_packed_len, kv_len, head_dim] = problem_shape_mnk;
+
+    const int sliding_window = LOCAL ? params.sliding_window : kv_len;
+    const float logits_soft_cap = params.logits_soft_cap;
+    const float sm_scale = params.sm_scale;
+    const float sm_scale_log2 = params.sm_scale_log2;
+    const auto& group_size = params.group_size;
+    const int q_len = q_packed_len / group_size;
+
+    // Construct shared memory tiles
+    auto& ss = *reinterpret_cast<SharedStorage*>(smem);
+
+    // (BLK_M, HEAD_DIM), k-major
+    Tensor sQ = make_tensor(make_smem_ptr(ss.smem_q.data()), SmemLayoutQ{});
+    // (BLK_N, HEAD_DIM), k-major
+    Tensor sK = make_tensor(make_smem_ptr(ss.smem_k.data()), SmemLayoutK{});
+    Tensor sV = make_tensor(make_smem_ptr(ss.smem_v.data()), SmemLayoutV{});
+
+    // Tensor for V^t; used in GEMM-II.
+    // (HEAD_DIM, BLK_N), m-major
+    Tensor sVt = make_tensor(make_smem_ptr(ss.smem_vt.data()), SmemLayoutVt{});
+
+    // g2s tiled copy for qkv
+    GmemTiledCopyQ gmem_tiled_copy_Q;
+    GmemTiledCopyKV gmem_tiled_copy_KV;
+    auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
+    auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_thread_slice(tidx);
+
+    // coordinate tensor for oob handling
+    // (BLK_M, HEAD_DIM) -> (blk_m, head_dim)
+    Tensor cQ = make_identity_tensor(Shape<BLK_M, HEAD_DIM>{});
+    Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);
+    // (BLK_N, HEAD_DIM) -> (blk_n, head_dim)
+    Tensor cKV = make_identity_tensor(Shape<BLK_N, HEAD_DIM>{});
+    Tensor tKVcKV = gmem_thr_copy_KV.partition_S(cKV);
+
+    Tensor tKsK = gmem_thr_copy_KV.partition_D(sK);
+    Tensor tVsV = gmem_thr_copy_KV.partition_D(sV);
+
+    TiledMma tiled_mma;
+    auto thr_mma = tiled_mma.get_slice(tidx);
+    // GEMM-I: S = Q@K.T
+    auto tSrQ = thr_mma.partition_fragment_A(sQ);  // (MMA,MMA_M,MMA_K)
+    auto tSrK = thr_mma.partition_fragment_B(sK);  // (MMA,MMA_N,MMA_K)
+
+    // s2r tiled copy for qkv
+    // copy query to rmem
+    SmemTiledCopyQ smem_tiled_copy_Q;
+    auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
+    auto tSsQ = smem_thr_copy_Q.partition_S(sQ);
+    auto tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
+
+    SmemTiledCopyK smem_tiled_copy_K;
+    auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
+    auto tSsK = smem_thr_copy_K.partition_S(sK);
+    auto tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK);
+
+    // S = Q@K.T
+    // tSrAccS: (MMA,MMA_M,MMA_N)
+    auto compute_qk = [&](auto& tSrAccS) {
+      // prefetch key
+      cute::copy(
+          smem_tiled_copy_K, tSsK(_, _, _0{}), tSrK_copy_view(_, _, _0{}));
+
+      CUTE_UNROLL
+      for (int ki = 0; ki < size<2>(tSrQ); ++ki) {
+        // prefetch next key
+        if (ki != size<2>(tSrQ) - 1) {
+          const auto next_ki = ki + 1;
+          cute::copy(smem_tiled_copy_K,
+                     tSsK(_, _, next_ki),
+                     tSrK_copy_view(_, _, next_ki));
+        }
+        cute::gemm(tiled_mma, tSrQ(_, _, ki), tSrK(_, _, ki), tSrAccS);
+      }
+    };
+
+    // GEMM-II: O = softmax(S)@V
+    auto tOrVt = thr_mma.partition_fragment_B(sVt);  // (MMA,MMA_K,MMA_N)
+
+    SmemTiledCopyVt smem_tiled_copy_Vt;
+    auto smem_thr_copy_Vt = smem_tiled_copy_Vt.get_thread_slice(tidx);
+    auto tOsVt = smem_thr_copy_Vt.partition_S(sVt);
+    auto tOrVt_copy_view = smem_thr_copy_Vt.retile_D(tOrVt);
+
+    // O = softmax(S)*V
+    // tSrAccS: (MMA,MMA_M,MMA_N)
+    // tOrAccO: (MMA,MMA_M,MMA_K)
+    auto compute_sv = [&](const auto& tSrAccS, auto& tOrAccO) {
+      // cast scores from Accumulator to Element
+      auto tSrS = make_tensor_like<Element>(tSrAccS);
+      fast_cast(tSrAccS, tSrS);
+
+      // convert layout from gemm-I C to gemm-II A
+      auto tOrS =
+          make_tensor(tSrS.data(), LayoutConvertor::to_mma_a(tSrS.layout()));
+
+      // prefetch V^t
+      cute::copy(
+          smem_tiled_copy_Vt, tOsVt(_, _, _0{}), tOrVt_copy_view(_, _, _0{}));
+      CUTE_UNROLL
+      for (int ki = 0; ki < size<2>(tOrS); ++ki) {
+        // prefetch next V^t
+        if (ki != size<2>(tOrS) - 1) {
+          const auto next_ki = ki + 1;
+          cute::copy(smem_tiled_copy_Vt,
+                     tOsVt(_, _, next_ki),
+                     tOrVt_copy_view(_, _, next_ki));
+        }
+        cute::gemm(tiled_mma, tOrS(_, _, ki), tOrVt(_, _, ki), tOrAccO);
+      }
+    };
+
+    auto tOrO_mn =
+        make_tensor(tOrO.data(), LayoutConvertor::to_mn(tOrO.layout()));
+
+    const int diagonal = (m_block_idx * kBlockM) / group_size + kv_len - q_len;
+    // process kv in range: [kv_idx_min, kv_idx_max)
+    const int kv_idx_min = std::max(0, diagonal - sliding_window);
+    const int kv_idx_max = std::min(kv_len, diagonal + kBlockM);
+    const int n_block_min = LOCAL ? kv_idx_min / kBlockN : 0;
+    const int n_block_max = cute::ceil_div(kv_idx_max, kBlockN);
+
+    if (n_block_min >= n_block_max) {
+      // no kv blocks to process
+      return;
+    }
+
+    auto apply_logits_soft_cap = [&](auto& tSrAccS) {
+      if constexpr (SOFT_CAP) {
+        CUTE_UNROLL
+        for (int i = 0; i < size(tSrAccS); ++i) {
+          tSrAccS(i) = tanh(tSrAccS(i) * logits_soft_cap);
+        }
+      }
+    };
+
+    // ###############  Prologue  ###############
+    // produce query: [] => [q]
+    auto tQgQ = gmem_thr_copy_Q.partition_S(gQ);
+    auto tQsQ = gmem_thr_copy_Q.partition_D(sQ);
+    auto max_coord = make_coord(q_packed_len - m_block_idx * kBlockM, head_dim);
+    safe_copy</*EVEN_MN=*/false, EVEN_K, /*ZFILL_MN=*/true, /*ZFILL_K=*/true>(
+        gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, max_coord);
+    cp_async_fence();
+
+    // wait g2s copy done for query
+    cp_async_wait<0>();
+    __syncthreads();
+
+    // copy query from smem to rmem
+    cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
+    // wait s2r copy done for query
+    __syncthreads();
+
+    // produce key: [q] => [q, k]
+    {
+      const int ni = n_block_max - 1;
+      auto tKgK = gmem_thr_copy_KV.partition_S(gK(_, _, ni));
+      auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim);
+      // skip ZFILL_MN for key since Mask will mask out oob with -inf
+      safe_copy</*EVEN_MN=*/false, EVEN_K, /*ZFILL_MN=*/false, /*ZFILL_K=*/true>(
+          gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, max_coord);
+    }
+
+    cp_async_fence();
+
+    // ###############  Mainloop  ###############
+    constexpr int n_oob_mask = cute::ceil_div(kBlockM, kBlockN) + 1;
+    const int n_blocks = n_block_max - n_block_min;
+
+    // attention score accumulator, (MMA,MMA_M,MMA_N)
+    auto tSrS = partition_fragment_C(tiled_mma, Shape<BLK_M, BLK_N>{});
+    auto tSrS_mn =
+        make_tensor(tSrS.data(), LayoutConvertor::to_mn(tSrS.layout()));
+
+    // identity tensor for score accumulator
+    auto tScS =
+        thr_mma.partition_C(make_identity_tensor(Shape<BLK_M, BLK_N>{}));
+    auto tScS_mn =
+        make_tensor(tScS.data(), LayoutConvertor::to_mn(tScS.layout()));
+
+    constexpr int kRowsPerThr = kRowsPerMMA * size<1>(tSrS);
+    using Mask = Mask<kRowsPerThr, ALIBI, LOCAL>;
+    Mask mask(q_len, kv_len, group_size, sliding_window);
+    if constexpr (ALIBI) {
+      mask.init_alibi(tScS_mn,
+                      m_block_idx * kBlockM,
+                      kv_head_idx,
+                      sm_scale,
+                      params.alibi_slopes_ptr);
+    }
+
+    CUTE_NO_UNROLL
+    for (int i = 0; i < n_blocks; ++i) {
+      const int n_block_idx = n_block_max - 1 - i;
+      clear(tSrS);
+
+      // wait key, queue: [q, k] => []
+      cp_async_wait<0>();
+      __syncthreads();
+
+      // produce value, [] => [v]
+      auto tVgV = gmem_thr_copy_KV.partition_S(gV(_, _, n_block_idx));
+      auto max_coord = make_coord(kv_len - n_block_idx * kBlockN, head_dim);
+      if (i == 0) {
+        safe_copy</*EVEN_MN=*/false, EVEN_K, /*ZFILL_MN=*/true, /*ZFILL_K=*/true>(
+            gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, max_coord);
+
+      } else {  // without oob handling
+        safe_copy</*EVEN_MN=*/true, EVEN_K, /*ZFILL_MN=*/false, /*ZFILL_K=*/false>(
+            gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, max_coord);
+      }
+      cp_async_fence();
+
+      // 1> S = Q@K.T
+      compute_qk(tSrS);
+
+      // wait value, [v] => []
+      cp_async_wait<0>();
+      __syncthreads();
+
+      if constexpr (SOFT_CAP) {
+        apply_logits_soft_cap(tSrS);
+      }
+
+      if (i < n_oob_mask) {
+        mask.template apply</*OOB_MASK=*/true>(
+            tSrS_mn, tScS_mn, m_block_idx * kBlockM, n_block_idx * kBlockN);
+      } else {
+        mask.template apply</*OOB_MASK=*/false>(
+            tSrS_mn, tScS_mn, m_block_idx * kBlockM, n_block_idx * kBlockN);
+      }
+      softmax.rescale(tSrS_mn, tOrO_mn);
+
+      // produce next key: [] => [k]
+      if (n_block_idx > n_block_min) {
+        // without oob handling
+        const int ni = n_block_idx - 1;
+        auto tKgK = gmem_thr_copy_KV.partition_S(gK(_, _, ni));
+        auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim);
+        safe_copy</*EVEN_MN=*/true, EVEN_K, /*ZFILL_MN=*/false, /*ZFILL_K=*/false>(
+            gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, max_coord);
+      }
+      cp_async_fence();
+
+      // 2> O = softmax(S)*V
+      compute_sv(tSrS, tOrO);
+    }
+
+    // normalize output: o /= rowsum
+    softmax.finalize(tOrO_mn);
+  }
+};
+
+}  // namespace llm
diff --git a/src/kernels/attention/sm80_kernel_mha.cuh b/src/kernels/attention/sm80_kernel_mha.cuh
new file mode 100644
index 00000000..19d4cb64
--- /dev/null
+++ b/src/kernels/attention/sm80_kernel_mha.cuh
@@ -0,0 +1,123 @@
+#pragma once
+
+#include
+#include
+
+#include
+#include
+
+#include "mha_tile.h"
+#include "online_softmax.cuh"
+
+namespace llm {
+
+using namespace cute;
+
+template <class CollectiveMainloop_, class CollectiveEpilogue_>
+class Sm80KernelMha {
+ public:
+  using CollectiveMainloop = CollectiveMainloop_;
+  using CollectiveEpilogue = CollectiveEpilogue_;
+
+  using TiledMma = typename CollectiveMainloop::TiledMma;
+
+  using Element = typename CollectiveMainloop::Element;
+  using BLK_M = typename CollectiveMainloop::BLK_M;
+  using BLK_N = typename CollectiveMainloop::BLK_N;
+  using HEAD_DIM = typename CollectiveMainloop::HEAD_DIM;
+
+  static constexpr int kBlockM = CollectiveMainloop::kBlockM;
+
+  static constexpr int kRowsPerMMA = CollectiveMainloop::kRowsPerMMA;
+
+  static constexpr int kSharedStorageSize =
+      cute::max(sizeof(typename CollectiveMainloop::SharedStorage),
+                sizeof(typename CollectiveEpilogue::SharedStorage));
+
+  static constexpr int kMmaThreads = CollectiveMainloop::kMmaThreads;
+
+  // Kernel params
+  using MainloopParams = typename CollectiveMainloop::Params;
+  using EpilogueParams = typename CollectiveEpilogue::Params;
+
+  template <class Params>
+  CUTE_DEVICE void operator()(const Params& params, char* smem) {
+    CollectiveMainloop mha;
+    CollectiveEpilogue epilogue;
+
+    const auto tidx = threadIdx.x;
+
+    // block coord
+    const int m_block_idx = blockIdx.x;
+    const int batch_idx = blockIdx.y;
+    const int kv_head_idx = blockIdx.z;
+    auto block_coord_mnk = make_coord(m_block_idx, batch_idx, kv_head_idx);
+
+    // (q_packed_len, HEAD_DIM)
+    MHATile tile(params, batch_idx, kv_head_idx);
+    auto [Q, O] = tile.template get_qo_tile<Element>();
+    // (kv_len, HEAD_DIM)
+    auto [K, V] = tile.template get_kv_tile<Element>();
+
+    // problem shape
+    const int q_packed_len = size<0>(Q);
+    const int kv_len = size<0>(K);
+    const int head_dim = params.head_dim;
+    auto problem_shape_mnk = make_shape(q_packed_len, kv_len, head_dim);
+
+    if (m_block_idx * kBlockM >= q_packed_len) {
+      // m out of bound, return
+      return;
+    }
+
+    // (BLK_M, HEAD_DIM)
+    Tensor gQ =
+        local_tile(Q, Shape<BLK_M, HEAD_DIM>{}, make_coord(m_block_idx, _0{}));
+    Tensor gO =
+        local_tile(O, Shape<BLK_M, HEAD_DIM>{}, make_coord(m_block_idx, _0{}));
+    // (BLK_N, HEAD_DIM, n)
+    Tensor gK = local_tile(K, Shape<BLK_N, HEAD_DIM>{}, make_coord(_, _0{}));
+    Tensor gV = local_tile(V, Shape<BLK_N, HEAD_DIM>{}, make_coord(_, _0{}));
+
+    // construct params
+    MainloopParams mainloop_params{params.sliding_window,
+                                   params.logits_soft_cap,
+                                   params.sm_scale,
+                                   params.sm_scale_log2,
+                                   params.alibi_slopes_ptr,
+                                   params.group_size};
+    EpilogueParams epilogue_params;
+
+    TiledMma tiled_mma;
+    // accumulator: (MMA,MMA_M,MMA_K)
+    auto tOrAccO = partition_fragment_C(tiled_mma, Shape<BLK_M, HEAD_DIM>{});
+    clear(tOrAccO);
+
+    constexpr int kRowsPerThr = kRowsPerMMA * size<1>(tOrAccO);
+    OnlineSoftmax<kRowsPerThr> softmax(params.sm_scale_log2);
+
+    // mainloop
+    mha(mainloop_params,
+        gQ,
+        gK,
+        gV,
+        tOrAccO,
+        softmax,
+        tidx,
+        block_coord_mnk,
+        problem_shape_mnk,
+        smem);
+
+    // epilogue
+    epilogue(epilogue_params,
+             tOrAccO,
+             tiled_mma,
+             gO,
+             tidx,
+             block_coord_mnk,
+             problem_shape_mnk,
+             smem);
+  }
+};
+
+}  // namespace llm
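The kv block-range selection in Sm80CollectiveMha is plain integer arithmetic and easy to sanity-check on the host; a standalone walk-through with illustrative sizes, mirroring the expressions above:

    #include <algorithm>
    #include <cstdio>

    int main() {
      // illustrative sizes; expressions mirror Sm80CollectiveMha
      const int kBlockM = 64, kBlockN = 64;
      const int m_block_idx = 2, group_size = 1;
      const int q_len = 512, kv_len = 1024, sliding_window = 256;
      const bool LOCAL = true;

      // right-most kv position the first row of this m block may attend to
      const int diagonal = (m_block_idx * kBlockM) / group_size + kv_len - q_len;
      // process kv in range: [kv_idx_min, kv_idx_max)
      const int kv_idx_min = std::max(0, diagonal - sliding_window);
      const int kv_idx_max = std::min(kv_len, diagonal + kBlockM);
      const int n_block_min = LOCAL ? kv_idx_min / kBlockN : 0;
      const int n_block_max = (kv_idx_max + kBlockN - 1) / kBlockN;  // ceil_div

      // diagonal = 128 + 1024 - 512 = 640; window => [384, 704) => blocks [6, 11)
      std::printf("kv blocks [%d, %d)\n", n_block_min, n_block_max);
      return 0;
    }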
diff --git a/src/kernels/attention/mha_sm80_bench.cu b/src/kernels/attention/sm80_mha_bench.cu
similarity index 94%
rename from src/kernels/attention/mha_sm80_bench.cu
rename to src/kernels/attention/sm80_mha_bench.cu
index a279453a..1e9faa3f 100644
--- a/src/kernels/attention/mha_sm80_bench.cu
+++ b/src/kernels/attention/sm80_mha_bench.cu
@@ -4,9 +4,9 @@
 #include
 #include
 
-#include "mha_dispatch_sm80.cuh"
-#include "mha_kernel_sm80.cuh"  // IWYU pragma: keep
 #include "mha_params.h"
+#include "sm80_mha_dispatch.cuh"
+#include "sm80_mha_launch.cuh"  // IWYU pragma: keep
 #include "static_dispatch.h"
 
 using namespace llm;
@@ -73,7 +73,7 @@ void mha_bench_sm80(nvbench::state& state) {
 
   state.exec([&](nvbench::launch& launch) {
     DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] {
-      run_mha_kernel_sm80<cute::half_t, HEAD_DIM>(params, launch.get_stream());
+      sm80_run_mha<cute::half_t, HEAD_DIM>(params, launch.get_stream());
    });
  });
 }
diff --git a/src/kernels/attention/mha_dispatch_sm80.cuh b/src/kernels/attention/sm80_mha_dispatch.cuh
similarity index 86%
rename from src/kernels/attention/mha_dispatch_sm80.cuh
rename to src/kernels/attention/sm80_mha_dispatch.cuh
index 6c990329..30136b7a 100644
--- a/src/kernels/attention/mha_dispatch_sm80.cuh
+++ b/src/kernels/attention/sm80_mha_dispatch.cuh
@@ -14,11 +14,11 @@
 template <typename DTYPE, int HEAD_DIM, bool EVEN_K, bool ALIBI, bool SOFT_CAP, bool LOCAL, typename Params>
-void launch_mha_kernel_sm80(const Params& params, cudaStream_t stream);
+void sm80_launch_mha_kernel(const Params& params, cudaStream_t stream);
 
 // user-facing function to run the attention kernel
 template <typename DTYPE, int HEAD_DIM, typename Params>
-void run_mha_kernel_sm80(Params& params, cudaStream_t stream = nullptr) {
+void sm80_run_mha(Params& params, cudaStream_t stream = nullptr) {
   // normalize params for performance optimization
   params.normalize();
@@ -27,7 +27,7 @@ void run_mha_kernel_sm80(Params& params, cudaStream_t stream = nullptr) {
   DISPATCH_BOOL(params.alibi_slopes_ptr != nullptr, ALIBI, [&] {
     DISPATCH_BOOL(params.logits_soft_cap > 0, SOFT_CAP, [&] {
       DISPATCH_BOOL(params.sliding_window >= 0, LOCAL, [&] {
-        launch_mha_kernel_sm80<DTYPE, HEAD_DIM, EVEN_K, ALIBI, SOFT_CAP, LOCAL>(params, stream);
+        sm80_launch_mha_kernel<DTYPE, HEAD_DIM, EVEN_K, ALIBI, SOFT_CAP, LOCAL>(params, stream);
       });
     });
   });
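Call sites for the renamed entry point stay thin: fill a params struct and let the static-dispatch macros pick the instantiation. A sketch mirroring the call site in attn_api.cpp above (params setup elided):

    // assuming an already-populated params struct, as in attn_api.cpp
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] {
      DISPATCH_TORCH_DTYPE(query.scalar_type(), DTYPE, [&] {
        sm80_run_mha<DTYPE, HEAD_DIM>(params, stream);
      });
    });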
diff --git a/src/kernels/attention/sm80_mha_launch.cuh b/src/kernels/attention/sm80_mha_launch.cuh
new file mode 100644
--- /dev/null
+++ b/src/kernels/attention/sm80_mha_launch.cuh
@@ -0,0 +1,76 @@
+#pragma once
+
+#include
+
+#include
+#include
+
+#include "sm80_collective_epilogue.cuh"
+#include "sm80_collective_mha.cuh"
+#include "sm80_kernel_mha.cuh"
+
+namespace llm {
+
+namespace detail {
+/// Generic kernel template.
+template <class Operator, class Params>
+__global__ __launch_bounds__(Operator::kMmaThreads) void device_kernel(
+    __grid_constant__ const Params params) {
+  extern __shared__ char smem[];
+  Operator op;
+  op(params, smem);
+}
+}  // namespace detail
+
+template <typename DTYPE,
+          int HEAD_DIM,
+          bool EVEN_K,
+          bool ALIBI,
+          bool SOFT_CAP,
+          bool LOCAL,
+          typename Params>
+void sm80_launch_mha_kernel(const Params& params, cudaStream_t stream) {
+  const auto batch_size = params.batch_size;
+  const auto n_kv_heads = params.n_kv_heads;
+  const auto max_q_packed_len = params.max_q_len * params.group_size;
+
+  // TODO: tune block shape MNK based on the head dim and smem size
+  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
+  // SM           | 7.0 | 7.2 | 7.5 | 8.0 | 8.6 | 8.7 | 8.9 | 9.0 | 10.x | 12.0|
+  // Max SMEM (KB)| 96  | 64  | 164 | 100 | 164 | 100 | 228 | 100 |
+  // valid dynamic shared memory sizes for different compute capabilities:
+  //   * 7.0 | 7.2 : 0, 8, 16, 32, 64, 96
+  //   * 7.5       : 0, 32, 64
+  //   * 8.0 | 8.7 : 0, 8, 16, 32, 64, 100, 132, 164
+  //   * 8.6 | 8.9 : 0, 8, 16, 32, 64, 100
+  //   * 9.0 | 10.x: 0, 8, 16, 32, 64, 100, 132, 164, 196, 228
+  //   * 12.0      : 0, 8, 16, 32, 64, 100
+  constexpr int BLK_M = 64;
+  constexpr int BLK_N = 64;
+  constexpr int BLK_K = HEAD_DIM % 64 == 0 ? 64 : 32;
+
+  using TileShape = Shape<Int<BLK_M>, Int<BLK_N>, Int<BLK_K>>;
+  using CollectiveMainloop =
+      Sm80CollectiveMha<TileShape, DTYPE, HEAD_DIM, EVEN_K, ALIBI, SOFT_CAP, LOCAL>;
+  using CollectiveEpilogue =
+      Sm80CollectiveEpilogue<TileShape, DTYPE, HEAD_DIM, EVEN_K>;
+
+  using AttnKernel = Sm80KernelMha<CollectiveMainloop, CollectiveEpilogue>;
+
+  auto mha_kernel = detail::device_kernel<AttnKernel, Params>;
+
+  const auto smem_size = AttnKernel::kSharedStorageSize;
+  if (smem_size >= 48 * 1024) {
+    cudaFuncSetAttribute(
+        mha_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+  }
+
+  // TODO: support persistent kernels
+  dim3 grid(cute::ceil_div(max_q_packed_len, BLK_M), batch_size, n_kv_heads);
+  dim3 block = AttnKernel::kMmaThreads;
+
+  mha_kernel<<<grid, block, smem_size, stream>>>(params);
+  // TODO: check launch status
+}
+
+}  // namespace llm
diff --git a/src/kernels/attention/mha_sm80_pagedkv_bench.cu b/src/kernels/attention/sm80_mha_pagedkv_bench.cu
similarity index 96%
rename from src/kernels/attention/mha_sm80_pagedkv_bench.cu
rename to src/kernels/attention/sm80_mha_pagedkv_bench.cu
index 08891818..7e8549c3 100644
--- a/src/kernels/attention/mha_sm80_pagedkv_bench.cu
+++ b/src/kernels/attention/sm80_mha_pagedkv_bench.cu
@@ -5,9 +5,9 @@
 #include
 #include
 
-#include "mha_dispatch_sm80.cuh"
-#include "mha_kernel_sm80.cuh"  // IWYU pragma: keep
 #include "mha_params.h"
+#include "sm80_mha_dispatch.cuh"
+#include "sm80_mha_launch.cuh"  // IWYU pragma: keep
 #include "static_dispatch.h"
 
 using namespace llm;
@@ -120,7 +120,7 @@ void mha_bench_sm80(nvbench::state& state) {
 
   state.exec([&](nvbench::launch& launch) {
     DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] {
-      run_mha_kernel_sm80<cute::half_t, HEAD_DIM>(params, launch.get_stream());
+      sm80_run_mha<cute::half_t, HEAD_DIM>(params, launch.get_stream());
    });
  });
 }
diff --git a/src/kernels/attention/mha_kernel_sm80_pagedkv_test.cu b/src/kernels/attention/sm80_mha_pagedkv_test.cu
similarity index 98%
rename from src/kernels/attention/mha_kernel_sm80_pagedkv_test.cu
rename to src/kernels/attention/sm80_mha_pagedkv_test.cu
index eca9dc9e..afa69d45 100644
--- a/src/kernels/attention/mha_kernel_sm80_pagedkv_test.cu
+++ b/src/kernels/attention/sm80_mha_pagedkv_test.cu
@@ -3,10 +3,10 @@
 #include
 
 #include "cute/layout.hpp"
-#include "mha_dispatch_sm80.cuh"
-#include "mha_kernel_sm80.cuh"  // IWYU pragma: keep
 #include "mha_params.h"
 #include "mha_ref.h"
+#include "sm80_mha_dispatch.cuh"
+#include "sm80_mha_launch.cuh"  // IWYU pragma: keep
 
 namespace llm {
 #define DISPATCH_HEAD_DIM_(HEAD_DIM_V, HEAD_DIM_NAME, ...) \
@@ -76,7 +76,7 @@ torch::Tensor mha_pagedkv_sm80(
   params.block_size = block_size;
 
   DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] {
-    run_mha_kernel_sm80<cute::half_t, HEAD_DIM>(params);
+    sm80_run_mha<cute::half_t, HEAD_DIM>(params);
  });
   return out;
 }
diff --git a/src/kernels/attention/mha_kernel_sm80_test.cu b/src/kernels/attention/sm80_mha_test.cu
similarity index 97%
rename from src/kernels/attention/mha_kernel_sm80_test.cu
rename to src/kernels/attention/sm80_mha_test.cu
index ac154211..17ed94ef 100644
--- a/src/kernels/attention/mha_kernel_sm80_test.cu
+++ b/src/kernels/attention/sm80_mha_test.cu
@@ -4,10 +4,10 @@
 #include
 
 #include "cute/layout.hpp"
-#include "mha_dispatch_sm80.cuh"
-#include "mha_kernel_sm80.cuh"  // IWYU pragma: keep
 #include "mha_params.h"
 #include "mha_ref.h"
+#include "sm80_mha_dispatch.cuh"
+#include "sm80_mha_launch.cuh"  // IWYU pragma: keep
 
 namespace llm {
 #define DISPATCH_HEAD_DIM_(HEAD_DIM_V, HEAD_DIM_NAME, ...) \
@@ -86,9 +86,8 @@ torch::Tensor mha_sm80(
   params.sliding_window = sliding_window;
 
   DISPATCH_TORCH_DTYPE_(query.dtype(), DTYPE, [&] {
-    DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] {
-      run_mha_kernel_sm80<DTYPE, HEAD_DIM>(params);
-    });
+    DISPATCH_HEAD_DIM_(
+        head_dim, HEAD_DIM, [&] { sm80_run_mha<DTYPE, HEAD_DIM>(params); });
  });
   return out;
 }
diff --git a/src/kernels/attention/traits_sm80/g2s_tiled_copy_kv.svg b/src/kernels/attention/sm80_mha_traits/g2s_tiled_copy_kv.svg
similarity index 100%
rename from src/kernels/attention/traits_sm80/g2s_tiled_copy_kv.svg
rename to src/kernels/attention/sm80_mha_traits/g2s_tiled_copy_kv.svg
diff --git a/src/kernels/attention/traits_sm80/g2s_tiled_copy_q.svg b/src/kernels/attention/sm80_mha_traits/g2s_tiled_copy_q.svg
similarity index 100%
rename from src/kernels/attention/traits_sm80/g2s_tiled_copy_q.svg
rename to src/kernels/attention/sm80_mha_traits/g2s_tiled_copy_q.svg
diff --git a/src/kernels/attention/traits_sm80/r2s_tiled_copy_o.svg b/src/kernels/attention/sm80_mha_traits/r2s_tiled_copy_o.svg
similarity index 100%
rename from src/kernels/attention/traits_sm80/r2s_tiled_copy_o.svg
rename to src/kernels/attention/sm80_mha_traits/r2s_tiled_copy_o.svg
diff --git a/src/kernels/attention/traits_sm80/s2g_tiled_copy_o.svg b/src/kernels/attention/sm80_mha_traits/s2g_tiled_copy_o.svg
similarity index 100%
rename from src/kernels/attention/traits_sm80/s2g_tiled_copy_o.svg
rename to src/kernels/attention/sm80_mha_traits/s2g_tiled_copy_o.svg
diff --git a/src/kernels/attention/traits_sm80/s2r_tiled_copy_k.svg b/src/kernels/attention/sm80_mha_traits/s2r_tiled_copy_k.svg
similarity index 100%
rename from src/kernels/attention/traits_sm80/s2r_tiled_copy_k.svg
rename to src/kernels/attention/sm80_mha_traits/s2r_tiled_copy_k.svg
diff --git a/src/kernels/attention/traits_sm80/s2r_tiled_copy_q.svg b/src/kernels/attention/sm80_mha_traits/s2r_tiled_copy_q.svg
similarity index 100%
rename from src/kernels/attention/traits_sm80/s2r_tiled_copy_q.svg
rename to src/kernels/attention/sm80_mha_traits/s2r_tiled_copy_q.svg
diff --git a/src/kernels/attention/traits_sm80/s2r_tiled_copy_vt.svg b/src/kernels/attention/sm80_mha_traits/s2r_tiled_copy_vt.svg
similarity index 100%
rename from src/kernels/attention/traits_sm80/s2r_tiled_copy_vt.svg
rename to src/kernels/attention/sm80_mha_traits/s2r_tiled_copy_vt.svg
diff --git a/src/kernels/attention/traits_sm80/smem_layout_k.svg b/src/kernels/attention/sm80_mha_traits/smem_layout_k.svg
similarity index 100%
rename from src/kernels/attention/traits_sm80/smem_layout_k.svg
rename to src/kernels/attention/sm80_mha_traits/smem_layout_k.svg
diff --git a/src/kernels/attention/traits_sm80/smem_layout_o.svg b/src/kernels/attention/sm80_mha_traits/smem_layout_o.svg
similarity index 100%
rename from src/kernels/attention/traits_sm80/smem_layout_o.svg
rename to src/kernels/attention/sm80_mha_traits/smem_layout_o.svg
diff --git a/src/kernels/attention/traits_sm80/smem_layout_q.svg b/src/kernels/attention/sm80_mha_traits/smem_layout_q.svg
similarity index 100%
rename from src/kernels/attention/traits_sm80/smem_layout_q.svg
rename to src/kernels/attention/sm80_mha_traits/smem_layout_q.svg
diff --git a/src/kernels/attention/traits_sm80/smem_layout_vt.svg b/src/kernels/attention/sm80_mha_traits/smem_layout_vt.svg
similarity index 100%
rename from src/kernels/attention/traits_sm80/smem_layout_vt.svg
rename to src/kernels/attention/sm80_mha_traits/smem_layout_vt.svg
diff --git a/src/kernels/attention/traits_sm80/tiled_mma.svg b/src/kernels/attention/sm80_mha_traits/tiled_mma.svg
similarity index 100%
rename from src/kernels/attention/traits_sm80/tiled_mma.svg
rename to src/kernels/attention/sm80_mha_traits/tiled_mma.svg
diff --git a/src/kernels/attention/tools/CMakeLists.txt b/src/kernels/attention/tools/CMakeLists.txt
index 64032d65..529cd9f8 100644
--- a/src/kernels/attention/tools/CMakeLists.txt
+++ b/src/kernels/attention/tools/CMakeLists.txt
@@ -1,13 +1,13 @@
 include(cc_binary)
 
-cc_binary(
-  NAME 
-    mha_traits_viewer
-  SRCS 
-    mha_traits_viewer.cpp
-  DEPS 
-    :common
-    cutlass
-    absl::strings
-    absl::str_format
-)
\ No newline at end of file
+# cc_binary(
+#   NAME 
+#     mha_traits_viewer
+#   SRCS 
+#     mha_traits_viewer.cpp
+#   DEPS 
+#     :common
+#     cutlass
+#     absl::strings
+#     absl::str_format
+# )
diff --git a/src/kernels/attention/tools/mha_traits_viewer.cpp b/src/kernels/attention/tools/mha_traits_viewer.cpp
index 8f616321..33c4ba74 100644
--- a/src/kernels/attention/tools/mha_traits_viewer.cpp
+++ b/src/kernels/attention/tools/mha_traits_viewer.cpp
@@ -1,7 +1,6 @@
 #include
 #include
 
-#include "../mha_traits_sm80.h"
 #include "common/pretty_print.h"
 #include "print_svg.hpp"
 
@@ -127,4 +126,4 @@ int main(int argc, char** argv) {
 
   print_attn_traits();
   return 0;
-}
\ No newline at end of file
+}
diff --git a/src/kernels/quantization/marlin/CMakeLists.txt b/src/kernels/quantization/marlin/CMakeLists.txt
index 0213b35e..43e5a320 100644
--- a/src/kernels/quantization/marlin/CMakeLists.txt
+++ b/src/kernels/quantization/marlin/CMakeLists.txt
@@ -12,8 +12,8 @@ execute_process(
   WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/
 )
 
-# globbing all generated files in sub directory "generated"
-file(GLOB GENERATED_SRC_FILES "${CMAKE_CURRENT_BINARY_DIR}/generated/*.cu")
+# globbing all generated files in sub directory "gensrc"
+file(GLOB GENERATED_SRC_FILES "${CMAKE_CURRENT_BINARY_DIR}/gensrc/*.cu")
 
 cc_library(
   NAME
@@ -32,4 +32,3 @@ cc_library(
     torch
     glog::glog
 )
-
diff --git a/src/kernels/quantization/marlin/generate_instantiations.py b/src/kernels/quantization/marlin/generate_instantiations.py
index 40ee45a1..0a75c6ca 100755
--- a/src/kernels/quantization/marlin/generate_instantiations.py
+++ b/src/kernels/quantization/marlin/generate_instantiations.py
@@ -104,7 +104,8 @@ def template(self) -> str:
 
     @property
     def filename(self) -> str:
-        return f"marlin_b{self.num_bits}_t{self.threads}_m{self.m_blocks}_n{self.n_blocks}_k{self.k_blocks}_s{self.stages}_{self.has_act_order}_{self.has_zp}_g{self.group_blocks}_sm80.cu"
+        return f"sm80_marlin_b{self.num_bits}_t{self.threads}_m{self.m_blocks}_n{self.n_blocks}_k{self.k_blocks}_s{self.stages}_{self.has_act_order}_{self.has_zp}_g{self.group_blocks}.cu"
+
 
 def all_kernels():
@@ -140,7 +141,7 @@ def write_kernel(kernel: Kernel, output_dir: Path) -> None:
 
 if __name__ == "__main__":
-    output_dir = Path.cwd() / "generated"
+    output_dir = Path.cwd() / "gensrc"
     shutil.rmtree(output_dir, ignore_errors=True)
     output_dir.mkdir(parents=True, exist_ok=True)
     for kernel in all_kernels():
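For completeness, the OnlineSoftmax rescale/finalize pair that threads through both the old kernel and the new mainloop implements the standard streaming ("online") softmax identity: whenever a new kv block raises the running row max, the old partial sums are rescaled before new terms are accumulated. A scalar sketch with illustrative values:

    #include <cmath>
    #include <cstdio>

    // streaming softmax over two kv blocks: keep a running row max and row sum,
    // and rescale old partial sums whenever the max grows (cf. rescale/finalize)
    int main() {
      const float scores[2][3] = {{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}};
      const float values[2][3] = {{1.f, 1.f, 1.f}, {2.f, 2.f, 2.f}};
      float row_max = -INFINITY, row_sum = 0.f, acc = 0.f;
      for (int b = 0; b < 2; ++b) {
        float new_max = row_max;
        for (float s : scores[b]) new_max = std::fmax(new_max, s);
        const float alpha = std::exp(row_max - new_max);  // rescale factor
        acc *= alpha;
        row_sum *= alpha;
        for (int i = 0; i < 3; ++i) {
          const float p = std::exp(scores[b][i] - new_max);
          row_sum += p;
          acc += p * values[b][i];
        }
        row_max = new_max;
      }
      std::printf("o = %f\n", acc / row_sum);  // finalize: o /= rowsum
      return 0;
    }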