@@ -47,6 +47,13 @@ def init_to_zero(names):
     return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]


+def rearrange_and_update_stride(tensor, pattern=None, dim=2):
+    # Rearrange according to pattern (skipped if pattern is None). If the resulting
+    # tensor.stride(dim) is not a multiple of eight, fall back to a contiguous copy.
+    tensor_rearranged = rearrange(tensor, pattern) if pattern is not None else tensor
+    return tensor_rearranged.contiguous() if tensor_rearranged.stride(dim) % 8 != 0 else tensor_rearranged
+
+
 @triton.autotune(
     configs=[
         triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])),
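
For context, a minimal sketch of the layout change this helper triggers, assuming a contiguous (batch, seqlen, dim) input and illustrative sizes (the dim of 100 is deliberately not a multiple of 8):

    import torch
    from einops import rearrange

    x = torch.randn(2, 64, 100)                      # contiguous (b, s, d): strides (6400, 100, 1)

    xt = rearrange(x, "b s d -> b d s")              # a view with strides (6400, 1, 100): channels-last, stride(2) == 100
    assert xt.stride(2) % 8 != 0                     # 100 is not a multiple of 8, so the helper would not return this view

    xt_fixed = rearrange_and_update_stride(x, "b s d -> b d s")
    assert xt_fixed.is_contiguous()                  # falls back to a (b, d, s)-contiguous copy with strides (6400, 64, 1)

When dim is already a multiple of 8, the helper returns the cheap permuted view unchanged, so the common path pays for no extra copy.
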
@@ -120,11 +127,13 @@ def _chunk_scan_chunk_state_bwd_dx_kernel(

     dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
     if not HAS_SEQ_IDX:
-        scale = tl.exp(dA_cs_last - dA_cs_m)
+        # scale = tl.exp(dA_cs_last - dA_cs_m)
+        scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))
     else:
         seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
         seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
-        scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
+        # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
+        scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0)
     # Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
     # However, we're getting error with the Triton compiler 2.1.0 for that code path:
     # Unexpected mma -> mma layout conversion
@@ -170,7 +179,8 @@ def _chunk_scan_chunk_state_bwd_dx_kernel(
         cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0)
         dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)
         dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)
-        cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
+        # cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
+        cb *= tl.exp(tl.minimum((dA_cs_k[None, :] - dA_cs_m[:, None]), 0.0))
         # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
         # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.
         # Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
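
The clamps added in the two hunks above (for scale and for cb) only change behavior when the log-scale difference comes out positive; with the usual Mamba2 parameterization (dt >= 0, A <= 0) the chunked dA cumulative sums make that difference non-positive, so for valid positions exp(min(diff, 0.0)) equals exp(diff). A tiny, made-up illustration of the failure mode being guarded against (the large positive value stands in for a degenerate or out-of-range position):

    import torch

    diff = torch.tensor([-3.0, 0.0, 200.0])                  # fp32, as in the kernel

    torch.exp(diff)                                           # tensor([0.0498, 1.0000, inf]): the inf later turns masked products into NaN
    torch.exp(torch.minimum(diff, torch.zeros_like(diff)))    # tensor([0.0498, 1.0000, 1.0000]): multiplier capped at 1
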
@@ -776,7 +786,7 @@ def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size,
         zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1)
         seq_idx = seq_idx.contiguous() if seq_idx is not None else None
         xBC_conv = rearrange(
-            causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
+            causal_conv1d_cuda.causal_conv1d_fwd(rearrange_and_update_stride(xBC, "b s d -> b d s"),
                                                  conv1d_weight, conv1d_bias, seq_idx, None, None, activation in ["silu", "swish"]),
             "b d s -> b s d"
         )
@@ -850,7 +860,7 @@ def backward(ctx, dout, *args):
         zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1)
         # Recompute x, B, C
         xBC_conv = rearrange(
-            causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
+            causal_conv1d_cuda.causal_conv1d_fwd(rearrange_and_update_stride(xBC, "b s d -> b d s"),
                                                  conv1d_weight, conv1d_bias, seq_idx, None, None, ctx.activation in ["silu", "swish"]),
             "b d s -> b s d"
         )
@@ -900,10 +910,14 @@ def backward(ctx, dout, *args):
         else:
             doutproj_weight, doutproj_bias = None, None
         dxBC_given = rearrange(dxBC_given, "b s d -> b d s")
-        dxBC_given, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
-            rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias,
-            rearrange(dxBC, "b s d -> b d s"), seq_idx, None, None, dxBC_given, False, ctx.activation in ["silu", "swish"]
+        dxBC_given_update, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
+            rearrange_and_update_stride(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias,
+            rearrange(dxBC, "b s d -> b d s"), seq_idx, None, None, rearrange_and_update_stride(dxBC_given), False, ctx.activation in ["silu", "swish"]
         )
+        if dxBC_given.stride() != dxBC_given_update.stride():
+            dxBC_given.copy_(dxBC_given_update)
+        else:
+            dxBC_given = dxBC_given_update
         dxBC_given = rearrange(dxBC_given, "b d s -> b s d")
         return dzxbcdt, dweight, dbias, ddt_bias, dA, dD, None, dinitial_states, None, None, None, None, drmsnorm_weight, None, doutproj_weight, doutproj_bias, None, None, None

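
The copy-back branch added above matters when causal_conv1d_bwd ended up writing into a contiguous copy rather than into dxBC_given itself: since dxBC_given appears to be a view into the preallocated dzxbcdt returned at the end (its allocation is not shown in this hunk), the result then has to be copied back into that view or the gradient would be dropped. A standalone sketch of the same pattern, with made-up names and a stand-in for the conv backward call:

    import torch

    def conv_bwd_stub(dst):
        # Stand-in for the gradient buffer handling: if the buffer it receives is not
        # laid out the way it wants, it computes into a fresh contiguous copy instead.
        buf = dst if dst.is_contiguous() else dst.contiguous()
        buf.fill_(1.0)                               # pretend this is the computed gradient
        return buf

    full_grad = torch.zeros(2, 100, 64)              # plays the role of dzxbcdt
    view = full_grad[:, :48]                         # plays the role of dxBC_given (a non-contiguous view)

    out = conv_bwd_stub(view)
    if view.stride() != out.stride():                # the stub wrote into a copy ...
        view.copy_(out)                              # ... so the result must be copied back into the view
    else:
        view = out                                   # same layout, same storage: nothing to copy

    assert full_grad[:, :48].eq(1.0).all()           # the enclosing gradient tensor sees the update
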