Skip to content

Commit 46ae7f6

Browse files
authored
[Bugfix] Mamba2 SSD varlen bug fix initstates decay, improve test, assert chunk pwr 2 (#21783)
Signed-off-by: Rishi Astra <[email protected]>
1 parent 1ece7f3 commit 46ae7f6

File tree

3 files changed

+15
-13
lines changed

3 files changed

+15
-13
lines changed

tests/kernels/mamba/test_mamba_ssm_ssd.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def end_boundary(n: int):
187187
[torch.float32, torch.float16, torch.bfloat16])
188188
@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32])
189189
@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128])
190-
@pytest.mark.parametrize("seq_len_chunk_size", [(119, 17), (128, 32)])
190+
@pytest.mark.parametrize("seq_len_chunk_size", [(112, 16), (128, 32)])
191191
def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
192192
itype):
193193

@@ -253,15 +253,15 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
253253
(8, 8, 16, 32, 16),
254254
]), # more examples with varied lengths
255255
256-
# odd chunk_size
257-
(64, 29, 2, [(11, 4), (13, 23), (19, 22),
258-
(21, 15)]), # irregular sizes
259-
260256
# large-ish chunk_size (256)
261257
(64, 256, 1, [(5, ), (1, ), (1, ),
262258
(1, )]), # irregular sizes with small sequences
263259
(64, 256, 2, [(5, 30), (1, 2), (1, 2),
264260
(1, 2)]), # irregular sizes with small sequences
261+
262+
# we also need to test some large seqlen
263+
# to catch errors with init states decay
264+
(768, 128, 2, [(138, 225), (138, 225)]),
265265
])
266266
def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
267267
itype):
@@ -271,10 +271,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
271271

272272
seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases
273273

274-
# TODO: the irregular chunk size cases have some issues and require higher
275-
# tolerance. This is to be investigated
276-
if chunk_size not in {8, 256}:
277-
atol, rtol = 5e-1, 5e-1
274+
# This test can have larger error for longer sequences
275+
if seqlen > 256:
276+
atol, rtol = 1e-2, 5e-3
278277
else:
279278
atol, rtol = 5e-3, 5e-3
280279

vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -290,10 +290,8 @@ def _chunk_scan_fwd_kernel(
290290
# get the cs at the offset boundary
291291
# - c_off == 0 is a passthrough
292292
dA_cs_m_boundary = tl.load(
293-
dA_cumsum_ptr +
294-
(pid_m * BLOCK_SIZE_M + c_off - 1) * stride_dA_cs_csize,
295-
mask=(((pid_m * BLOCK_SIZE_M + c_off - 1) > -1)
296-
and ((pid_m * BLOCK_SIZE_M + c_off) < chunk_size)),
293+
dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize,
294+
mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)),
297295
other=0.0).to(tl.float32)
298296

299297
if HAS_SEQ_IDX:

vllm/model_executor/layers/mamba/ops/ssd_combined.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
2222

2323

24+
def is_int_pow_2(n):
25+
return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0
26+
27+
2428
def _mamba_chunk_scan_combined_fwd(x,
2529
dt,
2630
A,
@@ -38,6 +42,7 @@ def _mamba_chunk_scan_combined_fwd(x,
3842
dt_softplus=False,
3943
dt_limit=(0.0, float("inf")),
4044
out=None):
45+
assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
4146
batch, seqlen, nheads, headdim = x.shape
4247
_, _, ngroups, dstate = B.shape
4348
assert nheads % ngroups == 0

0 commit comments

Comments
 (0)