
Commit db09d4a

Authored by zhuohan123, WoosukKwon, and Oliver-ss
[FIX] Fix Alibi implementation in PagedAttention kernel (#945)
* [FIX] Fix Alibi implementation in PagedAttention kernel
* Fix test_attention
* Fix

Co-authored-by: Woosuk Kwon <[email protected]>
Co-authored-by: Oliver-ss <[email protected]>
1 parent c957c74 commit db09d4a

File tree

2 files changed: +4 -3 lines changed

  csrc/attention/attention_kernels.cu
  tests/kernels/test_attention.py


csrc/attention/attention_kernels.cu

Lines changed: 1 addition & 1 deletion
@@ -178,7 +178,7 @@ __global__ void single_query_cached_kv_attention_kernel(
     // This includes a reduction across the threads in the same thread group.
     float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
     // Add the ALiBi bias if slopes are given.
-    qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len) : 0;
+    qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;

     if (thread_group_offset == 0) {
       // Store the partial reductions to shared memory.
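The kernel change shifts the bias by one so that the last key position (the token currently being decoded) receives a bias of exactly zero, matching the standard ALiBi formulation of slope * (key_position - query_position). Below is a minimal sketch of the old and corrected per-token bias; the slope and context length are hypothetical values chosen for illustration, not taken from the commit:

import torch

# Hypothetical values for illustration only (not from the commit).
alibi_slope = 0.25
context_len = 5

token_idx = torch.arange(context_len)

# Old bias: the last token (token_idx == context_len - 1) gets -slope, not 0.
old_bias = alibi_slope * (token_idx - context_len).float()

# Fixed bias: the last token gets exactly 0, and earlier tokens get increasingly
# negative values, i.e. slope * (key_position - query_position).
new_bias = alibi_slope * (token_idx - context_len + 1).float()

print(old_bias)  # tensor([-1.2500, -1.0000, -0.7500, -0.5000, -0.2500])
print(new_bias)  # tensor([-1.0000, -0.7500, -0.5000, -0.2500,  0.0000])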

tests/kernels/test_attention.py

Lines changed: 3 additions & 2 deletions
@@ -17,7 +17,7 @@
 NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
 HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 BLOCK_SIZES = [8, 16, 32]
-USE_ALIBI = [False]  # TODO(woosuk): Add USE_ALIBI=True
+USE_ALIBI = [False, True]
 SEEDS = [0]


@@ -83,7 +83,7 @@ def ref_single_query_cached_kv_attention(
     if alibi_slopes is not None:
         # Create the ALiBi bias used in the paged attention kernel.
         position_ids = torch.arange(context_len, device="cuda").int()
-        alibi_bias = (context_len - position_ids).float()
+        alibi_bias = (position_ids - context_len + 1).float()
         alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
             1, 1, -1)
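The reference implementation mirrors the kernel fix: the bias is zero at the newest key position, grows more negative for older positions, and is then scaled per head. A small self-contained sketch of the corrected construction follows; the head count and slope values are hypothetical, and it runs on CPU rather than CUDA purely so it is runnable anywhere (both assumptions, not from the commit):

import torch

# Hypothetical shapes for illustration (the test derives these from its parametrization).
num_heads = 4
context_len = 6
alibi_slopes = torch.tensor([0.5, 0.25, 0.125, 0.0625])  # one slope per head

# Same construction as the fixed reference code, on CPU for portability.
position_ids = torch.arange(context_len).int()
alibi_bias = (position_ids - context_len + 1).float()  # [-5, -4, -3, -2, -1, 0]
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1)

# Shape [num_heads, 1, context_len]: broadcasts over the single query position
# when added to the attention scores before softmax.
print(alibi_bias.shape)  # torch.Size([4, 1, 6])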

@@ -224,6 +224,7 @@ def ref_multi_query_kv_attention(
     return ref_output


+# TODO(woosuk): Add tests for USE_ALIBI=True.
 @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
