Skip to content

Commit dad2ba0

Browse files
authored
[Tutorial] Support non-causal in flash attention backward (#8565)
With non-causal attention, the dk/dv computation starts from the beginning of the sequence: ``` start_n = pid * BLOCK_N1 start_m = 0 ``` and likewise the dq computation: ``` start_m = pid * BLOCK_M2 start_n = 0 ```
1 parent 11d2077 commit dad2ba0

File tree

1 file changed

+48
-40
lines changed

1 file changed

+48
-40
lines changed

python/tutorials/06-fused-attention.py

Lines changed: 48 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,8 @@ def _attn_bwd(Q, K, V, sm_scale, #
373373
BLOCK_M2: tl.constexpr, #
374374
BLOCK_N2: tl.constexpr, #
375375
BLK_SLICE_FACTOR: tl.constexpr, #
376-
HEAD_DIM: tl.constexpr):
376+
HEAD_DIM: tl.constexpr, #
377+
CAUSAL: tl.constexpr):
377378
LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
378379

379380
bhid = tl.program_id(2)
@@ -396,7 +397,7 @@ def _attn_bwd(Q, K, V, sm_scale, #
396397
offs_k = tl.arange(0, HEAD_DIM)
397398

398399
start_n = pid * BLOCK_N1
399-
start_m = start_n
400+
start_m = 0
400401

401402
MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
402403
offs_n = start_n + tl.arange(0, BLOCK_N1)
@@ -408,23 +409,24 @@ def _attn_bwd(Q, K, V, sm_scale, #
408409
k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
409410
v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
410411

411-
num_steps = BLOCK_N1 // MASK_BLOCK_M1
412-
413-
dk, dv = _attn_bwd_dkdv(dk, dv, #
414-
Q, k, v, sm_scale, #
415-
DO, #
416-
M, D, #
417-
stride_tok, stride_d, #
418-
H, N_CTX, #
419-
MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM, #
420-
start_n, start_m, num_steps, #
421-
MASK=True #
422-
)
423-
424-
start_m += num_steps * MASK_BLOCK_M1
425-
num_steps = (N_CTX - start_m) // BLOCK_M1
412+
if CAUSAL:
413+
start_m = start_n
414+
num_steps = BLOCK_N1 // MASK_BLOCK_M1
415+
dk, dv = _attn_bwd_dkdv(dk, dv, #
416+
Q, k, v, sm_scale, #
417+
DO, #
418+
M, D, #
419+
stride_tok, stride_d, #
420+
H, N_CTX, #
421+
MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM, #
422+
start_n, start_m, num_steps, #
423+
MASK=True, #
424+
)
425+
426+
start_m += num_steps * MASK_BLOCK_M1
426427

427428
# Compute dK and dV for non-masked blocks.
429+
num_steps = (N_CTX - start_m) // BLOCK_M1
428430
dk, dv = _attn_bwd_dkdv( #
429431
dk, dv, #
430432
Q, k, v, sm_scale, #
@@ -434,7 +436,7 @@ def _attn_bwd(Q, K, V, sm_scale, #
434436
H, N_CTX, #
435437
BLOCK_M1, BLOCK_N1, HEAD_DIM, #
436438
start_n, start_m, num_steps, #
437-
MASK=False #
439+
MASK=False, #
438440
)
439441

440442
dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
@@ -447,7 +449,8 @@ def _attn_bwd(Q, K, V, sm_scale, #
447449

448450
# THIS BLOCK DOES DQ:
449451
start_m = pid * BLOCK_M2
450-
end_n = start_m + BLOCK_M2
452+
start_n = 0
453+
num_steps = N_CTX // BLOCK_N2
451454

452455
MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
453456
offs_m = start_m + tl.arange(0, BLOCK_M2)
@@ -459,30 +462,34 @@ def _attn_bwd(Q, K, V, sm_scale, #
459462
m = tl.load(M + offs_m)
460463
m = m[:, None]
461464

462-
# Compute dQ for masked (diagonal) blocks.
463-
# NOTE: This code scans each row of QK^T backward (from right to left,
464-
# but inside each call to _attn_bwd_dq, from left to right), but that's
465-
# not due to anything important. I just wanted to reuse the loop
466-
# structure for dK & dV above as much as possible.
467-
num_steps = BLOCK_M2 // MASK_BLOCK_N2
468-
dq = _attn_bwd_dq(dq, q, K, V, #
469-
do, m, D, #
470-
stride_tok, stride_d, #
471-
H, N_CTX, #
472-
BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM, #
473-
start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, #
474-
MASK=True #
475-
)
476-
end_n -= num_steps * MASK_BLOCK_N2
477-
# stage 2
478-
num_steps = end_n // BLOCK_N2
465+
if CAUSAL:
466+
# Compute dQ for masked (diagonal) blocks.
467+
# NOTE: This code scans each row of QK^T backward (from right to left,
468+
# but inside each call to _attn_bwd_dq, from left to right), but that's
469+
# not due to anything important. I just wanted to reuse the loop
470+
# structure for dK & dV above as much as possible.
471+
end_n = start_m + BLOCK_M2
472+
num_steps = BLOCK_M2 // MASK_BLOCK_N2
473+
dq = _attn_bwd_dq(dq, q, K, V, #
474+
do, m, D, #
475+
stride_tok, stride_d, #
476+
H, N_CTX, #
477+
BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM, #
478+
start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, #
479+
MASK=True, #
480+
)
481+
end_n -= num_steps * MASK_BLOCK_N2
482+
# stage 2
483+
num_steps = end_n // BLOCK_N2
484+
start_n = end_n - num_steps * BLOCK_N2
485+
479486
dq = _attn_bwd_dq(dq, q, K, V, #
480487
do, m, D, #
481488
stride_tok, stride_d, #
482489
H, N_CTX, #
483490
BLOCK_M2, BLOCK_N2, HEAD_DIM, #
484-
start_m, end_n - num_steps * BLOCK_N2, num_steps, #
485-
MASK=False #
491+
start_m, start_n, num_steps, #
492+
MASK=False, #
486493
)
487494
# Write back dQ.
488495
dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
@@ -599,7 +606,8 @@ def backward(ctx, do):
599606
BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, #
600607
HEAD_DIM=ctx.HEAD_DIM, #
601608
num_warps=NUM_WARPS, #
602-
num_stages=NUM_STAGES #
609+
num_stages=NUM_STAGES, #
610+
CAUSAL=ctx.causal, #
603611
)
604612

605613
return dq, dk, dv, None, None, None, None
@@ -614,7 +622,7 @@ def backward(ctx, do):
614622
@pytest.mark.parametrize("H", [2, 48])
615623
@pytest.mark.parametrize("N_CTX", [128, 1024, (2 if is_hip() else 4) * 1024])
616624
@pytest.mark.parametrize("HEAD_DIM", [64, 128])
617-
@pytest.mark.parametrize("causal", [True]) # FIXME: Non-causal tests do not pass at the moment.
625+
@pytest.mark.parametrize("causal", [False, True])
618626
@pytest.mark.parametrize("warp_specialize", [False, True] if is_blackwell() else [False])
619627
@pytest.mark.parametrize("mode", ["fwd", "bwd"])
620628
@pytest.mark.parametrize("provider", ["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else []))

0 commit comments

Comments (0)