diff --git a/src/kernels/attention/attention_kernel_sm80.cuh b/src/kernels/attention/attention_kernel_sm80.cuh index d701bc75..83e36675 100644 --- a/src/kernels/attention/attention_kernel_sm80.cuh +++ b/src/kernels/attention/attention_kernel_sm80.cuh @@ -71,8 +71,8 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) { auto [K, V] = tile.template get_kv_tile(batch_idx, head_idx / group_size); - const int q_len = size<0>(Q.shape()); - const int kv_len = size<0>(K.shape()); + const int q_len = size<0>(Q); + const int kv_len = size<0>(K); if (m_block * kBlockM >= q_len) { // out of bound, return @@ -141,10 +141,7 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) { auto produce_q = [&]() { auto tQgQ = gmem_thr_copy_Q.partition_S(gQ); auto tQsQ = gmem_thr_copy_Q.partition_D(sQ); - safe_copy( + safe_copy( gmem_tiled_copy_Q, tQgQ, tQsQ, @@ -157,10 +154,9 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) { auto produce_k = [&](int ni) { auto tKgK = gmem_thr_copy_KV.partition_S(gK(_, _, ni)); // skip zfill_mn for k since mask will mask out oob with -inf - safe_copy( + safe_copy( gmem_tiled_copy_KV, tKgK, tKsK, @@ -172,10 +168,7 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) { auto produce_v = [&](int ni) { auto tVgV = gmem_thr_copy_KV.partition_S(gV(_, _, ni)); // skipping ZFILL_MN for v may cause nan issue - safe_copy( + safe_copy( gmem_tiled_copy_KV, tVgV, tVsV, @@ -288,8 +281,8 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) { // wait for smem copy done before gmem copy __syncthreads(); - safe_copy( gmem_tiled_copy_O, diff --git a/src/kernels/attention/attention_traits_sm80.h b/src/kernels/attention/attention_traits_sm80.h index 934dd8d5..863711f2 100644 --- a/src/kernels/attention/attention_traits_sm80.h +++ b/src/kernels/attention/attention_traits_sm80.h @@ -91,7 +91,7 @@ struct AttentionTraitsSM80 { // Tiled copy for QKV // g2s tiled copy for q using GmemTiledCopyQ = decltype(make_tiled_copy( - Copy_Atom, DType>{}, + Copy_Atom, DType>{}, GmemCopyThrLayout{}, // Thr layout: (_16,_8)/(_32, _4) Layout>{} // Val layout: 8 vals per read )); diff --git a/src/kernels/attention/cute_extensions.cuh b/src/kernels/attention/cute_extensions.cuh index bba33669..677f4498 100644 --- a/src/kernels/attention/cute_extensions.cuh +++ b/src/kernels/attention/cute_extensions.cuh @@ -15,22 +15,62 @@ CUTE_HOST_DEVICE constexpr auto elem_less(IntTupleA const& a, return elem_less(get(a), get(b)); } -template +CUTE_HOST_DEVICE void zfill(const Copy_Atom& copy_atom, + const TensorS& src, + TensorD&& dst) { + CUTE_STATIC_ASSERT(TensorS::rank == TensorD::rank, "rank-mismatch."); + + auto has_with_bool = cute::is_valid( + [](auto t) -> void_t() + .with(true))> {}, + copy_atom); + if constexpr (has_with_bool) { + constexpr int R = TensorD::rank; + if constexpr (R == 1) { // Dispatch the copy + copy_atom.with(false).call(src, dst); + } else { // Loop over all but the first mode + Tensor src_v = group_modes<1, R>(src); + Tensor dst_v = group_modes<1, R>(dst); + CUTE_UNROLL + for (int i = 0; i < size<1>(dst_v); ++i) { + copy_atom.with(false).call(src_v(_, i), dst_v(_, i)); + } + } + } else { + // just call clear if no with method + clear(dst); + } +} + +template +CUTE_HOST_DEVICE void zfill(const Copy_Atom& copy_atom, + const TensorS& src, + TensorD& dst) { + zfill(copy_atom, src, dst); +} + +template CUTE_HOST_DEVICE void safe_copy( - const TiledCopy& tiled_copy, + const TiledCopy& tiled_copy, const TensorS& src, // (CPY, CPY_M/N, CPY_K) TensorD& dst, // (CPY, CPY_M/N, CPY_K) const TensorC& identity, // (CPY, CPY_M/N, CPY_K) -> (blk_m/n, blk_k) const Coord& max_coord // max_coord(blk_m/n, blk_k) ) { + CUTE_STATIC_ASSERT(TensorS::rank == TensorD::rank, "rank-mismatch."); + auto copy_atom = static_cast(tiled_copy); + if constexpr (!EVEN_MN && !EVEN_K) { // handle both m/n and k oob CUTE_UNROLL @@ -39,16 +79,16 @@ CUTE_HOST_DEVICE void safe_copy( CUTE_UNROLL for (int ki = 0; ki < size<2>(src); ++ki) { if (elem_less<1>(identity(_0{}, _0{}, ki), max_coord)) { - copy(tiled_copy, src(_, mi, ki), dst(_, mi, ki)); + copy(copy_atom, src(_, mi, ki), dst(_, mi, ki)); } else { - if constexpr (ZERO_FILL_K) { - clear(dst(_, mi, ki)); + if constexpr (ZFILL_K) { + zfill(copy_atom, src(_, mi, ki), dst(_, mi, ki)); } } } } else { - if constexpr (ZERO_FILL_MN) { - clear(dst(_, mi, _)); + if constexpr (ZFILL_MN) { + zfill(copy_atom, src(_, mi, _), dst(_, mi, _)); } } } @@ -57,10 +97,10 @@ CUTE_HOST_DEVICE void safe_copy( CUTE_UNROLL for (int mi = 0; mi < size<1>(src); ++mi) { if (elem_less<0>(identity(_0{}, mi, _0{}), max_coord)) { - copy(tiled_copy, src(_, mi, _), dst(_, mi, _)); + copy(copy_atom, src(_, mi, _), dst(_, mi, _)); } else { - if constexpr (ZERO_FILL_MN) { - clear(dst(_, mi, _)); + if constexpr (ZFILL_MN) { + zfill(copy_atom, src(_, mi, _), dst(_, mi, _)); } } } @@ -69,16 +109,16 @@ CUTE_HOST_DEVICE void safe_copy( CUTE_UNROLL for (int ki = 0; ki < size<2>(src); ++ki) { if (elem_less<1>(identity(_0{}, _0{}, ki), max_coord)) { - copy(tiled_copy, src(_, _, ki), dst(_, _, ki)); + copy(copy_atom, src(_, _, ki), dst(_, _, ki)); } else { - if constexpr (ZERO_FILL_K) { - clear(dst(_, _, ki)); + if constexpr (ZFILL_K) { + zfill(copy_atom, src(_, _, ki), dst(_, _, ki)); } } } } else { // no oob, just copy - copy(tiled_copy, src, dst); + copy(copy_atom, src, dst); } }