Skip to content

Commit 1e6aa5e

Browse files
committed
Force self.A_log to be fp32
1 parent 3c77dcf commit 1e6aa5e

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

mamba_ssm/modules/mamba2.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,8 @@ def forward(self, u, seqlen=None, seq_idx=None, inference_params=None):
170170
zxbcdt = self.in_proj(u) # (B, L, d_in_proj) or (B * L, d_in_proj)
171171
if seqlen_og is not None:
172172
zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
173-
A = -torch.exp(self.A_log) # (nheads) or (d_inner, d_state)
173+
# If the model is loaded in fp16, without the .float() here, A might be -inf
174+
A = -torch.exp(self.A_log.float()) # (nheads) or (d_inner, d_state)
174175
dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
175176
if self.use_mem_eff_path and inference_params is None:
176177
out = mamba_split_conv1d_scan_combined(

mamba_ssm/ops/triton/ssd_chunk_scan.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -680,7 +680,11 @@ def _chunk_scan_bwd_dx_kernel(
680680
dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)
681681
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)
682682
cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
683-
mask = k + offs_k[None, :] >= offs_m[:, None]
683+
# If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
684+
# we might have dA_cs_k = 0.0 (the masked-load fill value) and dA_cs_m very negative, so tl.exp(dA_cs_k - dA_cs_m) will return inf.
685+
# Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
686+
# This will cause NaN in acc, and hence NaN in dx and ddt.
687+
mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX)
684688
cb = tl.where(mask, cb, 0.0)
685689
cb = cb.to(dout_ptr.dtype.element_ty)
686690
acc += tl.dot(cb, dout)

0 commit comments

Comments
 (0)