Commit 5151644

attn: integrate in-house scale attention and use it by default
1 parent: a6af766

7 files changed (+79 −24 lines)

src/kernels/attention/attention_kernel_sm80_test.cu

Lines changed: 11 additions & 4 deletions

@@ -94,7 +94,8 @@ torch::Tensor attention_sm80(
 }  // namespace
 
 class AttentionKernelTest
-    : public ::testing::TestWithParam<std::tuple<int64_t /*batch_size*/,
+    : public ::testing::TestWithParam<std::tuple<torch::ScalarType /*q_dtype*/,
+                                                 int64_t /*batch_size*/,
                                                  int64_t /*q_len*/,
                                                  int64_t /*kv_len*/,
                                                  int64_t /*n_heads*/,
@@ -111,7 +112,8 @@ class AttentionKernelTest
 };
 
 TEST_P(AttentionKernelTest, MHA) {
-  const auto [batch_size,
+  const auto [dtype,
+              batch_size,
               q_len,
               kv_len,
               n_heads,
@@ -121,7 +123,7 @@ TEST_P(AttentionKernelTest, MHA) {
               alibi,
               sliding_window] = GetParam();
 
-  const auto options = torch::dtype(torch::kHalf).device(torch::kCUDA);
+  const auto options = torch::dtype(dtype).device(torch::kCUDA);
 
   // construct non-contiguous query, key and value
   const auto data = torch::randn(
@@ -143,13 +145,18 @@ TEST_P(AttentionKernelTest, MHA) {
   auto out = attention_sm80(
       query, key, value, alibi_slopes, logits_soft_cap, sliding_window, q_len);
 
-  EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-3, /*atol=*/1e-3));
+  if (dtype == torch::kBFloat16) {
+    EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-2, /*atol=*/1e-2));
+  } else {
+    EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-3, /*atol=*/1e-3));
+  }
 }
 
 INSTANTIATE_TEST_SUITE_P(
     MHA,
     AttentionKernelTest,
     ::testing::Combine(
+        ::testing::Values(torch::kHalf, torch::kBFloat16),  // q_dtype
        ::testing::Values(1, 2, 4),                          // batch_size
        ::testing::Values(1, 62, 125),                       // q_len
        ::testing::Values(127, 287, 1000),                   // kv_len
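
Note on the relaxed tolerance: bfloat16 keeps only 8 bits of significand precision (1 implicit + 7 stored) versus fp16's 11 (1 implicit + 10 stored), so element-wise rounding error is roughly an order of magnitude larger, which is why the test loosens rtol/atol from 1e-3 to 1e-2 for bf16. A small standalone libtorch check (not part of the commit) that illustrates the gap via a round-trip cast:

#include <torch/torch.h>

#include <iostream>

int main() {
  const auto x = torch::randn({1024}, torch::kFloat32);
  // round-trip through half and bfloat16, then measure the worst relative error
  const auto err_fp16 = ((x.to(torch::kHalf).to(torch::kFloat32) - x).abs() /
                         x.abs().clamp_min(1e-6)).max();
  const auto err_bf16 = ((x.to(torch::kBFloat16).to(torch::kFloat32) - x).abs() /
                         x.abs().clamp_min(1e-6)).max();
  std::cout << "max rel err fp16: " << err_fp16.item<float>() << "\n"
            << "max rel err bf16: " << err_bf16.item<float>() << "\n";
  return 0;
}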

src/kernels/attention/attention_launch_sm80.cuh

Lines changed: 6 additions & 6 deletions

@@ -51,43 +51,43 @@ void run_attention_kernel(const Params& params, cudaStream_t stream) {
 }  // namespace detail
 
 // user-facing function to run the attention kernel
-template <typename Element, int HEAD_DIM, typename Params>
+template <typename Dtype, int HEAD_DIM, typename Params>
 void run_attention_kernel_sm80(Params& params, cudaStream_t stream = nullptr) {
   // normalize params that for performance optimization
   params.normalize();
 
   // TODO: tune block shape MNK based on the head dim and smem size
   if constexpr (HEAD_DIM == 64) {
-    using Traits = AttentionTraitsSM80<Element,
+    using Traits = AttentionTraitsSM80<Dtype,
                                        HEAD_DIM,
                                        /*BLK_M=*/64,
                                        /*BLK_N=*/64,
                                        /*BLK_K=*/64>;
     detail::run_attention_kernel<Traits>(params, stream);
   } else if constexpr (HEAD_DIM == 96) {
-    using Traits = AttentionTraitsSM80<Element,
+    using Traits = AttentionTraitsSM80<Dtype,
                                        HEAD_DIM,
                                        /*BLK_M=*/64,
                                        /*BLK_N=*/64,
                                        /*BLK_K=*/32>;
     detail::run_attention_kernel<Traits>(params, stream);
   } else if constexpr (HEAD_DIM == 128) {
-    using Traits = AttentionTraitsSM80<Element,
+    using Traits = AttentionTraitsSM80<Dtype,
                                        HEAD_DIM,
                                        /*BLK_M=*/64,
                                        /*BLK_N=*/64,
                                        /*BLK_K=*/64>;
     detail::run_attention_kernel<Traits>(params, stream);
   } else if constexpr (HEAD_DIM == 256) {
-    using Traits = AttentionTraitsSM80<Element,
+    using Traits = AttentionTraitsSM80<Dtype,
                                        HEAD_DIM,
                                        /*BLK_M=*/64,
                                        /*BLK_N=*/64,
                                        /*BLK_K=*/64>;
     detail::run_attention_kernel<Traits>(params, stream);
   } else {
     // use the default block size
-    using Traits = AttentionTraitsSM80<Element,
+    using Traits = AttentionTraitsSM80<Dtype,
                                        HEAD_DIM,
                                        /*BLK_M=*/64,
                                        /*BLK_N=*/64,
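
Because Dtype and HEAD_DIM are compile-time template parameters, a caller still has to bridge from runtime values. A minimal caller-side sketch (not from the commit): run_attention_kernel_sm80 and its head-dim specializations come from the header above, while launch_mha_sm80, the cute type aliases, the include path, and the namespace handling are assumptions that may need adjusting to the repo's layout and CUTLASS version.

#include <cuda_runtime.h>
#include <torch/torch.h>

#include <cute/numeric/numeric_types.hpp>  // cute::half_t / cute::bfloat16_t

#include "attention_launch_sm80.cuh"

template <typename Params>
void launch_mha_sm80(Params& params,
                     torch::ScalarType dtype,
                     int head_dim,
                     cudaStream_t stream) {
  // bridge runtime dtype/head_dim to the compile-time template parameters
  auto with_dtype = [&](auto type_tag) {
    using Dtype = decltype(type_tag);
    switch (head_dim) {
      case 64:  run_attention_kernel_sm80<Dtype, 64>(params, stream);  break;
      case 96:  run_attention_kernel_sm80<Dtype, 96>(params, stream);  break;
      case 128: run_attention_kernel_sm80<Dtype, 128>(params, stream); break;
      case 256: run_attention_kernel_sm80<Dtype, 256>(params, stream); break;
      default:  TORCH_CHECK(false, "unsupported head_dim: ", head_dim);
    }
  };
  if (dtype == torch::kBFloat16) {
    with_dtype(cute::bfloat16_t{});
  } else {
    with_dtype(cute::half_t{});
  }
}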

src/kernels/attention/attn_api.cpp

Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+#include "attn_api.h"
+namespace llm {
+void paged_kv_varlen_mha(
+    torch::Tensor& out,               // [n_tokens, n_heads, head_dim]
+    const torch::Tensor& q,           // [n_tokens, n_heads, head_dim]
+    const torch::Tensor& k_cache,     // [n_slots, n_kv_heads, head_dim]
+    const torch::Tensor& v_cache,     // [n_slots, n_kv_heads, head_dim]
+    const torch::Tensor& q_cu_lens,   // [batch + 1]
+    const torch::Tensor& kv_cu_lens,  // [batch + 1]
+    const torch::Tensor& block_table,
+    const torch::Tensor& block_cu_lens,                // [batch + 1]
+    const std::optional<torch::Tensor>& alibi_slopes,  // [n_heads]
+    int block_size,
+    int max_q_len,
+    int max_kv_len,
+    float sm_scale,
+    float logits_soft_cap,
+    int sliding_window) {}
+
+}  // namespace llm

src/kernels/attention/attn_api.h

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+#include <torch/torch.h>
+#include <torch/types.h>
+
+namespace llm {
+// the input tensors are packed into one-dimensional tensors, and the sequence
+// lengths are stored in q_cu_lens and kv_cu_lens.
+// for each sequence,
+//   the starting offset: q/kv_cu_lens[i]
+//   the length: q/kv_cu_lens[i+1] - q/kv_cu_lens[i].
+// the maximum sequence length is max_q_len and max_kv_len, which are used
+// to decide the kernel dispatch.
+void paged_kv_varlen_mha(
+    torch::Tensor& out,               // [n_tokens, n_heads, head_dim]
+    const torch::Tensor& q,           // [n_tokens, n_heads, head_dim]
+    const torch::Tensor& k_cache,     // [n_slots, n_kv_heads, head_dim]
+    const torch::Tensor& v_cache,     // [n_slots, n_kv_heads, head_dim]
+    const torch::Tensor& q_cu_lens,   // [batch + 1]
+    const torch::Tensor& kv_cu_lens,  // [batch + 1]
+    const torch::Tensor& block_table,
+    const torch::Tensor& block_cu_lens,                // [batch + 1]
+    const std::optional<torch::Tensor>& alibi_slopes,  // [n_heads]
+    int block_size,
+    int max_q_len,
+    int max_kv_len,
+    float sm_scale,
+    float logits_soft_cap,
+    int sliding_window);
+
+}  // namespace llm
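
The cumulative-length ("cu_lens") convention documented in the header is the usual varlen packing: for a batch of N sequences there are N + 1 entries, the first is 0, and sequence i occupies rows [q_cu_lens[i], q_cu_lens[i+1]) of the packed tensor. A small illustrative snippet (not part of the commit) that builds q_cu_lens from per-sequence lengths and recovers the per-sequence offsets:

#include <torch/torch.h>

#include <iostream>

int main() {
  // three sequences with query lengths {5, 1, 3} pack into a 9-token q tensor
  const auto q_lens = torch::tensor({5, 1, 3}, torch::kInt32);
  // prepend a zero and take the running sum: q_cu_lens = {0, 5, 6, 9}
  const auto q_cu_lens = torch::cat(
      {torch::zeros({1}, torch::kInt32), torch::cumsum(q_lens, 0).to(torch::kInt32)});
  const int64_t batch = q_cu_lens.size(0) - 1;
  for (int64_t i = 0; i < batch; ++i) {
    const auto start = q_cu_lens[i].item<int64_t>();
    const auto len = q_cu_lens[i + 1].item<int64_t>() - start;
    // sequence i occupies rows [start, start + len) of the packed tensor
    std::cout << "seq " << i << ": offset " << start << ", len " << len << "\n";
  }
  return 0;
}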

src/layers/attention/CMakeLists.txt

Lines changed: 2 additions & 3 deletions

@@ -9,21 +9,20 @@ cc_library(
     handler.h
     ref_handler.h
     flash_attn_handler.h
-    flash_infer_handler.h
+    scale_attn_handler.h
     attention.h
   SRCS
     handler.cpp
     ref_handler.cpp
     flash_attn_handler.cpp
-    flash_infer_handler.cpp
+    scale_attn_handler.cpp
     attention.cpp
   DEPS
     :state_dict
     :memory
     :pos_embedding
     :kernels
     :flash_attn.kernels
-    # :flash_infer.kernels
     glog::glog
     gflags::gflags
     torch

src/layers/attention/flash_infer_handler.cpp renamed to src/layers/attention/scale_attn_handler.cpp

Lines changed: 6 additions & 6 deletions

@@ -1,4 +1,4 @@
-#include "flash_infer_handler.h"
+#include "scale_attn_handler.h"
 
 #include <torch/torch.h>
 
@@ -7,7 +7,7 @@
 
 namespace llm {
 
-FlashInferHandler::FlashInferHandler(float scale,
+ScaleAttnHandler::ScaleAttnHandler(float scale,
                                      int64_t rotary_dim,
                                      int64_t max_position,
                                      float rope_scaling,
@@ -17,13 +17,13 @@ FlashInferHandler::FlashInferHandler(float scale,
   LOG(FATAL) << "Not implemented yet";
 }
 
-FlashInferHandler::FlashInferHandler(
+ScaleAttnHandler::ScaleAttnHandler(
     float scale,
     torch::optional<torch::Tensor> alibi_slopes)
    : scale_(scale), alibi_slopes_(alibi_slopes) {}
 
 // batch prefill for attention, optimized for prefill stage
-void FlashInferHandler::batch_prefill(
+void ScaleAttnHandler::batch_prefill(
     const torch::Tensor& query,  // [n_tokens, n_heads, head_dim]
     const torch::Tensor& key,    // [n_tokens, n_kv_heads, head_dim]
     const torch::Tensor& value,  // [n_tokens, n_kv_heads, head_dim]
@@ -36,7 +36,7 @@ void FlashInferHandler::batch_prefill(
 
 // batch decode for attention, optimized for decode stage
 // support multiple queries: one sequence with multiple query tokens
-void FlashInferHandler::batch_decode(
+void ScaleAttnHandler::batch_decode(
     const torch::Tensor& query,           // [n_tokens, n_heads, head_dim]
     const KVCache& kv_cache,              // where to retrieval key and value
     const InputParameters& input_params,  // input paras used for attention
@@ -47,7 +47,7 @@ void FlashInferHandler::batch_decode(
 }
 
 // append key and value to kv_cache
-void FlashInferHandler::append_kv_cache(
+void ScaleAttnHandler::append_kv_cache(
     KVCache& kv_cache,           // where to store key and value
     const torch::Tensor& key,    // [n_tokens, n_kv_heads, head_dim]
     const torch::Tensor& value,  // [n_tokens, n_kv_heads, head_dim]

src/layers/attention/flash_infer_handler.h renamed to src/layers/attention/scale_attn_handler.h

Lines changed: 5 additions & 5 deletions

@@ -9,10 +9,10 @@
 namespace llm {
 
 // an flash attn implementation for attention operations
-class FlashInferHandler : public AttentionHandler {
+class ScaleAttnHandler : public AttentionHandler {
  public:
  // create a flash attn handler with rope positional embedding
-  FlashInferHandler(float scale,
+  ScaleAttnHandler(float scale,
                     int64_t rotary_dim,
                     int64_t max_position,
                     float rope_scaling,
@@ -21,9 +21,9 @@ class FlashInferHandler : public AttentionHandler {
                     const torch::TensorOptions& options);
 
   // constructor for attention with alibi
-  FlashInferHandler(float scale, torch::optional<torch::Tensor> alibi_slopes);
+  ScaleAttnHandler(float scale, std::optional<torch::Tensor> alibi_slopes);
 
-  virtual ~FlashInferHandler() = default;
+  virtual ~ScaleAttnHandler() = default;
 
   std::tuple<torch::Tensor, torch::Tensor> apply_pos_emb(
       const torch::Tensor& query,
@@ -63,7 +63,7 @@ class FlashInferHandler : public AttentionHandler {
   float scale_ = 0.0;
 
   // alibi slops
-  torch::optional<torch::Tensor> alibi_slopes_;
+  std::optional<torch::Tensor> alibi_slopes_;
 };
 
 }  // namespace llm
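
After the rename, call sites only need the new class name; the constructor shapes are unchanged apart from torch::optional becoming std::optional. A hypothetical call site (not from the commit) showing the alibi constructor with the conventional 1/sqrt(head_dim) softmax scale; make_handler is an illustrative name, and AttentionHandler is assumed to live in namespace llm alongside the handler:

#include <cmath>
#include <memory>
#include <optional>

#include <torch/torch.h>

#include "scale_attn_handler.h"

// Build the in-house handler; pass std::nullopt when alibi is not used.
std::unique_ptr<llm::AttentionHandler> make_handler(
    int64_t head_dim,
    std::optional<torch::Tensor> alibi_slopes = std::nullopt) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
  return std::make_unique<llm::ScaleAttnHandler>(scale, std::move(alibi_slopes));
}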
