156 changes: 104 additions & 52 deletions src/kernels/attention/attention_kernel_sm80.cuh
@@ -57,9 +57,9 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
using SmemTiledCopyO = typename Traits::SmemTiledCopyO;

const int m_block = blockIdx.x;
const auto batch_idx = blockIdx.y;
const auto head_idx = blockIdx.z;
const auto tidx = threadIdx.x;
const int batch_idx = blockIdx.y;
const int head_idx = blockIdx.z;
const int tidx = threadIdx.x;

AttentionTile<Params> tile(params);

@@ -75,7 +75,7 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
const int kv_len = size<0>(K);

if (m_block * kBlockM >= q_len) {
// out of bounds, return
// m out of bounds, return
return;
}

@@ -134,46 +134,51 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
// (BLK_M, HEAD_DIM) -> (blk_m, head_dim)
Tensor cQ = make_identity_tensor(Shape<_BLK_M, _HEAD_DIM>{});
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);
// (BLK_N, HEAD_DIM) -> (blk_n, head_dim)
Tensor cKV = make_identity_tensor(Shape<_BLK_N, _HEAD_DIM>{});
Tensor tKcKV = gmem_thr_copy_KV.partition_S(cKV);

auto produce_q = [&]() {
auto tQgQ = gmem_thr_copy_Q.partition_S(gQ);
auto tQsQ = gmem_thr_copy_Q.partition_D(sQ);
auto max_coord = make_coord(q_len - m_block * kBlockM, head_dim);
safe_copy</*EVEN_MN=*/false, EVEN_K>(
gmem_tiled_copy_Q,
tQgQ,
tQsQ,
tQcQ,
make_coord(q_len - m_block * kBlockM, head_dim));
gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, max_coord);
};

// TODO: separate mask iterations
// (BLK_N, HEAD_DIM) -> (blk_n, head_dim)
Tensor cKV = make_identity_tensor(Shape<_BLK_N, _HEAD_DIM>{});
Tensor tKVcKV = gmem_thr_copy_KV.partition_S(cKV);

Tensor tKsK = gmem_thr_copy_KV.partition_D(sK);
auto produce_k = [&](int ni) {
auto tKgK = gmem_thr_copy_KV.partition_S(gK(_, _, ni));
auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim);
// skip zfill_mn for k since mask will mask out oob with -inf
safe_copy</*EVEN_MN=*/false,
EVEN_K,
/*ZERO_FILL_MN=*/false>(
gmem_tiled_copy_KV,
tKgK,
tKsK,
tKcKV,
make_coord(kv_len - ni * kBlockN, head_dim));
/*ZFILL_MN=*/false>(
gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, max_coord);
};

auto produce_k_no_oob = [&](int ni) {
auto tKgK = gmem_thr_copy_KV.partition_S(gK(_, _, ni));
auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim);
safe_copy</*EVEN_MN=*/true, EVEN_K>(
gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, max_coord);
};

Tensor tVsV = gmem_thr_copy_KV.partition_D(sV);
auto produce_v = [&](int ni) {
auto tVgV = gmem_thr_copy_KV.partition_S(gV(_, _, ni));
auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim);
// skipping ZFILL_MN for v may cause NaN issues
safe_copy</*EVEN_MN=*/false, EVEN_K>(
gmem_tiled_copy_KV,
tVgV,
tVsV,
tKcKV,
make_coord(kv_len - ni * kBlockN, head_dim));
gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, max_coord);
};

auto produce_v_no_oob = [&](int ni) {
auto tVgV = gmem_thr_copy_KV.partition_S(gV(_, _, ni));
auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim);
safe_copy</*EVEN_MN=*/true, EVEN_K>(
gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, max_coord);
};

TiledMma tiled_mma;
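
Note: the producer lambdas above differ only in how the global-to-shared copy is guarded. The boundary variants (`produce_k`, `produce_v`) predicate rows against `max_coord` (`EVEN_MN=false`) and optionally zero-fill, while the `*_no_oob` variants assume a fully in-bounds tile (`EVEN_MN=true`) and skip the per-row checks. The sketch below is a simplified, standalone illustration of that guard/zero-fill idea; `guarded_copy`, `kRows` and `kCols` are placeholder names, not the repo's CUTE-based `safe_copy`.

```cpp
// Simplified sketch of a bounds-guarded tile copy (placeholder names; the real
// kernel copies CUTE tensor partitions with cp.async via safe_copy<...>).
// Example use: guarded_copy</*EVEN_MN=*/false, /*ZFILL_MN=*/true, 64, 128>(src, dst, rows_left);
template <bool EVEN_MN, bool ZFILL_MN, int kRows, int kCols, typename T>
void guarded_copy(const T* src, T* dst, int max_row) {
  for (int r = 0; r < kRows; ++r) {
    if (EVEN_MN || r < max_row) {
      // row is in bounds: copy it
      for (int c = 0; c < kCols; ++c) dst[r * kCols + c] = src[r * kCols + c];
    } else if (ZFILL_MN) {
      // out-of-bounds row: zero-fill so downstream math sees 0, not stale smem
      for (int c = 0; c < kCols; ++c) dst[r * kCols + c] = T(0);
    }
    // else: leave the row untouched; only safe when a later mask discards it
  }
}
```

This matches the comments in the hunk: K tiles can skip the zero-fill because the score mask writes -inf over out-of-bounds columns, while V tiles are zero-filled since skipping it can produce NaNs in the P@V product.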
@@ -281,84 +286,131 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {

// wait for smem copy done before gmem copy
__syncthreads();

auto max_coord = make_coord(q_len - m_block * kBlockM, head_dim);
safe_copy</*EVEN_MN=*/false,
EVEN_K,
/*ZERO_FILL_MN=*/false,
/*ZERO_FILL_K=*/false>(
gmem_tiled_copy_O,
tOsO,
tOgO,
tOcO,
make_coord(q_len - m_block * kBlockM, head_dim));
/*ZFILL_MN=*/false,
/*ZFILL_K=*/false>(
gmem_tiled_copy_O, tOsO, tOgO, tOcO, max_coord);
};

// output accumulator, (MMA,MMA_M,MMA_K)
auto tOrAccO = partition_fragment_C(tiled_mma, Shape<_BLK_M, _HEAD_DIM>{});
auto tOrAccO_rc_view =
make_tensor(tOrAccO.data(), Layout::to_rowcol(tOrAccO.layout()));
clear(tOrAccO);

const int diagonal = m_block * kBlockM + kv_len - q_len;
// process kv in range: [kv_idx_min, kv_idx_max)
const int kv_idx_min = std::max(0, diagonal - sliding_window);
const int kv_idx_max = std::min(kv_len, diagonal + kBlockM);
const int n_block_min = LOCAL ? kv_idx_min / kBlockN : 0;
const int n_block_max = cute::ceil_div(kv_idx_max, kBlockN);
// TODO: handle n_block_min >= n_block_max

// ############### Prologue ###############
if (n_block_min >= n_block_max) {
// write output to gmem
epilogue(tOrAccO);
return;
}

// ############### Prologue ###############
int n_block_idx = n_block_max - 1;
// produce q: [] => [q]
produce_q();
cp_async_fence();
// produce k: [q] => [q, k]
produce_k(n_block_min);
produce_k(n_block_idx);
cp_async_fence();

// ############### Mainloop ###############

// output accumulator, (MMA,MMA_M,MMA_K)
auto tOrAccO = partition_fragment_C(tiled_mma, Shape<_BLK_M, _HEAD_DIM>{});
auto tOrAccO_rc_view =
make_tensor(tOrAccO.data(), Layout::to_rowcol(tOrAccO.layout()));
OnlineSoftmax<kRowsPerMMA * size<1>(tOrAccO)> softmax(sm_scale_log2);
Mask<kBlockM, kBlockM, ALIBI, LOCAL> mask(
q_len, kv_len, sliding_window, alibi_slope);

// attention score accumulator, (MMA,MMA_M,MMA_N)
auto tSrAccS = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{});
auto tSrAccS_rc_view =
make_tensor(tSrAccS.data(), Layout::to_rowcol(tSrAccS.layout()));
// separate oob mask iterations for better performance
constexpr int n_oob_mask = cute::ceil_div(kBlockM, kBlockN) + 1;

OnlineSoftmax<kRowsPerMMA * size<1>(tOrAccO)> softmax(sm_scale_log2);
Mask<kBlockM, kBlockM, ALIBI, LOCAL> mask(
q_len, kv_len, sliding_window, alibi_slope);

clear(tOrAccO);
CUTE_NO_UNROLL
for (int ni = n_block_min; ni < n_block_max; ++ni) {
// oob mask iterations
CUTE_UNROLL
for (int i = 0; i < n_oob_mask; ++i) {
clear(tSrAccS);

// wait k, queue: [q, k] => []
cp_async_wait<0>();
__syncthreads();

// produce v, [] => [v]
produce_v(ni);
if (i == 0) {
produce_v(n_block_idx);
} else {
produce_v_no_oob(n_block_idx);
}
cp_async_fence();

// 1> S = Q@K.T
compute_qk(tSrAccS);

// apply soft cap if needed
if constexpr (SOFT_CAP) {
apply_logits_soft_cap(tSrAccS);
}
mask.apply(tSrAccS_rc_view, m_block, n_block_idx, tidx);
softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view);

// apply mask for block (m_block, ni)
mask.apply(tSrAccS_rc_view, m_block, ni, tidx);
// wait v, [v] => []
cp_async_wait<0>();
__syncthreads();

// produce next k: [] => [k]
if (n_block_idx > n_block_min) {
produce_k_no_oob(n_block_idx - 1);
}
cp_async_fence();

// 2> O = softmax(S)*V
compute_sv(tSrAccS, tOrAccO);

--n_block_idx;
if (n_block_idx < n_block_min) {
// no more kv blocks to process
break;
}
}

// apply softmax and rescale
// non-oob mask iterations
CUTE_NO_UNROLL
for (; n_block_idx >= n_block_min; --n_block_idx) {
clear(tSrAccS);

// wait k, queue: [q, k] => []
cp_async_wait<0>();
__syncthreads();

// produce v, [] => [v]
produce_v_no_oob(n_block_idx);
cp_async_fence();

// 1> S = Q@K.T
compute_qk(tSrAccS);

if constexpr (SOFT_CAP) {
apply_logits_soft_cap(tSrAccS);
}
mask.apply</*OOB_MASK=*/false>(tSrAccS_rc_view, m_block, n_block_idx, tidx);
softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view);

// wait v, [v] => []
cp_async_wait<0>();
__syncthreads();

// produce next k: [] => [k]
if (ni != n_block_max - 1) {
produce_k(ni + 1);
if (n_block_idx > n_block_min) {
produce_k_no_oob(n_block_idx - 1);
}
cp_async_fence();

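
Note: the net effect of this hunk is a reverse-order, two-phase mainloop. The early `n_block_min >= n_block_max` exit writes the zero-initialized accumulator and returns; otherwise the prologue loads Q and the last K block, an unrolled phase of `ceil_div(kBlockM, kBlockN) + 1` iterations applies the full out-of-bounds mask, and the remaining KV blocks run with `OOB_MASK=false` and the `*_no_oob` producers. A standalone sketch of just the control flow, with `process_block` as a stand-in for the real clear/copy/QK/mask/softmax/SV sequence:

```cpp
// Control-flow sketch of the two-phase KV mainloop; process_block is a
// placeholder for the per-block work, not a function in the kernel.
#include <cstdio>

constexpr int kBlockM = 64, kBlockN = 64;

void process_block(int ni, bool oob_mask) {
  std::printf("kv block %d, oob_mask=%d\n", ni, oob_mask);
}

void mainloop(int n_block_min, int n_block_max) {
  int ni = n_block_max - 1;  // walk KV blocks from last to first
  // Phase 1: blocks that may cross the causal/OOB boundary use the full mask.
  constexpr int n_oob_mask = (kBlockM + kBlockN - 1) / kBlockN + 1;
  for (int i = 0; i < n_oob_mask; ++i) {
    process_block(ni, /*oob_mask=*/true);
    if (--ni < n_block_min) return;  // no more KV blocks to process
  }
  // Phase 2: interior blocks skip the per-element OOB predication.
  for (; ni >= n_block_min; --ni) {
    process_block(ni, /*oob_mask=*/false);
  }
}

int main() { mainloop(0, 5); }
```

Handling the boundary blocks first lets the later iterations prefetch with `produce_k_no_oob` / `produce_v_no_oob` and apply the mask without OOB checks, which is what the "separate oob mask iterations for better performance" comment refers to.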
20 changes: 15 additions & 5 deletions src/kernels/attention/attention_kernel_sm80_pagedkv_test.cu
@@ -6,9 +6,21 @@
#include "attention_params.h"
#include "attention_ref.h"
#include "cute/layout.hpp"
#include "static_dispatch.h"

namespace llm {
#define DISPATCH_HEAD_DIM_(HEAD_DIM_V, HEAD_DIM_NAME, ...) \
[&] { \
if (HEAD_DIM_V <= 64) { \
constexpr static int HEAD_DIM_NAME = 64; \
return __VA_ARGS__(); \
} else if (HEAD_DIM_V <= 256) { \
constexpr static int HEAD_DIM_NAME = 256; \
return __VA_ARGS__(); \
} else { \
assert(false); \
} \
}()

namespace {
torch::Tensor attention_pagedkv_sm80(
torch::Tensor query, // [q_seq_len, n_heads, head_dim]
@@ -61,10 +73,8 @@ torch::Tensor attention_pagedkv_sm80(
params.block_cu_lens = block_cu_lens.const_data_ptr<int32_t>();
params.block_size = block_size;

DISPATCH_TORCH_DTYPE(query.dtype(), DTYPE, [&] {
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] {
run_attention_kernel_sm80<DTYPE, HEAD_DIM>(params);
});
DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] {
run_attention_kernel_sm80<cute::half_t, HEAD_DIM>(params);
});
return out;
}
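
Note: the test-local `DISPATCH_HEAD_DIM_` macro is an immediately-invoked lambda that buckets a runtime `head_dim` into a compile-time constant (64 or 256), so the callback can use it as a template argument. A minimal standalone illustration of the pattern, with `run_kernel` as a placeholder for `run_attention_kernel_sm80`:

```cpp
// Standalone illustration of the lambda-based head-dim dispatch used in the test.
#include <cassert>
#include <cstdio>

#define DISPATCH_HEAD_DIM_(HEAD_DIM_V, HEAD_DIM_NAME, ...) \
  [&] {                                                    \
    if (HEAD_DIM_V <= 64) {                                \
      constexpr static int HEAD_DIM_NAME = 64;             \
      return __VA_ARGS__();                                \
    } else if (HEAD_DIM_V <= 256) {                        \
      constexpr static int HEAD_DIM_NAME = 256;            \
      return __VA_ARGS__();                                \
    } else {                                               \
      assert(false);                                       \
    }                                                      \
  }()

// Placeholder for run_attention_kernel_sm80<DTYPE, HEAD_DIM>(params).
template <int HEAD_DIM>
void run_kernel() {
  std::printf("kernel instantiated with HEAD_DIM=%d\n", HEAD_DIM);
}

int main() {
  int head_dim = 128;  // runtime value, e.g. query.size(-1)
  DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] { run_kernel<HEAD_DIM>(); });
  return 0;
}
```

Since the pagedkv and varlen tests hardcode `cute::half_t`, only the head dimension is dispatched there; `attention_kernel_sm80_test.cu` below nests this with a dtype dispatch.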
30 changes: 28 additions & 2 deletions src/kernels/attention/attention_kernel_sm80_test.cu
@@ -10,6 +10,32 @@
#include "static_dispatch.h"

namespace llm {
#define DISPATCH_HEAD_DIM_(HEAD_DIM_V, HEAD_DIM_NAME, ...) \
[&] { \
if (HEAD_DIM_V <= 64) { \
constexpr static int HEAD_DIM_NAME = 64; \
return __VA_ARGS__(); \
} else if (HEAD_DIM_V <= 256) { \
constexpr static int HEAD_DIM_NAME = 256; \
return __VA_ARGS__(); \
} else { \
assert(false); \
} \
}()

#define DISPATCH_TORCH_DTYPE_(TORCH_DTYPE, TYPE_NAME, ...) \
[&] { \
if (TORCH_DTYPE == torch::kHalf) { \
using TYPE_NAME = cute::half_t; \
return __VA_ARGS__(); \
} else if (TORCH_DTYPE == torch::kBFloat16) { \
using TYPE_NAME = cute::bfloat16_t; \
return __VA_ARGS__(); \
} else { \
assert(false); \
} \
}()

namespace {
torch::Tensor attention_sm80(
torch::Tensor query, // [batch_size, q_len, n_heads, head_dim]
@@ -57,8 +83,8 @@ torch::Tensor attention_sm80(
params.logits_soft_cap = logits_soft_cap;
params.sliding_window = sliding_window;

DISPATCH_TORCH_DTYPE(query.dtype(), DTYPE, [&] {
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] {
DISPATCH_TORCH_DTYPE_(query.dtype(), DTYPE, [&] {
DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] {
run_attention_kernel_sm80<DTYPE, HEAD_DIM>(params);
});
});
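
Note: `DISPATCH_TORCH_DTYPE_` follows the same pattern but maps a runtime torch dtype (`torch::kHalf` / `torch::kBFloat16`) to a CUTE element type alias, and nesting it with `DISPATCH_HEAD_DIM_` yields one kernel instantiation per (dtype, head_dim) pair. A torch-free sketch of the dtype half of the pattern; `ScalarType`, `fp16_t` and `bf16_t` are placeholders for `torch::ScalarType`, `cute::half_t` and `cute::bfloat16_t`, and the macro is renamed to make clear it is not the real one:

```cpp
// Torch-free sketch of the dtype dispatch: a runtime tag picks the
// compile-time element type handed to the templated kernel runner.
#include <cstdio>

enum class ScalarType { kHalf, kBFloat16 };
struct fp16_t { static constexpr const char* name = "half"; };
struct bf16_t { static constexpr const char* name = "bfloat16"; };

#define DISPATCH_DTYPE_(DTYPE_V, TYPE_NAME, ...) \
  [&] {                                          \
    if (DTYPE_V == ScalarType::kHalf) {          \
      using TYPE_NAME = fp16_t;                  \
      return __VA_ARGS__();                      \
    } else {                                     \
      using TYPE_NAME = bf16_t;                  \
      return __VA_ARGS__();                      \
    }                                            \
  }()

template <typename DTYPE>
void run_kernel() {
  std::printf("kernel instantiated for %s\n", DTYPE::name);
}

int main() {
  ScalarType dtype = ScalarType::kBFloat16;  // runtime value, e.g. query.dtype()
  DISPATCH_DTYPE_(dtype, DTYPE, [&] { run_kernel<DTYPE>(); });
  return 0;
}
```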
20 changes: 15 additions & 5 deletions src/kernels/attention/attention_kernel_sm80_varlen_test.cu
@@ -6,9 +6,21 @@
#include "attention_params.h"
#include "attention_ref.h"
#include "cute/layout.hpp"
#include "static_dispatch.h"

namespace llm {
#define DISPATCH_HEAD_DIM_(HEAD_DIM_V, HEAD_DIM_NAME, ...) \
[&] { \
if (HEAD_DIM_V <= 64) { \
constexpr static int HEAD_DIM_NAME = 64; \
return __VA_ARGS__(); \
} else if (HEAD_DIM_V <= 256) { \
constexpr static int HEAD_DIM_NAME = 256; \
return __VA_ARGS__(); \
} else { \
assert(false); \
} \
}()

namespace {
torch::Tensor attention_varlen_sm80(
torch::Tensor query, // [q_len, n_heads, head_dim]
@@ -54,10 +66,8 @@ torch::Tensor attention_varlen_sm80(
params.q_cu_lens = q_cu_lens.const_data_ptr<int32_t>();
params.kv_cu_lens = kv_cu_lens.const_data_ptr<int32_t>();

DISPATCH_TORCH_DTYPE(query.dtype(), DTYPE, [&] {
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] {
run_attention_kernel_sm80<DTYPE, HEAD_DIM>(params);
});
DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] {
run_attention_kernel_sm80<cute::half_t, HEAD_DIM>(params);
});
return out;
}