diff --git a/src/kernels/attention/CMakeLists.txt b/src/kernels/attention/CMakeLists.txt
index 9e19b1fa..e99160ce 100644
--- a/src/kernels/attention/CMakeLists.txt
+++ b/src/kernels/attention/CMakeLists.txt
@@ -79,6 +79,4 @@ cc_binary(
     -lineinfo
 )
 
-add_subdirectory(flash_attn)
-# add_subdirectory(flash_infer)
 add_subdirectory(tools)
\ No newline at end of file
diff --git a/src/layers/CMakeLists.txt b/src/layers/CMakeLists.txt
index 8f5bc505..df6dafb6 100644
--- a/src/layers/CMakeLists.txt
+++ b/src/layers/CMakeLists.txt
@@ -58,7 +58,6 @@ cc_library(
     :pos_embedding
     :attention
     :kernels
-    :flash_attn.kernels
     glog::glog
     gflags::gflags
     torch
diff --git a/src/layers/attention/CMakeLists.txt b/src/layers/attention/CMakeLists.txt
index 20baea3c..2c8d25ad 100644
--- a/src/layers/attention/CMakeLists.txt
+++ b/src/layers/attention/CMakeLists.txt
@@ -8,13 +8,11 @@ cc_library(
   HDRS
     handler.h
     ref_handler.h
-    flash_attn_handler.h
     scale_attn_handler.h
     attention.h
   SRCS
     handler.cpp
     ref_handler.cpp
-    flash_attn_handler.cpp
     scale_attn_handler.cpp
     attention.cpp
   DEPS
@@ -22,7 +20,6 @@ cc_library(
     :memory
     :pos_embedding
     :kernels
-    :flash_attn.kernels
     :attention.kernels
     glog::glog
     gflags::gflags
diff --git a/src/layers/attention/flash_attn_handler.cpp b/src/layers/attention/flash_attn_handler.cpp
deleted file mode 100644
index 3bd1107c..00000000
--- a/src/layers/attention/flash_attn_handler.cpp
+++ /dev/null
@@ -1,84 +0,0 @@
-#include "flash_attn_handler.h"
-
-#include
-#include
-
-#include "kernels/attention/flash_attn/flash_api.h"
-#include "memory/kv_cache.h"
-#include "models/parameters.h"
-
-namespace llm {
-
-FlashAttnHandler::FlashAttnHandler(float sm_scale,
-                                   float logits_soft_cap,
-                                   int64_t rotary_dim,
-                                   int64_t max_position,
-                                   torch::Tensor inv_freq,
-                                   bool interleaved,
-                                   const torch::TensorOptions& options)
-    : sm_scale_(sm_scale), logits_soft_cap_(logits_soft_cap) {
-  // register rotary positional embedding
-  pos_emb_ =
-      RotaryEmbedding(rotary_dim, max_position, inv_freq, interleaved, options);
-}
-
-FlashAttnHandler::FlashAttnHandler(float sm_scale,
-                                   float logits_soft_cap,
-                                   torch::optional<torch::Tensor> alibi_slopes)
-    : sm_scale_(sm_scale),
-      logits_soft_cap_(logits_soft_cap),
-      alibi_slopes_(alibi_slopes) {}
-
-FlashAttnHandler::~FlashAttnHandler() {}
-
-std::tuple<torch::Tensor, torch::Tensor> FlashAttnHandler::apply_pos_emb(
-    const torch::Tensor& query,
-    const torch::Tensor& key,
-    const torch::Tensor& positions) {
-  // for alibi scenarios, the pos_emb_ is not defined
-  if (positions.defined() && pos_emb_) {
-    return pos_emb_(query, key, positions);
-  }
-  return {query, key};
-}
-
-// batch decode for attention, optimized for decode stage
-// support multiple queries: one sequence with multiple query tokens
-void FlashAttnHandler::batch_decode(
-    const torch::Tensor& query,           // [n_tokens, n_heads, head_dim]
-    const KVCache& kv_cache,              // where to retrieval key and value
-    const InputParameters& input_params,  // input paras used for attention
-    int32_t sliding_window,               // sliding window size
-    torch::Tensor& output) {
-  auto [key_cache, value_cache] = kv_cache.get_kv_cache();
-  mha_varlen_fwd(output,
-                 query,
-                 key_cache,
-                 value_cache,
-                 input_params.q_cu_seq_lens,
-                 input_params.kv_cu_seq_lens,
-                 input_params.block_tables,
-                 input_params.cu_block_lens,
-                 alibi_slopes_,
-                 input_params.q_max_seq_len,
-                 input_params.kv_max_seq_len,
-                 sm_scale_,
-                 logits_soft_cap_,
-                 /*window_size_left=*/sliding_window,
-                 /*window_size_right=*/0,
-                 /*num_splits=*/0);
-}
-
-// append key and value to kv_cache
-void FlashAttnHandler::append_kv_cache(
-    KVCache& kv_cache,           // where to store key and value
-    const torch::Tensor& key,    // [n_tokens, n_kv_heads, head_dim]
-    const torch::Tensor& value,  // [n_tokens, n_kv_heads, head_dim]
-    const InputParameters& input_params) {
-  // append key and value to kv_cache
-  if (!kv_cache.empty()) {
-    kv_cache.set_kv_cache(input_params.new_cache_slots, key, value);
-  }
-}
-
-}  // namespace llm
diff --git a/src/layers/attention/flash_attn_handler.h b/src/layers/attention/flash_attn_handler.h
deleted file mode 100644
index 2fa8646d..00000000
--- a/src/layers/attention/flash_attn_handler.h
+++ /dev/null
@@ -1,71 +0,0 @@
-#pragma once
-
-#include
-#include
-
-#include "handler.h"
-#include "layers/pos_embedding.h"
-#include "memory/kv_cache.h"
-#include "models/parameters.h"
-
-namespace llm {
-
-// an flash attn implementation for attention operations
-class FlashAttnHandler : public AttentionHandler {
- public:
-  // create a flash attn handler with rope positional embedding
-  FlashAttnHandler(float sm_scale,
-                   float logits_soft_cap,
-                   int64_t rotary_dim,
-                   int64_t max_position,
-                   torch::Tensor inv_freq,
-                   bool interleaved,
-                   const torch::TensorOptions& options);
-
-  // create a flash attn handler with alibi slopes
-  FlashAttnHandler(float sm_scale,
-                   float logits_soft_cap,
-                   torch::optional<torch::Tensor> alibi_slopes);
-
-  ~FlashAttnHandler() override;
-
-  // set workspace for temporary storage before calling any attention operations
-  void set_workspace(const torch::Tensor& workspace) override {}
-
-  // apply positional embedding to query and key if needed
-  std::tuple<torch::Tensor, torch::Tensor> apply_pos_emb(
-      const torch::Tensor& query,
-      const torch::Tensor& key,
-      const torch::Tensor& positions) override;
-
-  // batch decode for attention, optimized for decode stage
-  // support multiple queries: one sequence with multiple query tokens
-  void batch_decode(
-      const torch::Tensor& query,           // [n_tokens, n_heads, head_dim]
-      const KVCache& kv_cache,              // where to store and retrieval key and value
-      const InputParameters& input_params,  // input paras used for attention
-      int32_t sliding_window,               // sliding window size
-      torch::Tensor& output) override;
-
-  // append key and value to kv_cache
-  void append_kv_cache(
-      KVCache& kv_cache,           // where to store and retrieval key and value
-      const torch::Tensor& key,    // [n_tokens, n_kv_heads, head_dim]
-      const torch::Tensor& value,  // [n_tokens, n_kv_heads, head_dim]
-      const InputParameters& input_params) override;
-
- private:
-  // softmax scale factor
-  float sm_scale_ = 0.0;
-
-  // logits softcap
-  float logits_soft_cap_ = 0.0;
-
-  // ROPE positional embedding
-  RotaryEmbedding pos_emb_{nullptr};
-
-  // alibi slopes
-  torch::optional<torch::Tensor> alibi_slopes_;
-};
-
-}  // namespace llm
diff --git a/src/layers/attention/handler.cpp b/src/layers/attention/handler.cpp
index ebc15277..d26443ca 100644
--- a/src/layers/attention/handler.cpp
+++ b/src/layers/attention/handler.cpp
@@ -7,7 +7,6 @@
 #include
 #include
 
-#include "flash_attn_handler.h"
 #include "scale_attn_handler.h"
 #include "layers/pos_embedding.h"
 #include "ref_handler.h"
@@ -15,7 +14,7 @@
 // decide which attention implementation to use
 DEFINE_string(attention_handler,
               "auto",
-              "attention handler, e.g. auto, pytorch, flash_attn");
+              "attention handler, e.g. auto, pytorch, scale_attn");
 
 namespace llm {
 
@@ -64,15 +63,15 @@ std::unique_ptr<AttentionHandler> AttentionHandler::create_handler_with_alibi(
   }
 
   const bool is_cuda = options.device().is_cuda();
-  if (boost::iequals(FLAGS_attention_handler, "flash_attn")) {
-    CHECK(is_cuda) << "flash_attn only supports cuda device";
-    return std::make_unique<FlashAttnHandler>(
+  if (boost::iequals(FLAGS_attention_handler, "scale_attn")) {
+    CHECK(is_cuda) << "scale_attn only supports cuda device";
+    return std::make_unique<ScaleAttnHandler>(
         sm_scale, args.attn_logit_soft_cap(), alibi_slopes);
   }
 
   // choose the best handler based on device type
   if (is_cuda) {
-    // use flash_attn for cuda device
+    // use scale_attn for cuda device
     return std::make_unique<ScaleAttnHandler>(
         sm_scale, args.attn_logit_soft_cap(), alibi_slopes);
   }
@@ -111,9 +110,9 @@ std::unique_ptr<AttentionHandler> AttentionHandler::create_handler_with_rope(
   }
 
   const bool is_cuda = options.device().is_cuda();
-  if (boost::iequals(FLAGS_attention_handler, "flash_attn")) {
-    CHECK(is_cuda) << "flash_attn only supports cuda device";
-    return std::make_unique<FlashAttnHandler>(sm_scale,
+  if (boost::iequals(FLAGS_attention_handler, "scale_attn")) {
+    CHECK(is_cuda) << "scale_attn only supports cuda device";
+    return std::make_unique<ScaleAttnHandler>(sm_scale,
                                               args.attn_logit_soft_cap(),
                                               rotary_dim,
                                               args.max_position_embeddings(),
@@ -124,7 +123,7 @@ std::unique_ptr<AttentionHandler> AttentionHandler::create_handler_with_rope(
 
   // choose the best handler based on device type
   if (is_cuda) {
-    // use flash_attn for cuda device
+    // use scale_attn for cuda device
     return std::make_unique<ScaleAttnHandler>(sm_scale,
                                               args.attn_logit_soft_cap(),
                                               rotary_dim,