Skip to content

Commit 2d112a9

Browse files
authored
ENH: Updates to numpy.array_api (numpy#19937)
* Add __index__ to array_api and update __int__, __bool__, and __float__ The spec specifies that they should only work on arrays with corresponding dtypes. __index__ is new in the spec since the initial PR, and works identically to np.array.__index__. * Add the to_device method to the array_api This method is new since numpy#18585. It does nothing in NumPy since NumPy does not support non-CPU devices. * Update transpose methods in the array_api transpose() was renamed to matrix_transpose() and now operates on stacks of matrices. A function to permute dimensions will be added once it is finalized in the spec. The attribute mT was added and the T attribute was updated to only operate on 2-dimensional arrays as per the spec. * Restrict input dtypes in the array API statistical functions * Add the dtype parameter to the array API sum() and prod() * Add the function permute_dims() to the array_api namespace permute_dims() is the replacement for transpose(), which was split into permute_dims() and matrix_transpose(). * Add tril and triu to the array API namespace * Fix the array_api Array.__repr__ to indent the array properly * Make the Device type in the array_api just accept the string "cpu"
1 parent ac78192 commit 2d112a9

File tree

8 files changed

+147
-21
lines changed

8 files changed

+147
-21
lines changed

numpy/array_api/__init__.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@
143143
meshgrid,
144144
ones,
145145
ones_like,
146+
tril,
147+
triu,
146148
zeros,
147149
zeros_like,
148150
)
@@ -160,6 +162,8 @@
160162
"meshgrid",
161163
"ones",
162164
"ones_like",
165+
"tril",
166+
"triu",
163167
"zeros",
164168
"zeros_like",
165169
]
@@ -333,21 +337,22 @@
333337
# from ._linear_algebra_functions import einsum
334338
# __all__ += ['einsum']
335339

336-
from ._linear_algebra_functions import matmul, tensordot, transpose, vecdot
340+
from ._linear_algebra_functions import matmul, tensordot, matrix_transpose, vecdot
337341

338-
__all__ += ["matmul", "tensordot", "transpose", "vecdot"]
342+
__all__ += ["matmul", "tensordot", "matrix_transpose", "vecdot"]
339343

340344
from ._manipulation_functions import (
341345
concat,
342346
expand_dims,
343347
flip,
348+
permute_dims,
344349
reshape,
345350
roll,
346351
squeeze,
347352
stack,
348353
)
349354

350-
__all__ += ["concat", "expand_dims", "flip", "reshape", "roll", "squeeze", "stack"]
355+
__all__ += ["concat", "expand_dims", "flip", "permute_dims", "reshape", "roll", "squeeze", "stack"]
351356

352357
from ._searching_functions import argmax, argmin, nonzero, where
353358

numpy/array_api/_array_object.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,10 @@ def __repr__(self: Array, /) -> str:
9999
"""
100100
Performs the operation __repr__.
101101
"""
102-
return f"Array({np.array2string(self._array, separator=', ')}, dtype={self.dtype.name})"
102+
prefix = "Array("
103+
suffix = f", dtype={self.dtype.name})"
104+
mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix)
105+
return prefix + mid + suffix
103106

104107
# These are various helper functions to make the array behavior match the
105108
# spec in places where it either deviates from or is more strict than
@@ -391,6 +394,8 @@ def __bool__(self: Array, /) -> bool:
391394
# Note: This is an error here.
392395
if self._array.ndim != 0:
393396
raise TypeError("bool is only allowed on arrays with 0 dimensions")
397+
if self.dtype not in _boolean_dtypes:
398+
raise ValueError("bool is only allowed on boolean arrays")
394399
res = self._array.__bool__()
395400
return res
396401

@@ -429,6 +434,8 @@ def __float__(self: Array, /) -> float:
429434
# Note: This is an error here.
430435
if self._array.ndim != 0:
431436
raise TypeError("float is only allowed on arrays with 0 dimensions")
437+
if self.dtype not in _floating_dtypes:
438+
raise ValueError("float is only allowed on floating-point arrays")
432439
res = self._array.__float__()
433440
return res
434441

