6 changes: 3 additions & 3 deletions src/kernels/attention/CMakeLists.txt
@@ -37,11 +37,11 @@ file(GLOB GENERATED_SRC_FILES "${CMAKE_CURRENT_BINARY_DIR}/generated/*.cu")

cc_library(
  NAME
-   attention.kernel
+   attention.kernels
  HDRS
-   # attention.h
+   attn_api.h
  SRCS
-   # attention.cpp
+   attn_api.cpp
    ${GENERATED_SRC_FILES}
  INCLUDES
    ${CMAKE_CURRENT_SOURCE_DIR}
15 changes: 11 additions & 4 deletions src/kernels/attention/attention_kernel_sm80_test.cu
@@ -94,7 +94,8 @@ torch::Tensor attention_sm80(
} // namespace

class AttentionKernelTest
-   : public ::testing::TestWithParam<std::tuple<int64_t /*batch_size*/,
+   : public ::testing::TestWithParam<std::tuple<torch::ScalarType /*q_dtype*/,
+                                                int64_t /*batch_size*/,
                                                 int64_t /*q_len*/,
                                                 int64_t /*kv_len*/,
                                                 int64_t /*n_heads*/,
@@ -111,7 +112,8 @@ class AttentionKernelTest
};

TEST_P(AttentionKernelTest, MHA) {
-   const auto [batch_size,
+   const auto [dtype,
+               batch_size,
                q_len,
                kv_len,
                n_heads,
@@ -121,7 +123,7 @@ TEST_P(AttentionKernelTest, MHA) {
                alibi,
                sliding_window] = GetParam();

-   const auto options = torch::dtype(torch::kHalf).device(torch::kCUDA);
+   const auto options = torch::dtype(dtype).device(torch::kCUDA);

  // construct non-contiguous query, key and value
  const auto data = torch::randn(
@@ -143,13 +145,18 @@ TEST_P(AttentionKernelTest, MHA) {
  auto out = attention_sm80(
      query, key, value, alibi_slopes, logits_soft_cap, sliding_window, q_len);

-   EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-3, /*atol=*/1e-3));
+   if (dtype == torch::kBFloat16) {
+     EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-2, /*atol=*/1e-2));
+   } else {
+     EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-3, /*atol=*/1e-3));
+   }
}

INSTANTIATE_TEST_SUITE_P(
    MHA,
    AttentionKernelTest,
    ::testing::Combine(
+       ::testing::Values(torch::kHalf, torch::kBFloat16),  // q_dtype
        ::testing::Values(1, 2, 4),                         // batch_size
        ::testing::Values(1, 62, 125),                      // q_len
        ::testing::Values(127, 287, 1000),                  // kv_len
12 changes: 6 additions & 6 deletions src/kernels/attention/attention_launch_sm80.cuh
@@ -51,43 +51,43 @@ void run_attention_kernel(const Params& params, cudaStream_t stream) {
} // namespace detail

// user-facing function to run the attention kernel
- template <typename Element, int HEAD_DIM, typename Params>
+ template <typename Dtype, int HEAD_DIM, typename Params>
void run_attention_kernel_sm80(Params& params, cudaStream_t stream = nullptr) {
  // normalize params for performance optimization
  params.normalize();

  // TODO: tune block shape MNK based on the head dim and smem size
  if constexpr (HEAD_DIM == 64) {
-     using Traits = AttentionTraitsSM80<Element,
+     using Traits = AttentionTraitsSM80<Dtype,
                                        HEAD_DIM,
                                        /*BLK_M=*/64,
                                        /*BLK_N=*/64,
                                        /*BLK_K=*/64>;
    detail::run_attention_kernel<Traits>(params, stream);
  } else if constexpr (HEAD_DIM == 96) {
-     using Traits = AttentionTraitsSM80<Element,
+     using Traits = AttentionTraitsSM80<Dtype,
                                        HEAD_DIM,
                                        /*BLK_M=*/64,
                                        /*BLK_N=*/64,
                                        /*BLK_K=*/32>;
    detail::run_attention_kernel<Traits>(params, stream);
  } else if constexpr (HEAD_DIM == 128) {
-     using Traits = AttentionTraitsSM80<Element,
+     using Traits = AttentionTraitsSM80<Dtype,
                                        HEAD_DIM,
                                        /*BLK_M=*/64,
                                        /*BLK_N=*/64,
                                        /*BLK_K=*/64>;
    detail::run_attention_kernel<Traits>(params, stream);
  } else if constexpr (HEAD_DIM == 256) {
-     using Traits = AttentionTraitsSM80<Element,
+     using Traits = AttentionTraitsSM80<Dtype,
                                        HEAD_DIM,
                                        /*BLK_M=*/64,
                                        /*BLK_N=*/64,
                                        /*BLK_K=*/64>;
    detail::run_attention_kernel<Traits>(params, stream);
  } else {
    // use the default block size
-     using Traits = AttentionTraitsSM80<Element,
+     using Traits = AttentionTraitsSM80<Dtype,
                                        HEAD_DIM,
                                        /*BLK_M=*/64,
                                        /*BLK_N=*/64,
76 changes: 76 additions & 0 deletions src/kernels/attention/attn_api.cpp
@@ -0,0 +1,76 @@
#include "attn_api.h"

#include <ATen/cuda/CUDAContext.h>

#include "attention_params.h"
#include "cute/layout.hpp"
#include "static_dispatch.h"

namespace llm {
using namespace cute;

// forward declaration
template <typename Dtype, int HEAD_DIM, typename Params>
void run_attention_kernel_sm80(Params& params, cudaStream_t stream);

void paged_kv_varlen_mha(
torch::Tensor& out, // [n_tokens, n_heads, head_dim]
const torch::Tensor& query, // [n_tokens, n_heads, head_dim]
const torch::Tensor& key_cache, // [n_slots, n_kv_heads, head_dim]
const torch::Tensor& value_cache, // [n_slots, n_kv_heads, head_dim]
const torch::Tensor& q_cu_lens, // [batch + 1]
const torch::Tensor& kv_cu_lens, // [batch + 1]
const torch::Tensor& block_table,
const torch::Tensor& block_cu_lens, // [batch + 1]
const std::optional<torch::Tensor>& alibi_slopes, // [n_heads]
int block_size,
int max_q_len,
int max_kv_len,
float sm_scale,
float logits_soft_cap,
int sliding_window) {
const int batch_size = q_cu_lens.size(0) - 1;
const int n_heads = query.size(-2);
const int n_kv_heads = key_cache.size(-2);
const int head_dim = query.size(-1);

// construct attention params
PagedKVAttentionParams params;
params.q_ptr = query.const_data_ptr();
params.q_stride = make_stride(query.stride(0), query.stride(1));
params.k_ptr = key_cache.const_data_ptr();
params.k_stride = make_stride(key_cache.stride(0), key_cache.stride(1));
params.v_ptr = value_cache.const_data_ptr();
params.v_stride = make_stride(value_cache.stride(0), value_cache.stride(1));
params.o_ptr = out.mutable_data_ptr();
params.o_stride = make_stride(out.stride(0), out.stride(1));
params.alibi_slopes_ptr = alibi_slopes.has_value()
? alibi_slopes.value().const_data_ptr<float>()
: nullptr;
params.batch_size = batch_size;
params.block_size = block_size;
params.max_q_len = max_q_len;
(void)max_kv_len; // unused
params.n_heads = n_heads;
params.n_kv_heads = n_kv_heads;
params.head_dim = head_dim;

params.sm_scale = sm_scale;
params.logits_soft_cap = logits_soft_cap;
params.sliding_window = sliding_window;

params.q_cu_lens = q_cu_lens.const_data_ptr<int32_t>();
params.kv_cu_lens = kv_cu_lens.const_data_ptr<int32_t>();

params.block_table = block_table.const_data_ptr<int32_t>();
params.block_cu_lens = block_cu_lens.const_data_ptr<int32_t>();

cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] {
DISPATCH_TORCH_DTYPE(query.scalar_type(), DTYPE, [&] {
run_attention_kernel_sm80<DTYPE, HEAD_DIM>(params, stream);
});
});
}

} // namespace llm
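For readers unfamiliar with the static-dispatch pattern used in paged_kv_varlen_mha above, the sketch below shows one common way such macros are written: the runtime value (scalar type or head_dim) is switched on once, bound to a compile-time name, and the trailing lambda is invoked with that name in scope so it can be used as a template argument. This is an illustration only; the real DISPATCH_TORCH_DTYPE / DISPATCH_HEAD_DIM macros live in static_dispatch.h in this repo, and the element types, supported cases, and include path below are assumptions.

// Illustrative sketch only -- not the macros shipped in static_dispatch.h.
// Assumes the kernels take cute::half_t / cute::bfloat16_t element types.
#include <torch/torch.h>

#include "cute/numeric/numeric_types.hpp"  // cute::half_t, cute::bfloat16_t (path may vary by CUTLASS version)

#define DISPATCH_TORCH_DTYPE_EXAMPLE(SCALAR_TYPE, TYPE_NAME, ...) \
  [&] {                                                           \
    switch (SCALAR_TYPE) {                                        \
      case torch::kHalf: {                                        \
        using TYPE_NAME = cute::half_t;                           \
        return __VA_ARGS__();                                     \
      }                                                           \
      case torch::kBFloat16: {                                    \
        using TYPE_NAME = cute::bfloat16_t;                       \
        return __VA_ARGS__();                                     \
      }                                                           \
      default:                                                    \
        TORCH_CHECK(false, "unsupported dtype: ", SCALAR_TYPE);   \
    }                                                             \
  }()

#define DISPATCH_HEAD_DIM_EXAMPLE(HEAD_DIM_VALUE, CONST_NAME, ...)    \
  [&] {                                                               \
    switch (HEAD_DIM_VALUE) {                                         \
      case 64: {                                                      \
        constexpr int CONST_NAME = 64;                                \
        return __VA_ARGS__();                                         \
      }                                                               \
      case 128: {                                                     \
        constexpr int CONST_NAME = 128;                               \
        return __VA_ARGS__();                                         \
      }                                                               \
      default:                                                        \
        TORCH_CHECK(false, "unsupported head_dim: ", HEAD_DIM_VALUE); \
    }                                                                 \
  }()

With macros of this shape, the nested calls in attn_api.cpp expand to a switch over the dtype inside a switch over the head dim, so the run_attention_kernel_sm80<DTYPE, HEAD_DIM> instantiation is selected from runtime values while staying a compile-time template.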
29 changes: 29 additions & 0 deletions src/kernels/attention/attn_api.h
@@ -0,0 +1,29 @@
#include <torch/torch.h>
#include <torch/types.h>

namespace llm {
// the input tensors are packed along the token dimension (no batch dimension),
// and the per-sequence lengths are stored as cumulative sums in q_cu_lens and
// kv_cu_lens. for sequence i,
//    the starting offset: q/kv_cu_lens[i]
//    the length: q/kv_cu_lens[i+1] - q/kv_cu_lens[i].
// the maximum sequence lengths, max_q_len and max_kv_len, are used to decide
// the kernel dispatch.
void paged_kv_varlen_mha(
torch::Tensor& out, // [n_tokens, n_heads, head_dim]
const torch::Tensor& query, // [n_tokens, n_heads, head_dim]
const torch::Tensor& key_cache, // [n_slots, n_kv_heads, head_dim]
const torch::Tensor& value_cache, // [n_slots, n_kv_heads, head_dim]
const torch::Tensor& q_cu_lens, // [batch + 1]
const torch::Tensor& kv_cu_lens, // [batch + 1]
const torch::Tensor& block_table,
const torch::Tensor& block_cu_lens, // [batch + 1]
const std::optional<torch::Tensor>& alibi_slopes, // [n_heads]
int block_size,
int max_q_len,
int max_kv_len,
float sm_scale,
float logits_soft_cap,
int sliding_window);

} // namespace llm
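To make the cu_lens packing described in the header comment above concrete, here is a minimal, hypothetical caller (not part of this PR). The sequence lengths, block-table layout, block_size, and the assumption that logits_soft_cap = 0.0f and sliding_window = -1 disable those features are illustrative only; check the kernel parameters for the actual conventions.

// Hypothetical usage sketch for paged_kv_varlen_mha (illustration only).
#include <torch/torch.h>

#include <cmath>
#include <optional>

#include "attn_api.h"

int main() {
  const auto fp16 = torch::dtype(torch::kHalf).device(torch::kCUDA);
  const auto i32 = torch::dtype(torch::kInt32).device(torch::kCUDA);

  // Two sequences: q lens {3, 5}, kv lens {16, 32}, paged into 16-token blocks.
  const int block_size = 16;
  const int n_heads = 8, n_kv_heads = 8, head_dim = 64;

  const auto q_cu_lens = torch::tensor({0, 3, 8}, i32);      // prefix sums of q lens
  const auto kv_cu_lens = torch::tensor({0, 16, 48}, i32);   // prefix sums of kv lens
  const auto block_cu_lens = torch::tensor({0, 1, 3}, i32);  // prefix sums of blocks per sequence
  const auto block_table = torch::tensor({0, 1, 2}, i32);    // physical block ids, flattened

  const auto query = torch::randn({/*n_tokens=*/8, n_heads, head_dim}, fp16);
  const auto key_cache = torch::randn({/*n_slots=*/48, n_kv_heads, head_dim}, fp16);
  const auto value_cache = torch::randn({48, n_kv_heads, head_dim}, fp16);
  auto out = torch::empty_like(query);

  llm::paged_kv_varlen_mha(out,
                           query,
                           key_cache,
                           value_cache,
                           q_cu_lens,
                           kv_cu_lens,
                           block_table,
                           block_cu_lens,
                           /*alibi_slopes=*/std::nullopt,
                           block_size,
                           /*max_q_len=*/5,
                           /*max_kv_len=*/32,
                           /*sm_scale=*/1.0f / std::sqrt(static_cast<float>(head_dim)),
                           /*logits_soft_cap=*/0.0f,  // assumed: 0 disables soft-capping
                           /*sliding_window=*/-1);    // assumed: -1 disables the window
  return 0;
}

The key invariant is that query and out are packed as [n_tokens, ...] with n_tokens equal to the last entry of q_cu_lens, while key_cache and value_cache are addressed indirectly through block_table, block_cu_lens, and block_size.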
6 changes: 3 additions & 3 deletions src/layers/attention/CMakeLists.txt
@@ -9,21 +9,21 @@ cc_library(
    handler.h
    ref_handler.h
    flash_attn_handler.h
-   flash_infer_handler.h
+   scale_attn_handler.h
    attention.h
  SRCS
    handler.cpp
    ref_handler.cpp
    flash_attn_handler.cpp
-   flash_infer_handler.cpp
+   scale_attn_handler.cpp
    attention.cpp
  DEPS
    :state_dict
    :memory
    :pos_embedding
    :kernels
    :flash_attn.kernels
-   # :flash_infer.kernels
+   :attention.kernels
    glog::glog
    gflags::gflags
    torch
8 changes: 4 additions & 4 deletions src/layers/attention/attention_test.cpp
@@ -11,7 +11,7 @@

#include <cstdint>

- #include "flash_attn_handler.h"
+ #include "scale_attn_handler.h"
#include "gtest/gtest.h"
#include "models/parameters.h"
#include "ref_handler.h"
@@ -229,10 +229,10 @@ TEST_P(AttentionDecodeTest, KVCache) {
  ref_handler.batch_prefill(
      query, key, value, input_params, sliding_window, ref_output);

-   // flash attn handler
-   FlashAttnHandler flash_attn_handler(sm_scale, logits_soft_cap, alibi_slopes);
+   // attn handler
+   ScaleAttnHandler attn_handler(sm_scale, logits_soft_cap, alibi_slopes);
  torch::Tensor output = torch::empty_like(query);
-   flash_attn_handler.batch_decode(
+   attn_handler.batch_decode(
      query, kv_cache, input_params, sliding_window, output);

  const bool success =
3 changes: 0 additions & 3 deletions src/layers/attention/flash_attn_handler.h
@@ -66,9 +66,6 @@ class FlashAttnHandler : public AttentionHandler {

  // alibi slopes
  torch::optional<torch::Tensor> alibi_slopes_;
-
-   // stream for kv cache
-   cudaStream_t stream_ = nullptr;
};

} // namespace llm
47 changes: 0 additions & 47 deletions src/layers/attention/flash_infer_handler.cpp

This file was deleted.

6 changes: 3 additions & 3 deletions src/layers/attention/handler.cpp
@@ -8,7 +8,7 @@
#include <memory>

#include "flash_attn_handler.h"
- #include "flash_infer_handler.h"
+ #include "scale_attn_handler.h"
#include "layers/pos_embedding.h"
#include "ref_handler.h"

@@ -73,7 +73,7 @@ std::unique_ptr<AttentionHandler> AttentionHandler::create_handler_with_alibi(
  // choose the best handler based on device type
  if (is_cuda) {
    // use flash_attn for cuda device
-     return std::make_unique<FlashAttnHandler>(
+     return std::make_unique<ScaleAttnHandler>(
        sm_scale, args.attn_logit_soft_cap(), alibi_slopes);
  }

@@ -125,7 +125,7 @@ std::unique_ptr<AttentionHandler> AttentionHandler::create_handler_with_rope(
  // choose the best handler based on device type
  if (is_cuda) {
    // use flash_attn for cuda device
-     return std::make_unique<FlashAttnHandler>(sm_scale,
+     return std::make_unique<ScaleAttnHandler>(sm_scale,
                                                args.attn_logit_soft_cap(),
                                                rotary_dim,
                                                args.max_position_embeddings(),