Commit b12a63c

[Bugfix] Ensure correct handling for cases where seq_q<seq_kv in flash attention examples (#864)
* fix flash attention examples for `seqlen_q<seqlen_kv` cases
* lint
1 parent 3b21a67 commit b12a63c

File tree: 2 files changed, +16 −8 lines changed

examples/flash_attention/example_mha_fwd_bhsd.py

Lines changed: 8 additions & 4 deletions

@@ -34,6 +34,9 @@ def flashattn(batch,
     dtype = "float16"
     accum_dtype = "float"

+    past_len = seq_kv - seq_q
+    assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
+
     @T.macro
     def MMA0(
             K: T.Tensor(kv_shape, dtype),
@@ -45,7 +48,6 @@ def MMA0(
             by: T.int32,
             bz: T.int32,
     ):
-        past_len = seq_kv - seq_q
         T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
         if is_causal:
             for i, j in T.Parallel(block_M, block_N):
@@ -135,8 +137,10 @@ def main(
             T.fill(scores_max, -T.infinity(accum_dtype))

             loop_range = (
-                T.min(T.ceildiv(seq_kv, block_N), T.ceildiv(
-                    (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
+                T.min(
+                    T.ceildiv(seq_kv, block_N), T.ceildiv(
+                        (bx + 1) * block_M +
+                        past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))

             for k in T.Pipelined(loop_range, num_stages=num_stages):
                 MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
@@ -159,7 +163,7 @@ def ref_program(Q, K, V, is_causal):
     if is_causal:
         seq_q = Q.size(2)
         seq_kv = K.size(2)
-        mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device))
+        mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
         mask = mask.unsqueeze(0).unsqueeze(0)
         scores = scores.masked_fill(mask == 0, float('-inf'))
     attention_weights = F.softmax(scores, dim=-1)
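Note on the `ref_program` change above: when seq_kv > seq_q, query row i sits at absolute position past_len + i in the sequence, so the causal mask must admit keys 0 through past_len + i rather than only 0 through i. The standalone PyTorch sketch below (not part of the diff; the tensor names are illustrative) shows that torch.tril with diagonal=seq_kv - seq_q builds exactly that shifted mask:

import torch

seq_q, seq_kv = 4, 6        # query block shorter than the key/value sequence
past_len = seq_kv - seq_q   # queries correspond to the last seq_q positions

# Shifted lower-triangular mask: query row i may attend to keys 0 .. past_len + i.
mask = torch.tril(torch.ones(seq_q, seq_kv), diagonal=past_len)

# Equivalent elementwise condition on absolute positions.
rows = torch.arange(seq_q).unsqueeze(1)   # query index within the block
cols = torch.arange(seq_kv).unsqueeze(0)  # key index
assert torch.equal(mask, (cols <= rows + past_len).float())

With the old torch.tril(...) call (diagonal 0), each query row would mask out the past_len keys it is still allowed to attend to whenever seq_q < seq_kv.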

examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py

Lines changed: 8 additions & 4 deletions

@@ -34,6 +34,9 @@ def flashattn(batch,
     dtype = "float16"
     accum_dtype = "float"

+    past_len = seq_kv - seq_q
+    assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
+
     @T.macro
     def MMA0(
             K: T.Tensor(kv_shape, dtype),
@@ -45,7 +48,6 @@ def MMA0(
             by: T.int32,
             bz: T.int32,
     ):
-        past_len = seq_kv - seq_q
         T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
         if is_causal:
             for i, j in T.Parallel(block_M, block_N):
@@ -135,8 +137,10 @@ def main(
             T.fill(scores_max, -T.infinity(accum_dtype))

             loop_range = (
-                T.min(T.ceildiv(seq_kv, block_N), T.ceildiv(
-                    (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
+                T.min(
+                    T.ceildiv(seq_kv, block_N), T.ceildiv(
+                        (bx + 1) * block_M +
+                        past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))

             for k in T.Pipelined(
                     loop_range,
@@ -164,7 +168,7 @@ def ref_program(Q, K, V, is_causal):
     if is_causal:
         seq_q = Q.size(2)
         seq_kv = K.size(2)
-        mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device))
+        mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
         mask = mask.unsqueeze(0).unsqueeze(0)
         scores = scores.masked_fill(mask == 0, float('-inf'))
     attention_weights = F.softmax(scores, dim=-1)
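The loop_range change in both kernels follows the same logic: under causal masking, query block bx now covers absolute key columns up to (bx + 1) * block_M + past_len - 1, so the inner loop has to visit K/V tiles up to that column, capped at the full seq_kv extent. A plain-Python sketch of the bound (the helper name and example sizes are hypothetical, for illustration only):

from math import ceil

def causal_kv_tiles(bx, block_M, block_N, seq_q, seq_kv):
    # Mirrors the updated loop_range expression: min of the full K/V tile
    # count and the tiles needed to reach the last visible key column.
    past_len = seq_kv - seq_q
    assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
    last_key = (bx + 1) * block_M + past_len  # one past the last visible key
    return min(ceil(seq_kv / block_N), ceil(last_key / block_N))

# Example: seq_q=64, seq_kv=192, block_M=block_N=64, query block bx=0.
# Without the past_len term the bound would be ceil(64/64) = 1 tile;
# with it, the block correctly visits ceil(192/64) = 3 tiles.
print(causal_kv_tiles(bx=0, block_M=64, block_N=64, seq_q=64, seq_kv=192))  # -> 3

Dropping the past_len term is exactly the bug being fixed: with seq_q < seq_kv, the old bound stopped the pipelined loop before reaching the last keys each query block is allowed to attend to.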

0 commit comments