Skip to content

Commit 5d6a2db

Browse files
Fix mamba cumsum padded calculations (#1022)
Cherry-pick #1009 --------- Signed-off-by: Jan Kaniecki <jkaniecki@habana.ai> Co-authored-by: Artur Fierka <artur.fierka@intel.com>
1 parent fd33cc3 commit 5d6a2db

File tree

3 files changed

+17
-1
lines changed

3 files changed

+17
-1
lines changed

vllm_gaudi/ops/hpu_mamba_mixer2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,7 @@ def conv_ssm_forward(
437437
dt_limit=(0.0, float("inf")),
438438
out=output.view(output.shape[0], -1, self.head_dim),
439439
state_dtype=ssm_state.dtype,
440+
padding_mask=padding_mask_flat,
440441
)[last_chunk_indices_p]
441442
output = output * padding_mask_flat.view(output.shape[0], 1)
442443

vllm_gaudi/ops/pytorch_implementation.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,13 @@
55
from einops import rearrange, repeat
66

77

8-
def new_chunk_cumsum(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
8+
def new_chunk_cumsum(dt,
9+
A,
10+
chunk_size,
11+
dt_bias=None,
12+
dt_softplus=False,
13+
dt_limit=(0.0, float("inf")),
14+
padding_mask=None):
915
"""
1016
Arguments:
1117
dt: Tensor - (seqlen, nheads)
@@ -14,6 +20,7 @@ def new_chunk_cumsum(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limi
1420
dt_bias: Optional Tensor - (nheads)
1521
dt_softplus: bool
1622
dt_limit: tuple - (min: float, max: float)
23+
padding_mask: Optional Tensor - (seqlen, 1) or (seqlen,)
1724
1825
Return:
1926
dA_cumsum: Tensor - (nheads, nchunks, chunk_size)
@@ -32,6 +39,10 @@ def new_chunk_cumsum(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limi
3239
dt = torch.where(dt <= 20.0, F.softplus(dt), dt)
3340

3441
dt = torch.clamp(dt, dt_min, dt_max)
42+
43+
if padding_mask is not None:
44+
dt = dt * padding_mask.view(seqlen, 1).float()
45+
3546
dA = dt * A.view(1, nheads)
3647
dA = dA.transpose(0, 1).reshape(nheads, nchunks, chunk_size)
3748
dt = dt.transpose(0, 1).reshape(nheads, nchunks, chunk_size)

vllm_gaudi/ops/ssd_combined.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def _mamba_chunk_scan_combined_fwd(
3636
dt_softplus=False,
3737
dt_limit=(0.0, float("inf")),
3838
state_dtype=None,
39+
padding_mask=None,
3940
):
4041
assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
4142
seqlen, nheads, headdim = x.shape
@@ -82,6 +83,7 @@ def _mamba_chunk_scan_combined_fwd(
8283
dt_bias=dt_bias,
8384
dt_softplus=dt_softplus,
8485
dt_limit=dt_limit,
86+
padding_mask=padding_mask,
8587
)
8688

8789
# 2. Compute the state for each intra-chunk
@@ -143,6 +145,7 @@ def hpu_mamba_chunk_scan_combined_varlen(
143145
dt_softplus=False,
144146
dt_limit=(0.0, float("inf")),
145147
state_dtype=None,
148+
padding_mask=None,
146149
):
147150
"""
148151
Argument:
@@ -185,6 +188,7 @@ def hpu_mamba_chunk_scan_combined_varlen(
185188
dt_softplus=dt_softplus,
186189
dt_limit=dt_limit,
187190
state_dtype=state_dtype,
191+
padding_mask=padding_mask,
188192
)
189193

190194
return varlen_states

0 commit comments

Comments (0)