@@ -488,9 +495,18 @@ def __int__(self: Array, /) -> int:
488495
# Note: This is an error here.
489496
if self._array.ndim != 0:
490497
raise TypeError("int is only allowed on arrays with 0 dimensions")
498+
if self.dtype not in _integer_dtypes:
499+
raise ValueError("int is only allowed on integer arrays")
491500
res = self._array.__int__()
492501
return res
493502

503+
def __index__(self: Array, /) -> int:
504+
"""
505+
Performs the operation __index__.
506+
"""
507+
res = self._array.__index__()
508+
return res
509+
494510
def __invert__(self: Array, /) -> Array:
495511
"""
496512
Performs the operation __invert__.
@@ -979,6 +995,11 @@ def __rxor__(self: Array, other: Union[int, bool, Array], /) -> Array:
979995
res = self._array.__rxor__(other._array)
980996
return self.__class__._new(res)
981997

998+
def to_device(self: Array, device: Device, /) -> Array:
999+
if device == 'cpu':
1000+
return self
1001+
raise ValueError(f"Unsupported device {device!r}")
1002+
9821003
@property
9831004
def dtype(self) -> Dtype:
9841005
"""
@@ -992,6 +1013,12 @@ def dtype(self) -> Dtype:
9921013
def device(self) -> Device:
9931014
return "cpu"
9941015

1016+
# Note: mT is new in array API spec (see matrix_transpose)
1017+
@property
1018+
def mT(self) -> Array:
1019+
from ._linear_algebra_functions import matrix_transpose
1020+
return matrix_transpose(self)
1021+
9951022
@property
9961023
def ndim(self) -> int:
9971024
"""
@@ -1026,4 +1053,9 @@ def T(self) -> Array:
10261053
10271054
See its docstring for more information.
10281055
"""
1056+
# Note: T only works on 2-dimensional arrays. See the corresponding
1057+
# note in the specification:
1058+
# https://data-apis.org/array-api/latest/API_specification/array_object.html#t
1059+
if self.ndim != 2:
1060+
raise ValueError("x.T requires x to have 2 dimensions. Use x.mT to transpose stacks of matrices and permute_dims() to permute dimensions.")
10291061
return self._array.T

numpy/array_api/_creation_functions.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def _check_valid_dtype(dtype):
2222
# Note: Only spelling dtypes as the dtype objects is supported.
2323

