From a096b916cd878f2d7c6f84f1087e00e3690af026 Mon Sep 17 00:00:00 2001 From: Weihao Yu Date: Fri, 26 Sep 2025 13:01:53 +0800 Subject: [PATCH] Fix stride handling for optional dt_bias and D --- mamba_ssm/ops/triton/selective_state_update.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mamba_ssm/ops/triton/selective_state_update.py b/mamba_ssm/ops/triton/selective_state_update.py index d425bc728..876ff6018 100644 --- a/mamba_ssm/ops/triton/selective_state_update.py +++ b/mamba_ssm/ops/triton/selective_state_update.py @@ -199,11 +199,11 @@ def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, state.stride(0), state.stride(1), state.stride(2), state.stride(3), x.stride(0), x.stride(1), x.stride(2), dt.stride(0), dt.stride(1), dt.stride(2), - *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0, + *( (dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else (0, 0) ), A.stride(0), A.stride(1), A.stride(2), B.stride(0), B.stride(1), B.stride(2), C.stride(0), C.stride(1), C.stride(2), - *(D.stride(0), D.stride(1)) if D is not None else 0, + *( (D.stride(0), D.stride(1)) if D is not None else (0, 0) ), z_strides[0], z_strides[1], z_strides[2], out.stride(0), out.stride(1), out.stride(2), dt_softplus,