Skip to content

Commit 9259852

Browse files
authored
fixing softplus bug with _chunk_cumsum_bwd_kernel() triton kernel (#574)
1 parent 62db608 commit 9259852

File tree

3 files changed

+6
-8
lines changed

3 files changed

+6
-8
lines changed

mamba_ssm/ops/triton/selective_state_update.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,15 +85,15 @@ def _selective_scan_update_kernel(
8585
if HAS_DT_BIAS:
8686
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
8787
if DT_SOFTPLUS:
88-
dt = softplus(dt)
88+
dt = tl.where(dt <= 20.0, softplus(dt), dt)
8989
A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
9090
dA = tl.exp(A * dt[:, None])
9191
else:
9292
dt = tl.load(dt_ptr).to(tl.float32)
9393
if HAS_DT_BIAS:
9494
dt += tl.load(dt_bias_ptr).to(tl.float32)
9595
if DT_SOFTPLUS:
96-
dt = softplus(dt)
96+
dt = tl.where(dt <= 20.0, softplus(dt), dt)
9797
A = tl.load(A_ptr).to(tl.float32)
9898
dA = tl.exp(A * dt) # scalar, not a matrix
9999

mamba_ssm/ops/triton/softplus.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,8 @@
88
if TRITON3:
99
@triton.jit
1010
def softplus(dt):
11-
dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)
12-
return dt
11+
return tl.math.log(tl.math.exp(dt) + 1)
1312
else:
1413
@triton.jit
1514
def softplus(dt):
16-
dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
17-
return dt
15+
return tl.math.log1p(tl.exp(dt))

mamba_ssm/ops/triton/ssd_chunk_state.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def _chunk_cumsum_fwd_kernel(
6868
dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)
6969
dt += dt_bias[:, None]
7070
if DT_SOFTPLUS:
71-
dt = softplus(dt)
71+
dt = tl.where(dt <= 20.0, softplus(dt), dt)
7272
# As of Triton 2.2.0, tl.clamp is not available yet
7373
# dt = tl.clamp(dt, dt_min, dt_max)
7474
dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
@@ -141,7 +141,7 @@ def _chunk_cumsum_bwd_kernel(
141141
dt += dt_bias[:, None]
142142
if DT_SOFTPLUS:
143143
dt_presoftplus = dt
144-
dt = softplus(dt)
144+
dt = tl.where(dt <= 20.0, softplus(dt), dt)
145145
clamp_mask = (dt < dt_min) | (dt > dt_max)
146146
# As of Triton 2.2.0, tl.clamp is not available yet
147147
# dt = tl.clamp(dt, dt_min, dt_max)

0 commit comments

Comments
 (0)