@@ -40,6 +40,11 @@ class DTConfig:
40
40
rtol : float
41
41
Absolute and relative tolerances to check doctest examples with.
42
42
Specifically, the check is ``np.allclose(want, got, atol=atol, rtol=rtol)``
43
+ strict_check : bool
44
+ Whether to check that dtypes match or rely on the lax definition of
45
+ equality of numpy objects. For instance, `3 == np.float64(3)`, but
46
+ dtypes do not match.
47
+ Default is False.
43
48
optionflags : int
44
49
doctest optionflags
45
50
Default is ``NORMALIZE_WHITESPACE | ELLIPSIS | IGNORE_EXCEPTION_DETAIL``
@@ -107,6 +112,7 @@ def __init__(self, *, # DTChecker configuration
107
112
rndm_markers = None ,
108
113
atol = 1e-8 ,
109
114
rtol = 1e-2 ,
115
+ strict_check = False ,
110
116
# DTRunner configuration
111
117
optionflags = None ,
112
118
# DTFinder/DTParser configuration
@@ -161,8 +167,8 @@ def __init__(self, *, # DTChecker configuration
161
167
'#random' , '#Random' ,
162
168
"# may vary" }
163
169
self .rndm_markers = rndm_markers
164
-
165
170
self .atol , self .rtol = atol , rtol
171
+ self .strict_check = strict_check
166
172
167
173
### DTRunner configuration ###
168
174
@@ -363,23 +369,35 @@ def check_output(self, want, got, optionflags):
363
369
return False
364
370
365
371
# ... and defer to numpy
372
+ strict = self .config .strict_check
366
373
try :
367
- return self ._do_check (a_want , a_got )
374
+ return self ._do_check (a_want , a_got , strict )
368
375
except Exception :
369
376
# heterog tuple, eg (1, np.array([1., 2.]))
370
377
try :
371
- return all (self ._do_check (w , g ) for w , g in zip_longest (a_want , a_got ))
378
+ return all (
379
+ self ._do_check (w , g , strict ) for w , g in zip_longest (a_want , a_got )
380
+ )
372
381
except (TypeError , ValueError ):
373
382
return False
374
383
375
- def _do_check (self , want , got ):
384
+ def _do_check (self , want , got , strict_check ):
376
385
# This should be done exactly as written to correctly handle all of
377
386
# numpy-comparable objects, strings, and heterogeneous tuples
378
- try :
379
- if want == got :
380
- return True
381
- except Exception :
382
- pass
387
+
388
+ # NB: 3 == np.float64(3.0) but dtypes differ
389
+ if strict_check :
390
+ want_dtype = np .asarray (want ).dtype
391
+ got_dtype = np .asarray (got ).dtype
392
+ if want_dtype != got_dtype :
393
+ return False
394
+ else :
395
+ try :
396
+ if want == got :
397
+ return True
398
+ except Exception :
399
+ pass
400
+
383
401
with warnings .catch_warnings ():
384
402
# NumPy's ragged array deprecation of np.array([1, (2, 3)])
385
403
warnings .simplefilter ('ignore' , VisibleDeprecationWarning )
0 commit comments