Skip to content

Commit dc9a25c

Browse files
committed
ENH: handle both NumPy>=2.2 and NumPy<2.2 abbreviations
with and without shapes=
1 parent abbc68d commit dc9a25c

File tree

3 files changed

+49
-55
lines changed

3 files changed

+49
-55
lines changed

.github/workflows/pip.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ jobs:
1515
runs-on: ${{ matrix.os }}
1616
strategy:
1717
matrix:
18-
python-version: ['3.10', '3.11']
18+
python-version: ['3.11', '3.12']
19+
numpy: ['"numpy<2.2"', 'numpy']
1920
os: [ubuntu-latest]
2021
pytest: ['"pytest<8.0"', pytest]
2122
pre: ['', '--pre']
@@ -35,6 +36,7 @@ jobs:
3536
python -m pip install --upgrade pip
3637
python -m pip install ${{matrix.pytest}} ${{matrix.pre}}
3738
python -m pip install -e . ${{matrix.pre}}
39+
python -m pip install ${{matrix.numpy}}
3840
3941
- name: Echo versions
4042
run: |

scipy_doctest/impl.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,12 @@ def __init__(self, *, # DTChecker configuration
148148
'masked_array': np.ma.masked_array,
149149
'int64': np.int64,
150150
'uint64': np.uint64,
151-
'int8': np.int8,
152151
'int32': np.int32,
152+
'uint32': np.uint32,
153+
'int16': np.int16,
154+
'uint16': np.uint16,
155+
'int8': np.int8,
156+
'uint8': np.uint8,
153157
'float32': np.float32,
154158
'float64': np.float64,
155159
'dtype': np.dtype,
@@ -262,20 +266,27 @@ def has_masked(got):
262266
return 'masked_array' in got and '--' in got
263267

264268

265-
def remove_shape_from_abbrv(s_got):
266-
"""NumPy 2.2 added shape=(123,) to abbreviated array repr. Remove it.
269+
def try_split_shape_from_abbrv(s_got):
270+
"""NumPy 2.2 added shape=(123,) to abbreviated array repr.
271+
272+
If present, split it off, and return a tuple. `(array, shape)`
267273
"""
268274
if "shape=" in s_got:
269275
# handle
270276
# array(..., shape=(1000,))
271277
# array(..., shape=(100, 100))
272278
# array(..., shape=(100, 100), dtype=uint16)
273-
grp = re.match(r'(.+) shape=\(([\d\s,]+\))(.+)', s_got, flags=re.DOTALL).groups()
279+
match = re.match(r'(.+),\s+shape=\(([\d\s,]+)\)(.+)', s_got, flags=re.DOTALL)
280+
if match:
281+
grp = match.groups()
274282

275-
s_got = grp[0] + grp[-1]
276-
s_got = s_got.replace(',,', ',')
283+
s_got = grp[0] + grp[-1]
284+
s_got = s_got.replace(',,', ',')
285+
shape_str = f'({grp[1]})'
277286

278-
return ''.join(s_got.split('...,'))
287+
return ''.join(s_got.split('...,')), shape_str
288+
289+
return ''.join(s_got.split('...,')), ''
279290

280291

281292
class DTChecker(doctest.OutputChecker):
@@ -344,8 +355,14 @@ def check_output(self, want, got, optionflags):
344355
ndim_array = (s_want.startswith("array([") and "..." in s_want and
345356
s_got.startswith("array([") and "..." in s_got)
346357
if ndim_array:
347-
s_want = remove_shape_from_abbrv(s_want)
348-
s_got = remove_shape_from_abbrv(s_got)
358+
s_want, want_shape = try_split_shape_from_abbrv(s_want)
359+
s_got, got_shape = try_split_shape_from_abbrv(s_got)
360+
361+
if got_shape:
362+
# NumPy 2.2 output, `with shape=`, check the shapes, too
363+
s_want = f"{s_want}, {want_shape}"
364+
s_got = f"{s_got}, {got_shape}"
365+
349366
return self.check_output(s_want, s_got, optionflags)
350367

351368
# maybe we are dealing with masked arrays?

scipy_doctest/tests/module_cases.py

Lines changed: 20 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -164,14 +164,25 @@ def array_abbreviation():
164164
"""
165165
Numpy abbreviates arrays, check that it works.
166166
167-
NB: the implementation might need to change when
168-
numpy finally disallows default-creating ragged arrays.
169-
Currently, `...` gets interpreted as an Ellipsis,
170-
thus the `a_want/a_got` variables in DTChecker are in fact
171-
object arrays.
167+
XXX: check if ... creates ragged arrays, avoid if so.
168+
169+
NumPy 2.2 abbreviations
170+
=======================
171+
172+
NumPy 2.2 adds shape=(...) to abbreviated arrays.
173+
174+
This is not a valid argument to `array(...), so it cannot be eval-ed,
175+
and need to be removed for doctesting.
176+
177+
The implementation handles both formats, and checks the shapes if present
178+
in the actual output. If not present in the output, they are ignored.
179+
172180
>>> import numpy as np
173181
>>> np.arange(10000)
174-
array([0, 1, 2, ..., 9997, 9998, 9999])
182+
array([0, 1, 2, ..., 9997, 9998, 9999], shape=(10000,))
183+
184+
>>> np.arange(10000, dtype=np.uint16)
185+
array([ 0, 1, 2, ..., 9997, 9998, 9999], shape=(10000,), dtype=uint16)
175186
176187
>>> np.diag(np.arange(33)) / 30
177188
array([[0., 0., 0., ..., 0., 0.,0.],
@@ -180,53 +191,17 @@ def array_abbreviation():
180191
...,
181192
[0., 0., 0., ..., 1., 0., 0.],
182193
[0., 0., 0., ..., 0., 1.03333333, 0.],
183-
[0., 0., 0., ..., 0., 0., 1.06666667]])
194+
[0., 0., 0., ..., 0., 0., 1.06666667]], shape=(33, 33))
184195
185196
186-
>>> np.diag(np.arange(1, 1001, dtype=float))
197+
>>> np.diag(np.arange(1, 1001, dtype=np.uint16))
187198
array([[1, 0, 0, ..., 0, 0, 0],
188199
[0, 2, 0, ..., 0, 0, 0],
189200
[0, 0, 3, ..., 0, 0, 0],
190201
...,
191202
[0, 0, 0, ..., 998, 0, 0],
192203
[0, 0, 0, ..., 0, 999, 0],
193-
[0, 0, 0, ..., 0, 0, 1000]])
194-
"""
195-
196-
197-
def array_abbreviation_2():
198-
""" NumPy 2.2 adds shape=(...) to abbreviated arrays.
199-
200-
So the actual numpy==2.2.0 output below is
201-
# array([ 0, 1, 2, ..., 9997, 9998, 9999], shape=(10000,))
202-
203-
This is not a valid argument to `array(...), so it cannot be eval-ed,
204-
and need to be removed for doctesting.
205-
206-
>>> import numpy as np
207-
>>> np.arange(10000)
208-
array([ 0, 1, 2, ..., 9997, 9998, 9999])
209-
210-
>>> np.arange(10000, dtype=np.uint16)
211-
array([ 0, 1, 2, ..., 9997, 9998, 9999], dtype=np.uint16)
212-
213-
>>> np.arange(5000).reshape(50, 100)
214-
array([[ 0, 1, 2, ..., 97, 98, 99],
215-
[ 100, 101, 102, ..., 197, 198, 199],
216-
[ 200, 201, 202, ..., 297, 298, 299],
217-
...,
218-
[4700, 4701, 4702, ..., 4797, 4798, 4799],
219-
[4800, 4801, 4802, ..., 4897, 4898, 4899],
220-
[4900, 4901, 4902, ..., 4997, 4998, 4999]])
221-
222-
>>> np.arange(5000, dtype=np.uint16).reshape(50, 100)
223-
array([[ 0, 1, 2, ..., 97, 98, 99],
224-
[ 100, 101, 102, ..., 197, 198, 199],
225-
[ 200, 201, 202, ..., 297, 298, 299],
226-
...,
227-
[4700, 4701, 4702, ..., 4797, 4798, 4799],
228-
[4800, 4801, 4802, ..., 4897, 4898, 4899],
229-
[4900, 4901, 4902, ..., 4997, 4998, 4999]], dtype=uint16)
204+
[0, 0, 0, ..., 0, 0, 1000]], shape=(1000, 1000), dtype=uint16)
230205
"""
231206

232207

0 commit comments

Comments
 (0)