@@ -7,15 +7,14 @@
 #include <boost/algorithm/string.hpp>
 #include <memory>
 
-#include "flash_attn_handler.h"
 #include "scale_attn_handler.h"
 #include "layers/pos_embedding.h"
 #include "ref_handler.h"
 
 // decide which attention implementation to use
 DEFINE_string(attention_handler,
               "auto",
-              "attention handler, e.g. auto, pytorch, flash_attn");
+              "attention handler, e.g. auto, pytorch, scale_attn");
 
 namespace llm {
 
@@ -64,15 +63,15 @@ std::unique_ptr<AttentionHandler> AttentionHandler::create_handler_with_alibi(
   }
 
   const bool is_cuda = options.device().is_cuda();
-  if (boost::iequals(FLAGS_attention_handler, "flash_attn")) {
-    CHECK(is_cuda) << "flash_attn only supports cuda device";
-    return std::make_unique<FlashAttnHandler>(
+  if (boost::iequals(FLAGS_attention_handler, "scale_attn")) {
+    CHECK(is_cuda) << "scale_attn only supports cuda device";
+    return std::make_unique<ScaleAttnHandler>(
         sm_scale, args.attn_logit_soft_cap(), alibi_slopes);
   }
 
   // choose the best handler based on device type
   if (is_cuda) {
-    // use flash_attn for cuda device
+    // use scale_attn for cuda device
     return std::make_unique<ScaleAttnHandler>(
         sm_scale, args.attn_logit_soft_cap(), alibi_slopes);
   }
@@ -111,9 +110,9 @@ std::unique_ptr<AttentionHandler> AttentionHandler::create_handler_with_rope(
   }
 
   const bool is_cuda = options.device().is_cuda();
-  if (boost::iequals(FLAGS_attention_handler, "flash_attn")) {
-    CHECK(is_cuda) << "flash_attn only supports cuda device";
-    return std::make_unique<FlashAttnHandler>(sm_scale,
+  if (boost::iequals(FLAGS_attention_handler, "scale_attn")) {
+    CHECK(is_cuda) << "scale_attn only supports cuda device";
+    return std::make_unique<ScaleAttnHandler>(sm_scale,
                                               args.attn_logit_soft_cap(),
                                               rotary_dim,
                                               args.max_position_embeddings(),
@@ -124,7 +123,7 @@ std::unique_ptr<AttentionHandler> AttentionHandler::create_handler_with_rope(
 
   // choose the best handler based on device type
   if (is_cuda) {
-    // use flash_attn for cuda device
+    // use scale_attn for cuda device
     return std::make_unique<ScaleAttnHandler>(sm_scale,
                                               args.attn_logit_soft_cap(),
                                               rotary_dim,
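
Not part of the commit itself: a minimal sketch of how the renamed attention_handler flag would typically be set by a caller, assuming a gflags-based entry point. The main() and binary name below are hypothetical; only DECLARE_string and ParseCommandLineFlags are standard gflags usage.

#include <gflags/gflags.h>

// Refers to the DEFINE_string(attention_handler, ...) in the handler source above.
DECLARE_string(attention_handler);

int main(int argc, char* argv[]) {
  // Honors e.g. `./server --attention_handler=scale_attn` on the command line
  // (binary name hypothetical); the default "auto" lets the create_handler_*
  // factories pick ScaleAttnHandler on CUDA devices.
  gflags::ParseCommandLineFlags(&argc, &argv, /*remove_flags=*/true);

  // Hypothetical explicit override; the factory compares the value
  // case-insensitively via boost::iequals.
  // FLAGS_attention_handler = "scale_attn";
  return 0;
}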