Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions scipy/_lib/_testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import re
import sys
import numpy as np
from numpy.testing import assert_allclose, assert_array_almost_equal_nulp
import inspect
import array_api_compat


__all__ = ['PytestTester', 'check_free_memory', '_TestPythranFunc']
Expand Down Expand Up @@ -240,3 +242,85 @@ def _get_mem_available():
return info['memfree'] + info['cached']

return None


def _assert_allclose_host(a,
b,
rtol=1e-07,
atol=0,
equal_nan=True,
err_msg="",
verbose=True):
"""
NumPy assert_allclose() with added enforcement that
both arguments are stored on the host. This is intended
to simplify testing of array API backends other than NumPy
for which we'd like to move the data back to the host before
checking against the expected value (i.e., with CuPy, Torch
on GPU).
"""
# see: https://github.com/data-apis/array-api-compat/pull/40
# and: https://github.com/data-apis/array-api/issues/626
if not isinstance(a, float):
a = array_api_compat.to_device(a, "cpu")
if not isinstance(b, float):
b = array_api_compat.to_device(b, "cpu")
assert_allclose(a,
b,
rtol=rtol,
atol=atol,
equal_nan=equal_nan,
err_msg=err_msg,
verbose=verbose)


def _import_xp():
"""
Allows us to import an array API standard-compliant library
`xp` based on an environment variable, and defaults to
NumPy.
"""
try:
backend = os.environ["ARR_TST_BACKEND"]
if backend == "numpy":
import numpy as xp
elif backend == "cupy":
import cupy as xp
elif backend == "pytorch_cpu":
import torch
torch.set_default_device("cpu")
import torch as xp
elif backend == "pytorch_gpu":
import torch
torch.set_default_device("cuda")
import torch as xp
else:
raise ValueError(f"ARR_TST_BACKEND {backend} not recognized.")
except KeyError:
import numpy as xp
return xp


def _assert_matching_namespace(a, b):
"""
Check that a and b are in the same array namespace. Intended for
array API support/testing that array type in == array type out.
"""
a_space = array_api_compat.array_namespace(a)
b_space = array_api_compat.array_namespace(b)
assert a_space == b_space


def _assert_array_almost_equal_nulp_host(x, y, nulp=1):
# see: https://github.com/data-apis/array-api-compat/pull/40
# and: https://github.com/data-apis/array-api/issues/626
x = array_api_compat.to_device(x, "cpu")
y = array_api_compat.to_device(y, "cpu")
try:
assert_array_almost_equal_nulp(x=x, y=y, nulp=nulp)
except TypeError:
# unfortunately, we still need shims for Pytorch
# tensors x and y at the moment...
x = x.detach().cpu().numpy()
y = y.detach().cpu().numpy()
assert_array_almost_equal_nulp(x=x, y=y, nulp=nulp)
18 changes: 18 additions & 0 deletions scipy/_lib/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)

import numpy as np
import array_api_compat

IntNumber = Union[int, np.integer]
DecimalNumber = Union[float, np.floating, np.integer]
Expand Down Expand Up @@ -275,6 +276,23 @@ def _validate_int(k, name, minimum=None):
return k


def _get_namespace(*xs):
namespaces = set()
for x in xs:
if isinstance(x, list):
# When a list is involved, we default to NumPy,
# as before supporting multiple array namespaces.
x = array_api_compat.array_namespace(np.asarray(x))
else:
x = array_api_compat.array_namespace(x)
namespaces.add(x)

if len(namespaces) != 1:
raise TypeError("Input array-like objects do not belong to the same array namespace.")
xp, = namespaces
return xp


# Add a replacement for inspect.getfullargspec()/
# The version below is borrowed from Django,
# https://github.com/django/django/pull/4846.
Expand Down
41 changes: 23 additions & 18 deletions scipy/signal/_signaltools.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from ._filter_design import cheby1, _validate_sos, zpk2sos
from ._fir_filter_design import firwin
from ._sosfilt import _sosfilt
from scipy._lib._util import _get_namespace
import warnings


Expand Down Expand Up @@ -3520,20 +3521,21 @@ def detrend(data, axis=-1, type='linear', bp=0, overwrite_data=False):
0.06 # random

"""
xp = _get_namespace(data)
if type not in ['linear', 'l', 'constant', 'c']:
raise ValueError("Trend type must be 'linear' or 'constant'.")
data = np.asarray(data)
dtype = data.dtype.char
if dtype not in 'dfDF':
dtype = 'd'
data = xp.asarray(data)
dtype = data.dtype
if data.dtype not in [xp.float32, xp.float64, xp.complex128, xp.complex64]:
dtype = xp.float64
if type in ['constant', 'c']:
ret = data - np.mean(data, axis, keepdims=True)
ret = data - xp.mean(xp.asarray(data, dtype=dtype), axis=axis, keepdims=True)
return ret
else:
dshape = data.shape
N = dshape[axis]
bp = np.sort(np.unique(np.r_[0, bp, N]))
if np.any(bp > N):
bp = xp.sort(xp.unique(xp.asarray([0, bp, N])))
if xp.any(bp > N):
raise ValueError("Breakpoints must be less than length "
"of data along given axis.")
Nreg = len(bp) - 1
Expand All @@ -3543,28 +3545,31 @@ def detrend(data, axis=-1, type='linear', bp=0, overwrite_data=False):
if axis < 0:
axis = axis + rnk
newdims = tuple(np.r_[axis, 0:axis, axis + 1:rnk])
newdata = data.transpose(newdims).reshape(N, -1)

newdata = data.reshape(N, -1)
if not overwrite_data:
newdata = newdata.copy() # make sure we have a copy
if newdata.dtype.char not in 'dfDF':
newdata = xp.asarray(newdata, copy=True) # make sure we have a copy
if newdata.dtype not in [xp.float64, xp.float32, xp.complex128, xp.complex64]:
newdata = newdata.astype(dtype)

# Find leastsq fit and remove it for each piece
for m in range(Nreg):
Npts = bp[m + 1] - bp[m]
A = np.ones((Npts, 2), dtype)
A[:, 0] = np.arange(1, Npts + 1, dtype=dtype) / Npts
sl = slice(bp[m], bp[m + 1])
coef, resids, rank, s = linalg.lstsq(A, newdata[sl])
A = xp.ones((int(Npts), 2), dtype=dtype)
A[:, 0] = xp.arange(1, Npts + 1, dtype=dtype) / Npts
sl = slice(int(bp[m]), int(bp[m + 1]))
# NOTE: lstsq isn't in the array API standard, so for now
# we cheat by using CuPy implementation if needed
if "cupy" in xp.__name__ or "torch" in xp.__name__:
coef, resids, rank, s = xp.linalg.lstsq(A, newdata[sl], rcond=None)
else:
coef, resids, rank, s = linalg.lstsq(A, newdata[sl])
newdata[sl] = newdata[sl] - A @ coef

# Put data back in original shape.
tdshape = np.take(dshape, newdims, 0)
ret = np.reshape(newdata, tuple(tdshape))
tdshape = xp.take(xp.asarray(dshape), xp.asarray(newdims))
ret = xp.reshape(newdata, shape=tuple([int(e) for e in tdshape]))
vals = list(range(1, rnk))
olddims = vals[:axis] + [0] + vals[axis:]
ret = np.transpose(ret, tuple(olddims))
return ret


Expand Down
Loading