73 changes: 34 additions & 39 deletions src/kernels/attention/attention_kernel_sm80.cuh
@@ -27,13 +27,11 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {

constexpr int kBlockM = Traits::kBlockM;
constexpr int kBlockN = Traits::kBlockN;
constexpr int kBlockK = Traits::kBlockK;
constexpr int kHeadDim = Traits::kHeadDim;
constexpr int kRowsPerMMA = Traits::kRowsPerMMA;

using _BLK_M = Int<kBlockM>;
using _BLK_N = Int<kBlockN>;
using _BLK_K = Int<kBlockK>;
using _HEAD_DIM = Int<kHeadDim>;

// type alias
@@ -113,14 +111,14 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
DType* k_smem = q_smem + cosize(SmemLayoutQ{});
DType* v_smem = k_smem + cosize(SmemLayoutK{});

// (BLK_M, BLK_K), k-major
// (BLK_M, HEAD_DIM), k-major
Tensor sQ = make_tensor(make_smem_ptr(q_smem), SmemLayoutQ{});
// (BLK_N, BLK_K), k-major
// (BLK_N, HEAD_DIM), k-major
Tensor sK = make_tensor(make_smem_ptr(k_smem), SmemLayoutK{});
Tensor sV = make_tensor(make_smem_ptr(v_smem), SmemLayoutV{});

// Tensor for V^t; used in GEMM-II.
// (BLK_K, BLK_N)
// (HEAD_DIM, BLK_N), m-major
Tensor sVt = make_tensor(make_smem_ptr(v_smem), SmemLayoutVt{});

// Tiled Copy
@@ -135,11 +133,11 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
Tensor cQ = make_identity_tensor(Shape<_BLK_M, _HEAD_DIM>{});
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);

auto produce_q = [&]() {
auto produce_query = [&]() {
auto tQgQ = gmem_thr_copy_Q.partition_S(gQ);
auto tQsQ = gmem_thr_copy_Q.partition_D(sQ);
auto max_coord = make_coord(q_len - m_block * kBlockM, head_dim);
safe_copy</*EVEN_MN=*/false, EVEN_K>(
safe_copy</*EVEN_MN=*/false, EVEN_K, /*ZFILL_MN=*/true, /*ZFILL_K=*/true>(
gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, max_coord);
};

@@ -148,36 +146,36 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
Tensor tKVcKV = gmem_thr_copy_KV.partition_S(cKV);

Tensor tKsK = gmem_thr_copy_KV.partition_D(sK);
auto produce_k = [&](int ni) {
auto produce_key = [&](int ni) {
auto tKgK = gmem_thr_copy_KV.partition_S(gK(_, _, ni));
auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim);
// skip zfill_mn for k since mask will mask out oob with -inf
safe_copy</*EVEN_MN=*/false,
EVEN_K,
/*ZFILL_MN=*/false>(
// skip ZFILL_MN for key since Mask will mask out oob with -inf
safe_copy</*EVEN_MN=*/false, EVEN_K, /*ZFILL_MN=*/false, /*ZFILL_K=*/true>(
gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, max_coord);
};

auto produce_k_no_oob = [&](int ni) {
// produce key without oob handling
auto produce_key_no_oob = [&](int ni) {
auto tKgK = gmem_thr_copy_KV.partition_S(gK(_, _, ni));
auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim);
safe_copy</*EVEN_MN=*/true, EVEN_K>(
safe_copy</*EVEN_MN=*/true, EVEN_K, /*ZFILL_MN=*/false, /*ZFILL_K=*/false>(
gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, max_coord);
};

Tensor tVsV = gmem_thr_copy_KV.partition_D(sV);
auto produce_v = [&](int ni) {
auto produce_value = [&](int ni) {
auto tVgV = gmem_thr_copy_KV.partition_S(gV(_, _, ni));
auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim);
// skipping ZFILL_MN for v may cause nan issue
safe_copy</*EVEN_MN=*/false, EVEN_K>(
safe_copy</*EVEN_MN=*/false, EVEN_K, /*ZFILL_MN=*/true, /*ZFILL_K=*/true>(
gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, max_coord);
};

auto produce_v_no_oob = [&](int ni) {
// produce value without oob handling
auto produce_value_no_oob = [&](int ni) {
auto tVgV = gmem_thr_copy_KV.partition_S(gV(_, _, ni));
auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim);
safe_copy</*EVEN_MN=*/true, EVEN_K>(
safe_copy</*EVEN_MN=*/true, EVEN_K, /*ZFILL_MN=*/false, /*ZFILL_K=*/false>(
gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, max_coord);
};
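
A note on the asymmetry between the producers above: out-of-bounds key rows may be left uninitialized because `Mask` later pins their scores to `-inf`, but out-of-bounds value rows must be zero-filled, since even a zero softmax weight cannot neutralize a `NaN` operand in the P@V GEMM (the "nan issue" the old comment referred to). A standalone host-side illustration of that failure mode (sketch, not kernel code):

```cpp
#include <cmath>
#include <cstdio>

int main() {
  float p_oob = 0.0f;           // softmax weight of a masked-out key: exactly 0
  float v_oob = std::nanf("");  // uninitialized shared memory may hold NaN
  // 0 * NaN == NaN, so a NaN value row poisons the accumulator even though
  // its attention weight is zero; zero-filling V avoids this.
  std::printf("%f\n", p_oob * v_oob);  // prints: nan
  return 0;
}
```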

@@ -288,10 +286,7 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
__syncthreads();

auto max_coord = make_coord(q_len - m_block * kBlockM, head_dim);
safe_copy</*EVEN_MN=*/false,
EVEN_K,
/*ZFILL_MN=*/false,
/*ZFILL_K=*/false>(
safe_copy</*EVEN_MN=*/false, EVEN_K, /*ZFILL_MN=*/false, /*ZFILL_K=*/false>(
gmem_tiled_copy_O, tOsO, tOgO, tOcO, max_coord);
};

@@ -316,11 +311,11 @@

// ############### Prologue ###############
int n_block_idx = n_block_max - 1;
// produce q: [] => [q]
produce_q();
// produce query: [] => [q]
produce_query();
cp_async_fence();
// produce k: [q] => [q, k]
produce_k(n_block_idx);
// produce key: [q] => [q, k]
produce_key(n_block_idx);
cp_async_fence();
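
The bracketed annotations (`[] => [q]`, `[q] => [q, k]`) track the `cp.async` commit queue: each `cp_async_fence()` commits the copies issued since the previous fence as one batch, and `cp_async_wait<N>()` blocks until at most `N` batches remain in flight. A sketch of the handoff (illustrative only, mirroring the calls above):

```cpp
// Prologue: enqueue two async copy batches back to back.
produce_query();              // issue async copies for the Q tile
cp_async_fence();             // queue: [q]
produce_key(n_block_idx);     // issue async copies for the first K tile
cp_async_fence();             // queue: [q, k]

// Mainloop entry: drain the queue before reading shared memory.
cp_async_wait<0>();           // queue: [] -> Q and K tiles have landed
__syncthreads();              // make smem writes visible across the block
```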

// ############### Mainloop ###############
@@ -341,15 +336,15 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
for (int i = 0; i < n_oob_mask; ++i) {
clear(tSrAccS);

// wait k, queue: [q, k] => []
// wait key, queue: [q, k] => []
cp_async_wait<0>();
__syncthreads();

// produce v, [] => [v]
// produce value, [] => [v]
if (i == 0) {
produce_v(n_block_idx);
produce_value(n_block_idx);
} else {
produce_v_no_oob(n_block_idx);
produce_value_no_oob(n_block_idx);
}
cp_async_fence();

@@ -362,13 +357,13 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
mask.apply(tSrAccS_rc_view, m_block, n_block_idx, tidx);
softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view);

// wait v, [v] => []
// wait value, [v] => []
cp_async_wait<0>();
__syncthreads();

// produce next k: [] => [k]
// produce next key: [] => [k]
if (n_block_idx > n_block_min) {
produce_k_no_oob(n_block_idx - 1);
produce_key_no_oob(n_block_idx - 1);
}
cp_async_fence();

@@ -387,12 +382,12 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
for (; n_block_idx >= n_block_min; --n_block_idx) {
clear(tSrAccS);

// wait k, queue: [q, k] => []
// wait key, queue: [q, k] => []
cp_async_wait<0>();
__syncthreads();

// produce v, [] => [v]
produce_v_no_oob(n_block_idx);
// produce value, [] => [v]
produce_value_no_oob(n_block_idx);
cp_async_fence();

// 1> S = Q@K.T
@@ -404,13 +399,13 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
mask.apply</*OOB_MASK=*/false>(tSrAccS_rc_view, m_block, n_block_idx, tidx);
softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view);

// wait v, [v] => []
// wait value, [v] => []
cp_async_wait<0>();
__syncthreads();

// produce next k: [] => [k]
// produce next key: [] => [k]
if (n_block_idx > n_block_min) {
produce_k_no_oob(n_block_idx - 1);
produce_key_no_oob(n_block_idx - 1);
}
cp_async_fence();

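For context on the call sites this PR leaves untouched: `softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view)` is the online-softmax correction that keeps the output accumulator consistent as key blocks stream in. Per accumulator row it behaves roughly like the scalar sketch below (hypothetical names; the real implementation operates on register fragments across the MMA layout):

```cpp
#include <cmath>

// Fold a new block of n scores s[] into the running (max, sum) statistics
// and rescale the d-wide output accumulator o[] so contributions from
// earlier key blocks stay correctly weighted.
void rescale_row(float& row_max, float& row_sum,
                 float* s, float* o, int n, int d) {
  float new_max = row_max;
  for (int j = 0; j < n; ++j) new_max = fmaxf(new_max, s[j]);
  const float correction = expf(row_max - new_max);  // <= 1
  row_sum *= correction;
  for (int k = 0; k < d; ++k) o[k] *= correction;    // rescale old output
  for (int j = 0; j < n; ++j) {
    s[j] = expf(s[j] - new_max);                     // unnormalized probs
    row_sum += s[j];
  }
  row_max = new_max;  // o += s @ V follows in GEMM-II; o /= row_sum at the end
}
```
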
34 changes: 17 additions & 17 deletions src/kernels/attention/cute_extensions.cuh
@@ -52,8 +52,8 @@ CUTE_HOST_DEVICE void zfill(const Copy_Atom& copy_atom,

template <bool EVEN_MN,
bool EVEN_K,
bool ZFILL_MN = true,
bool ZFILL_K = true,
bool ZFILL_MN,
bool ZFILL_K,
class CopyAtom,
class TV,
class Tiler,
@@ -80,15 +80,19 @@ CUTE_HOST_DEVICE void safe_copy(
for (int ki = 0; ki < size<2>(src); ++ki) {
if (elem_less<1>(identity(_0{}, _0{}, ki), max_coord)) {
copy(copy_atom, src(_, mi, ki), dst(_, mi, ki));
} else {
if constexpr (ZFILL_K) {
zfill(copy_atom, src(_, mi, ki), dst(_, mi, ki));
}
} else if constexpr (ZFILL_K) {
zfill(copy_atom, src(_, mi, ki), dst(_, mi, ki));
}
}
} else {
if constexpr (ZFILL_MN) {
zfill(copy_atom, src(_, mi, _), dst(_, mi, _));
} else if constexpr (ZFILL_MN) {
zfill(copy_atom, src(_, mi, _), dst(_, mi, _));
} else if constexpr (ZFILL_K) {
// still need to handle k oob even if m/n is not zfilled
CUTE_UNROLL
for (int ki = 0; ki < size<2>(src); ++ki) {
if (!elem_less<1>(identity(_0{}, _0{}, ki), max_coord)) {
zfill(copy_atom, src(_, mi, ki), dst(_, mi, ki));
}
}
}
}
@@ -98,10 +102,8 @@ CUTE_HOST_DEVICE void safe_copy(
for (int mi = 0; mi < size<1>(src); ++mi) {
if (elem_less<0>(identity(_0{}, mi, _0{}), max_coord)) {
copy(copy_atom, src(_, mi, _), dst(_, mi, _));
} else {
if constexpr (ZFILL_MN) {
zfill(copy_atom, src(_, mi, _), dst(_, mi, _));
}
} else if constexpr (ZFILL_MN) {
zfill(copy_atom, src(_, mi, _), dst(_, mi, _));
}
}
} else if constexpr (EVEN_MN && !EVEN_K) {
Expand All @@ -110,10 +112,8 @@ CUTE_HOST_DEVICE void safe_copy(
for (int ki = 0; ki < size<2>(src); ++ki) {
if (elem_less<1>(identity(_0{}, _0{}, ki), max_coord)) {
copy(copy_atom, src(_, _, ki), dst(_, _, ki));
} else {
if constexpr (ZFILL_K) {
zfill(copy_atom, src(_, _, ki), dst(_, _, ki));
}
} else if constexpr (ZFILL_K) {
zfill(copy_atom, src(_, _, ki), dst(_, _, ki));
}
}
} else {
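The signature change makes `ZFILL_MN` and `ZFILL_K` mandatory at every call site, and the new `else if constexpr (ZFILL_K)` branch covers a case the old defaults silently missed: when out-of-bounds rows are skipped rather than zero-filled, their k tail still has to be cleared. A host-side analogue of the `!EVEN_MN && !EVEN_K` control flow (hypothetical scalar names; the real code copies CuTe fragments):

```cpp
// Per-element decision for one (mi, ki) fragment against (m_max, k_max).
template <bool ZFILL_MN, bool ZFILL_K>
void copy_or_zfill(float& dst, const float& src,
                   int mi, int ki, int m_max, int k_max) {
  if (mi < m_max) {                 // row in bounds
    if (ki < k_max) {
      dst = src;                    // fully in bounds: plain copy
    } else if constexpr (ZFILL_K) {
      dst = 0.0f;                   // k out of bounds: zero-fill
    }
  } else if constexpr (ZFILL_MN) {
    dst = 0.0f;                     // whole row out of bounds: zero-fill
  } else if constexpr (ZFILL_K) {
    // The branch this PR adds: rows skipped by ZFILL_MN still get their
    // out-of-bounds k tail cleared.
    if (ki >= k_max) dst = 0.0f;
  }
}
```
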
9 changes: 3 additions & 6 deletions src/kernels/attention/mask.h
@@ -65,7 +65,6 @@ struct Mask {
} else if constexpr (!OOB_MASK && LOCAL) {
// local mask
return (q_idx - kv_idx) > sliding_window_;

} else {
// !OOB_MASK && !LOCAL
return false;
@@ -74,12 +73,10 @@ struct Mask {

if (out_of_boundary) {
rAccS(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
} else {
} else if constexpr (ALIBI) {
// Apply alibi bias
if constexpr (ALIBI) {
rAccS(make_coord(i, mi), make_coord(j, nj)) +=
alibi_slope_ * kv_idx;
}
rAccS(make_coord(i, mi), make_coord(j, nj)) +=
alibi_slope_ * kv_idx;
}
}
}
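Behavior in `mask.h` is unchanged; the rewrite only flattens the nesting into `else if constexpr` chains. The resulting per-score logic, modeled in scalar form (hypothetical names):

```cpp
#include <cmath>

// A score is pinned to -inf when masked; otherwise, with ALIBI enabled,
// it gets a bias that grows linearly with key position at a per-head slope.
template <bool ALIBI>
float apply_one(float score, bool out_of_boundary,
                float alibi_slope, int kv_idx) {
  if (out_of_boundary) {
    return -INFINITY;                     // softmax weight becomes 0
  } else if constexpr (ALIBI) {
    return score + alibi_slope * kv_idx;  // ALiBi positional bias
  } else {
    return score;
  }
}
```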