Skip to content

Commit cf28fa8

Browse files
committed
ENH: add a strict_check flag to (optionally) check dtypes
This is to decide whether to require matching dtypes or rely on NumPy's lax definition of equality: >>> np.float64(3) == 3 True
1 parent f9d7065 commit cf28fa8

File tree

3 files changed

+45
-9
lines changed

3 files changed

+45
-9
lines changed

scipy_doctest/impl.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ class DTConfig:
4040
rtol : float
4141
Absolute and relative tolerances to check doctest examples with.
4242
Specifically, the check is ``np.allclose(want, got, atol=atol, rtol=rtol)``
43+
strict_check : bool
44+
Whether to check that dtypes match or rely on the lax definition of
45+
equality of numpy objects. For instance, `3 == np.float64(3)`, but
46+
dtypes do not match.
47+
Default is False.
4348
optionflags : int
4449
doctest optionflags
4550
Default is ``NORMALIZE_WHITESPACE | ELLIPSIS | IGNORE_EXCEPTION_DETAIL``
@@ -107,6 +112,7 @@ def __init__(self, *, # DTChecker configuration
107112
rndm_markers=None,
108113
atol=1e-8,
109114
rtol=1e-2,
115+
strict_check=False,
110116
# DTRunner configuration
111117
optionflags=None,
112118
# DTFinder/DTParser configuration
@@ -161,8 +167,8 @@ def __init__(self, *, # DTChecker configuration
161167
'#random', '#Random',
162168
"# may vary"}
163169
self.rndm_markers = rndm_markers
164-
165170
self.atol, self.rtol = atol, rtol
171+
self.strict_check = strict_check
166172

167173
### DTRunner configuration ###
168174

@@ -363,23 +369,35 @@ def check_output(self, want, got, optionflags):
363369
return False
364370

365371
# ... and defer to numpy
372+
strict = self.config.strict_check
366373
try:
367-
return self._do_check(a_want, a_got)
374+
return self._do_check(a_want, a_got, strict)
368375
except Exception:
369376
# heterog tuple, eg (1, np.array([1., 2.]))
370377
try:
371-
return all(self._do_check(w, g) for w, g in zip_longest(a_want, a_got))
378+
return all(
379+
self._do_check(w, g, strict) for w, g in zip_longest(a_want, a_got)
380+
)
372381
except (TypeError, ValueError):
373382
return False
374383

375-
def _do_check(self, want, got):
384+
def _do_check(self, want, got, strict_check):
376385
# This should be done exactly as written to correctly handle all of
377386
# numpy-comparable objects, strings, and heterogeneous tuples
378-
try:
379-
if want == got:
380-
return True
381-
except Exception:
382-
pass
387+
388+
# NB: 3 == np.float64(3.0) but dtypes differ
389+
if strict_check:
390+
want_dtype = np.asarray(want).dtype
391+
got_dtype = np.asarray(got).dtype
392+
if want_dtype != got_dtype:
393+
return False
394+
else:
395+
try:
396+
if want == got:
397+
return True
398+
except Exception:
399+
pass
400+
383401
with warnings.catch_warnings():
384402
# NumPy's ragged array deprecation of np.array([1, (2, 3)])
385403
warnings.simplefilter('ignore', VisibleDeprecationWarning)

scipy_doctest/tests/failure_cases.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,11 @@ def tuple_and_list_2():
4545
>>> (0, 1, 2)
4646
[0, 1, 2]
4747
"""
48+
49+
50+
def dtype_mismatch():
51+
"""
52+
>>> import numpy as np
53+
>>> 3.0
54+
3
55+
"""

scipy_doctest/tests/test_testmod.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,16 @@ def test_tuple_and_list():
117117
assert res.failed == 2
118118

119119

120+
@pytest.mark.parametrize('strict, num_fails', [(True, 1), (False, 0)])
121+
class TestStrictDType:
122+
def test_np_fix(self, strict, num_fails):
123+
config = DTConfig(strict_check=strict)
124+
res, _ = _testmod(failure_cases,
125+
strategy=[failure_cases.dtype_mismatch],
126+
config=config)
127+
assert res.failed == num_fails
128+
129+
120130
class TestLocalFiles:
121131
def test_local_files(self):
122132
# A doctest tries to open a local file. Test that it works

0 commit comments

Comments
 (0)