-
Notifications
You must be signed in to change notification settings - Fork 3
ENH: asarray
for Array API support
#24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
ea9c544
28a6339
65aca1e
912e55a
de56fe7
ca5ff59
cec8bd2
94d3044
2ef5892
ad770b8
039e931
bb8e89c
c4a2c7f
31a0300
3732e1c
1ebc347
4333c99
068b4fa
d72de86
d748ed1
0a1aa67
2a4f2f7
633648b
735101a
eb79de0
b4b9f32
1623d0e
abdb9b2
c432d88
603fbf3
a0b06cf
37c4c0f
afc5fdc
f900a4d
d6b568a
560b591
9ea70e6
e54c36d
9da0cd9
85c7045
8bc56ab
2eb7791
4596247
67ac6e0
52e982d
d8f66dd
488e3c3
77feb2f
d2d2e8d
e395445
1146129
bc6c61b
61d5860
dd2cab5
0469887
276d3f5
06e59f3
82a8aa3
de2257f
85fccda
0c90e8d
1d090e6
1f3856f
ea6f58b
a33efc6
4dbde35
bef978c
60e3103
b201851
0f022a1
226ac88
f7f2568
050a96e
f2dfb3c
4787f50
a38cd2e
2d12ad4
7aea9d3
0e2c5b1
15286a6
92814ac
b7455f2
5744618
5033b54
8bd49ae
50d7ea4
84cfe65
b9fe722
98f8309
f0c2ca4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
"""Utility functions to use Python Array API compatible libraries. | ||
|
||
For the context about the Array API see: | ||
https://data-apis.org/array-api/latest/purpose_and_scope.html | ||
|
||
The SciPy use case of the Array API is described on the following page: | ||
https://data-apis.org/array-api/latest/use_cases.html#use-case-scipy | ||
""" | ||
import os | ||
|
||
import numpy as np | ||
import numpy.array_api | ||
# probably want to vendor it (submodule) | ||
import array_api_compat | ||
import array_api_compat.numpy | ||
|
||
__all__ = ['array_namespace', 'asarray', 'asarray_namespace'] | ||
|
||
|
||
# SCIPY_ARRAY_API, array_api_dispatch is used by sklearn | ||
array_api_dispatch = os.environ.get("array_api_dispatch", False) | ||
SCIPY_ARRAY_API = os.environ.get("SCIPY_ARRAY_API", array_api_dispatch) | ||
|
||
_GLOBAL_CONFIG = {"SCIPY_ARRAY_API": SCIPY_ARRAY_API} | ||
|
||
|
||
def compliance_scipy(*arrays): | ||
for array in arrays: | ||
if isinstance(array, np.ma.MaskedArray): | ||
raise TypeError("'numpy.ma.MaskedArray' are not supported") | ||
elif isinstance(array, np.matrix): | ||
raise TypeError("'numpy.matrix' are not supported") | ||
|
||
|
||
def array_namespace(*arrays, single_namespace=True): | ||
compliance_scipy(*arrays) | ||
|
||
if not _GLOBAL_CONFIG["SCIPY_ARRAY_API"]: | ||
return np | ||
|
||
# if we cannot get the namespace, np is used | ||
# here until moved upstream | ||
namespaces = set() | ||
for array in arrays: | ||
try: | ||
namespaces.add(array_api_compat.array_namespace(array)) | ||
except TypeError: | ||
namespaces.add(array_api_compat.numpy) | ||
|
||
if single_namespace and len(namespaces) != 1: | ||
raise ValueError( | ||
f"Expected a single common namespace for array inputs, \ | ||
but got: {[n.__name__ for n in namespaces]}" | ||
) | ||
|
||
(xp,) = namespaces | ||
|
||
return xp | ||
|
||
|
||
def asarray(array, dtype=None, order=None, copy=None, *, xp=None): | ||
"""Drop-in replacement for `np.asarray`. | ||
|
||
Memory layout parameter `order` is not exposed in the Array API standard. | ||
`order` is only enforced if the input array implementation | ||
is NumPy based, otherwise `order` is just silently ignored. | ||
|
||
The purpose of this helper is to make it possible to share code for data | ||
container validation without memory copies for both downstream use cases | ||
""" | ||
if xp is None: | ||
xp = array_namespace(array) | ||
if xp.__name__ in {"numpy", "array_api_compat.numpy", "numpy.array_api"}: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No one's actually doing it yet as far as I know, but this wouldn't work if someone vendors array_api_compat and tries to call a scipy function. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also I don't know if it makes sense to list There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For scikit-learn, we wanted the performance with def scipy_func(X):
xp = array_namespace(X)
# switch order for performance reasons
X_f = asarray(X, xp, order="F")
# Do some operations that require prefer F ordered.
return xp.sum(X_f, axis=0) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. By the way something like _X = numpy.asarray(..., order="F")
X = numpy.array_api.asarray(X) will also work. That's maybe a little more "spec compliant" in the sense that converting arrays from one library to another with asarray is supported. In this case it's a trivial zero-copy wrapping but in general it will use DLPack (although I don't know how DLPack handles order, so maybe someone could confirm whether this would actually work in a more general setting). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually that's wrong. I thought asarray in the spec used dlpack, but it's only numpy.asarray that does. In the spec you have to use from_dlpack (I'm not sure why they are separate). |
||
# Use NumPy API to support order | ||
if copy is True: | ||
array = np.array(array, order=order, dtype=dtype) | ||
else: | ||
array = np.asarray(array, order=order, dtype=dtype) | ||
|
||
# At this point array is a NumPy ndarray. We convert it to an array | ||
# container that is consistent with the input's namespace. | ||
return xp.asarray(array) | ||
else: | ||
return xp.asarray(array, dtype=dtype, copy=copy) | ||
|
||
|
||
def asarray_namespace(*arrays): | ||
"""Validate and convert arrays to a common namespace. | ||
|
||
Parameters | ||
---------- | ||
*arrays : sequence of array_like | ||
Arrays to validate and convert. | ||
|
||
Returns | ||
------- | ||
*arrays : sequence of array_like | ||
Validated and converted arrays to the common namespace. | ||
namespace : module | ||
Common namespace. | ||
|
||
Examples | ||
-------- | ||
>>> import numpy as np | ||
>>> x, y, xp = asarray_namespace([0, 1, 2], np.arange(3)) | ||
>>> xp.__name__ | ||
'array_api_compat.numpy' | ||
>>> x, y | ||
(array([0, 1, 2]]), array([0, 1, 2])) | ||
|
||
""" | ||
tupui marked this conversation as resolved.
Show resolved
Hide resolved
|
||
arrays = list(arrays) # probably not good | ||
xp = array_namespace(*arrays) | ||
|
||
for i, array in enumerate(arrays): | ||
arrays[i] = asarray(array, xp=xp) | ||
|
||
return *arrays, xp | ||
|
||
|
||
def to_numpy(array, xp): | ||
"""Convert `array` into a NumPy ndarray on the CPU. | ||
|
||
This is specially useful to pass arrays to Cython. | ||
""" | ||
xp_name = xp.__name__ | ||
|
||
if xp_name in {"array_api_compat.torch", "torch"}: | ||
return array.cpu().numpy() | ||
|
||
elif xp_name == "cupy.array_api": | ||
return array._array.get() | ||
elif xp_name in {"array_api_compat.cupy", "cupy"}: # pragma: nocover | ||
return array.get() | ||
|
||
return np.asarray(array) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import numpy as np | ||
from numpy.testing import assert_equal | ||
import pytest | ||
|
||
from scipy.conftest import array_api_compatible | ||
from scipy._lib._array_api import ( | ||
_GLOBAL_CONFIG, array_namespace, asarray, asarray_namespace, | ||
to_numpy | ||
) | ||
|
||
|
||
if not _GLOBAL_CONFIG["SCIPY_ARRAY_API"]: | ||
pytest.skip( | ||
"Array API test; set environment variable array_api_dispatch=1 to run it", | ||
allow_module_level=True | ||
) | ||
|
||
|
||
def test_array_namespace(): | ||
x, y = [0, 1, 2], np.arange(3) | ||
xp = array_namespace(x, y) | ||
assert xp.__name__ == 'array_api_compat.numpy' | ||
|
||
_GLOBAL_CONFIG["SCIPY_ARRAY_API"] = False | ||
xp = array_namespace(x, y) | ||
assert xp.__name__ == 'numpy' | ||
_GLOBAL_CONFIG["SCIPY_ARRAY_API"] = True | ||
|
||
|
||
@array_api_compatible | ||
def test_asarray(xp): | ||
x, y = asarray([0, 1, 2], xp=xp), asarray(np.arange(3), xp=xp) | ||
ref = np.array([0, 1, 2]) | ||
assert_equal(x, ref) | ||
assert_equal(y, ref) | ||
|
||
|
||
def test_asarray_namespace(): | ||
x, y = [0, 1, 2], np.arange(3) | ||
x, y, xp_ = asarray_namespace(x, y) | ||
assert xp_.__name__ == 'array_api_compat.numpy' | ||
ref = np.array([0, 1, 2]) | ||
assert_equal(x, ref) | ||
assert_equal(y, ref) | ||
assert type(x) == type(y) | ||
|
||
_GLOBAL_CONFIG["SCIPY_ARRAY_API"] = False | ||
x, y, xp_ = asarray_namespace(x, y) | ||
assert xp_.__name__ == 'numpy' | ||
_GLOBAL_CONFIG["SCIPY_ARRAY_API"] = True | ||
|
||
|
||
@array_api_compatible | ||
def test_to_numpy(xp): | ||
x = xp.asarray([0, 1, 2]) | ||
x = to_numpy(x, xp=xp) | ||
assert isinstance(x, np.ndarray) | ||
|
||
|
||
@pytest.mark.filterwarnings("ignore: the matrix subclass") | ||
def test_raises(): | ||
msg = "'numpy.ma.MaskedArray' are not supported" | ||
with pytest.raises(TypeError, match=msg): | ||
array_namespace(np.ma.array(1), np.array(1)) | ||
|
||
msg = "'numpy.matrix' are not supported" | ||
with pytest.raises(TypeError, match=msg): | ||
array_namespace(np.array(1), np.matrix(1)) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
array_api_compat.array_namespace already accepts multiple arrays. Is the issue that there isn't a way to specify numpy as a default?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes I would need the function to not raise but return NumPy. This is because we want to accept things like Python lists or scalar.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can add a flag to
array_namespace
likedefault=numpy
that makes it do this. That function's not part of the spec so we can adjust it however makes it most useful.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That would help with our use case 😃
Quick question, should this really fail?
This is why I have the following to go around and convert
numpy.array_api
toarray_api_compat.numpy
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And in general any compat code that's useful to more than one library can go in the compat library (assuming it's not too complex and pure Python).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok I can do a PR for that if you want.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wouldn't try mixing numpy and numpy.array_api. numpy.array_api should only be used for testing purposes. It implements a strict version of the standard so you can use it to check that you are saying within the spec. It shouldn't be used for actual user code.
The whole purpose of the compat library is to provide sufficient wrappers around numpy itself to make it array API compatible. Using numpy.array_api for user code was found to be too challenging because it uses a different array class from NumPy, which is not what most users want.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
mmm ok so when I test, I should consider that
numpy.array_api
is something different thannumpy
such as lets saycupy
. Makes sense to me 👍 Thanks for the explanations Aaron 😃There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@asmeurer I opened data-apis/array-api-compat#39