Skip to content

Commit a6a1dae

Browse files
matthieuleMatthieu Le
andauthored
Fix _chunk_state_bwd_db_kernel when using seq_idx (#746)
Co-authored-by: Matthieu Le <[email protected]>
1 parent 74729d0 commit a6a1dae

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

mamba_ssm/ops/triton/ssd_chunk_state.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ def _chunk_state_bwd_db_kernel(
441441
scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))
442442
else:
443443
# scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
444-
scale = tl.where(seq_idx_m == seq_idx_last, tl.minimum((dA_cs_last - dA_cs_m), 0.0), 0.0)
444+
scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0)
445445
db *= (scale * dt_m)[:, None]
446446
if HAS_DDA_CS:
447447
# This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum

0 commit comments

Comments
 (0)