
Commit 32ec9e2

Authored by Yu Chin Fabian Lim
Mamba V2 Test not Asserting Failures. (#21379)
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
1 parent accac82 commit 32ec9e2

2 files changed (+25, −10 lines)

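The bug this commit fixes: `torch.allclose` only *returns* a bool, and the tests were calling it without checking or asserting that return value, so a mismatch between kernel output and reference could never fail the test. `torch.testing.assert_close` instead raises `AssertionError` on mismatch. A minimal illustration (tensor values invented for the example):

```python
import torch

actual = torch.tensor([1.0, 2.0])
expected = torch.tensor([1.0, 9.0])  # clearly mismatched

# Bug pattern: torch.allclose returns a bool that was never checked, so
# the comparison result is silently dropped and the test always "passes".
torch.allclose(actual, expected, atol=1e-3, rtol=1e-3)  # -> False, no error

# Fix: torch.testing.assert_close raises AssertionError on mismatch,
# so the test actually fails when the tensors diverge.
torch.testing.assert_close(actual, expected, atol=1e-3, rtol=1e-3)
```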

tests/kernels/mamba/test_mamba_mixer2.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -119,7 +119,8 @@ def mixer2_gated_norm_tensor_parallel(
         gate_states[..., local_rank * N:(local_rank + 1) * N],
     )
     ref_output = mixer_single_gpu(hidden_states, gate_states)
-    torch.allclose(output,
-                   ref_output[..., local_rank * N:(local_rank + 1) * N],
-                   atol=1e-3,
-                   rtol=1e-3)
+    torch.testing.assert_close(output,
+                               ref_output[...,
+                                          local_rank * N:(local_rank + 1) * N],
+                               atol=5e-3,
+                               rtol=1e-3)
```
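Note the `atol` bump from 1e-3 to 5e-3: once the comparison actually raises, the old threshold was evidently too tight. Both `torch.allclose` and `torch.testing.assert_close` apply the elementwise bound |actual − expected| ≤ atol + rtol · |expected|; a small sketch of that check (the helper and the numbers are illustrative, not from the test):

```python
import torch

def within_tolerance(actual, expected, atol, rtol):
    # Elementwise bound used by torch.allclose / torch.testing.assert_close:
    # |actual - expected| <= atol + rtol * |expected|
    return bool(
        (torch.abs(actual - expected)
         <= atol + rtol * torch.abs(expected)).all()
    )

# A 4e-3 deviation fails the old atol=1e-3 but passes the new atol=5e-3.
print(within_tolerance(torch.tensor([0.104]), torch.tensor([0.100]),
                       atol=1e-3, rtol=1e-3))  # False: 4e-3 > 1e-3 + 1e-4
print(within_tolerance(torch.tensor([0.104]), torch.tensor([0.100]),
                       atol=5e-3, rtol=1e-3))  # True: 4e-3 <= 5e-3 + 1e-4
```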

tests/kernels/mamba/test_mamba_ssm_ssd.py

Lines changed: 20 additions & 6 deletions
```diff
@@ -193,6 +193,13 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
 
     # this tests the kernels on a single example (no batching)
 
+    # TODO: the bfloat16 case requires higher thresholds. To be investigated
+
+    if itype == torch.bfloat16:
+        atol, rtol = 5e-2, 5e-2
+    else:
+        atol, rtol = 8e-3, 5e-3
+
     # set seed
     batch_size = 1  # batch_size
     # ssd_minimal_discrete requires chunk_size divide seqlen
@@ -216,14 +223,14 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
                                  return_final_states=True)
 
     # just test the last in sequence
-    torch.allclose(Y[:, -1], Y_min[:, -1], atol=1e-3, rtol=1e-3)
+    torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=atol, rtol=rtol)
 
     # just test the last head
     # NOTE, in the kernel we always cast states to fp32
-    torch.allclose(final_state[:, -1],
-                   final_state_min[:, -1].to(torch.float32),
-                   atol=1e-3,
-                   rtol=1e-3)
+    torch.testing.assert_close(final_state[:, -1],
+                               final_state_min[:, -1].to(torch.float32),
+                               atol=atol,
+                               rtol=rtol)
 
 
 @pytest.mark.parametrize("itype", [torch.float32, torch.float16])
```
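The new bfloat16 branch reflects precision limits: bfloat16 carries only 8 significant mantissa bits versus float32's 24, so kernel outputs legitimately diverge more and need looser thresholds. A self-contained sketch of the same dtype-dependent-tolerance pattern (the helper and the round-trip test are hypothetical, not part of the diff):

```python
import pytest
import torch

def tolerances_for(itype: torch.dtype) -> tuple[float, float]:
    # Mirror of the pattern in the diff: coarse thresholds for bfloat16,
    # tighter ones for everything else.
    if itype == torch.bfloat16:
        return 5e-2, 5e-2  # atol, rtol
    return 8e-3, 5e-3

@pytest.mark.parametrize("itype",
                         [torch.float32, torch.float16, torch.bfloat16])
def test_roundtrip_close(itype):
    x = torch.randn(16, dtype=torch.float32)
    y = x.to(itype).to(torch.float32)  # lossy round-trip through itype
    atol, rtol = tolerances_for(itype)
    torch.testing.assert_close(y, x, atol=atol, rtol=rtol)
```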
```diff
@@ -263,6 +270,13 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
 
     seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases
 
+    # TODO: the irregular chunk size cases have some issues and require higher
+    # tolerance. This is to be investigated
+    if chunk_size not in {8, 256}:
+        atol, rtol = 5e-1, 5e-1
+    else:
+        atol, rtol = 5e-3, 5e-3
+
     # hold state during the cutting process so we know if an
     # example has been exhausted and needs to cycle
     last_taken: dict = {}  # map: eg -> pointer to last taken sample
@@ -300,7 +314,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
         # just test one dim and dstate
         Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
         Y_min_eg = Y_min[i][:, 0, 0]
-        torch.allclose(Y_eg, Y_min_eg, atol=1e-3, rtol=1e-3)
+        torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol)
 
         # update states
         states = new_states
```
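For context, the continuous-batch test packs every example's tokens into one tensor and uses `cu_seqlens` boundaries to slice each example back out for comparison against its per-example reference. A minimal mock of that loop (shapes and values invented for illustration; the references here are slices of the same tensor, so the check passes):

```python
import torch

# Packed layout: (batch=1, total_tokens, heads, d_head), two examples of
# lengths 3 and 7 concatenated along the token axis.
Y = torch.randn(1, 10, 4, 8)
Y_min = [Y[0, 0:3], Y[0, 3:10]]          # per-example references
cu_seqlens = torch.tensor([0, 3, 10])    # cumulative sequence boundaries

for i in range(len(cu_seqlens) - 1):
    # Slice example i out of the packed output; test one dim and dstate.
    Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
    Y_min_eg = Y_min[i][:, 0, 0]
    torch.testing.assert_close(Y_eg, Y_min_eg, atol=5e-3, rtol=5e-3)
```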
