Skip to content

Commit 19cbe87

Browse files
committed
ENH: more robust detection of printed arrays
Old-style numpy format, `[ 0 1 2 ]` was adding whitespace after the opening bracket. This was throwing off the attempt at reinserting the commas, which was trying to eval `[, 0, 1, 2, ]` instead of `[0, 1, 2]`. Thus the checker was falling back to the (whitespace-sensitive) vanilla doctest.
1 parent 90ac42a commit 19cbe87

File tree

2 files changed

+53
-2
lines changed

2 files changed

+53
-2
lines changed

scipy_doctest/impl.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,30 @@ def try_convert_namedtuple(got):
219219
return got_again
220220

221221

222+
def try_convert_printed_array(got):
223+
"""Printed arrays: reinsert commas.
224+
"""
225+
# a minimal version is `s_got = ", ".join(got[1:-1].split())`
226+
# but it fails if there's a space after the opening bracket: "[ 0 1 2 ]"
227+
# For 2D arrays, split into rows, drop spurious entries, then reassemble.
228+
if not got.startswith('['):
229+
return got
230+
231+
g1 = got[1:-1] # strip outer "[...]"-s
232+
rows = [x for x in g1.split("[") if x]
233+
rows2 = [", ".join(row.split()) for row in rows]
234+
235+
if got.startswith("[["):
236+
# was a 2D array, restore the opening brackets in rows; XXX clean up
237+
rows3 = ["[" + row for row in rows2]
238+
else:
239+
rows3 = rows2
240+
241+
# add back the outer brackets
242+
s_got = "[" + ", ".join(rows3) + "]"
243+
return s_got
244+
245+
222246
def has_masked(got):
223247
return 'masked_array' in got and '--' in got
224248

@@ -280,8 +304,9 @@ def check_output(self, want, got, optionflags):
280304
cond = (s_want.startswith("[") and s_want.endswith("]") and
281305
s_got.startswith("[") and s_got.endswith("]"))
282306
if cond:
283-
s_want = ", ".join(s_want[1:-1].split())
284-
s_got = ", ".join(s_got[1:-1].split())
307+
s_want = try_convert_printed_array(s_want)
308+
s_got = try_convert_printed_array(s_got)
309+
285310
return self.check_output(s_want, s_got, optionflags)
286311

287312
#handle array abbreviation for n-dimensional arrays, n >= 1

scipy_doctest/tests/module_cases.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,32 @@ def func3():
4444
"""
4545

4646

47+
def func_printed_arrays():
48+
"""
49+
Check various ways handling of printed arrays can go wrong.
50+
51+
>>> import numpy as np
52+
>>> a = np.arange(8).reshape(2, 4) / 3
53+
>>> print(a) # numpy 1.26.4
54+
[[0. 0.33333333 0.66666667 1. ]
55+
[1.33333333 1.66666667 2. 2.33333333]]
56+
57+
>>> print(a) # add spaces (older repr?)
58+
[[ 0. 0.33333333 0.66666667 1. ]
59+
[ 1.33333333 1.66666667 2. 2.33333333 ]]
60+
61+
Also check 1D arrays
62+
>>> a1 = np.arange(3)
63+
>>> print(a1)
64+
[0 1 2]
65+
>>> print(a1)
66+
[ 0 1 2]
67+
>>> print(a1)
68+
[ 0 1 2 ]
69+
70+
"""
71+
72+
4773
def func4():
4874
"""
4975
Test `# may vary` markers : these should not break doctests (but the code

0 commit comments

Comments
 (0)