Skip to content

Commit c091ccd

Browse files
authored
Merge pull request #151 from ev-br/nan_equal
ENH: treat nans as equal
2 parents 4dd6eea + 5235282 commit c091ccd

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

scpdt/impl.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ def __init__(self, *, # DTChecker configuration
139139
'float64': np.float64,
140140
'dtype': np.dtype,
141141
'nan': np.nan,
142+
'nanj': np.complex128(1j*np.nan),
143+
'infj': complex(0, np.inf),
142144
'NaN': np.nan,
143145
'inf': np.inf,
144146
'Inf': np.inf, }
@@ -343,7 +345,7 @@ def _do_check(self, want, got):
343345
warnings.simplefilter('ignore', VisibleDeprecationWarning)
344346

345347
# This line is the crux of the whole thing. The rest is mostly scaffolding.
346-
result = np.allclose(want, got, atol=self.atol, rtol=self.rtol)
348+
result = np.allclose(want, got, atol=self.atol, rtol=self.rtol, equal_nan=True)
347349
return result
348350

349351

scpdt/tests/module_cases.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,24 @@ def array_abbreviation():
163163
[0, 0, 0, ..., 0, 999, 0],
164164
[0, 0, 0, ..., 0, 0, 1000]])
165165
"""
166+
167+
def nan_equal():
168+
"""
169+
Test that nans are treated as equal.
170+
171+
>>> import numpy as np
172+
>>> np.nan
173+
np.float64(nan)
174+
175+
Complex nans
176+
>>> np.nan - 1j*np.nan
177+
nan + nanj
178+
179+
>>> np.nan + 1j*np.nan
180+
np.complex128(nan+nanj)
181+
182+
Throw in infs, for a good measure
183+
>>> np.inf + 1j*np.inf
184+
inf + infj
185+
186+
"""

0 commit comments

Comments
 (0)