src/kernels/attention/attention_kernel_sm80.cuh (9 additions, 16 deletions)
@@ -71,8 +71,8 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
auto [K, V] =
tile.template get_kv_tile<DType>(batch_idx, head_idx / group_size);

const int q_len = size<0>(Q.shape());
const int kv_len = size<0>(K.shape());
const int q_len = size<0>(Q);
const int kv_len = size<0>(K);

if (m_block * kBlockM >= q_len) {
// out of bound, return
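
Note on this hunk: in CuTe, size<0>(tensor) already reduces over the tensor's shape, so dropping the .shape() indirection changes nothing semantically. A minimal host-side sketch of the equivalence (made-up 8x64 tile; compiled with nvcc against the CUTLASS/CuTe headers):

```cpp
#include <cute/tensor.hpp>

#include <cstdio>
#include <vector>

int main() {
  using namespace cute;
  // Hypothetical q_len x head_dim tile; the sizes are illustrative only.
  std::vector<float> buf(8 * 64);
  auto Q = make_tensor(buf.data(),
                       make_layout(make_shape(8, 64), make_stride(64, 1)));
  // size<0>(Q) forwards to the tensor's layout/shape, so both forms agree.
  printf("size<0>(Q)         = %d\n", int(size<0>(Q)));
  printf("size<0>(Q.shape()) = %d\n", int(size<0>(Q.shape())));
  return 0;
}
```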
@@ -141,10 +141,7 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
auto produce_q = [&]() {
auto tQgQ = gmem_thr_copy_Q.partition_S(gQ);
auto tQsQ = gmem_thr_copy_Q.partition_D(sQ);
safe_copy<EVEN_K,
/*EVEN_MN=*/false,
/*ZERO_FILL_MN=*/true,
/*ZERO_FILL_K=*/true>(
safe_copy</*EVEN_MN=*/false, EVEN_K>(
gmem_tiled_copy_Q,
tQgQ,
tQsQ,
@@ -157,10 +154,9 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
auto produce_k = [&](int ni) {
auto tKgK = gmem_thr_copy_KV.partition_S(gK(_, _, ni));
// skip ZFILL_MN for k: the mask sets oob scores to -inf before softmax anyway
safe_copy<EVEN_K,
/*EVEN_MN=*/false,
/*ZERO_FILL_MN=*/false,
/*ZERO_FILL_K=*/true>(
safe_copy</*EVEN_MN=*/false,
EVEN_K,
/*ZERO_FILL_MN=*/false>(
gmem_tiled_copy_KV,
tKgK,
tKsK,
@@ -172,10 +168,7 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
auto produce_v = [&](int ni) {
auto tVgV = gmem_thr_copy_KV.partition_S(gV(_, _, ni));
// keep ZFILL_MN (default true) for v: oob smem may hold NaN/Inf, and 0 * NaN = NaN
safe_copy<EVEN_K,
/*EVEN_MN=*/false,
/*ZERO_FILL_MN=*/true,
/*ZERO_FILL_K=*/true>(
safe_copy</*EVEN_MN=*/false, EVEN_K>(
gmem_tiled_copy_KV,
tVgV,
tVsV,
@@ -288,8 +281,8 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {

// wait for smem copy done before gmem copy
__syncthreads();
safe_copy<EVEN_K,
/*EVEN_MN=*/false,
safe_copy</*EVEN_MN=*/false,
EVEN_K,
/*ZERO_FILL_MN=*/false,
/*ZERO_FILL_K=*/false>(
gmem_tiled_copy_O,
src/kernels/attention/attention_traits_sm80.h (1 addition, 1 deletion)
@@ -91,7 +91,7 @@ struct AttentionTraitsSM80 {
// Tiled copy for QKV
// g2s tiled copy for q
using GmemTiledCopyQ = decltype(make_tiled_copy(
Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, DType>{},
Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL_ZFILL<cute::uint128_t>, DType>{},
GmemCopyThrLayout{}, // Thr layout: (_16,_8)/(_32, _4)
Layout<Shape<_1, _8>>{} // Val layout: 8 vals per read
));
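
Why the atom swap matters for q: as far as I can tell from the CUTLASS SM80 copy traits, only the *_ZFILL cp.async atoms expose the with(bool) hook that the new zfill() helper below probes for; with the plain SM80_CP_ASYNC_CACHEGLOBAL atom the out-of-bounds path would silently fall back to clear(). A host-side check of that assumption (nvcc + CUTLASS headers; has_with_bool is my name, not the PR's):

```cpp
#include <cute/tensor.hpp>
#include <cute/atom/copy_atom.hpp>
#include <cute/atom/copy_traits_sm80.hpp>

#include <cstdio>

using namespace cute;

// Same detection idiom as zfill(): is Traits::with(bool) well-formed for this atom?
template <class Atom>
bool has_with_bool(const Atom& atom) {
  return is_valid(
      [](auto t) -> void_t<decltype(declval<typename decltype(t)::Traits>()
                                        .with(true))> {},
      atom);
}

int main() {
  Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL_ZFILL<uint128_t>, float> zfill_atom;
  Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<uint128_t>, float> plain_atom;
  // Expected 1 / 0 with the CUTLASS version this PR builds against.
  printf("ZFILL atom has with(bool): %d\n", int(has_with_bool(zfill_atom)));
  printf("plain atom has with(bool): %d\n", int(has_with_bool(plain_atom)));
  return 0;
}
```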
src/kernels/attention/cute_extensions.cuh (58 additions, 18 deletions)
@@ -15,22 +15,62 @@ CUTE_HOST_DEVICE constexpr auto elem_less(IntTupleA const& a,
return elem_less(get<I>(a), get<I>(b));
}

template <bool EVEN_K,
bool EVEN_MN,
bool ZERO_FILL_MN,
bool ZERO_FILL_K,
class TiledCopy,
// Zero-fill dst: prefer the atom's predicated copy path (for SM80 cp.async
// *_ZFILL atoms, with(false) zero-fills smem) and fall back to clear() otherwise.
template <class Copy_Atom, class TensorS, class TensorD>
CUTE_HOST_DEVICE void zfill(const Copy_Atom& copy_atom,
const TensorS& src,
TensorD&& dst) {
CUTE_STATIC_ASSERT(TensorS::rank == TensorD::rank, "rank-mismatch.");

// probe at compile time whether the atom's Traits expose with(bool)
auto has_with_bool = cute::is_valid(
[](auto t) -> void_t<decltype(declval<typename decltype(t)::Traits>()
.with(true))> {},
copy_atom);
if constexpr (has_with_bool) {
constexpr int R = TensorD::rank;
if constexpr (R == 1) { // Dispatch the copy
copy_atom.with(false).call(src, dst);
} else { // Loop over all but the first mode
Tensor src_v = group_modes<1, R>(src);
Tensor dst_v = group_modes<1, R>(dst);
CUTE_UNROLL
for (int i = 0; i < size<1>(dst_v); ++i) {
copy_atom.with(false).call(src_v(_, i), dst_v(_, i));
}
}
} else {
// just call clear if no with method
clear(dst);
}
}

template <class Copy_Atom, class TensorS, class TensorD>
CUTE_HOST_DEVICE void zfill(const Copy_Atom& copy_atom,
const TensorS& src,
TensorD& dst) {
// cast to TensorD&& so the forwarding-reference overload above is selected;
// a plain zfill(copy_atom, src, dst) would re-select this lvalue overload
zfill(copy_atom, src, static_cast<TensorD&&>(dst));
}
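
What zfill() buys over a plain clear(): for the SM80 *_ZFILL cp.async atoms, copy_atom.with(false) sets the traits' predicate so each issued cp.async reads zero bytes from global memory and zero-fills its shared-memory destination, keeping out-of-bounds fills on the same asynchronous copy path instead of issuing separate smem stores (that is my reading of the CUTLASS SM80 copy traits; atoms without with(bool) still take the clear() fallback above). The group_modes<1, R> loop mirrors how cute::copy walks an atom over all non-instruction modes; a small host-side sketch of that flattening with toy shapes (nvcc + CuTe headers):

```cpp
#include <cute/tensor.hpp>

#include <cstdio>
#include <vector>

int main() {
  using namespace cute;
  // Toy (CPY, CPY_M, CPY_K)-shaped tensor; sizes are illustrative only.
  std::vector<float> buf(8 * 4 * 2);
  auto t = make_tensor(buf.data(), make_layout(make_shape(8, 4, 2)));  // rank 3
  auto v = group_modes<1, 3>(t);  // rank 2: (8, (4, 2))
  // A single loop over size<1>(v) == 4 * 2 now covers both trailing modes.
  printf("rank(t) = %d, rank(v) = %d, size<1>(v) = %d\n",
         int(rank(t)), int(rank(v)), int(size<1>(v)));
  return 0;
}
```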

template <bool EVEN_MN,
bool EVEN_K,
bool ZFILL_MN = true,
bool ZFILL_K = true,
class CopyAtom,
class TV,
class Tiler,
class TensorS,
class TensorD,
class TensorC,
class Coord>
CUTE_HOST_DEVICE void safe_copy(
const TiledCopy& tiled_copy,
const TiledCopy<CopyAtom, TV, Tiler>& tiled_copy,
const TensorS& src, // (CPY, CPY_M/N, CPY_K)
TensorD& dst, // (CPY, CPY_M/N, CPY_K)
const TensorC& identity, // (CPY, CPY_M/N, CPY_K) -> (blk_m/n, blk_k)
const Coord& max_coord // max_coord(blk_m/n, blk_k)
) {
CUTE_STATIC_ASSERT(TensorS::rank == TensorD::rank, "rank-mismatch.");
// TiledCopy derives from its Copy_Atom; recover the atom for the per-slice
// copy/zfill calls below
auto copy_atom = static_cast<const CopyAtom&>(tiled_copy);

if constexpr (!EVEN_MN && !EVEN_K) {
// handle both m/n and k oob
CUTE_UNROLL
@@ -39,16 +79,16 @@ CUTE_HOST_DEVICE void safe_copy(
CUTE_UNROLL
for (int ki = 0; ki < size<2>(src); ++ki) {
if (elem_less<1>(identity(_0{}, _0{}, ki), max_coord)) {
copy(tiled_copy, src(_, mi, ki), dst(_, mi, ki));
copy(copy_atom, src(_, mi, ki), dst(_, mi, ki));
} else {
if constexpr (ZERO_FILL_K) {
clear(dst(_, mi, ki));
if constexpr (ZFILL_K) {
zfill(copy_atom, src(_, mi, ki), dst(_, mi, ki));
}
}
}
} else {
if constexpr (ZERO_FILL_MN) {
clear(dst(_, mi, _));
if constexpr (ZFILL_MN) {
zfill(copy_atom, src(_, mi, _), dst(_, mi, _));
}
}
}
@@ -57,10 +97,10 @@ CUTE_HOST_DEVICE void safe_copy(
CUTE_UNROLL
for (int mi = 0; mi < size<1>(src); ++mi) {
if (elem_less<0>(identity(_0{}, mi, _0{}), max_coord)) {
copy(tiled_copy, src(_, mi, _), dst(_, mi, _));
copy(copy_atom, src(_, mi, _), dst(_, mi, _));
} else {
if constexpr (ZERO_FILL_MN) {
clear(dst(_, mi, _));
if constexpr (ZFILL_MN) {
zfill(copy_atom, src(_, mi, _), dst(_, mi, _));
}
}
}
@@ -69,16 +109,16 @@ CUTE_HOST_DEVICE void safe_copy(
CUTE_UNROLL
for (int ki = 0; ki < size<2>(src); ++ki) {
if (elem_less<1>(identity(_0{}, _0{}, ki), max_coord)) {
copy(tiled_copy, src(_, _, ki), dst(_, _, ki));
copy(copy_atom, src(_, _, ki), dst(_, _, ki));
} else {
if constexpr (ZERO_FILL_K) {
clear(dst(_, _, ki));
if constexpr (ZFILL_K) {
zfill(copy_atom, src(_, _, ki), dst(_, _, ki));
}
}
}
} else {
// no oob, just copy
copy(tiled_copy, src, dst);
copy(copy_atom, src, dst);
}
}
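
How the predication fits together: identity is an identity tensor of logical (blk_m/n, blk_k) coordinates, and a copy slice is issued only while its coordinate stays below max_coord; otherwise it is zero-filled through zfill() or skipped when the caller opts out via ZFILL_MN/ZFILL_K. The kernel checks the m/n and k modes separately with the elem_less<I> helper above; the sketch below uses cute's full elem_less on a toy 4x4 tile purely to visualize the in-bounds region (host-only; nvcc + CuTe headers; sizes are made up):

```cpp
#include <cute/tensor.hpp>

#include <cstdio>

int main() {
  using namespace cute;
  // A toy 4x4 tile of logical (m, k) coordinates.
  auto cA = make_identity_tensor(make_shape(Int<4>{}, Int<4>{}));
  // Pretend only 3 rows and 2 columns are actually in bounds.
  auto max_coord = make_coord(3, 2);
  for (int m = 0; m < size<0>(cA); ++m) {
    for (int k = 0; k < size<1>(cA); ++k) {
      // 'C' = would be copied, 'Z' = out of bounds (zero-filled or skipped).
      printf("%c ", elem_less(cA(m, k), max_coord) ? 'C' : 'Z');
    }
    printf("\n");
  }
  return 0;
}
```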
