2 changes: 1 addition & 1 deletion src/kernels/attention/CMakeLists.txt
@@ -60,7 +60,7 @@ cc_test(
attention_kernel_sm80_varlen_test.cu
attention_kernel_sm80_pagedkv_test.cu
DEPS
:attention.kernel
:attention.template
absl::random_random
GTest::gtest_main
torch
24 changes: 14 additions & 10 deletions src/kernels/attention/attention_kernel_sm80.cuh
@@ -80,10 +80,10 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
}

const int head_dim = params.head_dim;
const int sliding_window = LOCAL ? params.sliding_window : kv_len;
const float logits_soft_cap = params.logits_soft_cap;
const float sm_scale = params.sm_scale;
const float sm_scale_log2 = params.sm_scale_log2;
const int sliding_window = LOCAL ? params.sliding_window : kv_len;
const float alibi_slope =
ALIBI ? (params.alibi_slopes_ptr[head_idx] / sm_scale) : 0.0f;

@@ -156,12 +156,12 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
Tensor tKsK = gmem_thr_copy_KV.partition_D(sK);
auto produce_k = [&](int ni) {
auto tKgK = gmem_thr_copy_KV.partition_S(gK(_, _, ni));
// skip zero fill oob for k since mask will mask out oob with -inf
// skip zfill_mn for k since mask will mask out oob with -inf
safe_copy<EVEN_K,
/*EVEN_MN=*/false,
/*ZERO_FILL_MN=*/false,
/*ZERO_FILL_K=*/true>(
gmem_tiled_copy_Q,
gmem_tiled_copy_KV,
tKgK,
tKsK,
tKcKV,
@@ -171,12 +171,12 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
Tensor tVsV = gmem_thr_copy_KV.partition_D(sV);
auto produce_v = [&](int ni) {
auto tVgV = gmem_thr_copy_KV.partition_S(gV(_, _, ni));
// TODO: skip zero fill oob for v, may have nan issue
// skipping ZFILL_MN for v may cause nan issue
safe_copy<EVEN_K,
/*EVEN_MN=*/false,
/*ZERO_FILL_MN=*/true,
/*ZERO_FILL_K=*/true>(
gmem_tiled_copy_Q,
gmem_tiled_copy_KV,
tVgV,
tVsV,
tKcKV,
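The two comments above encode an asymmetry worth spelling out: out-of-bounds K only produces logits that the mask overwrites with -inf anyway, while out-of-bounds V is multiplied by the (zero) softmax weights, and 0 times an uninitialized NaN is still NaN. A tiny standalone illustration of that reasoning, not kernel code:

#include <cmath>
#include <cstdio>

int main() {
  // An oob logit is masked to -inf, so its softmax weight is exactly 0 ...
  const float masked_logit = -INFINITY;
  const float weight = std::exp(masked_logit);         // 0.0f
  // ... but multiplying that 0 by V is only safe if V holds a finite number.
  const float v_zero_filled = 0.0f;                    // what ZFILL_MN guarantees
  const float v_uninitialized = NAN;                   // worst case without it
  std::printf("weight = %f\n", weight);                // 0.000000
  std::printf("ok:  %f\n", weight * v_zero_filled);    // 0.000000
  std::printf("bad: %f\n", weight * v_uninitialized);  // nan
  return 0;
}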
@@ -299,13 +299,21 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
make_coord(q_len - m_block * kBlockM, head_dim));
};

const int diagonal = m_block * kBlockM + kv_len - q_len;
// process kv in range: [kv_idx_min, kv_idx_max)
const int kv_idx_min = std::max(0, diagonal - sliding_window);
const int kv_idx_max = std::min(kv_len, diagonal + kBlockM);
const int n_block_min = LOCAL ? kv_idx_min / kBlockN : 0;
const int n_block_max = cute::ceil_div(kv_idx_max, kBlockN);
// TODO: handle n_block_min >= n_block_max

// ############### Prologue ###############

// produce q: [] => [q]
produce_q();
cp_async_fence();
// produce k: [q] => [q, k]
produce_k(0);
produce_k(n_block_min);
cp_async_fence();

// ############### Mainloop ###############
@@ -324,10 +332,6 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
Mask<kBlockM, kBlockM, ALIBI, LOCAL> mask(
q_len, kv_len, sliding_window, alibi_slope);

// TODO: control block min/max precisely
const int n_block_min = 0;
const int n_block_max = cute::ceil_div(kv_len, kBlockN);

clear(tOrAccO);
CUTE_NO_UNROLL
for (int ni = n_block_min; ni < n_block_max; ++ni) {
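For orientation, here is a minimal standalone sketch (not part of the PR) of the [kv_idx_min, kv_idx_max) block-range arithmetic introduced above, with made-up sizes and LOCAL assumed true; cute::ceil_div is spelled out as a plain rounded-up division.

#include <algorithm>
#include <cstdio>

int main() {
  // made-up sizes: one 64-row query block against 192 kv rows, window of 96
  const int q_len = 64, kv_len = 192;
  const int kBlockM = 64, kBlockN = 64;
  const int sliding_window = 96;
  const int m_block = 0;  // first (and only) query block

  // kv index on the causal diagonal for the first row of this query block
  const int diagonal = m_block * kBlockM + kv_len - q_len;        // 128
  // process kv in range [kv_idx_min, kv_idx_max)
  const int kv_idx_min = std::max(0, diagonal - sliding_window);  // 32
  const int kv_idx_max = std::min(kv_len, diagonal + kBlockM);    // 192
  const int n_block_min = kv_idx_min / kBlockN;                   // 0 (LOCAL assumed true)
  const int n_block_max = (kv_idx_max + kBlockN - 1) / kBlockN;   // 3, i.e. ceil_div

  std::printf("kv blocks [%d, %d) cover kv rows [%d, %d)\n",
              n_block_min, n_block_max, kv_idx_min, kv_idx_max);
  return 0;
}

Only kv blocks that can contain in-window, causally visible keys are visited; the removed code below always scanned [0, ceil_div(kv_len, kBlockN)).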
102 changes: 1 addition & 101 deletions src/kernels/attention/attention_kernel_sm80_varlen_test.cu
@@ -4,112 +4,12 @@

#include "attention_launch_sm80.cuh"
#include "attention_params.h"
#include "attention_ref.h"
#include "cute/layout.hpp"
#include "static_dispatch.h"

namespace llm {
namespace {
// Multi-head attention implementation using pytorch
torch::Tensor attention_ref(
torch::Tensor query, // [q_len, n_heads, head_dim]
torch::Tensor key, // [kv_len, n_kv_heads, head_dim]
torch::Tensor value, // [kv_len, n_kv_heads, head_dim]
torch::optional<torch::Tensor> alibi_slopes, //[n_heads]
float logits_soft_cap,
int32_t sliding_window) {
const auto q_len = query.size(-3);
const auto kv_len = key.size(-3);
const auto n_heads = query.size(-2);
const auto n_kv_heads = key.size(-2);
const auto head_dim = query.size(-1);
assert(kv_len >= q_len);

if (n_heads != n_kv_heads) {
assert(n_heads % n_kv_heads == 0);
const auto group_size = n_heads / n_kv_heads;
key = key.repeat_interleave(/*repeats=*/group_size, /*dim=*/-2);
value = value.repeat_interleave(/*repeats=*/group_size, /*dim=*/-2);
}

const float sm_scale = 1.0 / sqrt(head_dim);
// query * key => [n_heads, q_len, kv_len]
auto scores = torch::einsum("qhd,khd->hqk",
{query.to(torch::kFloat), key.to(torch::kFloat)});
// apply scale
scores *= sm_scale;

// apply softcap if needed
if (logits_soft_cap != 0.0) {
scores = torch::tanh(scores / logits_soft_cap) * logits_soft_cap;
}

// apply alibi bias
if (alibi_slopes) {
const auto& slopes = alibi_slopes.value();
// calculate alibi attention bias
// since it's a causal mask, we can just use [0, 1, ..., kv_len)
auto distance = torch::arange(0, kv_len, query.options());
// [n_heads, 1, kv_len]
auto bias = distance.view({1, 1, kv_len}) * slopes.view({n_heads, 1, 1});
scores += bias;
}

auto mask = torch::ones({q_len, kv_len}, torch::kBool);
if (sliding_window >= 0) {
// sliding window mask
// returns the upper triangular part of a matrix
mask = torch::triu(mask, /*diagonal=*/kv_len - q_len - sliding_window);
}

// apply causal mask
// causal mask: returns the lower triangular part of a matrix
mask = torch::tril(mask, /*diagonal=*/kv_len - q_len).to(query);
scores = scores.masked_fill(mask == 0, -INFINITY);

// safe softmax
scores = torch::softmax(scores, /*dim=*/-1);

// score * value => [q_len, n_heads, head_dim]
return torch::einsum("hqk,khd->qhd", {scores, value.to(torch::kFloat)})
.type_as(query);
}

torch::Tensor attention_varlen_ref(
torch::Tensor query, // [q_len, n_heads, head_dim]
torch::Tensor key, // [kv_len, n_kv_heads, head_dim]
torch::Tensor value, // [kv_len, n_kv_heads, head_dim]
torch::Tensor q_cu_lens, // [batch_size + 1]
torch::Tensor kv_cu_lens, // [batch_size + 1]
torch::optional<torch::Tensor> alibi_slopes, //[n_heads]
float logits_soft_cap,
int32_t sliding_window) {
torch::Tensor q_cu_lens_cpu = q_cu_lens.cpu();
torch::Tensor kv_cu_seq_lens_cpu = kv_cu_lens.cpu();
const size_t n_seqs = q_cu_lens_cpu.numel() - 1;
const int32_t* q_cu_lens_ptr = q_cu_lens_cpu.data_ptr<int32_t>();
const int32_t* kv_cu_lens_ptr = kv_cu_seq_lens_cpu.data_ptr<int32_t>();

std::vector<torch::Tensor> out_list;
// process sequence one by one
for (int64_t i = 0; i < n_seqs; ++i) {
// calculate attention for each sequence
const int32_t q_start = q_cu_lens_ptr[i];
const int32_t q_end = q_cu_lens_ptr[i + 1];
const int32_t kv_start = kv_cu_lens_ptr[i];
const int32_t kv_end = kv_cu_lens_ptr[i + 1];

torch::Tensor q = query.slice(/*dim=*/0, /*start=*/q_start, /*end=*/q_end);
torch::Tensor k = key.slice(/*dim=*/0, /*start=*/kv_start, /*end=*/kv_end);
torch::Tensor v =
value.slice(/*dim=*/0, /*start=*/kv_start, /*end=*/kv_end);

auto output =
attention_ref(q, k, v, alibi_slopes, logits_soft_cap, sliding_window);
out_list.push_back(output);
}
return torch::cat(out_list, /*dim=*/0);
}

torch::Tensor attention_varlen_sm80(
torch::Tensor query, // [q_len, n_heads, head_dim]
torch::Tensor key, // [kv_len, n_kv_heads, head_dim]
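Aside on the reference code removed above (those helpers now come from attention_ref.h): the q_cu_lens / kv_cu_lens tensors are cumulative lengths that partition the packed varlen layout. A minimal sketch with made-up lengths, independent of torch:

#include <cstdio>
#include <vector>

int main() {
  // q_cu_lens = [0, 3, 7, 12] packs three sequences of lengths 3, 4 and 5
  const std::vector<int> q_cu_lens = {0, 3, 7, 12};
  const size_t n_seqs = q_cu_lens.size() - 1;
  for (size_t i = 0; i < n_seqs; ++i) {
    // sequence i owns rows [q_cu_lens[i], q_cu_lens[i + 1]) of the packed tensor
    std::printf("seq %zu: rows [%d, %d), len %d\n", i, q_cu_lens[i],
                q_cu_lens[i + 1], q_cu_lens[i + 1] - q_cu_lens[i]);
  }
  return 0;
}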
8 changes: 6 additions & 2 deletions src/kernels/attention/attention_traits_sm80.h
@@ -117,16 +117,20 @@ struct AttentionTraitsSM80 {
// O smem: (BLK_M, K):(K, 1), k-major, same as Q
using SmemLayoutO = SmemLayoutQ;

// use 128-bit vectorizing copy
using VectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<128>;

// s2g tiled copy for O
using GmemTiledCopyO = decltype(make_tiled_copy(
Copy_Atom<DefaultCopy, DType>{},
Copy_Atom<VectorizingCopy, DType>{},
GmemCopyThrLayout{}, // Thr layout: (_16,_8)/(_32, _4)
Layout<Shape<_1, _8>>{} // Val layout: 8 vals per read
));

// r2s tiled copy for O
using SmemTiledCopyO =
decltype(make_tiled_copy_C(Copy_Atom<DefaultCopy, DType>{}, TiledMma{}));
decltype(make_tiled_copy_C(Copy_Atom<VectorizingCopy, DType>{},
TiledMma{}));

// constexpr values for kernel launch
static constexpr size_t kSmemSize =
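A small sanity check on the 128-bit figure behind AutoVectorizingCopyWithAssumedAlignment<128> (a sketch assuming DType is a 16-bit element, which this hunk does not state): the Val layout Shape<_1, _8> gives eight contiguous values per thread, i.e. exactly one 128-bit access.

#include <climits>
#include <cstdint>

// stand-in for a 16-bit DType such as half/bfloat16 (assumption, not from the diff)
using DType16 = std::uint16_t;

static_assert(8 * sizeof(DType16) * CHAR_BIT == 128,
              "8 contiguous 16-bit values per thread form one 128-bit vector access");

int main() { return 0; }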
22 changes: 13 additions & 9 deletions src/kernels/attention/flash_attn/src/kernel_traits.h
@@ -32,12 +32,15 @@ struct Flash_kernel_traits {
using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
#endif

// use 128-bit vectorizing copy
using VectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<128>;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
#else
using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
using SmemCopyAtom = Copy_Atom<VectorizingCopy, elem_type>;
using SmemCopyAtomTransposed = Copy_Atom<VectorizingCopy, elem_type>;
#endif
};

@@ -49,6 +52,7 @@ struct Flash_fwd_kernel_traits : public Base {
using ElementAccum = typename Base::ElementAccum;
using index_t = typename Base::index_t;
static constexpr bool Has_cp_async = Base::Has_cp_async;
using VectorizingCopy = typename Base::VectorizingCopy;
using SmemCopyAtom = typename Base::SmemCopyAtom;
using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;

@@ -97,8 +101,8 @@ struct Flash_fwd_kernel_traits : public Base {
using SmemLayoutO = decltype(tile_to_shape(
SmemLayoutAtomO{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
using SmemCopyAtomO = Copy_Atom<VectorizingCopy, Element>;
using SmemCopyAtomOaccum = Copy_Atom<VectorizingCopy, ElementAccum>;

static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
@@ -121,7 +125,7 @@ using Gmem_copy_struct = std::conditional_t<
using Gmem_copy_struct = std::conditional_t<
Has_cp_async,
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
DefaultCopy
VectorizingCopy
>;
using GmemTiledCopyQKV = decltype(
make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
@@ -140,7 +144,7 @@
Layout<Shape<Int<kGmemRowsPerThread>, _8>, Stride<_8, _1>>{}));

using GmemTiledCopyO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
make_tiled_copy(Copy_Atom<VectorizingCopy, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store

@@ -152,7 +156,7 @@
Stride< _16, _1>>
>;
using GmemTiledCopyOaccum = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
make_tiled_copy(Copy_Atom<VectorizingCopy, ElementAccum>{},
GmemLayoutAtomOaccum{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
using GmemLayoutAtomRotcossin = GmemLayoutAtom;
@@ -161,15 +165,15 @@
GmemLayoutAtomRotcossin{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per load
using GmemTiledCopyRotcossinCont = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
make_tiled_copy(Copy_Atom<VectorizingCopy, Element>{},
GmemLayoutAtomRotcossin{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per load
using GmemTiledCopyRotcossinPaged = decltype(
make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
GmemLayoutAtomRotcossin{},
Layout<Shape<Int<kGmemRowsPerThread>, _4>, Stride<_4, _1>>{})); // Val layout, 4 vals per load
using GmemTiledCopyRotcossinContPaged = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
make_tiled_copy(Copy_Atom<VectorizingCopy, Element>{},
GmemLayoutAtomRotcossin{},
Layout<Shape<Int<kGmemRowsPerThread>, _8>, Stride<_8, _1>>{})); // Val layout, 8 vals per load
};