diff --git a/src/kernels/attention/CMakeLists.txt b/src/kernels/attention/CMakeLists.txt
index d9c72a63..9e19b1fa 100644
--- a/src/kernels/attention/CMakeLists.txt
+++ b/src/kernels/attention/CMakeLists.txt
@@ -37,11 +37,11 @@ file(GLOB GENERATED_SRC_FILES "${CMAKE_CURRENT_BINARY_DIR}/generated/*.cu")
 cc_library(
   NAME
-    attention.kernel
+    attention.kernels
   HDRS
-    # attention.h
+    attn_api.h
   SRCS
-    # attention.cpp
+    attn_api.cpp
     ${GENERATED_SRC_FILES}
   INCLUDES
     ${CMAKE_CURRENT_SOURCE_DIR}
diff --git a/src/kernels/attention/attention_kernel_sm80_test.cu b/src/kernels/attention/attention_kernel_sm80_test.cu
index f20e994a..9329a2d1 100644
--- a/src/kernels/attention/attention_kernel_sm80_test.cu
+++ b/src/kernels/attention/attention_kernel_sm80_test.cu
@@ -94,7 +94,8 @@ torch::Tensor attention_sm80(
 }  // namespace
 class AttentionKernelTest
-    : public ::testing::TestWithParam
+template <typename DTYPE, int HEAD_DIM, typename Params>
 void run_attention_kernel_sm80(Params& params, cudaStream_t stream = nullptr) {
   // normalize params that for performance optimization
   params.normalize();
   // TODO: tune block shape MNK based on the head dim and smem size
   if constexpr (HEAD_DIM == 64) {
-    using Traits = AttentionTraitsSM80
     detail::run_attention_kernel(params, stream);
   } else if constexpr (HEAD_DIM == 96) {
-    using Traits = AttentionTraitsSM80
     detail::run_attention_kernel(params, stream);
   } else if constexpr (HEAD_DIM == 128) {
-    using Traits = AttentionTraitsSM80
     detail::run_attention_kernel(params, stream);
   } else if constexpr (HEAD_DIM == 256) {
-    using Traits = AttentionTraitsSM80(params, stream);
   } else {
     // use the default block size
-    using Traits = AttentionTraitsSM80
diff --git a/src/kernels/attention/attn_api.cpp b/src/kernels/attention/attn_api.cpp
new file mode 100644
--- /dev/null
+++ b/src/kernels/attention/attn_api.cpp
+
+#include "attention_params.h"
+#include "cute/layout.hpp"
+#include "static_dispatch.h"
+
+namespace llm {
+using namespace cute;
+
+// forward declaration
+template <typename DTYPE, int HEAD_DIM, typename Params>
+void run_attention_kernel_sm80(Params& params, cudaStream_t stream);
+
+void paged_kv_varlen_mha(
+    torch::Tensor& out,                  // [n_tokens, n_heads, head_dim]
+    const torch::Tensor& query,          // [n_tokens, n_heads, head_dim]
+    const torch::Tensor& key_cache,      // [n_slots, n_kv_heads, head_dim]
+    const torch::Tensor& value_cache,    // [n_slots, n_kv_heads, head_dim]
+    const torch::Tensor& q_cu_lens,      // [batch + 1]
+    const torch::Tensor& kv_cu_lens,     // [batch + 1]
+    const torch::Tensor& block_table,
+    const torch::Tensor& block_cu_lens,  // [batch + 1]
+    const std::optional<torch::Tensor>& alibi_slopes,  // [n_heads]
+    int block_size,
+    int max_q_len,
+    int max_kv_len,
+    float sm_scale,
+    float logits_soft_cap,
+    int sliding_window) {
+  const int batch_size = q_cu_lens.size(0) - 1;
+  const int n_heads = query.size(-2);
+  const int n_kv_heads = key_cache.size(-2);
+  const int head_dim = query.size(-1);
+
+  // construct attention params
+  PagedKVAttentionParams params;
+  params.q_ptr = query.const_data_ptr();
+  params.q_stride = make_stride(query.stride(0), query.stride(1));
+  params.k_ptr = key_cache.const_data_ptr();
+  params.k_stride = make_stride(key_cache.stride(0), key_cache.stride(1));
+  params.v_ptr = value_cache.const_data_ptr();
+  params.v_stride = make_stride(value_cache.stride(0), value_cache.stride(1));
+  params.o_ptr = out.mutable_data_ptr();
+  params.o_stride = make_stride(out.stride(0), out.stride(1));
+  params.alibi_slopes_ptr = alibi_slopes.has_value()
+                                ? alibi_slopes.value().const_data_ptr()
+                                : nullptr;
+  params.batch_size = batch_size;
+  params.block_size = block_size;
+  params.max_q_len = max_q_len;
+  (void)max_kv_len;  // unused
+  params.n_heads = n_heads;
+  params.n_kv_heads = n_kv_heads;
+  params.head_dim = head_dim;
+
+  params.sm_scale = sm_scale;
+  params.logits_soft_cap = logits_soft_cap;
+  params.sliding_window = sliding_window;
+
+  params.q_cu_lens = q_cu_lens.const_data_ptr();
+  params.kv_cu_lens = kv_cu_lens.const_data_ptr();
+
+  params.block_table = block_table.const_data_ptr();
+  params.block_cu_lens = block_cu_lens.const_data_ptr();
+
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+  DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] {
+    DISPATCH_TORCH_DTYPE(query.scalar_type(), DTYPE, [&] {
+      run_attention_kernel_sm80<DTYPE, HEAD_DIM>(params, stream);
+    });
+  });
+}
+
+}  // namespace llm
diff --git a/src/kernels/attention/attn_api.h b/src/kernels/attention/attn_api.h
new file mode 100644
index 00000000..d5e1bd22
--- /dev/null
+++ b/src/kernels/attention/attn_api.h
@@ -0,0 +1,29 @@
+#include <optional>
+#include <torch/torch.h>
+
+namespace llm {
+// the input tensors are packed into one-dimensional tensors, and the sequence
+// lengths are stored in q_cu_lens and kv_cu_lens.
+// for each sequence,
+//   the starting offset: q/kv_cu_lens[i]
+//   the length: q/kv_cu_lens[i+1] - q/kv_cu_lens[i].
+// the maximum sequence lengths are max_q_len and max_kv_len, which are used
+// to decide the kernel dispatch.
+void paged_kv_varlen_mha(
+    torch::Tensor& out,                  // [n_tokens, n_heads, head_dim]
+    const torch::Tensor& query,          // [n_tokens, n_heads, head_dim]
+    const torch::Tensor& key_cache,      // [n_slots, n_kv_heads, head_dim]
+    const torch::Tensor& value_cache,    // [n_slots, n_kv_heads, head_dim]
+    const torch::Tensor& q_cu_lens,      // [batch + 1]
+    const torch::Tensor& kv_cu_lens,     // [batch + 1]
+    const torch::Tensor& block_table,
+    const torch::Tensor& block_cu_lens,  // [batch + 1]
+    const std::optional<torch::Tensor>& alibi_slopes,  // [n_heads]
+    int block_size,
+    int max_q_len,
+    int max_kv_len,
+    float sm_scale,
+    float logits_soft_cap,
+    int sliding_window);
+
+}  // namespace llm
\ No newline at end of file
diff --git a/src/layers/attention/CMakeLists.txt b/src/layers/attention/CMakeLists.txt
index f0e25165..20baea3c 100644
--- a/src/layers/attention/CMakeLists.txt
+++ b/src/layers/attention/CMakeLists.txt
@@ -9,13 +9,13 @@ cc_library(
     handler.h
     ref_handler.h
     flash_attn_handler.h
-    flash_infer_handler.h
+    scale_attn_handler.h
     attention.h
   SRCS
     handler.cpp
     ref_handler.cpp
     flash_attn_handler.cpp
-    flash_infer_handler.cpp
+    scale_attn_handler.cpp
     attention.cpp
   DEPS
     :state_dict
@@ -23,7 +23,7 @@ cc_library(
     :pos_embedding
     :kernels
    :flash_attn.kernels
-    # :flash_infer.kernels
+    :attention.kernels
     glog::glog
     gflags::gflags
     torch
diff --git a/src/layers/attention/attention_test.cpp b/src/layers/attention/attention_test.cpp
index a6364dac..f1be201c 100644
--- a/src/layers/attention/attention_test.cpp
+++ b/src/layers/attention/attention_test.cpp
@@ -11,7 +11,7 @@
 #include
 
-#include "flash_attn_handler.h"
+#include "scale_attn_handler.h"
 #include "gtest/gtest.h"
 #include "models/parameters.h"
 #include "ref_handler.h"
@@ -229,10 +229,10 @@ TEST_P(AttentionDecodeTest, KVCache) {
   ref_handler.batch_prefill(
       query, key, value, input_params, sliding_window, ref_output);
 
-  // flash attn handler
-  FlashAttnHandler flash_attn_handler(sm_scale, logits_soft_cap, alibi_slopes);
+  // attn handler
+  ScaleAttnHandler attn_handler(sm_scale, logits_soft_cap, alibi_slopes);
   torch::Tensor output = torch::empty_like(query);
-  flash_attn_handler.batch_decode(
+  attn_handler.batch_decode(
       query, kv_cache, input_params, sliding_window, output);
 
   const bool success =
diff --git a/src/layers/attention/flash_attn_handler.h b/src/layers/attention/flash_attn_handler.h
index 0f1391eb..2fa8646d 100644
--- a/src/layers/attention/flash_attn_handler.h
+++ b/src/layers/attention/flash_attn_handler.h
@@ -66,9 +66,6 @@ class FlashAttnHandler : public AttentionHandler {
 
   // alibi slopes
   torch::optional<torch::Tensor> alibi_slopes_;
-
-  // stream for kv cache
-  cudaStream_t stream_ = nullptr;
 };
 
 }  // namespace llm
diff --git a/src/layers/attention/flash_infer_handler.cpp b/src/layers/attention/flash_infer_handler.cpp
deleted file mode 100644
index 838b6d26..00000000
--- a/src/layers/attention/flash_infer_handler.cpp
+++ /dev/null
@@ -1,47 +0,0 @@
-#include "flash_infer_handler.h"
-
-#include
-
-#include "memory/kv_cache.h"
-#include "models/parameters.h"
-
-namespace llm {
-
-FlashInferHandler::FlashInferHandler(float scale,
-                                     int64_t rotary_dim,
-                                     int64_t max_position,
-                                     float rope_scaling,
-                                     float rope_theta,
-                                     bool interleaved,
-                                     const torch::TensorOptions& options) {
-  LOG(FATAL) << "Not implemented yet";
-}
-
-FlashInferHandler::FlashInferHandler(
-    float scale,
-    torch::optional<torch::Tensor> alibi_slopes)
-    : scale_(scale), alibi_slopes_(alibi_slopes) {}
-
-// batch decode for attention, optimized for decode stage
-// support multiple queries: one sequence with multiple query tokens
-void FlashInferHandler::batch_decode(
-    const torch::Tensor& query,           // [n_tokens, n_heads, head_dim]
-    const KVCache& kv_cache,              // where to retrieval key and value
-    const InputParameters& input_params,  // input paras used for attention
-    int32_t sliding_window,               // sliding window size
-    torch::Tensor& output) {
-  // TODO: add implementation
-  LOG(FATAL) << "Not implemented yet";
-}
-
-// append key and value to kv_cache
-void FlashInferHandler::append_kv_cache(
-    KVCache& kv_cache,           // where to store key and value
-    const torch::Tensor& key,    // [n_tokens, n_kv_heads, head_dim]
-    const torch::Tensor& value,  // [n_tokens, n_kv_heads, head_dim]
-    const InputParameters& input_params) {
-  // TODO: add implementation
-  LOG(FATAL) << "Not implemented yet";
-}
-
-}  // namespace llm
diff --git a/src/layers/attention/handler.cpp b/src/layers/attention/handler.cpp
index 37a0cff4..ebc15277 100644
--- a/src/layers/attention/handler.cpp
+++ b/src/layers/attention/handler.cpp
@@ -8,7 +8,7 @@
 #include
 
 #include "flash_attn_handler.h"
-#include "flash_infer_handler.h"
+#include "scale_attn_handler.h"
 #include "layers/pos_embedding.h"
 #include "ref_handler.h"
@@ -73,7 +73,7 @@ std::unique_ptr<AttentionHandler> AttentionHandler::create_handler_with_alibi(
   // choose the best handler based on device type
   if (is_cuda) {
     // use flash_attn for cuda device
-    return std::make_unique<FlashAttnHandler>(
+    return std::make_unique<ScaleAttnHandler>(
         sm_scale, args.attn_logit_soft_cap(), alibi_slopes);
   }
 
@@ -125,7 +125,7 @@ std::unique_ptr<AttentionHandler> AttentionHandler::create_handler_with_rope(
   // choose the best handler based on device type
   if (is_cuda) {
     // use flash_attn for cuda device
-    return std::make_unique<FlashAttnHandler>(sm_scale,
+    return std::make_unique<ScaleAttnHandler>(sm_scale,
                                               args.attn_logit_soft_cap(),
                                               rotary_dim,
                                               args.max_position_embeddings(),
diff --git a/src/layers/attention/scale_attn_handler.cpp b/src/layers/attention/scale_attn_handler.cpp
new file mode 100644
index 00000000..fb9a9343
--- /dev/null
+++ b/src/layers/attention/scale_attn_handler.cpp
@@ -0,0 +1,80 @@
+#include "scale_attn_handler.h"
+
+#include
+
+#include "kernels/attention/attn_api.h"
+#include "memory/kv_cache.h" +#include "models/parameters.h" + +namespace llm { + +ScaleAttnHandler::ScaleAttnHandler(float sm_scale, + float logits_soft_cap, + int64_t rotary_dim, + int64_t max_position, + torch::Tensor inv_freq, + bool interleaved, + const torch::TensorOptions& options) + : sm_scale_(sm_scale), logits_soft_cap_(logits_soft_cap) { + // register rotary positional embedding + pos_emb_ = + RotaryEmbedding(rotary_dim, max_position, inv_freq, interleaved, options); +} + +ScaleAttnHandler::ScaleAttnHandler(float sm_scale, + float logits_soft_cap, + torch::optional alibi_slopes) + : sm_scale_(sm_scale), + logits_soft_cap_(logits_soft_cap), + alibi_slopes_(alibi_slopes) {} + +std::tuple ScaleAttnHandler::apply_pos_emb( + const torch::Tensor& query, + const torch::Tensor& key, + const torch::Tensor& positions) { + // for alibi scenarios, the pos_emb_ is not defined + if (positions.defined() && pos_emb_) { + return pos_emb_(query, key, positions); + } + return {query, key}; +} + +// batch decode for attention, optimized for decode stage +// support multiple queries: one sequence with multiple query tokens +void ScaleAttnHandler::batch_decode( + const torch::Tensor& query, // [n_tokens, n_heads, head_dim] + const KVCache& kv_cache, // where to retrieval key and value + const InputParameters& input_params, // input paras used for attention + int32_t sliding_window, // sliding window size + torch::Tensor& output) { + auto [key_cache, value_cache, block_size] = kv_cache.get_kv_cache_slot_view(); + paged_kv_varlen_mha(output, + query, + key_cache, + value_cache, + input_params.q_cu_seq_lens, + input_params.kv_cu_seq_lens, + input_params.block_tables, + input_params.cu_block_lens, + alibi_slopes_, + block_size, + input_params.q_max_seq_len, + input_params.kv_max_seq_len, + sm_scale_, + logits_soft_cap_, + sliding_window); +} + +// append key and value to kv_cache +void ScaleAttnHandler::append_kv_cache( + KVCache& kv_cache, // where to store key and value + const torch::Tensor& key, // [n_tokens, n_kv_heads, head_dim] + const torch::Tensor& value, // [n_tokens, n_kv_heads, head_dim] + const InputParameters& input_params) { + // append key and value to kv_cache + if (!kv_cache.empty()) { + kv_cache.set_kv_cache(input_params.new_cache_slots, key, value); + } +} + +} // namespace llm diff --git a/src/layers/attention/flash_infer_handler.h b/src/layers/attention/scale_attn_handler.h similarity index 55% rename from src/layers/attention/flash_infer_handler.h rename to src/layers/attention/scale_attn_handler.h index acd75fdf..cb7e7203 100644 --- a/src/layers/attention/flash_infer_handler.h +++ b/src/layers/attention/scale_attn_handler.h @@ -1,37 +1,41 @@ #pragma once +#include #include #include "handler.h" +#include "layers/pos_embedding.h" #include "memory/kv_cache.h" #include "models/parameters.h" namespace llm { // an flash attn implementation for attention operations -class FlashInferHandler : public AttentionHandler { +class ScaleAttnHandler : public AttentionHandler { public: // create a flash attn handler with rope positional embedding - FlashInferHandler(float scale, - int64_t rotary_dim, - int64_t max_position, - float rope_scaling, - float rope_theta, - bool interleaved, - const torch::TensorOptions& options); + ScaleAttnHandler(float sm_scale, + float logits_soft_cap, + int64_t rotary_dim, + int64_t max_position, + torch::Tensor inv_freq, + bool interleaved, + const torch::TensorOptions& options); - // constructor for attention with alibi - FlashInferHandler(float scale, 
-                    torch::optional<torch::Tensor> alibi_slopes);
+  // create a flash attn handler with alibi slopes
+  ScaleAttnHandler(float sm_scale,
+                   float logits_soft_cap,
+                   torch::optional<torch::Tensor> alibi_slopes);
 
-  virtual ~FlashInferHandler() = default;
+  ~ScaleAttnHandler() override = default;
+
+  // set workspace for temporary storage before calling any attention operations
+  void set_workspace(const torch::Tensor& workspace) override {}
 
   std::tuple<torch::Tensor, torch::Tensor> apply_pos_emb(
       const torch::Tensor& query,
       const torch::Tensor& key,
-      const torch::Tensor& /*positions*/) override {
-    // no positional embedding since we will apply pos emb on the fly
-    return {query, key};
-  }
+      const torch::Tensor& positions) override;
 
   // batch decode for attention, optimized for decode stage
   // support multiple queries: one sequence with multiple query tokens
@@ -50,10 +54,16 @@ class FlashInferHandler : public AttentionHandler {
       const InputParameters& input_params) override;
 
  private:
-  // scale factor
-  float scale_ = 0.0;
+  // softmax scale factor
+  float sm_scale_ = 0.0;
+
+  // logits softcap
+  float logits_soft_cap_ = 0.0;
+
+  // ROPE positional embedding
+  RotaryEmbedding pos_emb_{nullptr};
 
-  // alibi slops
+  // alibi slopes
   torch::optional<torch::Tensor> alibi_slopes_;
 };
 
diff --git a/src/memory/kv_cache.h b/src/memory/kv_cache.h
index 6c273d65..fd3d7041 100644
--- a/src/memory/kv_cache.h
+++ b/src/memory/kv_cache.h
@@ -24,6 +24,13 @@ class KVCache final {
     return {key_cache_, value_cache_};
   }
 
+  std::tuple<torch::Tensor, torch::Tensor, int32_t> get_kv_cache_slot_view()
+      const {
+    return {key_cache_.view({-1, num_kv_heads_, head_size_}),
+            value_cache_.view({-1, num_kv_heads_, head_size_}),
+            block_size_};
+  }
+
   // set key and value cache for the given slot_ids
   // the slot_ids are the indices of the key/value cache, [num_slots] IntTensor
   // keys/values: [num_slots, num_heads, head_dim]
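For context, a rough sketch of how the new paged_kv_varlen_mha entry point declared in attn_api.h can be driven directly, mirroring the argument list that ScaleAttnHandler::batch_decode assembles from InputParameters and the flattened slot view returned by KVCache::get_kv_cache_slot_view(). The concrete values below (two packed sequences, fp16 caches, int32 index tensors, sliding_window = -1 meaning "disabled", sm_scale = 1/sqrt(head_dim)) are illustrative assumptions and are not part of this diff.

#include <cmath>
#include <optional>
#include <torch/torch.h>

#include "kernels/attention/attn_api.h"

int main() {
  // two sequences packed into one varlen batch: 3 + 5 = 8 query tokens
  const int n_heads = 8, n_kv_heads = 8, head_dim = 64, block_size = 16;
  const auto idx_opts = torch::dtype(torch::kInt32).device(torch::kCUDA);
  const auto data_opts = torch::dtype(torch::kHalf).device(torch::kCUDA);

  // cumulative query lengths: seq 0 has 3 tokens, seq 1 has 5
  auto q_cu_lens = torch::tensor({0, 3, 8}, idx_opts);
  // cumulative kv lengths: 16 and 32 cached slots respectively
  auto kv_cu_lens = torch::tensor({0, 16, 48}, idx_opts);
  // block ids packed per sequence: seq 0 owns block 0, seq 1 owns blocks 1 and 2
  auto block_table = torch::tensor({0, 1, 2}, idx_opts);
  auto block_cu_lens = torch::tensor({0, 1, 3}, idx_opts);

  // paged KV cache with 3 blocks of block_size slots each
  auto query = torch::randn({8, n_heads, head_dim}, data_opts);
  auto key_cache = torch::randn({3 * block_size, n_kv_heads, head_dim}, data_opts);
  auto value_cache = torch::randn({3 * block_size, n_kv_heads, head_dim}, data_opts);
  auto out = torch::empty_like(query);

  llm::paged_kv_varlen_mha(out, query, key_cache, value_cache,
                           q_cu_lens, kv_cu_lens, block_table, block_cu_lens,
                           /*alibi_slopes=*/std::nullopt, block_size,
                           /*max_q_len=*/5, /*max_kv_len=*/32,
                           /*sm_scale=*/1.0f / std::sqrt(static_cast<float>(head_dim)),
                           /*logits_soft_cap=*/0.0f, /*sliding_window=*/-1);
  return 0;
}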