From 5237dc962a908d5291b7ee6e35446e4f3067f0e0 Mon Sep 17 00:00:00 2001
From: Michael Mi
Date: Sat, 25 Jan 2025 12:04:26 -0800
Subject: [PATCH] kernel: only zfill k once to improve perf for attention

---
 .../attention/attention_kernel_sm80.cuh      | 73 +++++++++----------
 src/kernels/attention/cute_extensions.cuh    | 34 ++++-----
 src/kernels/attention/mask.h                 |  9 +--
 3 files changed, 54 insertions(+), 62 deletions(-)

diff --git a/src/kernels/attention/attention_kernel_sm80.cuh b/src/kernels/attention/attention_kernel_sm80.cuh
index 046ea981..a9617d09 100644
--- a/src/kernels/attention/attention_kernel_sm80.cuh
+++ b/src/kernels/attention/attention_kernel_sm80.cuh
@@ -27,13 +27,11 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
   constexpr int kBlockM = Traits::kBlockM;
   constexpr int kBlockN = Traits::kBlockN;
-  constexpr int kBlockK = Traits::kBlockK;
   constexpr int kHeadDim = Traits::kHeadDim;
   constexpr int kRowsPerMMA = Traits::kRowsPerMMA;
 
   using _BLK_M = Int<kBlockM>;
   using _BLK_N = Int<kBlockN>;
-  using _BLK_K = Int<kBlockK>;
   using _HEAD_DIM = Int<kHeadDim>;
 
   // type alias
@@ -113,14 +111,14 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
   DType* k_smem = q_smem + cosize(SmemLayoutQ{});
   DType* v_smem = k_smem + cosize(SmemLayoutK{});
 
-  // (BLK_M, BLK_K), k-major
+  // (BLK_M, HEAD_DIM), k-major
   Tensor sQ = make_tensor(make_smem_ptr(q_smem), SmemLayoutQ{});
-  // (BLK_N, BLK_K), k-major
+  // (BLK_N, HEAD_DIM), k-major
   Tensor sK = make_tensor(make_smem_ptr(k_smem), SmemLayoutK{});
   Tensor sV = make_tensor(make_smem_ptr(v_smem), SmemLayoutV{});
 
   // Tensor for V^t; used in GEMM-II.
-  // (BLK_K, BLK_N)
+  // (HEAD_DIM, BLK_N), m-major
   Tensor sVt = make_tensor(make_smem_ptr(v_smem), SmemLayoutVt{});
 
   // Tiled Copy
@@ -135,11 +133,11 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
   Tensor cQ = make_identity_tensor(Shape<_BLK_M, _HEAD_DIM>{});
   Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);
 
-  auto produce_q = [&]() {
+  auto produce_query = [&]() {
     auto tQgQ = gmem_thr_copy_Q.partition_S(gQ);
     auto tQsQ = gmem_thr_copy_Q.partition_D(sQ);
     auto max_coord = make_coord(q_len - m_block * kBlockM, head_dim);
-    safe_copy<EVEN_K, /*ZFILL_MN=*/true, /*ZFILL_K=*/true, /*EVEN_MN=*/false>(
+    safe_copy</*EVEN_MN=*/false, EVEN_K, /*ZFILL_MN=*/true, /*ZFILL_K=*/true>(
        gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, max_coord);
   };
 
@@ -148,36 +146,36 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
   Tensor tKVcKV = gmem_thr_copy_KV.partition_S(cKV);
   Tensor tKsK = gmem_thr_copy_KV.partition_D(sK);
 
-  auto produce_k = [&](int ni) {
+  auto produce_key = [&](int ni) {
     auto tKgK = gmem_thr_copy_KV.partition_S(gK(_, _, ni));
     auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim);
-    // skip zfill_mn for k since mask will mask out oob with -inf
-    safe_copy<EVEN_K, /*ZFILL_MN=*/false, /*ZFILL_K=*/true, /*EVEN_MN=*/false>(
+    // skip ZFILL_MN for key since Mask will mask out oob with -inf
+    safe_copy</*EVEN_MN=*/false, EVEN_K, /*ZFILL_MN=*/false, /*ZFILL_K=*/true>(
        gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, max_coord);
   };
 
-  // produce k without oob handling
-  auto produce_k_no_oob = [&](int ni) {
+  // produce key without oob handling
+  auto produce_key_no_oob = [&](int ni) {
     auto tKgK = gmem_thr_copy_KV.partition_S(gK(_, _, ni));
     auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim);
-    safe_copy<EVEN_K, /*ZFILL_MN=*/false, /*ZFILL_K=*/true, /*EVEN_MN=*/true>(
+    safe_copy</*EVEN_MN=*/true, EVEN_K, /*ZFILL_MN=*/false, /*ZFILL_K=*/false>(
        gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, max_coord);
   };
 
   Tensor tVsV = gmem_thr_copy_KV.partition_D(sV);
-  auto produce_v = [&](int ni) {
+  auto produce_value = [&](int ni) {
     auto tVgV = gmem_thr_copy_KV.partition_S(gV(_, _, ni));
     auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim);
     // skipping ZFILL_MN for v may cause nan issue
-    safe_copy<EVEN_K, /*ZFILL_MN=*/true, /*ZFILL_K=*/true, /*EVEN_MN=*/false>(
+    safe_copy</*EVEN_MN=*/false, EVEN_K, /*ZFILL_MN=*/true, /*ZFILL_K=*/true>(
        gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, max_coord);
   };
 
-  // produce v without oob handling
-  auto produce_v_no_oob = [&](int ni) {
+  // produce value without oob handling
+  auto produce_value_no_oob = [&](int ni) {
     auto tVgV = gmem_thr_copy_KV.partition_S(gV(_, _, ni));
     auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim);
-    safe_copy<EVEN_K, /*ZFILL_MN=*/true, /*ZFILL_K=*/true, /*EVEN_MN=*/true>(
+    safe_copy</*EVEN_MN=*/true, EVEN_K, /*ZFILL_MN=*/false, /*ZFILL_K=*/false>(
        gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, max_coord);
   };
 
@@ -288,10 +286,7 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
   __syncthreads();
 
   auto max_coord = make_coord(q_len - m_block * kBlockM, head_dim);
-  safe_copy<EVEN_K,
-            /*ZFILL_MN=*/false,
-            /*ZFILL_K=*/false,
-            /*EVEN_MN=*/false>(
+  safe_copy</*EVEN_MN=*/false, EVEN_K, /*ZFILL_MN=*/false, /*ZFILL_K=*/false>(
       gmem_tiled_copy_O, tOsO, tOgO, tOcO, max_coord);
 };
 
@@ -316,11 +311,11 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
 
   // ############### Prologue ###############
   int n_block_idx = n_block_max - 1;
-  // produce q: [] => [q]
-  produce_q();
+  // produce query: [] => [q]
+  produce_query();
   cp_async_fence();
-  // produce k: [q] => [q, k]
-  produce_k(n_block_idx);
+  // produce key: [q] => [q, k]
+  produce_key(n_block_idx);
   cp_async_fence();
 
   // ############### Mainloop ###############
@@ -341,15 +336,15 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
   for (int i = 0; i < n_oob_mask; ++i) {
     clear(tSrAccS);
 
-    // wait k, queue: [q, k] => []
+    // wait key, queue: [q, k] => []
     cp_async_wait<0>();
     __syncthreads();
 
-    // produce v, [] => [v]
+    // produce value, [] => [v]
     if (i == 0) {
-      produce_v(n_block_idx);
+      produce_value(n_block_idx);
     } else {
-      produce_v_no_oob(n_block_idx);
+      produce_value_no_oob(n_block_idx);
     }
     cp_async_fence();
 
@@ -362,13 +357,13 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
     mask.apply(tSrAccS_rc_view, m_block, n_block_idx, tidx);
     softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view);
 
-    // wait v, [v] => []
+    // wait value, [v] => []
     cp_async_wait<0>();
     __syncthreads();
 
-    // produce next k: [] => [k]
+    // produce next key: [] => [k]
     if (n_block_idx > n_block_min) {
-      produce_k_no_oob(n_block_idx - 1);
+      produce_key_no_oob(n_block_idx - 1);
     }
     cp_async_fence();
 
@@ -387,12 +382,12 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
   for (; n_block_idx >= n_block_min; --n_block_idx) {
     clear(tSrAccS);
 
-    // wait k, queue: [q, k] => []
+    // wait key, queue: [q, k] => []
     cp_async_wait<0>();
     __syncthreads();
 
-    // produce v, [] => [v]
-    produce_v_no_oob(n_block_idx);
+    // produce value, [] => [v]
+    produce_value_no_oob(n_block_idx);
     cp_async_fence();
 
     // 1> S = Q@K.T
@@ -404,13 +399,13 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
     mask.apply(tSrAccS_rc_view, m_block, n_block_idx, tidx);
     softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view);
 
-    // wait v, [v] => []
+    // wait value, [v] => []
     cp_async_wait<0>();
     __syncthreads();
 
-    // produce next k: [] => [k]
+    // produce next key: [] => [k]
     if (n_block_idx > n_block_min) {
-      produce_k_no_oob(n_block_idx - 1);
+      produce_key_no_oob(n_block_idx - 1);
     }
     cp_async_fence();
 
diff --git a/src/kernels/attention/cute_extensions.cuh b/src/kernels/attention/cute_extensions.cuh
index 677f4498..a00a64f3 100644
--- a/src/kernels/attention/cute_extensions.cuh
+++ b/src/kernels/attention/cute_extensions.cuh
@@ -52,8 +52,8 @@ CUTE_HOST_DEVICE void zfill(const Copy_Atom<Traits, T>& copy_atom,
-template <bool EVEN_K,
+template <bool EVEN_MN,
+          bool EVEN_K,
           bool ZFILL_MN,
           bool ZFILL_K,
-          bool EVEN_MN,
           typename TiledCopy,
           typename TensorS,
           typename TensorD,
           typename TensorC,
@@ -77,15 +77,19 @@ CUTE_HOST_DEVICE void safe_copy(
         for (int ki = 0; ki < size<2>(src); ++ki) {
           if (elem_less<1>(identity(_0{}, _0{}, ki), max_coord)) {
             copy(copy_atom, src(_, mi, ki), dst(_, mi, ki));
-          } else {
-            if constexpr (ZFILL_K) {
-              zfill(copy_atom, src(_, mi, ki), dst(_, mi, ki));
-            }
+          } else if constexpr (ZFILL_K) {
+            zfill(copy_atom, src(_, mi, ki), dst(_, mi, ki));
           }
         }
-      } else {
-        if constexpr (ZFILL_MN) {
-          zfill(copy_atom, src(_, mi, _), dst(_, mi, _));
+      } else if constexpr (ZFILL_MN) {
+        zfill(copy_atom, src(_, mi, _), dst(_, mi, _));
+      } else if constexpr (ZFILL_K) {
+        // still need to handle k oob even if m/n is not zfilled
+        CUTE_UNROLL
+        for (int ki = 0; ki < size<2>(src); ++ki) {
+          if (!elem_less<1>(identity(_0{}, _0{}, ki), max_coord)) {
+            zfill(copy_atom, src(_, mi, ki), dst(_, mi, ki));
+          }
         }
       }
     }
@@ -98,10 +102,8 @@ CUTE_HOST_DEVICE void safe_copy(
     for (int mi = 0; mi < size<1>(src); ++mi) {
       if (elem_less<0>(identity(_0{}, mi, _0{}), max_coord)) {
         copy(copy_atom, src(_, mi, _), dst(_, mi, _));
-      } else {
-        if constexpr (ZFILL_MN) {
-          zfill(copy_atom, src(_, mi, _), dst(_, mi, _));
-        }
+      } else if constexpr (ZFILL_MN) {
+        zfill(copy_atom, src(_, mi, _), dst(_, mi, _));
       }
     }
   } else if constexpr (EVEN_MN && !EVEN_K) {
@@ -110,10 +112,8 @@ CUTE_HOST_DEVICE void safe_copy(
     CUTE_UNROLL
     for (int ki = 0; ki < size<2>(src); ++ki) {
       if (elem_less<1>(identity(_0{}, _0{}, ki), max_coord)) {
         copy(copy_atom, src(_, _, ki), dst(_, _, ki));
-      } else {
-        if constexpr (ZFILL_K) {
-          zfill(copy_atom, src(_, _, ki), dst(_, _, ki));
-        }
+      } else if constexpr (ZFILL_K) {
+        zfill(copy_atom, src(_, _, ki), dst(_, _, ki));
       }
     }
   } else {
diff --git a/src/kernels/attention/mask.h b/src/kernels/attention/mask.h
index 7c66aa52..13d7085e 100644
--- a/src/kernels/attention/mask.h
+++ b/src/kernels/attention/mask.h
@@ -65,7 +65,6 @@ struct Mask {
     } else if constexpr (!OOB_MASK && LOCAL) {
       // local mask
       return (q_idx - kv_idx) > sliding_window_;
-
     } else {
       // !OOB_MASK && !LOCAL
       return false;
@@ -74,12 +73,10 @@
 
       if (out_of_boundary) {
         rAccS(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
-      } else {
+      } else if constexpr (ALIBI) {
         // Apply alibi bias
-        if constexpr (ALIBI) {
-          rAccS(make_coord(i, mi), make_coord(j, nj)) +=
-              alibi_slope_ * kv_idx;
-        }
+        rAccS(make_coord(i, mi), make_coord(j, nj)) +=
+            alibi_slope_ * kv_idx;
       }
     }
   }
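
Note on the "zfill k once" idea: safe_copy's ZFILL_K only matters the first time a k/v tile is loaded into a given shared-memory buffer. Every later load into the same buffer copies only in-bounds head-dim columns, so a k tail that was zero-filled once stays zero, and the per-iteration zfill in produce_key_no_oob / produce_value_no_oob can be dropped (along with the m/n bounds check, via EVEN_MN=true). Below is a minimal standalone sketch of that pattern in plain CUDA rather than the PR's CuTe; the names (load_tile, sum_tiles, kBlockN, kHeadDim) and the reverse tile walk are illustrative assumptions, not the PR's actual code.

// zfill_once.cu: zero-fill the oob tail of an smem tile once, then skip it.
#include <cstdio>
#include <cuda_runtime.h>

constexpr int kBlockN = 4;   // rows per tile (kv tile size)
constexpr int kHeadDim = 8;  // padded head dim of the smem tile

// Copy one (kBlockN x kHeadDim) tile from gmem to smem.
// EVEN_MN: rows are guaranteed in-bounds, skip the row check.
// ZFILL_MN / ZFILL_K: zero-fill oob rows / oob columns instead of
// leaving stale smem contents behind.
template <bool EVEN_MN, bool ZFILL_MN, bool ZFILL_K>
__device__ void load_tile(const float* g, float* s, int row0, int n_rows, int dim) {
  for (int idx = threadIdx.x; idx < kBlockN * kHeadDim; idx += blockDim.x) {
    const int r = idx / kHeadDim;
    const int c = idx % kHeadDim;
    const bool row_ok = EVEN_MN || (row0 + r) < n_rows;
    if (row_ok && c < dim) {
      s[idx] = g[(row0 + r) * dim + c];
    } else if constexpr (ZFILL_MN) {
      s[idx] = 0.f;  // zero both row-oob and column-oob slots
    } else if constexpr (ZFILL_K) {
      if (c >= dim) s[idx] = 0.f;  // still zero the k tail when rows are skipped
    }
  }
}

// Walk tiles from the last (possibly ragged) one backwards, like the mainloop.
__global__ void sum_tiles(const float* g, int n_rows, int dim, float* out) {
  __shared__ float s[kBlockN * kHeadDim];
  float acc = 0.f;
  const int n_tiles = (n_rows + kBlockN - 1) / kBlockN;
  for (int t = n_tiles - 1; t >= 0; --t) {
    if (t == n_tiles - 1) {
      // first load: handle oob rows and zero the k tail exactly once
      load_tile<false, true, true>(g, s, t * kBlockN, n_rows, dim);
    } else {
      // later loads never touch the zeroed tail, so all zfill work is skipped
      load_tile<true, false, false>(g, s, t * kBlockN, n_rows, dim);
    }
    __syncthreads();
    if (threadIdx.x == 0)
      for (int i = 0; i < kBlockN * kHeadDim; ++i) acc += s[i];
    __syncthreads();
  }
  if (threadIdx.x == 0) *out = acc;
}

int main() {
  const int n_rows = 10, dim = 6;  // ragged vs kBlockN = 4, kHeadDim = 8
  float h[n_rows * dim];
  for (int i = 0; i < n_rows * dim; ++i) h[i] = 1.f;
  float *dg, *dout, hout = 0.f;
  cudaMalloc(&dg, sizeof(h));
  cudaMalloc(&dout, sizeof(float));
  cudaMemcpy(dg, h, sizeof(h), cudaMemcpyHostToDevice);
  sum_tiles<<<1, 32>>>(dg, n_rows, dim, dout);
  cudaMemcpy(&hout, dout, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %.0f (expect %d)\n", hout, n_rows * dim);
  cudaFree(dg);
  cudaFree(dout);
  return 0;
}

The same division of labor appears in the patch: produce_key / produce_value (ZFILL_K=true) run once per buffer in the prologue and first mainloop iteration, while the *_no_oob variants (EVEN_MN=true, no zfill) run in steady state.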
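Note on why keys can skip ZFILL_MN but values cannot: an out-of-bounds key row only produces a bogus logit, which Mask::apply overwrites with -INFINITY before softmax, and exp(-inf) == 0 removes its contribution exactly. An out-of-bounds value row, however, is multiplied by that zero weight, and if the stale shared memory happens to hold NaN or Inf, 0 * NaN is still NaN, which is what the "skipping ZFILL_MN for v may cause nan issue" comment guards against. A tiny illustrative snippet of the two IEEE-754 behaviors (not part of the PR):

#include <cmath>
#include <cstdio>

int main() {
  // Key side: masking with -inf makes the softmax weight exactly zero,
  // regardless of whatever garbage logit the stale K smem produced.
  float masked_logit = -INFINITY;          // written by Mask::apply
  float weight = std::exp(masked_logit);   // exp(-inf) == 0
  printf("masked weight = %f\n", weight);  // 0.000000

  // Value side: a zero weight does not neutralize garbage in V smem.
  float v_garbage = NAN;                   // stale V smem without zfill
  printf("0 * NaN = %f\n", weight * v_garbage);  // nan: pollutes the output
  return 0;
}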