-
Notifications
You must be signed in to change notification settings - Fork 3
ENH: asarray
for Array API support
#24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 14 commits
ea9c544
28a6339
65aca1e
912e55a
de56fe7
ca5ff59
cec8bd2
94d3044
2ef5892
ad770b8
039e931
bb8e89c
c4a2c7f
31a0300
3732e1c
1ebc347
4333c99
068b4fa
d72de86
d748ed1
0a1aa67
2a4f2f7
633648b
735101a
eb79de0
b4b9f32
1623d0e
abdb9b2
c432d88
603fbf3
a0b06cf
37c4c0f
afc5fdc
f900a4d
d6b568a
560b591
9ea70e6
e54c36d
9da0cd9
85c7045
8bc56ab
2eb7791
4596247
67ac6e0
52e982d
d8f66dd
488e3c3
77feb2f
d2d2e8d
e395445
1146129
bc6c61b
61d5860
dd2cab5
0469887
276d3f5
06e59f3
82a8aa3
de2257f
85fccda
0c90e8d
1d090e6
1f3856f
ea6f58b
a33efc6
4dbde35
bef978c
60e3103
b201851
0f022a1
226ac88
f7f2568
050a96e
f2dfb3c
4787f50
a38cd2e
2d12ad4
7aea9d3
0e2c5b1
15286a6
92814ac
b7455f2
5744618
5033b54
8bd49ae
50d7ea4
84cfe65
b9fe722
98f8309
f0c2ca4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
"""Utility functions to use Python Array API compatible libraries. | ||
|
||
For the context about the Array API see: | ||
https://data-apis.org/array-api/latest/purpose_and_scope.html | ||
|
||
The SciPy use case of the Array API is described on the following page: | ||
https://data-apis.org/array-api/latest/use_cases.html#use-case-scipy | ||
""" | ||
import os | ||
|
||
import numpy as np | ||
# probably want to vendor it (submodule) | ||
import array_api_compat | ||
import array_api_compat.numpy | ||
|
||
__all__ = ['array_namespace', 'asarray', 'asarray_namespace'] | ||
|
||
|
||
# SCIPY_ARRAY_API, array_api_dispatch is used by sklearn | ||
array_api_dispatch = os.environ.get("array_api_dispatch", False) | ||
SCIPY_ARRAY_API = os.environ.get("SCIPY_ARRAY_API", array_api_dispatch) | ||
|
||
_GLOBAL_CONFIG = {"SCIPY_ARRAY_API": SCIPY_ARRAY_API} | ||
|
||
|
||
def compliance_scipy(*arrays): | ||
"""Raise exceptions on known-bad subclasses. | ||
|
||
The following subclasses are not supported and raise and error: | ||
- `np.ma.MaskedArray` | ||
- `numpy.matrix` | ||
- Any array-like which is not Array API compatible | ||
""" | ||
for array in arrays: | ||
if isinstance(array, np.ma.MaskedArray): | ||
raise TypeError("'numpy.ma.MaskedArray' are not supported") | ||
elif isinstance(array, np.matrix): | ||
raise TypeError("'numpy.matrix' are not supported") | ||
elif not array_api_compat.is_array_api_obj(array): | ||
raise TypeError("Only support Array API compatible arrays") | ||
|
||
|
||
def array_namespace(*arrays): | ||
"""Get the array API compatible namespace for the arrays xs. | ||
|
||
Parameters | ||
---------- | ||
*arrays : sequence of array_like | ||
Arrays used to infer the common namespace. | ||
|
||
Returns | ||
------- | ||
namespace : module | ||
Common namespace. | ||
|
||
Notes | ||
----- | ||
Thin wrapper around `array_api_compat.array_namespace`. | ||
|
||
1. Check for the global switch: SCIPY_ARRAY_API. This can also be accessed | ||
dynamically through ``_GLOBAL_CONFIG['SCIPY_ARRAY_API']``. | ||
2. `compliance_scipy` raise exceptions on known-bad subclasses. See | ||
it's definition for more details. | ||
|
||
When the global switch is False, it defaults to the `numpy` namespace. | ||
In that case, there is no compliance check. This is a convenience to | ||
ease the adoption. Otherwise, arrays must comply with the new rules. | ||
""" | ||
if not _GLOBAL_CONFIG["SCIPY_ARRAY_API"]: | ||
# here we could wrap the namespace if needed | ||
return np | ||
|
||
compliance_scipy(*arrays) | ||
|
||
return array_api_compat.array_namespace(*arrays) | ||
|
||
|
||
def asarray(array, dtype=None, order=None, copy=None, *, xp=None): | ||
"""Drop-in replacement for `np.asarray`. | ||
|
||
Memory layout parameter `order` is not exposed in the Array API standard. | ||
`order` is only enforced if the input array implementation | ||
is NumPy based, otherwise `order` is just silently ignored. | ||
|
||
The purpose of this helper is to make it possible to share code for data | ||
container validation without memory copies for both downstream use cases. | ||
""" | ||
if xp is None: | ||
xp = array_namespace(array) | ||
if xp.__name__ in {"numpy", "array_api_compat.numpy", "numpy.array_api"}: | ||
# Use NumPy API to support order | ||
if copy is True: | ||
array = np.array(array, order=order, dtype=dtype) | ||
else: | ||
array = np.asarray(array, order=order, dtype=dtype) | ||
|
||
# At this point array is a NumPy ndarray. We convert it to an array | ||
# container that is consistent with the input's namespace. | ||
return xp.asarray(array) | ||
else: | ||
return xp.asarray(array, dtype=dtype, copy=copy) | ||
|
||
|
||
def asarray_namespace(*arrays): | ||
"""Validate and convert arrays to a common namespace. | ||
|
||
Parameters | ||
---------- | ||
*arrays : sequence of array_like | ||
Arrays to validate and convert. | ||
|
||
Returns | ||
------- | ||
*arrays : sequence of array_like | ||
Validated and converted arrays to the common namespace. | ||
namespace : module | ||
Common namespace. | ||
|
||
Notes | ||
----- | ||
This function is meant to be called from each public function in a SciPy | ||
submodule it does the following: | ||
|
||
1. Check for the global switch: SCIPY_ARRAY_API. This can also be accessed | ||
dynamically through ``_GLOBAL_CONFIG['SCIPY_ARRAY_API']``. | ||
2. `compliance_scipy` raise exceptions on known-bad subclasses. See | ||
it's definition for more details. | ||
3. Determine the namespace, without doing any coercion of array(-like) | ||
inputs. | ||
4. Call `xp.asarray` on all array. | ||
|
||
Examples | ||
-------- | ||
>>> import numpy as np | ||
>>> x, y, xp = asarray_namespace(np.array([0, 1, 2]), np.array([0, 1, 2])) | ||
>>> xp.__name__ | ||
'array_api_compat.numpy' | ||
>>> x, y | ||
(array([0, 1, 2]), array([0, 1, 2])) | ||
|
||
""" | ||
tupui marked this conversation as resolved.
Show resolved
Hide resolved
|
||
arrays = list(arrays) | ||
xp = array_namespace(*arrays) | ||
|
||
for i, array in enumerate(arrays): | ||
arrays[i] = xp.asarray(array) | ||
|
||
return *arrays, xp | ||
|
||
|
||
def to_numpy(array, xp): | ||
"""Convert `array` into a NumPy ndarray on the CPU. | ||
|
||
ONLY FOR TESTING | ||
""" | ||
xp_name = xp.__name__ | ||
|
||
if xp_name in {"array_api_compat.torch", "torch"}: | ||
return array.cpu().numpy() | ||
|
||
elif xp_name == "cupy.array_api": | ||
return array._array.get() | ||
elif xp_name in {"array_api_compat.cupy", "cupy"}: # pragma: nocover | ||
return array.get() | ||
|
||
return np.asarray(array) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import numpy as np | ||
from numpy.testing import assert_equal | ||
import pytest | ||
|
||
from scipy.conftest import array_api_compatible | ||
from scipy._lib._array_api import ( | ||
_GLOBAL_CONFIG, array_namespace, asarray, asarray_namespace, | ||
to_numpy | ||
) | ||
|
||
|
||
if not _GLOBAL_CONFIG["SCIPY_ARRAY_API"]: | ||
pytest.skip( | ||
"Array API test; set environment variable array_api_dispatch=1 to run it", | ||
allow_module_level=True | ||
) | ||
|
||
|
||
def test_array_namespace(): | ||
x, y = np.array([0, 1, 2]), np.array([0, 1, 2]) | ||
xp = array_namespace(x, y) | ||
assert xp.__name__ == 'array_api_compat.numpy' | ||
|
||
_GLOBAL_CONFIG["SCIPY_ARRAY_API"] = False | ||
xp = array_namespace(x, y) | ||
assert xp.__name__ == 'numpy' | ||
_GLOBAL_CONFIG["SCIPY_ARRAY_API"] = True | ||
|
||
|
||
@array_api_compatible | ||
def test_asarray(xp): | ||
x, y = asarray([0, 1, 2], xp=xp), asarray(np.arange(3), xp=xp) | ||
ref = np.array([0, 1, 2]) | ||
assert_equal(x, ref) | ||
assert_equal(y, ref) | ||
|
||
|
||
def test_asarray_namespace(): | ||
x, y = np.array([0, 1, 2]), np.array([0, 1, 2]) | ||
x, y, xp_ = asarray_namespace(x, y) | ||
assert xp_.__name__ == 'array_api_compat.numpy' | ||
ref = np.array([0, 1, 2]) | ||
assert_equal(x, ref) | ||
assert_equal(y, ref) | ||
assert type(x) == type(y) | ||
|
||
_GLOBAL_CONFIG["SCIPY_ARRAY_API"] = False | ||
x, y, xp_ = asarray_namespace(x, y) | ||
assert xp_.__name__ == 'numpy' | ||
_GLOBAL_CONFIG["SCIPY_ARRAY_API"] = True | ||
|
||
|
||
@array_api_compatible | ||
def test_to_numpy(xp): | ||
x = xp.asarray([0, 1, 2]) | ||
x = to_numpy(x, xp=xp) | ||
assert isinstance(x, np.ndarray) | ||
|
||
|
||
@pytest.mark.filterwarnings("ignore: the matrix subclass") | ||
def test_raises(): | ||
msg = "'numpy.ma.MaskedArray' are not supported" | ||
with pytest.raises(TypeError, match=msg): | ||
array_namespace(np.ma.array(1), np.array(1)) | ||
|
||
msg = "'numpy.matrix' are not supported" | ||
with pytest.raises(TypeError, match=msg): | ||
array_namespace(np.array(1), np.matrix(1)) | ||
|
||
msg = "Only support Array API" | ||
with pytest.raises(TypeError, match=msg): | ||
array_namespace([0, 1, 2]) | ||
|
||
with pytest.raises(TypeError, match=msg): | ||
array_namespace(1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No one's actually doing it yet as far as I know, but this wouldn't work if someone vendors array_api_compat and tries to call a scipy function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also I don't know if it makes sense to list
numpy.array_api
here. That namespace is designed to only support a strict implementation of the standard, which doesn't includeorder
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For scikit-learn, we wanted the performance with
numpy.array_api
to be the same asnumpy
. When one explicitly sets the order, there is usually a performance reason for doing so. For example:There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
By the way something like
will also work. That's maybe a little more "spec compliant" in the sense that converting arrays from one library to another with asarray is supported. In this case it's a trivial zero-copy wrapping but in general it will use DLPack (although I don't know how DLPack handles order, so maybe someone could confirm whether this would actually work in a more general setting).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually that's wrong. I thought asarray in the spec used dlpack, but it's only numpy.asarray that does. In the spec you have to use from_dlpack (I'm not sure why they are separate).