Skip to content

Commit 3b0dde5

Browse files
ilyasch2Ilyas.Chahed
andauthored
Fix exploding gradients when ngroups larger than one (#547)
Co-authored-by: Ilyas.Chahed <[email protected]>
1 parent 9259852 commit 3b0dde5

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

mamba_ssm/ops/triton/ssd_combined.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -887,7 +887,7 @@ def backward(ctx, dout, *args):
887887
x_rms = rearrange(out, "b s h p -> (b s) (h p)")
888888
z_rms = rearrange(z, "b s h p -> (b s) (h p)")
889889
out1_recompute = rearrange(out1_recompute, "b s d -> (b s) d") if recompute_output else None
890-
dout, drmsnorm_weight, _, dz, *rest = _layer_norm_bwd(dy_rms, x_rms, rmsnorm_weight, None, ctx.rmsnorm_eps, None, rstd, z_rms, norm_before_gate=ctx.norm_before_gate, is_rms_norm=True, recompute_output=recompute_output, dz=dz, out=out1_recompute if recompute_output else None)
890+
dout, drmsnorm_weight, _, dz, *rest = _layer_norm_bwd(dy_rms, x_rms, rmsnorm_weight, None, ctx.rmsnorm_eps, None, rstd, z_rms, group_size=dim//ctx.ngroups, norm_before_gate=ctx.norm_before_gate, is_rms_norm=True, recompute_output=recompute_output, dz=dz, out=out1_recompute if recompute_output else None)
891891
out_for_linear = out_recompute if recompute_output else None
892892
dout = rearrange(dout, "(b s) (h p) -> b s h p", b=batch, p=headdim)
893893
dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(

0 commit comments

Comments
 (0)