Merged
2 changes: 0 additions & 2 deletions src/kernels/attention/CMakeLists.txt
@@ -79,6 +79,4 @@ cc_binary(
 -lineinfo
 )
 
-add_subdirectory(flash_attn)
-# add_subdirectory(flash_infer)
 add_subdirectory(tools)
1 change: 0 additions & 1 deletion src/layers/CMakeLists.txt
@@ -58,7 +58,6 @@ cc_library(
 :pos_embedding
 :attention
 :kernels
-:flash_attn.kernels
 glog::glog
 gflags::gflags
 torch
3 changes: 0 additions & 3 deletions src/layers/attention/CMakeLists.txt
@@ -8,21 +8,18 @@ cc_library(
 HDRS
 handler.h
 ref_handler.h
-flash_attn_handler.h
 scale_attn_handler.h
 attention.h
 SRCS
 handler.cpp
 ref_handler.cpp
-flash_attn_handler.cpp
 scale_attn_handler.cpp
 attention.cpp
 DEPS
 :state_dict
 :memory
 :pos_embedding
 :kernels
-:flash_attn.kernels
 :attention.kernels
 glog::glog
 gflags::gflags
84 changes: 0 additions & 84 deletions src/layers/attention/flash_attn_handler.cpp

This file was deleted.

71 changes: 0 additions & 71 deletions src/layers/attention/flash_attn_handler.h

This file was deleted.

19 changes: 9 additions & 10 deletions src/layers/attention/handler.cpp
@@ -7,15 +7,14 @@
 #include <boost/algorithm/string.hpp>
 #include <memory>
 
-#include "flash_attn_handler.h"
 #include "scale_attn_handler.h"
 #include "layers/pos_embedding.h"
 #include "ref_handler.h"
 
 // decide which attention implementation to use
 DEFINE_string(attention_handler,
               "auto",
-              "attention handler, e.g. auto, pytorch, flash_attn");
+              "attention handler, e.g. auto, pytorch, scale_attn");
 
 namespace llm {
 
@@ -64,15 +63,15 @@ std::unique_ptr<AttentionHandler> AttentionHandler::create_handler_with_alibi(
   }
 
   const bool is_cuda = options.device().is_cuda();
-  if (boost::iequals(FLAGS_attention_handler, "flash_attn")) {
-    CHECK(is_cuda) << "flash_attn only supports cuda device";
-    return std::make_unique<FlashAttnHandler>(
+  if (boost::iequals(FLAGS_attention_handler, "scale_attn")) {
+    CHECK(is_cuda) << "scale_attn only supports cuda device";
+    return std::make_unique<ScaleAttnHandler>(
         sm_scale, args.attn_logit_soft_cap(), alibi_slopes);
   }
 
   // choose the best handler based on device type
   if (is_cuda) {
-    // use flash_attn for cuda device
+    // use scale_attn for cuda device
     return std::make_unique<ScaleAttnHandler>(
         sm_scale, args.attn_logit_soft_cap(), alibi_slopes);
   }
@@ -111,9 +110,9 @@ std::unique_ptr<AttentionHandler> AttentionHandler::create_handler_with_rope(
   }
 
   const bool is_cuda = options.device().is_cuda();
-  if (boost::iequals(FLAGS_attention_handler, "flash_attn")) {
-    CHECK(is_cuda) << "flash_attn only supports cuda device";
-    return std::make_unique<FlashAttnHandler>(sm_scale,
+  if (boost::iequals(FLAGS_attention_handler, "scale_attn")) {
+    CHECK(is_cuda) << "scale_attn only supports cuda device";
+    return std::make_unique<ScaleAttnHandler>(sm_scale,
         args.attn_logit_soft_cap(),
         rotary_dim,
         args.max_position_embeddings(),
@@ -124,7 +123,7 @@ std::unique_ptr<AttentionHandler> AttentionHandler::create_handler_with_rope(
 
   // choose the best handler based on device type
   if (is_cuda) {
-    // use flash_attn for cuda device
+    // use scale_attn for cuda device
    return std::make_unique<ScaleAttnHandler>(sm_scale,
        args.attn_logit_soft_cap(),
        rotary_dim,
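
The handler.cpp hunks above rely on two library mechanisms worth spelling out: `DEFINE_string` from gflags turns `attention_handler` into a `--attention_handler` command-line flag (default `"auto"`), and `boost::iequals` matches its value case-insensitively. Below is a minimal, self-contained sketch of that flag-driven dispatch; the `Backend` enum and the `select_backend` helper are hypothetical stand-ins for illustration, not code from this repository.

```cpp
// Minimal sketch (not repository code): flag-driven backend selection with
// gflags + boost::iequals, mirroring the pattern used in handler.cpp.
#include <boost/algorithm/string.hpp>
#include <gflags/gflags.h>

#include <iostream>

// Defines a --attention_handler command-line flag with a default of "auto".
DEFINE_string(attention_handler,
              "auto",
              "attention handler, e.g. auto, pytorch, scale_attn");

// Hypothetical backend tags used only for this demo.
enum class Backend { kRef, kScaleAttn };

Backend select_backend(bool is_cuda) {
  // Explicit override: case-insensitive, so "Scale_Attn" also matches.
  if (boost::iequals(FLAGS_attention_handler, "scale_attn")) {
    return Backend::kScaleAttn;
  }
  if (boost::iequals(FLAGS_attention_handler, "pytorch")) {
    return Backend::kRef;
  }
  // "auto": pick based on device type.
  return is_cuda ? Backend::kScaleAttn : Backend::kRef;
}

int main(int argc, char* argv[]) {
  gflags::ParseCommandLineFlags(&argc, &argv, /*remove_flags=*/true);
  const Backend backend = select_backend(/*is_cuda=*/true);
  std::cout << "selected backend #" << static_cast<int>(backend) << std::endl;
  return 0;
}
```

With this in place, passing `--attention_handler=scale_attn` forces the CUDA-only backend, which is what the `CHECK(is_cuda)` guard in the hunks above protects against on CPU-only builds.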
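
A note on why the removal in this PR stays so contained: within this diff, callers obtain handlers only through the `AttentionHandler::create_handler_*` factories and hold the result as a `std::unique_ptr` to the abstract base, so dropping one concrete backend means editing the factory and the build files rather than every call site. The sketch below illustrates that factory shape with hypothetical `Handler`, `RefHandler`, and `ScaleHandler` types; it is not the repository's actual class hierarchy.

```cpp
// Toy factory sketch (hypothetical types): callers depend only on the abstract
// Handler, so concrete backends can be added or removed behind create_handler().
#include <iostream>
#include <memory>
#include <string>

class Handler {
 public:
  virtual ~Handler() = default;
  virtual std::string name() const = 0;
};

class RefHandler : public Handler {
 public:
  std::string name() const override { return "ref (pytorch)"; }
};

class ScaleHandler : public Handler {
 public:
  std::string name() const override { return "scale_attn"; }
};

// The only place that knows about concrete types. Removing a backend (as this
// PR removes FlashAttnHandler) means editing this function, not its callers.
std::unique_ptr<Handler> create_handler(bool is_cuda) {
  if (is_cuda) {
    return std::make_unique<ScaleHandler>();
  }
  return std::make_unique<RefHandler>();
}

int main() {
  std::unique_ptr<Handler> handler = create_handler(/*is_cuda=*/true);
  std::cout << "using handler: " << handler->name() << std::endl;
  return 0;
}
```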