@@ -193,6 +193,13 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
193
193
194
194
# this tests the kernels on a single example (no batching)
195
195
196
+ # TODO: the bfloat16 case requires higher thresholds. To be investigated
197
+
198
+ if itype == torch .bfloat16 :
199
+ atol , rtol = 5e-2 , 5e-2
200
+ else :
201
+ atol , rtol = 8e-3 , 5e-3
202
+
196
203
# set seed
197
204
batch_size = 1 # batch_size
198
205
# ssd_minimal_discrete requires chunk_size divide seqlen
@@ -216,14 +223,14 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
216
223
return_final_states = True )
217
224
218
225
# just test the last in sequence
219
- torch .allclose (Y [:, - 1 ], Y_min [:, - 1 ], atol = 1e-3 , rtol = 1e-3 )
226
+ torch .testing . assert_close (Y [:, - 1 ], Y_min [:, - 1 ], atol = atol , rtol = rtol )
220
227
221
228
# just test the last head
222
229
# NOTE, in the kernel we always cast states to fp32
223
- torch .allclose (final_state [:, - 1 ],
224
- final_state_min [:, - 1 ].to (torch .float32 ),
225
- atol = 1e-3 ,
226
- rtol = 1e-3 )
230
+ torch .testing . assert_close (final_state [:, - 1 ],
231
+ final_state_min [:, - 1 ].to (torch .float32 ),
232
+ atol = atol ,
233
+ rtol = rtol )
227
234
228
235
229
236
@pytest .mark .parametrize ("itype" , [torch .float32 , torch .float16 ])
@@ -263,6 +270,13 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
263
270
264
271
seqlen , chunk_size , num_examples , cases = seq_len_chunk_size_cases
265
272
273
+ # TODO: the irregular chunk size cases have some issues and require higher
274
+ # tolerance. This is to be investigated
275
+ if chunk_size not in {8 , 256 }:
276
+ atol , rtol = 5e-1 , 5e-1
277
+ else :
278
+ atol , rtol = 5e-3 , 5e-3
279
+
266
280
# hold state during the cutting process so we know if an
267
281
# example has been exhausted and needs to cycle
268
282
last_taken : dict = {} # map: eg -> pointer to last taken sample
@@ -300,7 +314,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
300
314
# just test one dim and dstate
301
315
Y_eg = Y [0 , cu_seqlens [i ]:cu_seqlens [i + 1 ], 0 , 0 ]
302
316
Y_min_eg = Y_min [i ][:, 0 , 0 ]
303
- torch .allclose (Y_eg , Y_min_eg , atol = 1e-3 , rtol = 1e-3 )
317
+ torch .testing . assert_close (Y_eg , Y_min_eg , atol = atol , rtol = rtol )
304
318
305
319
# update states
306
320
states = new_states
0 commit comments