-
Notifications
You must be signed in to change notification settings - Fork 3
ENH: asarray
for Array API support
#24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 41 commits
ea9c544
28a6339
65aca1e
912e55a
de56fe7
ca5ff59
cec8bd2
94d3044
2ef5892
ad770b8
039e931
bb8e89c
c4a2c7f
31a0300
3732e1c
1ebc347
4333c99
068b4fa
d72de86
d748ed1
0a1aa67
2a4f2f7
633648b
735101a
eb79de0
b4b9f32
1623d0e
abdb9b2
c432d88
603fbf3
a0b06cf
37c4c0f
afc5fdc
f900a4d
d6b568a
560b591
9ea70e6
e54c36d
9da0cd9
85c7045
8bc56ab
2eb7791
4596247
67ac6e0
52e982d
d8f66dd
488e3c3
77feb2f
d2d2e8d
e395445
1146129
bc6c61b
61d5860
dd2cab5
0469887
276d3f5
06e59f3
82a8aa3
de2257f
85fccda
0c90e8d
1d090e6
1f3856f
ea6f58b
a33efc6
4dbde35
bef978c
60e3103
b201851
0f022a1
226ac88
f7f2568
050a96e
f2dfb3c
4787f50
a38cd2e
2d12ad4
7aea9d3
0e2c5b1
15286a6
92814ac
b7455f2
5744618
5033b54
8bd49ae
50d7ea4
84cfe65
b9fe722
98f8309
f0c2ca4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
name: Array API | ||
|
||
on: | ||
push: | ||
branches: | ||
- maintenance/** | ||
pull_request: | ||
branches: | ||
- main | ||
- maintenance/** | ||
|
||
permissions: | ||
contents: read # to fetch code (actions/checkout) | ||
|
||
env: | ||
CCACHE_DIR: "${{ github.workspace }}/.ccache" | ||
INSTALLDIR: "build-install" | ||
|
||
concurrency: | ||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} | ||
cancel-in-progress: true | ||
|
||
jobs: | ||
pytorch_cpu: | ||
name: Linux PyTorch CPU | ||
# if: "github.repository == 'scipy/scipy' || github.repository == ''" | ||
runs-on: ubuntu-22.04 | ||
strategy: | ||
matrix: | ||
python-version: ['3.11'] | ||
maintenance-branch: | ||
- ${{ contains(github.ref, 'maintenance/') || contains(github.base_ref, 'maintenance/') }} | ||
exclude: | ||
- maintenance-branch: true | ||
|
||
steps: | ||
- uses: actions/checkout@v3 | ||
with: | ||
submodules: recursive | ||
|
||
- name: Setup Python | ||
uses: actions/setup-python@v4 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
cache: 'pip' # not using a path to also cache pytorch | ||
|
||
- name: Install Ubuntu dependencies | ||
run: | | ||
sudo apt-get update | ||
sudo apt-get install -y libopenblas-dev libatlas-base-dev liblapack-dev gfortran libgmp-dev libmpfr-dev libsuitesparse-dev ccache libmpc-dev | ||
|
||
- name: Install Python packages | ||
run: | | ||
python -m pip install numpy cython pytest pytest-xdist pytest-timeout pybind11 mpmath gmpy2 pythran ninja meson click rich-click doit pydevtool pooch | ||
# Packages for Array API testing | ||
python -m pip install array-api-compat | ||
python -m pip install torch --index-url https://download.pytorch.org/whl/cpu | ||
|
||
- name: Prepare compiler cache | ||
id: prep-ccache | ||
shell: bash | ||
run: | | ||
mkdir -p "${CCACHE_DIR}" | ||
echo "dir=$CCACHE_DIR" >> $GITHUB_OUTPUT | ||
NOW=$(date -u +"%F-%T") | ||
echo "timestamp=${NOW}" >> $GITHUB_OUTPUT | ||
|
||
- name: Setup compiler cache | ||
uses: actions/cache@v3 | ||
id: cache-ccache | ||
with: | ||
path: ${{ steps.prep-ccache.outputs.dir }} | ||
key: ${{ github.workflow }}-${{ matrix.python-version }}-ccache-linux-${{ steps.prep-ccache.outputs.timestamp }} | ||
restore-keys: | | ||
${{ github.workflow }}-${{ matrix.python-version }}-ccache-linux- | ||
|
||
- name: Setup build and install scipy | ||
run: | | ||
python dev.py build | ||
|
||
- name: Test SciPy | ||
run: | | ||
export OMP_NUM_THREADS=2 | ||
export SCIPY_USE_PROPACK=1 | ||
python dev.py --no-build test --array-api-backend pytorch -s cluster -- --durations 10 --timeout=60 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
"""Utility functions to use Python Array API compatible libraries. | ||
|
||
For the context about the Array API see: | ||
https://data-apis.org/array-api/latest/purpose_and_scope.html | ||
|
||
The SciPy use case of the Array API is described on the following page: | ||
https://data-apis.org/array-api/latest/use_cases.html#use-case-scipy | ||
""" | ||
import os | ||
|
||
import numpy as np | ||
from numpy.core.numerictypes import typecodes | ||
# probably want to vendor it (submodule) | ||
import array_api_compat | ||
import array_api_compat.numpy | ||
|
||
__all__ = ['array_namespace', 'as_xparray', 'as_xparray_namespace'] | ||
|
||
|
||
# SCIPY_ARRAY_API, array_api_dispatch is used by sklearn | ||
array_api_dispatch = os.environ.get("array_api_dispatch", False) | ||
SCIPY_ARRAY_API = os.environ.get("SCIPY_ARRAY_API", array_api_dispatch) | ||
|
||
_GLOBAL_CONFIG = {"SCIPY_ARRAY_API": SCIPY_ARRAY_API} | ||
|
||
|
||
def compliance_scipy(*arrays): | ||
"""Raise exceptions on known-bad subclasses. | ||
|
||
The following subclasses are not supported and raise and error: | ||
- `np.ma.MaskedArray` | ||
- `numpy.matrix` | ||
- Any array-like which is not Array API compatible | ||
""" | ||
for array in arrays: | ||
if isinstance(array, np.ma.MaskedArray): | ||
raise TypeError("'numpy.ma.MaskedArray' are not supported") | ||
elif isinstance(array, np.matrix): | ||
raise TypeError("'numpy.matrix' are not supported") | ||
elif not array_api_compat.is_array_api_obj(array): | ||
raise TypeError("Only support Array API compatible arrays") | ||
elif array.dtype is np.dtype('O'): | ||
raise ValueError('object arrays are not supported') | ||
|
||
|
||
def _check_finite(array, xp): | ||
"""Check for NaNs or Infs.""" | ||
msg = "array must not contain infs or NaNs" | ||
try: | ||
if not xp.isfinite(array).all(): | ||
raise ValueError(msg) | ||
except TypeError: | ||
raise ValueError(msg) | ||
|
||
|
||
def array_namespace(*arrays): | ||
"""Get the array API compatible namespace for the arrays xs. | ||
|
||
Parameters | ||
---------- | ||
*arrays : sequence of array_like | ||
Arrays used to infer the common namespace. | ||
|
||
Returns | ||
------- | ||
namespace : module | ||
Common namespace. | ||
|
||
Notes | ||
----- | ||
Thin wrapper around `array_api_compat.array_namespace`. | ||
|
||
1. Check for the global switch: SCIPY_ARRAY_API. This can also be accessed | ||
dynamically through ``_GLOBAL_CONFIG['SCIPY_ARRAY_API']``. | ||
2. `compliance_scipy` raise exceptions on known-bad subclasses. See | ||
it's definition for more details. | ||
|
||
When the global switch is False, it defaults to the `numpy` namespace. | ||
In that case, there is no compliance check. This is a convenience to | ||
ease the adoption. Otherwise, arrays must comply with the new rules. | ||
""" | ||
if not _GLOBAL_CONFIG["SCIPY_ARRAY_API"]: | ||
# here we could wrap the namespace if needed | ||
return np | ||
|
||
arrays = [array for array in arrays if array is not None] | ||
|
||
compliance_scipy(*arrays) | ||
|
||
return array_api_compat.array_namespace(*arrays) | ||
|
||
|
||
def as_xparray( | ||
array, dtype=None, order=None, copy=None, *, xp=None, check_finite=True | ||
): | ||
"""Drop-in replacement for `np.asarray`. | ||
|
||
Memory layout parameter `order` is not exposed in the Array API standard. | ||
`order` is only enforced if the input array implementation | ||
is NumPy based, otherwise `order` is just silently ignored. | ||
|
||
The purpose of this helper is to make it possible to share code for data | ||
container validation without memory copies for both downstream use cases. | ||
""" | ||
if xp is None: | ||
xp = array_namespace(array) | ||
if xp.__name__ in {"numpy", "array_api_compat.numpy", "numpy.array_api"}: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No one's actually doing it yet as far as I know, but this wouldn't work if someone vendors array_api_compat and tries to call a scipy function. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also I don't know if it makes sense to list There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For scikit-learn, we wanted the performance with def scipy_func(X):
xp = array_namespace(X)
# switch order for performance reasons
X_f = asarray(X, xp, order="F")
# Do some operations that require prefer F ordered.
return xp.sum(X_f, axis=0) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. By the way something like _X = numpy.asarray(..., order="F")
X = numpy.array_api.asarray(X) will also work. That's maybe a little more "spec compliant" in the sense that converting arrays from one library to another with asarray is supported. In this case it's a trivial zero-copy wrapping but in general it will use DLPack (although I don't know how DLPack handles order, so maybe someone could confirm whether this would actually work in a more general setting). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually that's wrong. I thought asarray in the spec used dlpack, but it's only numpy.asarray that does. In the spec you have to use from_dlpack (I'm not sure why they are separate). |
||
# Use NumPy API to support order | ||
if copy is True: | ||
array = np.array(array, order=order, dtype=dtype) | ||
else: | ||
array = np.asarray(array, order=order, dtype=dtype) | ||
|
||
# At this point array is a NumPy ndarray. We convert it to an array | ||
# container that is consistent with the input's namespace. | ||
array = xp.asarray(array) | ||
else: | ||
array = xp.asarray(array, dtype=dtype, copy=copy) | ||
|
||
if check_finite: | ||
_check_finite(array, xp) | ||
tupui marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
return array | ||
|
||
|
||
def as_xparray_namespace(*arrays): | ||
"""Validate and convert arrays to a common namespace. | ||
|
||
Parameters | ||
---------- | ||
*arrays : sequence of array_like | ||
Arrays to validate and convert. | ||
|
||
Returns | ||
------- | ||
*arrays : sequence of array_like | ||
Validated and converted arrays to the common namespace. | ||
namespace : module | ||
Common namespace. | ||
|
||
Notes | ||
----- | ||
This function is meant to be called from each public function in a SciPy | ||
submodule it does the following: | ||
|
||
1. Check for the global switch: SCIPY_ARRAY_API. This can also be accessed | ||
dynamically through ``_GLOBAL_CONFIG['SCIPY_ARRAY_API']``. | ||
2. `compliance_scipy` raise exceptions on known-bad subclasses. See | ||
it's definition for more details. | ||
3. Determine the namespace, without doing any coercion of array(-like) | ||
inputs. | ||
4. Call `xp.asarray` on all array. | ||
|
||
Examples | ||
-------- | ||
>>> import numpy as np | ||
>>> x, y, xp = as_xparray_namespace(np.array([0, 1, 2]), np.array([0, 1, 2])) | ||
>>> xp.__name__ | ||
'array_api_compat.numpy' | ||
>>> x, y | ||
(array([0, 1, 2]), array([0, 1, 2])) | ||
|
||
""" | ||
tupui marked this conversation as resolved.
Show resolved
Hide resolved
|
||
arrays = list(arrays) | ||
xp = array_namespace(*arrays) | ||
|
||
for i, array in enumerate(arrays): | ||
arrays[i] = xp.asarray(array) | ||
|
||
return *arrays, xp | ||
|
||
|
||
def to_numpy(array, xp): | ||
"""Convert `array` into a NumPy ndarray on the CPU. | ||
|
||
ONLY FOR TESTING | ||
""" | ||
xp_name = xp.__name__ | ||
|
||
if xp_name in {"array_api_compat.torch", "torch"}: | ||
return array.cpu().numpy() | ||
|
||
elif xp_name == "cupy.array_api": | ||
return array._array.get() | ||
elif xp_name in {"array_api_compat.cupy", "cupy"}: # pragma: nocover | ||
return array.get() | ||
|
||
return np.asarray(array) |
Uh oh!
There was an error while loading. Please reload this page.