From 5151644b90d27e44c4aa8794868a12d57a6c7121 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Mon, 20 Jan 2025 20:44:11 -0800 Subject: [PATCH 1/2] attn: integrate in-house scale attention and use it by default --- .../attention/attention_kernel_sm80_test.cu | 15 +++++++--- .../attention/attention_launch_sm80.cuh | 12 ++++---- src/kernels/attention/attn_api.cpp | 20 +++++++++++++ src/kernels/attention/attn_api.h | 29 +++++++++++++++++++ src/layers/attention/CMakeLists.txt | 5 ++-- ...fer_handler.cpp => scale_attn_handler.cpp} | 12 ++++---- ...h_infer_handler.h => scale_attn_handler.h} | 10 +++---- 7 files changed, 79 insertions(+), 24 deletions(-) create mode 100644 src/kernels/attention/attn_api.cpp create mode 100644 src/kernels/attention/attn_api.h rename src/layers/attention/{flash_infer_handler.cpp => scale_attn_handler.cpp} (89%) rename src/layers/attention/{flash_infer_handler.h => scale_attn_handler.h} (89%) diff --git a/src/kernels/attention/attention_kernel_sm80_test.cu b/src/kernels/attention/attention_kernel_sm80_test.cu index f20e994a..9329a2d1 100644 --- a/src/kernels/attention/attention_kernel_sm80_test.cu +++ b/src/kernels/attention/attention_kernel_sm80_test.cu @@ -94,7 +94,8 @@ torch::Tensor attention_sm80( } // namespace class AttentionKernelTest - : public ::testing::TestWithParam +template void run_attention_kernel_sm80(Params& params, cudaStream_t stream = nullptr) { // normalize params that for performance optimization params.normalize(); // TODO: tune block shape MNK based on the head dim and smem size if constexpr (HEAD_DIM == 64) { - using Traits = AttentionTraitsSM80; detail::run_attention_kernel(params, stream); } else if constexpr (HEAD_DIM == 96) { - using Traits = AttentionTraitsSM80; detail::run_attention_kernel(params, stream); } else if constexpr (HEAD_DIM == 128) { - using Traits = AttentionTraitsSM80; detail::run_attention_kernel(params, stream); } else if constexpr (HEAD_DIM == 256) { - using Traits = AttentionTraitsSM80(params, stream); } else { // use the default block size - using Traits = AttentionTraitsSM80& alibi_slopes, // [n_heads] + int block_size, + int max_q_len, + int max_kv_len, + float sm_scale, + float logits_soft_cap, + int sliding_window) {} + +} // namespace llm diff --git a/src/kernels/attention/attn_api.h b/src/kernels/attention/attn_api.h new file mode 100644 index 00000000..83b5a717 --- /dev/null +++ b/src/kernels/attention/attn_api.h @@ -0,0 +1,29 @@ +#include +#include + +namespace llm { +// the input tensors are packed into one-dimensional tensors, and the sequence +// lengths are stored in q_cu_lens and k_cu_lens. +// for each sequence, +// the starting offset: q/kv_cu_lens[i] +// the length: q/kv_cu_lens[i+1] - q/kv_cu_lens[i]. +// the maximum sequence length is max_q_len and max_kv_len, which are used +// to decide the kernel dispatch. 
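// Illustrative sketch of the cumulative-length convention described above
// (not part of this header; the vector names and lengths below are made up
// for the example, assuming <vector> and <cstdint>): given per-sequence query
// lengths {3, 1, 5}, q_cu_lens is the exclusive prefix sum {0, 3, 4, 9}, the
// packed q tensor has q_cu_lens.back() == 9 rows, and sequence i occupies
// rows [q_cu_lens[i], q_cu_lens[i+1]).
//
//   std::vector<int32_t> q_lens = {3, 1, 5};
//   std::vector<int32_t> q_cu_lens = {0};
//   for (int32_t len : q_lens) {
//     q_cu_lens.push_back(q_cu_lens.back() + len);  // exclusive prefix sum
//   }
//   for (size_t i = 0; i + 1 < q_cu_lens.size(); ++i) {
//     const int32_t offset = q_cu_lens[i];                  // first row of sequence i
//     const int32_t len = q_cu_lens[i + 1] - q_cu_lens[i];  // tokens in sequence i
//   }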
+void paged_kv_varlen_mha( + torch::Tensor& out, // [n_tokens, n_heads, head_dim] + const torch::Tensor& q, // [n_tokens, n_heads, head_dim] + const torch::Tensor& k_cache, // [n_slots, n_kv_heads, head_dim] + const torch::Tensor& v_cache, // [n_slots, n_kv_heads, head_dim] + const torch::Tensor& q_cu_lens, // [batch + 1] + const torch::Tensor& kv_cu_lens, // [batch + 1] + const torch::Tensor& block_table, + const torch::Tensor& block_cu_lens, // [batch + 1] + const std::optional& alibi_slopes, // [n_heads] + int block_size, + int max_q_len, + int max_kv_len, + float sm_scale, + float logits_soft_cap, + int sliding_window); + +} // namespace llm \ No newline at end of file diff --git a/src/layers/attention/CMakeLists.txt b/src/layers/attention/CMakeLists.txt index f0e25165..77a9e8f2 100644 --- a/src/layers/attention/CMakeLists.txt +++ b/src/layers/attention/CMakeLists.txt @@ -9,13 +9,13 @@ cc_library( handler.h ref_handler.h flash_attn_handler.h - flash_infer_handler.h + scale_attn_handler.h attention.h SRCS handler.cpp ref_handler.cpp flash_attn_handler.cpp - flash_infer_handler.cpp + scale_attn_handler.cpp attention.cpp DEPS :state_dict @@ -23,7 +23,6 @@ cc_library( :pos_embedding :kernels :flash_attn.kernels - # :flash_infer.kernels glog::glog gflags::gflags torch diff --git a/src/layers/attention/flash_infer_handler.cpp b/src/layers/attention/scale_attn_handler.cpp similarity index 89% rename from src/layers/attention/flash_infer_handler.cpp rename to src/layers/attention/scale_attn_handler.cpp index 84937457..16f2329d 100644 --- a/src/layers/attention/flash_infer_handler.cpp +++ b/src/layers/attention/scale_attn_handler.cpp @@ -1,4 +1,4 @@ -#include "flash_infer_handler.h" +#include "scale_attn_handler.h" #include @@ -7,7 +7,7 @@ namespace llm { -FlashInferHandler::FlashInferHandler(float scale, +ScaleAttnHandler::ScaleAttnHandler(float scale, int64_t rotary_dim, int64_t max_position, float rope_scaling, @@ -17,13 +17,13 @@ FlashInferHandler::FlashInferHandler(float scale, LOG(FATAL) << "Not implemented yet"; } -FlashInferHandler::FlashInferHandler( +ScaleAttnHandler::ScaleAttnHandler( float scale, torch::optional alibi_slopes) : scale_(scale), alibi_slopes_(alibi_slopes) {} // batch prefill for attention, optimized for prefill stage -void FlashInferHandler::batch_prefill( +void ScaleAttnHandler::batch_prefill( const torch::Tensor& query, // [n_tokens, n_heads, head_dim] const torch::Tensor& key, // [n_tokens, n_kv_heads, head_dim] const torch::Tensor& value, // [n_tokens, n_kv_heads, head_dim] @@ -36,7 +36,7 @@ void FlashInferHandler::batch_prefill( // batch decode for attention, optimized for decode stage // support multiple queries: one sequence with multiple query tokens -void FlashInferHandler::batch_decode( +void ScaleAttnHandler::batch_decode( const torch::Tensor& query, // [n_tokens, n_heads, head_dim] const KVCache& kv_cache, // where to retrieval key and value const InputParameters& input_params, // input paras used for attention @@ -47,7 +47,7 @@ void FlashInferHandler::batch_decode( } // append key and value to kv_cache -void FlashInferHandler::append_kv_cache( +void ScaleAttnHandler::append_kv_cache( KVCache& kv_cache, // where to store key and value const torch::Tensor& key, // [n_tokens, n_kv_heads, head_dim] const torch::Tensor& value, // [n_tokens, n_kv_heads, head_dim] diff --git a/src/layers/attention/flash_infer_handler.h b/src/layers/attention/scale_attn_handler.h similarity index 89% rename from src/layers/attention/flash_infer_handler.h rename to 
src/layers/attention/scale_attn_handler.h index f02771ff..c33b17b8 100644 --- a/src/layers/attention/flash_infer_handler.h +++ b/src/layers/attention/scale_attn_handler.h @@ -9,10 +9,10 @@ namespace llm { // an flash attn implementation for attention operations -class FlashInferHandler : public AttentionHandler { +class ScaleAttnHandler : public AttentionHandler { public: // create a flash attn handler with rope positional embedding - FlashInferHandler(float scale, + ScaleAttnHandler(float scale, int64_t rotary_dim, int64_t max_position, float rope_scaling, @@ -21,9 +21,9 @@ class FlashInferHandler : public AttentionHandler { const torch::TensorOptions& options); // constructor for attention with alibi - FlashInferHandler(float scale, torch::optional alibi_slopes); + ScaleAttnHandler(float scale, std::optional alibi_slopes); - virtual ~FlashInferHandler() = default; + virtual ~ScaleAttnHandler() = default; std::tuple apply_pos_emb( const torch::Tensor& query, @@ -63,7 +63,7 @@ class FlashInferHandler : public AttentionHandler { float scale_ = 0.0; // alibi slops - torch::optional alibi_slopes_; + std::optional alibi_slopes_; }; } // namespace llm From b59c4038f36f7fc471672a3324324e280c2670f2 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Fri, 24 Jan 2025 16:11:01 -0800 Subject: [PATCH 2/2] use scale attention by default --- src/kernels/attention/CMakeLists.txt | 6 +- src/kernels/attention/attn_api.cpp | 70 ++++++++++++++++++--- src/kernels/attention/attn_api.h | 12 ++-- src/layers/attention/CMakeLists.txt | 1 + src/layers/attention/attention_test.cpp | 8 +-- src/layers/attention/flash_attn_handler.h | 3 - src/layers/attention/handler.cpp | 6 +- src/layers/attention/scale_attn_handler.cpp | 65 ++++++++++++++----- src/layers/attention/scale_attn_handler.h | 46 ++++++++------ src/memory/kv_cache.h | 7 +++ 10 files changed, 164 insertions(+), 60 deletions(-) diff --git a/src/kernels/attention/CMakeLists.txt b/src/kernels/attention/CMakeLists.txt index d9c72a63..9e19b1fa 100644 --- a/src/kernels/attention/CMakeLists.txt +++ b/src/kernels/attention/CMakeLists.txt @@ -37,11 +37,11 @@ file(GLOB GENERATED_SRC_FILES "${CMAKE_CURRENT_BINARY_DIR}/generated/*.cu") cc_library( NAME - attention.kernel + attention.kernels HDRS - # attention.h + attn_api.h SRCS - # attention.cpp + attn_api.cpp ${GENERATED_SRC_FILES} INCLUDES ${CMAKE_CURRENT_SOURCE_DIR} diff --git a/src/kernels/attention/attn_api.cpp b/src/kernels/attention/attn_api.cpp index 8e2eafe8..be32b56f 100644 --- a/src/kernels/attention/attn_api.cpp +++ b/src/kernels/attention/attn_api.cpp @@ -1,12 +1,25 @@ #include "attn_api.h" + +#include + +#include "attention_params.h" +#include "cute/layout.hpp" +#include "static_dispatch.h" + namespace llm { +using namespace cute; + +// forward declaration +template +void run_attention_kernel_sm80(Params& params, cudaStream_t stream); + void paged_kv_varlen_mha( - torch::Tensor& out, // [n_tokens, n_heads, head_dim] - const torch::Tensor& q, // [n_tokens, n_heads, head_dim] - const torch::Tensor& k_cache, // [n_slots, n_kv_heads, head_dim] - const torch::Tensor& v_cache, // [n_slots, n_kv_heads, head_dim] - const torch::Tensor& q_cu_lens, // [batch + 1] - const torch::Tensor& kv_cu_lens, // [batch + 1] + torch::Tensor& out, // [n_tokens, n_heads, head_dim] + const torch::Tensor& query, // [n_tokens, n_heads, head_dim] + const torch::Tensor& key_cache, // [n_slots, n_kv_heads, head_dim] + const torch::Tensor& value_cache, // [n_slots, n_kv_heads, head_dim] + const torch::Tensor& q_cu_lens, // [batch + 1] + 
const torch::Tensor& kv_cu_lens, // [batch + 1] const torch::Tensor& block_table, const torch::Tensor& block_cu_lens, // [batch + 1] const std::optional& alibi_slopes, // [n_heads] @@ -15,6 +28,49 @@ void paged_kv_varlen_mha( int max_kv_len, float sm_scale, float logits_soft_cap, - int sliding_window) {} + int sliding_window) { + const int batch_size = q_cu_lens.size(0) - 1; + const int n_heads = query.size(-2); + const int n_kv_heads = key_cache.size(-2); + const int head_dim = query.size(-1); + + // construct attention params + PagedKVAttentionParams params; + params.q_ptr = query.const_data_ptr(); + params.q_stride = make_stride(query.stride(0), query.stride(1)); + params.k_ptr = key_cache.const_data_ptr(); + params.k_stride = make_stride(key_cache.stride(0), key_cache.stride(1)); + params.v_ptr = value_cache.const_data_ptr(); + params.v_stride = make_stride(value_cache.stride(0), value_cache.stride(1)); + params.o_ptr = out.mutable_data_ptr(); + params.o_stride = make_stride(out.stride(0), out.stride(1)); + params.alibi_slopes_ptr = alibi_slopes.has_value() + ? alibi_slopes.value().const_data_ptr() + : nullptr; + params.batch_size = batch_size; + params.block_size = block_size; + params.max_q_len = max_q_len; + (void)max_kv_len; // unused + params.n_heads = n_heads; + params.n_kv_heads = n_kv_heads; + params.head_dim = head_dim; + + params.sm_scale = sm_scale; + params.logits_soft_cap = logits_soft_cap; + params.sliding_window = sliding_window; + + params.q_cu_lens = q_cu_lens.const_data_ptr(); + params.kv_cu_lens = kv_cu_lens.const_data_ptr(); + + params.block_table = block_table.const_data_ptr(); + params.block_cu_lens = block_cu_lens.const_data_ptr(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] { + DISPATCH_TORCH_DTYPE(query.scalar_type(), DTYPE, [&] { + run_attention_kernel_sm80(params, stream); + }); + }); +} } // namespace llm diff --git a/src/kernels/attention/attn_api.h b/src/kernels/attention/attn_api.h index 83b5a717..d5e1bd22 100644 --- a/src/kernels/attention/attn_api.h +++ b/src/kernels/attention/attn_api.h @@ -10,12 +10,12 @@ namespace llm { // the maximum sequence length is max_q_len and max_kv_len, which are used // to decide the kernel dispatch. 
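// Caller-side sketch (illustrative only, not part of this patch): the index
// tensor dtype (int32), query dtype (fp16), and the "-1 disables sliding
// window / 0 disables soft cap" conventions are assumptions here, not taken
// from this API.
//
//   const int n_heads = 8, n_kv_heads = 8, head_dim = 128, block_size = 16;
//   auto opts = torch::dtype(torch::kHalf).device(torch::kCUDA);
//   auto iopts = torch::dtype(torch::kInt).device(torch::kCUDA);
//   auto query = torch::randn({4, n_heads, head_dim}, opts);          // 3 + 1 query tokens
//   auto out = torch::empty_like(query);
//   auto key_cache = torch::randn({48, n_kv_heads, head_dim}, opts);  // 3 blocks * 16 slots
//   auto value_cache = torch::randn({48, n_kv_heads, head_dim}, opts);
//   auto q_cu_lens = torch::tensor({0, 3, 4}, iopts);                 // 2 sequences
//   auto kv_cu_lens = torch::tensor({0, 17, 20}, iopts);              // kv lens 17 and 3
//   auto block_table = torch::tensor({0, 1, 2}, iopts);               // seq0 -> {0,1}, seq1 -> {2}
//   auto block_cu_lens = torch::tensor({0, 2, 3}, iopts);
//   llm::paged_kv_varlen_mha(out, query, key_cache, value_cache, q_cu_lens,
//                            kv_cu_lens, block_table, block_cu_lens,
//                            /*alibi_slopes=*/std::nullopt, block_size,
//                            /*max_q_len=*/3, /*max_kv_len=*/17,
//                            /*sm_scale=*/1.0f / std::sqrt(128.0f),
//                            /*logits_soft_cap=*/0.0f, /*sliding_window=*/-1);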
void paged_kv_varlen_mha( - torch::Tensor& out, // [n_tokens, n_heads, head_dim] - const torch::Tensor& q, // [n_tokens, n_heads, head_dim] - const torch::Tensor& k_cache, // [n_slots, n_kv_heads, head_dim] - const torch::Tensor& v_cache, // [n_slots, n_kv_heads, head_dim] - const torch::Tensor& q_cu_lens, // [batch + 1] - const torch::Tensor& kv_cu_lens, // [batch + 1] + torch::Tensor& out, // [n_tokens, n_heads, head_dim] + const torch::Tensor& query, // [n_tokens, n_heads, head_dim] + const torch::Tensor& key_cache, // [n_slots, n_kv_heads, head_dim] + const torch::Tensor& value_cache, // [n_slots, n_kv_heads, head_dim] + const torch::Tensor& q_cu_lens, // [batch + 1] + const torch::Tensor& kv_cu_lens, // [batch + 1] const torch::Tensor& block_table, const torch::Tensor& block_cu_lens, // [batch + 1] const std::optional& alibi_slopes, // [n_heads] diff --git a/src/layers/attention/CMakeLists.txt b/src/layers/attention/CMakeLists.txt index 77a9e8f2..20baea3c 100644 --- a/src/layers/attention/CMakeLists.txt +++ b/src/layers/attention/CMakeLists.txt @@ -23,6 +23,7 @@ cc_library( :pos_embedding :kernels :flash_attn.kernels + :attention.kernels glog::glog gflags::gflags torch diff --git a/src/layers/attention/attention_test.cpp b/src/layers/attention/attention_test.cpp index a6364dac..f1be201c 100644 --- a/src/layers/attention/attention_test.cpp +++ b/src/layers/attention/attention_test.cpp @@ -11,7 +11,7 @@ #include -#include "flash_attn_handler.h" +#include "scale_attn_handler.h" #include "gtest/gtest.h" #include "models/parameters.h" #include "ref_handler.h" @@ -229,10 +229,10 @@ TEST_P(AttentionDecodeTest, KVCache) { ref_handler.batch_prefill( query, key, value, input_params, sliding_window, ref_output); - // flash attn handler - FlashAttnHandler flash_attn_handler(sm_scale, logits_soft_cap, alibi_slopes); + // attn handler + ScaleAttnHandler attn_handler(sm_scale, logits_soft_cap, alibi_slopes); torch::Tensor output = torch::empty_like(query); - flash_attn_handler.batch_decode( + attn_handler.batch_decode( query, kv_cache, input_params, sliding_window, output); const bool success = diff --git a/src/layers/attention/flash_attn_handler.h b/src/layers/attention/flash_attn_handler.h index 0f1391eb..2fa8646d 100644 --- a/src/layers/attention/flash_attn_handler.h +++ b/src/layers/attention/flash_attn_handler.h @@ -66,9 +66,6 @@ class FlashAttnHandler : public AttentionHandler { // alibi slopes torch::optional alibi_slopes_; - - // stream for kv cache - cudaStream_t stream_ = nullptr; }; } // namespace llm diff --git a/src/layers/attention/handler.cpp b/src/layers/attention/handler.cpp index 37a0cff4..ebc15277 100644 --- a/src/layers/attention/handler.cpp +++ b/src/layers/attention/handler.cpp @@ -8,7 +8,7 @@ #include #include "flash_attn_handler.h" -#include "flash_infer_handler.h" +#include "scale_attn_handler.h" #include "layers/pos_embedding.h" #include "ref_handler.h" @@ -73,7 +73,7 @@ std::unique_ptr AttentionHandler::create_handler_with_alibi( // choose the best handler based on device type if (is_cuda) { // use flash_attn for cuda device - return std::make_unique( + return std::make_unique( sm_scale, args.attn_logit_soft_cap(), alibi_slopes); } @@ -125,7 +125,7 @@ std::unique_ptr AttentionHandler::create_handler_with_rope( // choose the best handler based on device type if (is_cuda) { // use flash_attn for cuda device - return std::make_unique(sm_scale, + return std::make_unique(sm_scale, args.attn_logit_soft_cap(), rotary_dim, args.max_position_embeddings(), diff --git 
a/src/layers/attention/scale_attn_handler.cpp b/src/layers/attention/scale_attn_handler.cpp index 7d7487b1..fb9a9343 100644 --- a/src/layers/attention/scale_attn_handler.cpp +++ b/src/layers/attention/scale_attn_handler.cpp @@ -2,25 +2,42 @@ #include +#include "kernels/attention/attn_api.h" #include "memory/kv_cache.h" #include "models/parameters.h" namespace llm { -ScaleAttnHandler::ScaleAttnHandler(float scale, - int64_t rotary_dim, - int64_t max_position, - float rope_scaling, - float rope_theta, - bool interleaved, - const torch::TensorOptions& options) { - LOG(FATAL) << "Not implemented yet"; +ScaleAttnHandler::ScaleAttnHandler(float sm_scale, + float logits_soft_cap, + int64_t rotary_dim, + int64_t max_position, + torch::Tensor inv_freq, + bool interleaved, + const torch::TensorOptions& options) + : sm_scale_(sm_scale), logits_soft_cap_(logits_soft_cap) { + // register rotary positional embedding + pos_emb_ = + RotaryEmbedding(rotary_dim, max_position, inv_freq, interleaved, options); } -ScaleAttnHandler::ScaleAttnHandler( - float scale, - torch::optional alibi_slopes) - : scale_(scale), alibi_slopes_(alibi_slopes) {} +ScaleAttnHandler::ScaleAttnHandler(float sm_scale, + float logits_soft_cap, + torch::optional alibi_slopes) + : sm_scale_(sm_scale), + logits_soft_cap_(logits_soft_cap), + alibi_slopes_(alibi_slopes) {} + +std::tuple ScaleAttnHandler::apply_pos_emb( + const torch::Tensor& query, + const torch::Tensor& key, + const torch::Tensor& positions) { + // for alibi scenarios, the pos_emb_ is not defined + if (positions.defined() && pos_emb_) { + return pos_emb_(query, key, positions); + } + return {query, key}; +} // batch decode for attention, optimized for decode stage // support multiple queries: one sequence with multiple query tokens @@ -30,8 +47,22 @@ void ScaleAttnHandler::batch_decode( const InputParameters& input_params, // input paras used for attention int32_t sliding_window, // sliding window size torch::Tensor& output) { - // TODO: add implementation - LOG(FATAL) << "Not implemented yet"; + auto [key_cache, value_cache, block_size] = kv_cache.get_kv_cache_slot_view(); + paged_kv_varlen_mha(output, + query, + key_cache, + value_cache, + input_params.q_cu_seq_lens, + input_params.kv_cu_seq_lens, + input_params.block_tables, + input_params.cu_block_lens, + alibi_slopes_, + block_size, + input_params.q_max_seq_len, + input_params.kv_max_seq_len, + sm_scale_, + logits_soft_cap_, + sliding_window); } // append key and value to kv_cache @@ -40,8 +71,10 @@ void ScaleAttnHandler::append_kv_cache( const torch::Tensor& key, // [n_tokens, n_kv_heads, head_dim] const torch::Tensor& value, // [n_tokens, n_kv_heads, head_dim] const InputParameters& input_params) { - // TODO: add implementation - LOG(FATAL) << "Not implemented yet"; + // append key and value to kv_cache + if (!kv_cache.empty()) { + kv_cache.set_kv_cache(input_params.new_cache_slots, key, value); + } } } // namespace llm diff --git a/src/layers/attention/scale_attn_handler.h b/src/layers/attention/scale_attn_handler.h index e2c63ce8..cb7e7203 100644 --- a/src/layers/attention/scale_attn_handler.h +++ b/src/layers/attention/scale_attn_handler.h @@ -1,8 +1,10 @@ #pragma once +#include #include #include "handler.h" +#include "layers/pos_embedding.h" #include "memory/kv_cache.h" #include "models/parameters.h" @@ -12,26 +14,28 @@ namespace llm { class ScaleAttnHandler : public AttentionHandler { public: // create a flash attn handler with rope positional embedding - ScaleAttnHandler(float scale, - int64_t rotary_dim, - 
int64_t max_position, - float rope_scaling, - float rope_theta, - bool interleaved, - const torch::TensorOptions& options); + ScaleAttnHandler(float sm_scale, + float logits_soft_cap, + int64_t rotary_dim, + int64_t max_position, + torch::Tensor inv_freq, + bool interleaved, + const torch::TensorOptions& options); - // constructor for attention with alibi - ScaleAttnHandler(float scale, std::optional alibi_slopes); + // create a flash attn handler with alibi slopes + ScaleAttnHandler(float sm_scale, + float logits_soft_cap, + torch::optional alibi_slopes); - virtual ~ScaleAttnHandler() = default; + ~ScaleAttnHandler() override = default; + + // set workspace for temporary storage before calling any attention operations + void set_workspace(const torch::Tensor& workspace) override {} std::tuple apply_pos_emb( const torch::Tensor& query, const torch::Tensor& key, - const torch::Tensor& /*positions*/) override { - // no positional embedding since we will apply pos emb on the fly - return {query, key}; - } + const torch::Tensor& positions) override; // batch decode for attention, optimized for decode stage // support multiple queries: one sequence with multiple query tokens @@ -50,11 +54,17 @@ class ScaleAttnHandler : public AttentionHandler { const InputParameters& input_params) override; private: - // scale factor - float scale_ = 0.0; + // softmax scale factor + float sm_scale_ = 0.0; + + // logits softcap + float logits_soft_cap_ = 0.0; + + // ROPE positional embedding + RotaryEmbedding pos_emb_{nullptr}; - // alibi slops - std::optional alibi_slopes_; + // alibi slopes + torch::optional alibi_slopes_; }; } // namespace llm diff --git a/src/memory/kv_cache.h b/src/memory/kv_cache.h index 6c273d65..fd3d7041 100644 --- a/src/memory/kv_cache.h +++ b/src/memory/kv_cache.h @@ -24,6 +24,13 @@ class KVCache final { return {key_cache_, value_cache_}; } + std::tuple get_kv_cache_slot_view() + const { + return {key_cache_.view({-1, num_kv_heads_, head_size_}), + value_cache_.view({-1, num_kv_heads_, head_size_}), + block_size_}; + } + // set key and value cache for the given slot_ids // the slot_ids are the indices of the key/value cache, [num_slots] IntTensor // keys/values: [num_slots, num_heads, head_dim]
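// Illustrative note on the slot view (the underlying cache layout is an
// assumption inferred from the view above, not stated in this patch): if
// key_cache_ is [n_blocks, block_size, n_kv_heads, head_size], then
// get_kv_cache_slot_view() exposes it as
// [n_blocks * block_size, n_kv_heads, head_size], and the cache slot for
// token `offset` of block `block_id` is slot_id = block_id * block_size + offset.
//
//   auto key_cache = torch::zeros({/*n_blocks=*/4, /*block_size=*/16,
//                                  /*n_kv_heads=*/8, /*head_size=*/128});
//   auto slot_view = key_cache.view({-1, 8, 128});  // [64, 8, 128]
//   const int64_t slot_id = 2 * 16 + 3;             // block 2, offset 3
//   TORCH_CHECK(slot_view[slot_id].equal(key_cache[2][3]));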