
Commit 2e16fc3

Add numerical instability patch and enable nheads to be a non-multiple of 8 (#713)

* Numerical stability for large negative values
* Fix causal_conv1d xBC stride not a multiple of 8 issue
* Fix backprop for causal_conv1d xBC stride not a multiple of 8 issue
* Fix ddt -> dt typo
* Add nit comment
* Call contiguous() before causal_conv1d only when stride is not a multiple of 8
* Copy only if strides differ

---------

Co-authored-by: Roger Waleffe <[email protected]>
Co-authored-by: Duncan Riach <[email protected]>
1 parent 0cce0fa · commit 2e16fc3

File tree: 3 files changed (+50 −22 lines)

mamba_ssm/ops/triton/ssd_chunk_scan.py

Lines changed: 10 additions & 5 deletions
@@ -132,7 +132,8 @@ def _chunk_scan_fwd_kernel(
     dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
     # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
     # So we don't need masking wrt seq_idx here.
-    cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :]))
+    # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :]))
+    cb *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_k[None, :]), 0.0))
     dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
     cb *= dt_k
     if IS_CAUSAL:
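
This clamped-exp pattern recurs throughout the commit. In exact arithmetic the exponent dA_cs_m - dA_cs_k is non-positive along the causal direction (dA_cumsum is non-increasing because A < 0), but masked loads fill out-of-range slots with other=0.0, so the difference can come out large and positive and tl.exp overflows to inf; multiplied by a 0.0 entry of cb, that inf becomes NaN, exactly as the comments in the backward kernels below warn. Clamping the exponent at zero keeps every scale factor in (0, 1]. A minimal PyTorch sketch of the idea (stable_exp_diff is an illustrative name, not part of the commit):

import torch

def stable_exp_diff(a, b):
    # exp(a - b) with the exponent clamped at 0. In exact arithmetic
    # a - b <= 0 here, so the clamp only touches slots that masked
    # loads filled with default values, and it prevents overflow to inf.
    return torch.exp(torch.clamp(a - b, max=0.0))

# A slot where a is a masked default 0.0 and b is very negative used to
# give exp(large positive) = inf; with the clamp it gives exp(0) = 1,
# which the zeroed cb entry then multiplies away to 0.0 instead of NaN.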
@@ -679,7 +680,8 @@ def _chunk_scan_bwd_dx_kernel(
     cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0)
     dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)
     dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)
-    cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
+    # cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
+    cb *= tl.exp(tl.minimum((dA_cs_k[None, :] - dA_cs_m[:, None]), 0.0))
     # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
     # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.
     # Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
@@ -816,7 +818,8 @@ def _chunk_scan_bwd_dcb_kernel(
     dcb *= dt_n
     dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
     dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size_limit, other=0.0).to(tl.float32)
-    dcb *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
+    # dcb *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
+    dcb *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0))
     if HAS_DDA_CS:
         tl.static_assert(not HAS_SEQ_IDX, "HAS_SEQ_IDX not supported with HAS_DDA_CS yet")
         ddA_cs = dcb * cb
@@ -1008,7 +1011,8 @@ def _chunk_scan_bwd_ddAcs_stable_kernel_old(
     acc *= dt_n
     dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
     dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size, other=0.0).to(tl.float32)
-    acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
+    # acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
+    acc *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0))
     mask = offs_m[:, None] >= offs_n[None, :] + 1
     acc = tl.where(mask, acc, 0.0)
     acc = tl.cumsum(acc, axis=1)
@@ -1134,7 +1138,8 @@ def _chunk_scan_bwd_ddAcs_stable_kernel(
     cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - start_n), other=0.0).to(tl.float32)
     acc *= cb
     dA_cs_n = tl.load(dA_cumsum_ptr + (start_n + offs_n) * stride_dA_cs_csize, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32)
-    acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
+    # acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
+    acc *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0))
     mask = offs_m[:, None] >= start_n + offs_n[None, :] + 1
     acc = tl.where(mask, acc, 0.0)
     rowsum_new = rowsum + tl.sum(acc, axis=1)

mamba_ssm/ops/triton/ssd_chunk_state.py

Lines changed: 18 additions & 9 deletions
@@ -141,7 +141,7 @@ def _chunk_cumsum_bwd_kernel(
     dt += dt_bias[:, None]
     if DT_SOFTPLUS:
         dt_presoftplus = dt
-        dt = tl.where(dt <= 20.0, softplus(dt), ddt)
+        dt = tl.where(dt <= 20.0, softplus(dt), dt)
     clamp_mask = (dt < dt_min) | (dt > dt_max)
     # As of Triton 2.2.0, tl.clamp is not available yet
     # dt = tl.clamp(dt, dt_min, dt_max)
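
This is the "ddt -> dt typo" from the commit message: softplus(x) = log(1 + exp(x)) converges to x for large x, so above the threshold the kernel should pass dt through unchanged rather than fall back to ddt, an unrelated gradient tensor. A PyTorch sketch of the intended behavior (softplus_with_threshold is an illustrative name):

import torch
import torch.nn.functional as F

def softplus_with_threshold(dt: torch.Tensor, threshold: float = 20.0) -> torch.Tensor:
    # For dt > threshold, softplus(dt) equals dt to within float precision,
    # so return dt directly instead of computing log(1 + exp(dt)).
    return torch.where(dt <= threshold, F.softplus(dt), dt)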
@@ -229,9 +229,11 @@ def _chunk_state_fwd_kernel(
     seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1)
     dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
     if not HAS_SEQ_IDX:
-        scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k
+        # scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k
+        scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k
     else:
-        scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
+        # scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
+        scale = tl.where(seq_idx_k == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, 0.0)
     b *= scale[:, None]
     b = b.to(x_ptr.dtype.element_ty)
     acc += tl.dot(x, b)
@@ -332,7 +334,8 @@ def _chunk_state_bwd_dx_kernel(
     dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
     dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
     dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
-    acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None]
+    # acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None]
+    acc *= tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))[:, None]

     x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
     x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
@@ -434,9 +437,11 @@ def _chunk_state_bwd_db_kernel(
     dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
     dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
     if not HAS_SEQ_IDX:
-        scale = tl.exp(dA_cs_last - dA_cs_m)
+        # scale = tl.exp(dA_cs_last - dA_cs_m)
+        scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))
     else:
-        scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
+        # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
+        scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0)
     db *= (scale * dt_m)[:, None]
     if HAS_DDA_CS:
         # This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum
@@ -549,11 +554,13 @@ def _chunk_state_bwd_ddAcs_stable_kernel(
     dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
     dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
     if not HAS_SEQ_IDX:
-        scale = tl.exp(dA_cs_last - dA_cs_m)
+        # scale = tl.exp(dA_cs_last - dA_cs_m)
+        scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))
     else:
         seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
         seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
-        scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
+        # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
+        scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0)
     acc *= scale[:, None]

     x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
@@ -634,8 +641,10 @@ def _chunk_state_varlen_kernel(
     b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate) & (offs_k[:, None] >= start_idx_cur - k), other=0.0).to(tl.float32)
     dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
     dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
+    # scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
+    #                  tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
     scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
-                     tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
+                     tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, 0.0)
     b *= scale[:, None]
     b = b.to(x_ptr.dtype.element_ty)
     acc += tl.dot(x, b)

mamba_ssm/ops/triton/ssd_combined.py

Lines changed: 22 additions & 8 deletions
@@ -47,6 +47,13 @@ def init_to_zero(names):
     return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]


+def rearrange_and_update_stride(tensor, pattern=None, dim=2):
+    # Ensure tensor.stride(dim) is a multiple of eight after rearranging according to
+    # pattern; if it is not, call contiguous(). Rearrange only if pattern is not None.
+    tensor_rearranged = rearrange(tensor, pattern) if pattern is not None else tensor
+    return tensor_rearranged.contiguous() if tensor_rearranged.stride(dim) % 8 != 0 else tensor_rearranged
+
+
 @triton.autotune(
     configs=[
         triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])),
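
To see what the new helper guards against, consider an xBC whose feature dimension is not a multiple of 8 (now possible since nheads need not be a multiple of 8). The sketch below is illustrative, not part of the commit; the exact alignment requirement is imposed inside causal_conv1d's CUDA kernels, and the shapes here are made up:

import torch
from einops import rearrange

xBC = torch.randn(2, 100, 260)          # (batch, seqlen, dim) with dim % 8 != 0
t = rearrange(xBC, "b s d -> b d s")    # a view: shape (2, 260, 100), strides (26000, 1, 260)
assert t.stride(2) % 8 != 0             # seqlen stride is 260, which trips the stride check
t2 = t.contiguous()                     # the fallback rearrange_and_update_stride takes
assert t2.stride() == (26000, 100, 1)   # seqlen-contiguous layout, no alignment issue

The copy is only made when the stride check fails, so the helper costs nothing in the common case where the stride is already a multiple of 8.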
@@ -120,11 +127,13 @@ def _chunk_scan_chunk_state_bwd_dx_kernel(

     dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
     if not HAS_SEQ_IDX:
-        scale = tl.exp(dA_cs_last - dA_cs_m)
+        # scale = tl.exp(dA_cs_last - dA_cs_m)
+        scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))
     else:
         seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
         seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
-        scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
+        # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
+        scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0)
     # Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
     # However, we're getting error with the Triton compiler 2.1.0 for that code path:
     # Unexpected mma -> mma layout conversion
@@ -170,7 +179,8 @@ def _chunk_scan_chunk_state_bwd_dx_kernel(
     cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0)
     dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)
     dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)
-    cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
+    # cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
+    cb *= tl.exp(tl.minimum((dA_cs_k[None, :] - dA_cs_m[:, None]), 0.0))
     # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
     # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.
     # Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
@@ -776,7 +786,7 @@ def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size,
     zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1)
     seq_idx = seq_idx.contiguous() if seq_idx is not None else None
     xBC_conv = rearrange(
-        causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
+        causal_conv1d_cuda.causal_conv1d_fwd(rearrange_and_update_stride(xBC, "b s d -> b d s"),
                                              conv1d_weight, conv1d_bias, seq_idx, None, None, activation in ["silu", "swish"]),
         "b d s -> b s d"
     )
@@ -850,7 +860,7 @@ def backward(ctx, dout, *args):
     zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1)
     # Recompute x, B, C
     xBC_conv = rearrange(
-        causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
+        causal_conv1d_cuda.causal_conv1d_fwd(rearrange_and_update_stride(xBC, "b s d -> b d s"),
                                              conv1d_weight, conv1d_bias, seq_idx, None, None, ctx.activation in ["silu", "swish"]),
         "b d s -> b s d"
     )
@@ -900,10 +910,14 @@ def backward(ctx, dout, *args):
     else:
         doutproj_weight, doutproj_bias = None, None
     dxBC_given = rearrange(dxBC_given, "b s d -> b d s")
-    dxBC_given, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
-        rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias,
-        rearrange(dxBC, "b s d -> b d s"), seq_idx, None, None, dxBC_given, False, ctx.activation in ["silu", "swish"]
+    dxBC_given_update, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
+        rearrange_and_update_stride(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias,
+        rearrange(dxBC, "b s d -> b d s"), seq_idx, None, None, rearrange_and_update_stride(dxBC_given), False, ctx.activation in ["silu", "swish"]
     )
+    if dxBC_given.stride() != dxBC_given_update.stride():
+        dxBC_given.copy_(dxBC_given_update)
+    else:
+        dxBC_given = dxBC_given_update
     dxBC_given = rearrange(dxBC_given, "b d s -> b s d")
     return dzxbcdt, dweight, dbias, ddt_bias, dA, dD, None, dinitial_states, None, None, None, None, drmsnorm_weight, None, doutproj_weight, doutproj_bias, None, None, None
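The write-back guard in the hunk above is the "Copy only if strides differ" bullet: when rearrange_and_update_stride had to hand causal_conv1d_bwd a contiguous copy of dxBC_given, the returned dxBC_given_update no longer shares layout with the original buffer, so its values are copied back in place rather than rebinding the name. A minimal sketch of the pattern, under that assumption (write_back is an illustrative name):

import torch

def write_back(dst: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
    # If the kernel operated on a contiguous copy, src's strides differ
    # from dst's; copy the values back so the original buffer holds the
    # result. If the strides match, reuse src directly.
    if dst.stride() != src.stride():
        dst.copy_(src)
        return dst
    return src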
