
Commit 790e5ba

refactor: skip flash_attn build (#388)
1 parent 46b0401

6 files changed: +9, -171 lines

src/kernels/attention/CMakeLists.txt

Lines changed: 0 additions & 2 deletions
@@ -79,6 +79,4 @@ cc_binary(
     -lineinfo
 )

-add_subdirectory(flash_attn)
-# add_subdirectory(flash_infer)
 add_subdirectory(tools)
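
The build change itself is just the deletion of the add_subdirectory(flash_attn) call (flash_infer was already commented out). Purely as a hypothetical sketch, not something this commit does, the kernels could instead be kept buildable behind an opt-in switch; BUILD_FLASH_ATTN below is an invented option name, not part of the repository:

# Hypothetical sketch only: this commit deletes the line outright rather than
# gating it. BUILD_FLASH_ATTN is an invented option name, not part of the repo.
option(BUILD_FLASH_ATTN "Build the flash_attn kernels" OFF)

if(BUILD_FLASH_ATTN)
  add_subdirectory(flash_attn)
endif()

add_subdirectory(tools)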

src/layers/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
@@ -58,7 +58,6 @@ cc_library(
     :pos_embedding
     :attention
     :kernels
-    :flash_attn.kernels
     glog::glog
     gflags::gflags
     torch

src/layers/attention/CMakeLists.txt

Lines changed: 0 additions & 3 deletions
@@ -8,21 +8,18 @@ cc_library(
   HDRS
     handler.h
     ref_handler.h
-    flash_attn_handler.h
     scale_attn_handler.h
     attention.h
   SRCS
     handler.cpp
     ref_handler.cpp
-    flash_attn_handler.cpp
     scale_attn_handler.cpp
     attention.cpp
   DEPS
     :state_dict
     :memory
     :pos_embedding
     :kernels
-    :flash_attn.kernels
     :attention.kernels
     glog::glog
     gflags::gflags

src/layers/attention/flash_attn_handler.cpp

Lines changed: 0 additions & 84 deletions
This file was deleted.

src/layers/attention/flash_attn_handler.h

Lines changed: 0 additions & 71 deletions
This file was deleted.

src/layers/attention/handler.cpp

Lines changed: 9 additions & 10 deletions
@@ -7,15 +7,14 @@
 #include <boost/algorithm/string.hpp>
 #include <memory>

-#include "flash_attn_handler.h"
 #include "scale_attn_handler.h"
 #include "layers/pos_embedding.h"
 #include "ref_handler.h"

 // decide which attention implementation to use
 DEFINE_string(attention_handler,
               "auto",
-              "attention handler, e.g. auto, pytorch, flash_attn");
+              "attention handler, e.g. auto, pytorch, scale_attn");

 namespace llm {

@@ -64,15 +63,15 @@ std::unique_ptr<AttentionHandler> AttentionHandler::create_handler_with_alibi(
   }

   const bool is_cuda = options.device().is_cuda();
-  if (boost::iequals(FLAGS_attention_handler, "flash_attn")) {
-    CHECK(is_cuda) << "flash_attn only supports cuda device";
-    return std::make_unique<FlashAttnHandler>(
+  if (boost::iequals(FLAGS_attention_handler, "scale_attn")) {
+    CHECK(is_cuda) << "scale_attn only supports cuda device";
+    return std::make_unique<ScaleAttnHandler>(
         sm_scale, args.attn_logit_soft_cap(), alibi_slopes);
   }

   // choose the best handler based on device type
   if (is_cuda) {
-    // use flash_attn for cuda device
+    // use scale_attn for cuda device
     return std::make_unique<ScaleAttnHandler>(
         sm_scale, args.attn_logit_soft_cap(), alibi_slopes);
   }
@@ -111,9 +110,9 @@ std::unique_ptr<AttentionHandler> AttentionHandler::create_handler_with_rope(
   }

   const bool is_cuda = options.device().is_cuda();
-  if (boost::iequals(FLAGS_attention_handler, "flash_attn")) {
-    CHECK(is_cuda) << "flash_attn only supports cuda device";
-    return std::make_unique<FlashAttnHandler>(sm_scale,
+  if (boost::iequals(FLAGS_attention_handler, "scale_attn")) {
+    CHECK(is_cuda) << "scale_attn only supports cuda device";
+    return std::make_unique<ScaleAttnHandler>(sm_scale,
                                               args.attn_logit_soft_cap(),
                                               rotary_dim,
                                               args.max_position_embeddings(),
@@ -124,7 +123,7 @@ std::unique_ptr<AttentionHandler> AttentionHandler::create_handler_with_rope(

   // choose the best handler based on device type
   if (is_cuda) {
-    // use flash_attn for cuda device
+    // use scale_attn for cuda device
     return std::make_unique<ScaleAttnHandler>(sm_scale,
                                               args.attn_logit_soft_cap(),
                                               rotary_dim,
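
With FlashAttnHandler removed, handler selection reduces to ScaleAttnHandler on CUDA devices and the reference handler elsewhere, and the flag value scale_attn replaces flash_attn. The following is a minimal, self-contained sketch of that dispatch pattern, not the project's real code: the handler classes are stand-in stubs (the real ones take sm_scale, attn_logit_soft_cap() and alibi_slopes as shown in the diff), and create_handler is a simplified stand-in for create_handler_with_alibi.

// Minimal sketch of the flag-based dispatch kept in handler.cpp after this
// change. Handler classes are stubs; mapping the "pytorch" flag value to the
// reference handler is an assumption based on ref_handler.h above.
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>

#include <boost/algorithm/string.hpp>

struct AttentionHandler {
  virtual ~AttentionHandler() = default;
  virtual const char* name() const = 0;
};
struct ScaleAttnHandler : AttentionHandler {
  const char* name() const override { return "scale_attn"; }
};
struct RefHandler : AttentionHandler {
  const char* name() const override { return "ref (pytorch)"; }
};

// Mirrors the selection order in create_handler_with_alibi: an explicit
// "scale_attn" request requires a CUDA device; "auto" picks scale_attn on
// CUDA and otherwise falls back to the reference handler.
std::unique_ptr<AttentionHandler> create_handler(const std::string& flag,
                                                 bool is_cuda) {
  if (boost::iequals(flag, "scale_attn")) {
    if (!is_cuda) {
      throw std::invalid_argument("scale_attn only supports cuda device");
    }
    return std::make_unique<ScaleAttnHandler>();
  }
  if (is_cuda) {
    // use scale_attn for cuda device
    return std::make_unique<ScaleAttnHandler>();
  }
  return std::make_unique<RefHandler>();
}

int main() {
  std::cout << create_handler("auto", /*is_cuda=*/true)->name() << "\n";   // scale_attn
  std::cout << create_handler("auto", /*is_cuda=*/false)->name() << "\n";  // ref (pytorch)
}

At runtime the choice is made through the gflags flag defined above, e.g. passing --attention_handler=scale_attn (or auto / pytorch) to whichever binary links this library; the binary name is not shown in this diff.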
