@@ -373,7 +373,8 @@ def _attn_bwd(Q, K, V, sm_scale, #
373373 BLOCK_M2 : tl .constexpr , #
374374 BLOCK_N2 : tl .constexpr , #
375375 BLK_SLICE_FACTOR : tl .constexpr , #
376- HEAD_DIM : tl .constexpr ):
376+ HEAD_DIM : tl .constexpr , #
377+ CAUSAL : tl .constexpr ):
377378 LN2 : tl .constexpr = 0.6931471824645996 # = ln(2)
378379
379380 bhid = tl .program_id (2 )
@@ -396,7 +397,7 @@ def _attn_bwd(Q, K, V, sm_scale, #
396397 offs_k = tl .arange (0 , HEAD_DIM )
397398
398399 start_n = pid * BLOCK_N1
399- start_m = start_n
400+ start_m = 0
400401
401402 MASK_BLOCK_M1 : tl .constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
402403 offs_n = start_n + tl .arange (0 , BLOCK_N1 )
@@ -408,23 +409,24 @@ def _attn_bwd(Q, K, V, sm_scale, #
408409 k = tl .load (K + offs_n [:, None ] * stride_tok + offs_k [None , :] * stride_d )
409410 v = tl .load (V + offs_n [:, None ] * stride_tok + offs_k [None , :] * stride_d )
410411
411- num_steps = BLOCK_N1 // MASK_BLOCK_M1
412-
413- dk , dv = _attn_bwd_dkdv ( dk , dv , #
414- Q , k , v , sm_scale , #
415- DO , #
416- M , D , #
417- stride_tok , stride_d , #
418- H , N_CTX , #
419- MASK_BLOCK_M1 , BLOCK_N1 , HEAD_DIM , #
420- start_n , start_m , num_steps , #
421- MASK = True #
422- )
423-
424- start_m += num_steps * MASK_BLOCK_M1
425- num_steps = ( N_CTX - start_m ) // BLOCK_M1
412+ if CAUSAL :
413+ start_m = start_n
414+ num_steps = BLOCK_N1 // MASK_BLOCK_M1
415+ dk , dv = _attn_bwd_dkdv ( dk , dv , #
416+ Q , k , v , sm_scale , #
417+ DO , #
418+ M , D , #
419+ stride_tok , stride_d , #
420+ H , N_CTX , #
421+ MASK_BLOCK_M1 , BLOCK_N1 , HEAD_DIM , #
422+ start_n , start_m , num_steps , #
423+ MASK = True , #
424+ )
425+
426+ start_m += num_steps * MASK_BLOCK_M1
426427
427428 # Compute dK and dV for non-masked blocks.
429+ num_steps = (N_CTX - start_m ) // BLOCK_M1
428430 dk , dv = _attn_bwd_dkdv ( #
429431 dk , dv , #
430432 Q , k , v , sm_scale , #
@@ -434,7 +436,7 @@ def _attn_bwd(Q, K, V, sm_scale, #
434436 H , N_CTX , #
435437 BLOCK_M1 , BLOCK_N1 , HEAD_DIM , #
436438 start_n , start_m , num_steps , #
437- MASK = False #
439+ MASK = False , #
438440 )
439441
440442 dv_ptrs = DV + offs_n [:, None ] * stride_tok + offs_k [None , :] * stride_d
@@ -447,7 +449,8 @@ def _attn_bwd(Q, K, V, sm_scale, #
447449
448450 # THIS BLOCK DOES DQ:
449451 start_m = pid * BLOCK_M2
450- end_n = start_m + BLOCK_M2
452+ start_n = 0
453+ num_steps = N_CTX // BLOCK_N2
451454
452455 MASK_BLOCK_N2 : tl .constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
453456 offs_m = start_m + tl .arange (0 , BLOCK_M2 )
@@ -459,30 +462,34 @@ def _attn_bwd(Q, K, V, sm_scale, #
459462 m = tl .load (M + offs_m )
460463 m = m [:, None ]
461464
462- # Compute dQ for masked (diagonal) blocks.
463- # NOTE: This code scans each row of QK^T backward (from right to left,
464- # but inside each call to _attn_bwd_dq, from left to right), but that's
465- # not due to anything important. I just wanted to reuse the loop
466- # structure for dK & dV above as much as possible.
467- num_steps = BLOCK_M2 // MASK_BLOCK_N2
468- dq = _attn_bwd_dq (dq , q , K , V , #
469- do , m , D , #
470- stride_tok , stride_d , #
471- H , N_CTX , #
472- BLOCK_M2 , MASK_BLOCK_N2 , HEAD_DIM , #
473- start_m , end_n - num_steps * MASK_BLOCK_N2 , num_steps , #
474- MASK = True #
475- )
476- end_n -= num_steps * MASK_BLOCK_N2
477- # stage 2
478- num_steps = end_n // BLOCK_N2
465+ if CAUSAL :
466+ # Compute dQ for masked (diagonal) blocks.
467+ # NOTE: This code scans each row of QK^T backward (from right to left,
468+ # but inside each call to _attn_bwd_dq, from left to right), but that's
469+ # not due to anything important. I just wanted to reuse the loop
470+ # structure for dK & dV above as much as possible.
471+ end_n = start_m + BLOCK_M2
472+ num_steps = BLOCK_M2 // MASK_BLOCK_N2
473+ dq = _attn_bwd_dq (dq , q , K , V , #
474+ do , m , D , #
475+ stride_tok , stride_d , #
476+ H , N_CTX , #
477+ BLOCK_M2 , MASK_BLOCK_N2 , HEAD_DIM , #
478+ start_m , end_n - num_steps * MASK_BLOCK_N2 , num_steps , #
479+ MASK = True , #
480+ )
481+ end_n -= num_steps * MASK_BLOCK_N2
482+ # stage 2
483+ num_steps = end_n // BLOCK_N2
484+ start_n = end_n - num_steps * BLOCK_N2
485+
479486 dq = _attn_bwd_dq (dq , q , K , V , #
480487 do , m , D , #
481488 stride_tok , stride_d , #
482489 H , N_CTX , #
483490 BLOCK_M2 , BLOCK_N2 , HEAD_DIM , #
484- start_m , end_n - num_steps * BLOCK_N2 , num_steps , #
485- MASK = False #
491+ start_m , start_n , num_steps , #
492+ MASK = False , #
486493 )
487494 # Write back dQ.
488495 dq_ptrs = DQ + offs_m [:, None ] * stride_tok + offs_k [None , :] * stride_d
@@ -599,7 +606,8 @@ def backward(ctx, do):
599606 BLK_SLICE_FACTOR = BLK_SLICE_FACTOR , #
600607 HEAD_DIM = ctx .HEAD_DIM , #
601608 num_warps = NUM_WARPS , #
602- num_stages = NUM_STAGES #
609+ num_stages = NUM_STAGES , #
610+ CAUSAL = ctx .causal , #
603611 )
604612
605613 return dq , dk , dv , None , None , None , None
@@ -614,7 +622,7 @@ def backward(ctx, do):
614622@pytest .mark .parametrize ("H" , [2 , 48 ])
615623@pytest .mark .parametrize ("N_CTX" , [128 , 1024 , (2 if is_hip () else 4 ) * 1024 ])
616624@pytest .mark .parametrize ("HEAD_DIM" , [64 , 128 ])
617- @pytest .mark .parametrize ("causal" , [True ]) # FIXME: Non-causal tests do not pass at the moment.
625+ @pytest .mark .parametrize ("causal" , [False , True ])
618626@pytest .mark .parametrize ("warp_specialize" , [False , True ] if is_blackwell () else [False ])
619627@pytest .mark .parametrize ("mode" , ["fwd" , "bwd" ])
620628@pytest .mark .parametrize ("provider" , ["triton-fp16" ] + (["triton-fp8" ] if TORCH_HAS_FP8 else []))
0 commit comments