Skip to content

Commit c2568f5

Browse files
authored
modify mamba triton kernels compatible to triton version >= 3.0.0 (#377)
1 parent f9dbb4f commit c2568f5

File tree

3 files changed

+25
-4
lines changed

3 files changed

+25
-4
lines changed

mamba_ssm/ops/triton/selective_state_update.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
from einops import rearrange, repeat
1414

15+
from mamba_ssm.ops.triton.softplus import softplus
16+
1517

1618
@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
1719
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
@@ -83,15 +85,15 @@ def _selective_scan_update_kernel(
8385
if HAS_DT_BIAS:
8486
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
8587
if DT_SOFTPLUS:
86-
dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
88+
dt = softplus(dt)
8789
A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
8890
dA = tl.exp(A * dt[:, None])
8991
else:
9092
dt = tl.load(dt_ptr).to(tl.float32)
9193
if HAS_DT_BIAS:
9294
dt += tl.load(dt_bias_ptr).to(tl.float32)
9395
if DT_SOFTPLUS:
94-
dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
96+
dt = softplus(dt)
9597
A = tl.load(A_ptr).to(tl.float32)
9698
dA = tl.exp(A * dt) # scalar, not a matrix
9799

mamba_ssm/ops/triton/softplus.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import triton
2+
import triton.language as tl
3+
from packaging import version
4+
5+
TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
6+
7+
8+
if TRITON3:
9+
@triton.jit
10+
def softplus(dt):
11+
dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)
12+
return dt
13+
else:
14+
@triton.jit
15+
def softplus(dt):
16+
dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
17+
return dt

mamba_ssm/ops/triton/ssd_chunk_state.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
from einops import rearrange, repeat
1414

15+
from mamba_ssm.ops.triton.softplus import softplus
16+
1517

1618
def init_to_zero(names):
1719
return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
@@ -66,7 +68,7 @@ def _chunk_cumsum_fwd_kernel(
6668
dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)
6769
dt += dt_bias[:, None]
6870
if DT_SOFTPLUS:
69-
dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
71+
dt = softplus(dt)
7072
# As of Triton 2.2.0, tl.clamp is not available yet
7173
# dt = tl.clamp(dt, dt_min, dt_max)
7274
dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
@@ -139,7 +141,7 @@ def _chunk_cumsum_bwd_kernel(
139141
dt += dt_bias[:, None]
140142
if DT_SOFTPLUS:
141143
dt_presoftplus = dt
142-
dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), ddt)
144+
dt = softplus(dt)
143145
clamp_mask = (dt < dt_min) | (dt > dt_max)
144146
# As of Triton 2.2.0, tl.clamp is not available yet
145147
# dt = tl.clamp(dt, dt_min, dt_max)

0 commit comments

Comments
 (0)