2424
# We use this instead of "dtype in _all_dtypes" because the dtype objects
25-
# define equality with the sorts of things we want to disallw.
25+
# define equality with the sorts of things we want to disallow.
2626
for d in (None,) + _all_dtypes:
2727
if dtype is d:
2828
return
@@ -281,6 +281,34 @@ def ones_like(
281281
return Array._new(np.ones_like(x._array, dtype=dtype))
282282

283283

284+
def tril(x: Array, /, *, k: int = 0) -> Array:
285+
"""
286+
Array API compatible wrapper for :py:func:`np.tril <numpy.tril>`.
287+
288+
See its docstring for more information.
289+
"""
290+
from ._array_object import Array
291+
292+
if x.ndim < 2:
293+
# Note: Unlike np.tril, x must be at least 2-D
294+
raise ValueError("x must be at least 2-dimensional for tril")
295+
return Array._new(np.tril(x._array, k=k))
296+
297+
298+
def triu(x: Array, /, *, k: int = 0) -> Array:
299+
"""
300+
Array API compatible wrapper for :py:func:`np.triu <numpy.triu>`.
301+
302+
See its docstring for more information.
303+
"""
304+
from ._array_object import Array
305+
306+
if x.ndim < 2:
307+
# Note: Unlike np.triu, x must be at least 2-D
308+
raise ValueError("x must be at least 2-dimensional for triu")
309+
return Array._new(np.triu(x._array, k=k))
310+
311+
284312
def zeros(
285313
shape: Union[int, Tuple[int, ...]],
286314
*,

numpy/array_api/_linear_algebra_functions.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,12 @@ def tensordot(
5252
return Array._new(np.tensordot(x1._array, x2._array, axes=axes))
5353

5454

55-
def transpose(x: Array, /, *, axes: Optional[Tuple[int, ...]] = None) -> Array:
56-
"""
57-
Array API compatible wrapper for :py:func:`np.transpose <numpy.transpose>`.
58-
59-
See its docstring for more information.
60-
"""
61-
return Array._new(np.transpose(x._array, axes=axes))
55+
# Note: this function is new in the array API spec. Unlike transpose, it only
56+
# transposes the last two axes.
57+
def matrix_transpose(x: Array, /) -> Array:
58+
if x.ndim < 2:
59+
raise ValueError("x must be at least 2-dimensional for matrix_transpose")
60+
return Array._new(np.swapaxes(x._array, -1, -2))
6261

6362

6463
# Note: vecdot is not in NumPy

numpy/array_api/_manipulation_functions.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,17 @@ def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) ->
4141
return Array._new(np.flip(x._array, axis=axis))
4242

4343

44+
# Note: The function name is different here (see also matrix_transpose).
45+
# Unlike transpose(), the axes argument is required.
46+
def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
47+
"""
48+
Array API compatible wrapper for :py:func:`np.transpose <numpy.transpose>`.
49+
50+
See its docstring for more information.
51+
"""
52+
return Array._new(np.transpose(x._array, axes))
53+
54+
4455
def reshape(x: Array, /, shape: Tuple[int, ...]) -> Array:
4556
"""
4657
Array API compatible wrapper for :py:func:`np.reshape <numpy.reshape>`.

numpy/array_api/_statistical_functions.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,17 @@
11
from __future__ import annotations
22

3+
from ._dtypes import (
4+
_floating_dtypes,
5+
_numeric_dtypes,
6+
)
37
from ._array_object import Array
8+
from ._creation_functions import asarray
9+
from ._dtypes import float32, float64
410

5-
from typing import Optional, Tuple, Union
11+
from typing import TYPE_CHECKING, Optional, Tuple, Union
12+
13+
if TYPE_CHECKING:
14+
from ._typing import Dtype
615

716
import numpy as np
817

@@ -14,6 +23,8 @@ def max(
1423
axis: Optional[Union[int, Tuple[int, ...]]] = None,
1524
keepdims: bool = False,
1625
) -> Array:
26+
if x.dtype not in _numeric_dtypes:
27+
raise TypeError("Only numeric dtypes are allowed in max")
1728
return Array._new(np.max(x._array, axis=axis, keepdims=keepdims))
1829

1930

@@ -24,6 +35,8 @@ def mean(
2435
axis: Optional[Union[int, Tuple[int, ...]]] = None,
2536
keepdims: bool = False,
2637
) -> Array:
38+
if x.dtype not in _floating_dtypes:
39+
raise TypeError("Only floating-point dtypes are allowed in mean")
2740
return Array._new(np.mean(x._array, axis=axis, keepdims=keepdims))
2841

2942

@@ -34,6 +47,8 @@ def min(
3447
axis: Optional[Union[int, Tuple[int, ...]]] = None,
3548
keepdims: bool = False,
3649
) -> Array:
50+
if x.dtype not in _numeric_dtypes:
51+
raise TypeError("Only numeric dtypes are allowed in min")
3752
return Array._new(np.min(x._array, axis=axis, keepdims=keepdims))
3853

3954

@@ -42,8 +57,15 @@ def prod(
4257
/,
4358
*,
4459
axis: Optional[Union[int, Tuple[int, ...]]] = None,
60+
dtype: Optional[Dtype] = None,
4561
keepdims: bool = False,
4662
) -> Array:
63+
if x.dtype not in _numeric_dtypes:
64+
raise TypeError("Only numeric dtypes are allowed in prod")
65+
# Note: sum() and prod() always upcast float32 to float64 for dtype=None
66+
# We need to do so here before computing the product to avoid overflow
67+
if dtype is None and x.dtype == float32:
68+
x = asarray(x, dtype=float64)
4769
return Array._new(np.prod(x._array, axis=axis, keepdims=keepdims))
4870

4971

@@ -56,6 +78,8 @@ def std(
5678
keepdims: bool = False,
5779
) -> Array:
5880
# Note: the keyword argument correction is different here
81+
if x.dtype not in _floating_dtypes:
82+
raise TypeError("Only floating-point dtypes are allowed in std")
5983
return Array._new(np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims))
6084

6185

@@ -64,8 +88,15 @@ def sum(
6488
/,
6589
*,
6690
axis: Optional[Union[int, Tuple[int, ...]]] = None,
91+
dtype: Optional[Dtype] = None,
6792
keepdims: bool = False,
6893
) -> Array:
94+
if x.dtype not in _numeric_dtypes:
95+
raise TypeError("Only numeric dtypes are allowed in sum")
96+
# Note: sum() and prod() always upcast float32 to float64 for dtype=None
97+
# We need to do so here before summing to avoid overflow
98+
if dtype is None and x.dtype == float32:
99+
x = asarray(x, dtype=float64)
69100
return Array._new(np.sum(x._array, axis=axis, keepdims=keepdims))
70101

71102

@@ -78,4 +109,6 @@ def var(
78109
keepdims: bool = False,
79110
) -> Array:
80111
# Note: the keyword argument correction is different here
112+
if x.dtype not in _floating_dtypes:
113+
raise TypeError("Only floating-point dtypes are allowed in var")
81114
return Array._new(np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims))

numpy/array_api/_typing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"PyCapsule",
1616
]
1717

18-
from typing import Any, Sequence, Type, Union
18+
from typing import Any, Literal, Sequence, Type, Union
1919

2020
from . import (
2121
Array,
@@ -35,7 +35,7 @@
3535
# similar comment in numpy/typing/_array_like.py
3636
NestedSequence = Sequence[Sequence[Any]]
3737

38-
Device = Any
38+
Device = Literal["cpu"]
3939
Dtype = Type[
4040
Union[int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64]
4141
]

numpy/array_api/tests/test_array_object.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import operator
2+
13
from numpy.testing import assert_raises
24
import numpy as np
35

@@ -255,15 +257,31 @@ def _matmul_array_vals():
255257

256258

257259
def test_python_scalar_construtors():
258-
a = asarray(False)
259-
b = asarray(0)
260-
c = asarray(0.0)
260+
b = asarray(False)
261+
i = asarray(0)
262+
f = asarray(0.0)
261263

262-
assert bool(a) == bool(b) == bool(c) == False
263-
assert int(a) == int(b) == int(c) == 0
264-
assert float(a) == float(b) == float(c) == 0.0
264+
assert bool(b) == False
265+
assert int(i) == 0
266+
assert float(f) == 0.0
267+
assert operator.index(i) == 0
265268

266269
# bool/int/float should only be allowed on 0-D arrays.
267270
assert_raises(TypeError, lambda: bool(asarray([False])))
268271
assert_raises(TypeError, lambda: int(asarray([0])))
269272
assert_raises(TypeError, lambda: float(asarray([0.0])))
273+
assert_raises(TypeError, lambda: operator.index(asarray([0])))
274+
275+
# bool/int/float should only be allowed on arrays of the corresponding
276+
# dtype
277+
assert_raises(ValueError, lambda: bool(i))
278+
assert_raises(ValueError, lambda: bool(f))
279+
280+
assert_raises(ValueError, lambda: int(b))
281+
assert_raises(ValueError, lambda: int(f))
282+
283+
assert_raises(ValueError, lambda: float(b))
284+
assert_raises(ValueError, lambda: float(i))
285+
286+
assert_raises(TypeError, lambda: operator.index(b))
287+
assert_raises(TypeError, lambda: operator.index(f))

0 commit comments

Comments
 (0)