diff --git a/scipy/_lib/_testutils.py b/scipy/_lib/_testutils.py index daad1194214c..7e0380bda641 100644 --- a/scipy/_lib/_testutils.py +++ b/scipy/_lib/_testutils.py @@ -7,7 +7,9 @@ import re import sys import numpy as np +from numpy.testing import assert_allclose, assert_array_almost_equal_nulp import inspect +import array_api_compat __all__ = ['PytestTester', 'check_free_memory', '_TestPythranFunc'] @@ -240,3 +242,85 @@ def _get_mem_available(): return info['memfree'] + info['cached'] return None + + +def _assert_allclose_host(a, + b, + rtol=1e-07, + atol=0, + equal_nan=True, + err_msg="", + verbose=True): + """ + NumPy assert_allclose() with added enforcement that + both arguments are stored on the host. This is intended + to simplify testing of array API backends other than NumPy + for which we'd like to move the data back to the host before + checking against the expected value (i.e., with CuPy, Torch + on GPU). + """ + # see: https://github.com/data-apis/array-api-compat/pull/40 + # and: https://github.com/data-apis/array-api/issues/626 + if not isinstance(a, float): + a = array_api_compat.to_device(a, "cpu") + if not isinstance(b, float): + b = array_api_compat.to_device(b, "cpu") + assert_allclose(a, + b, + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + err_msg=err_msg, + verbose=verbose) + + +def _import_xp(): + """ + Allows us to import an array API standard-compliant library + `xp` based on an environment variable, and defaults to + NumPy. + """ + try: + backend = os.environ["ARR_TST_BACKEND"] + if backend == "numpy": + import numpy as xp + elif backend == "cupy": + import cupy as xp + elif backend == "pytorch_cpu": + import torch + torch.set_default_device("cpu") + import torch as xp + elif backend == "pytorch_gpu": + import torch + torch.set_default_device("cuda") + import torch as xp + else: + raise ValueError(f"ARR_TST_BACKEND {backend} not recognized.") + except KeyError: + import numpy as xp + return xp + + +def _assert_matching_namespace(a, b): + """ + Check that a and b are in the same array namespace. Intended for + array API support/testing that array type in == array type out. + """ + a_space = array_api_compat.array_namespace(a) + b_space = array_api_compat.array_namespace(b) + assert a_space == b_space + + +def _assert_array_almost_equal_nulp_host(x, y, nulp=1): + # see: https://github.com/data-apis/array-api-compat/pull/40 + # and: https://github.com/data-apis/array-api/issues/626 + x = array_api_compat.to_device(x, "cpu") + y = array_api_compat.to_device(y, "cpu") + try: + assert_array_almost_equal_nulp(x=x, y=y, nulp=nulp) + except TypeError: + # unfortunately, we still need shims for Pytorch + # tensors x and y at the moment... + x = x.detach().cpu().numpy() + y = y.detach().cpu().numpy() + assert_array_almost_equal_nulp(x=x, y=y, nulp=nulp) diff --git a/scipy/_lib/_util.py b/scipy/_lib/_util.py index 4a588a216bb9..cbbb4fa4daf1 100644 --- a/scipy/_lib/_util.py +++ b/scipy/_lib/_util.py @@ -15,6 +15,7 @@ ) import numpy as np +import array_api_compat IntNumber = Union[int, np.integer] DecimalNumber = Union[float, np.floating, np.integer] @@ -275,6 +276,23 @@ def _validate_int(k, name, minimum=None): return k +def _get_namespace(*xs): + namespaces = set() + for x in xs: + if isinstance(x, list): + # When a list is involved, we default to NumPy, + # as before supporting multiple array namespaces. + x = array_api_compat.array_namespace(np.asarray(x)) + else: + x = array_api_compat.array_namespace(x) + namespaces.add(x) + + if len(namespaces) != 1: + raise TypeError("Input array-like objects do not belong to the same array namespace.") + xp, = namespaces + return xp + + # Add a replacement for inspect.getfullargspec()/ # The version below is borrowed from Django, # https://github.com/django/django/pull/4846. diff --git a/scipy/signal/_signaltools.py b/scipy/signal/_signaltools.py index 60efefac7a1d..5e6ee106fac4 100644 --- a/scipy/signal/_signaltools.py +++ b/scipy/signal/_signaltools.py @@ -18,6 +18,7 @@ from ._filter_design import cheby1, _validate_sos, zpk2sos from ._fir_filter_design import firwin from ._sosfilt import _sosfilt +from scipy._lib._util import _get_namespace import warnings @@ -3520,20 +3521,21 @@ def detrend(data, axis=-1, type='linear', bp=0, overwrite_data=False): 0.06 # random """ + xp = _get_namespace(data) if type not in ['linear', 'l', 'constant', 'c']: raise ValueError("Trend type must be 'linear' or 'constant'.") - data = np.asarray(data) - dtype = data.dtype.char - if dtype not in 'dfDF': - dtype = 'd' + data = xp.asarray(data) + dtype = data.dtype + if data.dtype not in [xp.float32, xp.float64, xp.complex128, xp.complex64]: + dtype = xp.float64 if type in ['constant', 'c']: - ret = data - np.mean(data, axis, keepdims=True) + ret = data - xp.mean(xp.asarray(data, dtype=dtype), axis=axis, keepdims=True) return ret else: dshape = data.shape N = dshape[axis] - bp = np.sort(np.unique(np.r_[0, bp, N])) - if np.any(bp > N): + bp = xp.sort(xp.unique(xp.asarray([0, bp, N]))) + if xp.any(bp > N): raise ValueError("Breakpoints must be less than length " "of data along given axis.") Nreg = len(bp) - 1 @@ -3543,28 +3545,31 @@ def detrend(data, axis=-1, type='linear', bp=0, overwrite_data=False): if axis < 0: axis = axis + rnk newdims = tuple(np.r_[axis, 0:axis, axis + 1:rnk]) - newdata = data.transpose(newdims).reshape(N, -1) - + newdata = data.reshape(N, -1) if not overwrite_data: - newdata = newdata.copy() # make sure we have a copy - if newdata.dtype.char not in 'dfDF': + newdata = xp.asarray(newdata, copy=True) # make sure we have a copy + if newdata.dtype not in [xp.float64, xp.float32, xp.complex128, xp.complex64]: newdata = newdata.astype(dtype) # Find leastsq fit and remove it for each piece for m in range(Nreg): Npts = bp[m + 1] - bp[m] - A = np.ones((Npts, 2), dtype) - A[:, 0] = np.arange(1, Npts + 1, dtype=dtype) / Npts - sl = slice(bp[m], bp[m + 1]) - coef, resids, rank, s = linalg.lstsq(A, newdata[sl]) + A = xp.ones((int(Npts), 2), dtype=dtype) + A[:, 0] = xp.arange(1, Npts + 1, dtype=dtype) / Npts + sl = slice(int(bp[m]), int(bp[m + 1])) + # NOTE: lstsq isn't in the array API standard, so for now + # we cheat by using CuPy implementation if needed + if "cupy" in xp.__name__ or "torch" in xp.__name__: + coef, resids, rank, s = xp.linalg.lstsq(A, newdata[sl], rcond=None) + else: + coef, resids, rank, s = linalg.lstsq(A, newdata[sl]) newdata[sl] = newdata[sl] - A @ coef # Put data back in original shape. - tdshape = np.take(dshape, newdims, 0) - ret = np.reshape(newdata, tuple(tdshape)) + tdshape = xp.take(xp.asarray(dshape), xp.asarray(newdims)) + ret = xp.reshape(newdata, shape=tuple([int(e) for e in tdshape])) vals = list(range(1, rnk)) olddims = vals[:axis] + [0] + vals[axis:] - ret = np.transpose(ret, tuple(olddims)) return ret diff --git a/scipy/signal/_spectral_py.py b/scipy/signal/_spectral_py.py index 4aa28e0bc254..6d8ab56a5c62 100644 --- a/scipy/signal/_spectral_py.py +++ b/scipy/signal/_spectral_py.py @@ -1,13 +1,17 @@ """Tools for spectral analysis. """ +import os +import math import numpy as np from scipy import fft as sp_fft +from scipy._lib._util import _get_namespace from . import _signaltools from .windows import get_window from ._spectral import _lombscargle from ._arraytools import const_ext, even_ext, odd_ext, zero_ext import warnings +import array_api_compat __all__ = ['periodogram', 'welch', 'lombscargle', 'csd', 'coherence', @@ -452,12 +456,14 @@ def welch(x, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None, >>> plt.show() """ + xp = _get_namespace(x) freqs, Pxx = csd(x, x, fs=fs, window=window, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=return_onesided, scaling=scaling, axis=axis, average=average) - return freqs, Pxx.real + # TODO: remove coercion here I think, was needed for PyTorch + return freqs, xp.asarray(Pxx.real) def csd(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None, @@ -591,25 +597,36 @@ def csd(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None, detrend, return_onesided, scaling, axis, mode='psd') + xp = _get_namespace(x, y) # Average over windows. - if len(Pxy.shape) >= 2 and Pxy.size > 0: + if len(Pxy.shape) >= 2 and math.prod(Pxy.shape) > 0: if Pxy.shape[-1] > 1: if average == 'median': # np.median must be passed real arrays for the desired result bias = _median_bias(Pxy.shape[-1]) - if np.iscomplexobj(Pxy): - Pxy = (np.median(np.real(Pxy), axis=-1) - + 1j * np.median(np.imag(Pxy), axis=-1)) + if Pxy.dtype in [xp.complex64, xp.complex128]: + Pxy = (xp.median(xp.real(Pxy), axis=-1) + + 1j * xp.median(xp.imag(Pxy), axis=-1)) else: - Pxy = np.median(Pxy, axis=-1) - Pxy /= bias + Pxy = xp.median(Pxy, axis=-1) + # for PyTorch, Pxy is torch.return_types.median + # which is super confusing... + # NOTE: I don't actually see median in the API std + try: + device_pxy = array_api_compat.device(Pxy) + except AttributeError: + Pxy = Pxy.values + device_pxy = array_api_compat.device(Pxy) + bias = xp.asarray(bias) + bias = array_api_compat.to_device(bias, device_pxy) + Pxy = Pxy / bias elif average == 'mean': Pxy = Pxy.mean(axis=-1) else: raise ValueError('average must be "median" or "mean", got %s' % (average,)) else: - Pxy = np.reshape(Pxy, Pxy.shape[:-1]) + Pxy = xp.reshape(Pxy, Pxy.shape[:-1]) return freqs, Pxy @@ -1721,6 +1738,7 @@ def _spectral_helper(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None, .. versionadded:: 0.16.0 """ + xp = _get_namespace(x, y) if mode not in ['psd', 'stft']: raise ValueError("Unknown value for mode %s, must be one of: " "{'psd', 'stft'}" % mode) @@ -1744,12 +1762,15 @@ def _spectral_helper(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None, axis = int(axis) # Ensure we have np.arrays, get outdtype - x = np.asarray(x) + x = xp.asarray(x) + # TODO: remove temporary hack for: + # https://github.com/data-apis/array-api-compat/issues/43 + tmp = xp.asarray([0], dtype=xp.complex64) if not same_data: - y = np.asarray(y) - outdtype = np.result_type(x, y, np.complex64) + y = xp.asarray(y) + outdtype = xp.result_type(x, y, xp.complex64) else: - outdtype = np.result_type(x, np.complex64) + outdtype = xp.result_type(x, tmp) if not same_data: # Check if we can broadcast the outer axes together @@ -1758,24 +1779,30 @@ def _spectral_helper(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None, xouter.pop(axis) youter.pop(axis) try: - outershape = np.broadcast(np.empty(xouter), np.empty(youter)).shape + outershape = xp.broadcast(xp.empty(xouter), xp.empty(youter)).shape except ValueError as e: raise ValueError('x and y cannot be broadcast together.') from e if same_data: - if x.size == 0: - return np.empty(x.shape), np.empty(x.shape), np.empty(x.shape) + # TODO: what to do about PyTorch not treating + # size as a property, and doing weird stuff + # with size() in general (uses subclass of tuple...) + size = x.size + if not isinstance(size, int): + size = math.prod(x.shape) + if size == 0: + return xp.empty(x.shape), xp.empty(x.shape), xp.empty(x.shape) else: if x.size == 0 or y.size == 0: outshape = outershape + (min([x.shape[axis], y.shape[axis]]),) - emptyout = np.moveaxis(np.empty(outshape), -1, axis) + emptyout = xp.moveaxis(xp.empty(outshape), -1, axis) return emptyout, emptyout, emptyout if x.ndim > 1: if axis != -1: - x = np.moveaxis(x, axis, -1) + x = xp.moveaxis(x, axis, -1) if not same_data and y.ndim > 1: - y = np.moveaxis(y, axis, -1) + y = xp.moveaxis(y, axis, -1) # Check if x and y are the same length, zero-pad if necessary if not same_data: @@ -1783,11 +1810,11 @@ def _spectral_helper(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None, if x.shape[-1] < y.shape[-1]: pad_shape = list(x.shape) pad_shape[-1] = y.shape[-1] - x.shape[-1] - x = np.concatenate((x, np.zeros(pad_shape)), -1) + x = xp.concatenate((x, xp.zeros(pad_shape)), -1) else: pad_shape = list(y.shape) pad_shape[-1] = x.shape[-1] - y.shape[-1] - y = np.concatenate((y, np.zeros(pad_shape)), -1) + y = xp.concatenate((y, xp.zeros(pad_shape)), -1) if nperseg is not None: # if specified by user nperseg = int(nperseg) @@ -1796,6 +1823,12 @@ def _spectral_helper(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None, # parse window; if array like, then set nperseg = win.shape win, nperseg = _triage_segments(window, nperseg, input_length=x.shape[-1]) + # NOTE: asarray is API conformant, but I wonder + # if what really needs to happen is a deeper set of xp + # shims around the various window functions that ultimately + # get called by _triage_segments to avoid NumPy use more completely + # when not using NumPy? + win = xp.asarray(win) if nfft is None: nfft = nperseg @@ -1829,10 +1862,10 @@ def _spectral_helper(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None, # I.e make x.shape[-1] = nperseg + (nseg-1)*nstep, with integer nseg nadd = (-(x.shape[-1]-nperseg) % nstep) % nperseg zeros_shape = list(x.shape[:-1]) + [nadd] - x = np.concatenate((x, np.zeros(zeros_shape)), axis=-1) + x = xp.concatenate((x, xp.zeros(zeros_shape)), axis=-1) if not same_data: zeros_shape = list(y.shape[:-1]) + [nadd] - y = np.concatenate((y, np.zeros(zeros_shape)), axis=-1) + y = xp.concatenate((y, xp.zeros(zeros_shape)), axis=-1) # Handle detrending and window functions if not detrend: @@ -1845,14 +1878,21 @@ def detrend_func(d): # Wrap this function so that it receives a shape that it could # reasonably expect to receive. def detrend_func(d): - d = np.moveaxis(d, -1, axis) + d = xp.moveaxis(d, -1, axis) d = detrend(d) - return np.moveaxis(d, axis, -1) + return xp.moveaxis(d, axis, -1) else: detrend_func = detrend - if np.result_type(win, np.complex64) != outdtype: - win = win.astype(outdtype) + # TODO: remove tmp usage when this is fixed: + # https://github.com/data-apis/array-api-compat/issues/43 + if xp.result_type(win, tmp) != outdtype: + try: + win = win.astype(outdtype) + except AttributeError: + # TODO: remove this shim when array-api-compat + # has suitable shims for complex types? (PyTorch related) + win = win.to(outdtype) if scaling == 'density': scale = 1.0 / (fs * (win*win).sum()) @@ -1862,17 +1902,23 @@ def detrend_func(d): raise ValueError('Unknown scaling: %r' % scaling) if mode == 'stft': - scale = np.sqrt(scale) + scale = xp.sqrt(scale) if return_onesided: - if np.iscomplexobj(x): + try: + is_complex = xp.iscomplexobj(x) + except AttributeError: + # TODO: deal with PyTorch vs. other libs here... + is_complex = xp.is_complex(x) + + if is_complex: sides = 'twosided' warnings.warn('Input data is complex, switching to ' 'return_onesided=False') else: sides = 'onesided' if not same_data: - if np.iscomplexobj(y): + if xp.iscomplexobj(y): sides = 'twosided' warnings.warn('Input data is complex, switching to ' 'return_onesided=False') @@ -1880,20 +1926,21 @@ def detrend_func(d): sides = 'twosided' if sides == 'twosided': - freqs = sp_fft.fftfreq(nfft, 1/fs) + freqs = xp.fft.fftfreq(nfft, 1/fs) elif sides == 'onesided': - freqs = sp_fft.rfftfreq(nfft, 1/fs) + freqs = xp.fft.rfftfreq(nfft, 1/fs) # Perform the windowed FFTs + win = xp.asarray(win) result = _fft_helper(x, win, detrend_func, nperseg, noverlap, nfft, sides) if not same_data: # All the same operations on the y data result_y = _fft_helper(y, win, detrend_func, nperseg, noverlap, nfft, sides) - result = np.conjugate(result) * result_y + result = xp.conj(result) * result_y elif mode == 'psd': - result = np.conjugate(result) * result + result = xp.conj(result) * result result *= scale if sides == 'onesided' and mode == 'psd': @@ -1903,12 +1950,16 @@ def detrend_func(d): # Last point is unpaired Nyquist freq point, don't double result[..., 1:-1] *= 2 - time = np.arange(nperseg/2, x.shape[-1] - nperseg/2 + 1, + time = xp.arange(nperseg/2, x.shape[-1] - nperseg/2 + 1, nperseg - noverlap)/float(fs) if boundary is not None: time -= (nperseg/2) / fs - result = result.astype(outdtype) + try: + result = result.astype(outdtype) + except AttributeError: + # TODO: handle complex types for PyTorch + result = result.to(outdtype) # All imaginary parts are zero anyways if same_data and mode != 'stft': @@ -1920,7 +1971,7 @@ def detrend_func(d): axis -= 1 # Roll frequency axis back to axis where the data came from - result = np.moveaxis(result, -1, axis) + result = xp.moveaxis(result, -1, axis) return freqs, time, result @@ -1947,6 +1998,7 @@ def _fft_helper(x, win, detrend_func, nperseg, noverlap, nfft, sides): .. versionadded:: 0.16.0 """ + xp = _get_namespace(x, win) # Created strided array of data segments if nperseg == 1 and noverlap == 0: result = x[..., np.newaxis] @@ -1954,22 +2006,61 @@ def _fft_helper(x, win, detrend_func, nperseg, noverlap, nfft, sides): # https://stackoverflow.com/a/5568169 step = nperseg - noverlap shape = x.shape[:-1]+((x.shape[-1]-noverlap)//step, nperseg) - strides = x.strides[:-1]+(step*x.strides[-1], x.strides[-1]) - result = np.lib.stride_tricks.as_strided(x, shape=shape, - strides=strides) + result = xp.empty(shape, dtype=x.dtype, device=array_api_compat.device(x)) + # as_strided makes a huge performance difference here, + # so we allow its usage when an array API library provides it + # despite its absence from the standard + # NOTE: we also provide an env variable for easy disabling + # of code paths that circumvent the array API for perf reasons + # NOTE: as_strided lives in different namespaces in different + # projects (i.e., top-level in torch, nested in CuPy/NumPy) + if "cupy" in xp.__name__ and not os.environ.get("SCIPY_STRICT_ARR_API"): + import cupy as cp + strides = x.strides[:-1]+(step*x.strides[-1], x.strides[-1]) + result = cp.lib.stride_tricks.as_strided(x, shape=shape, + strides=strides) + elif "numpy" in xp.__name__ and not os.environ.get("SCIPY_STRICT_ARR_API"): + strides = x.strides[:-1]+(step*x.strides[-1], x.strides[-1]) + result = np.lib.stride_tricks.as_strided(x, shape=shape, + strides=strides) + elif "torch" in xp.__name__ and not os.environ.get("SCIPY_STRICT_ARR_API"): + import torch + strides = x.stride()[:-1]+(step*x.stride()[-1], x.stride()[-1]) + result = torch.as_strided(x, size=shape, stride=strides) + else: + # NOTE: there is perhaps a dimensionally-agnostic + # way to circumvent as_strided, but for now this is a modified + # version of the shim described in gh-18286, which did not hold + # for 3D+ arrays + for ii in range(shape[0]): + if len(shape) == 2: + result[ii, :] = x[ii * step: (ii * step + nperseg)] + elif len(shape) == 3: + for jj in range(shape[1]): + result[ii, jj, :] = x[ii, jj * step: (jj * step + nperseg)] + elif len(shape) == 4: + for jj in range(shape[1]): + for kk in range(shape[2]): + result[ii, jj, kk, :] = x[ii, jj, kk * step: (kk * step + nperseg)] + # Detrend each data segment individually result = detrend_func(result) # Apply window by multiplication + # NOTE: win not always on torch CUDA device here + # when using PyTorch with GPU tensors + # probably need a deeper analysis to avoid this shim + result_device = array_api_compat.device(result) + win = array_api_compat.to_device(win, result_device) result = win * result # Perform the fft. Acts on last axis by default. Zero-pads automatically if sides == 'twosided': - func = sp_fft.fft + func = xp.fft.fft else: result = result.real - func = sp_fft.rfft + func = xp.fft.rfft result = func(result, n=nfft) return result @@ -2020,7 +2111,8 @@ def _triage_segments(window, nperseg, input_length): nperseg = input_length win = get_window(window, nperseg) else: - win = np.asarray(window) + xp = _get_namespace(window) + win = xp.asarray(window) if len(win.shape) != 1: raise ValueError('window must be 1-D') if input_length < win.shape[-1]: diff --git a/scipy/signal/tests/test_spectral.py b/scipy/signal/tests/test_spectral.py index c897d51ffa79..177e8fa09443 100644 --- a/scipy/signal/tests/test_spectral.py +++ b/scipy/signal/tests/test_spectral.py @@ -12,6 +12,10 @@ from scipy.signal import (periodogram, welch, lombscargle, csd, coherence, spectrogram, stft, istft, check_COLA, check_NOLA) from scipy.signal._spectral_py import _spectral_helper +from scipy._lib._testutils import (_assert_allclose_host, _import_xp, + _assert_matching_namespace, + _assert_array_almost_equal_nulp_host) +xp = _import_xp() class TestPeriodogram: @@ -237,149 +241,189 @@ def test_shorter_window_error(self): class TestWelch: def test_real_onesided_even(self): - x = np.zeros(16) + x = xp.zeros(16) x[0] = 1 x[8] = 1 f, p = welch(x, nperseg=8) - assert_allclose(f, np.linspace(0, 0.5, 5)) + _assert_allclose_host(f, np.linspace(0, 0.5, 5)) q = np.array([0.08333333, 0.15277778, 0.22222222, 0.22222222, 0.11111111]) - assert_allclose(p, q, atol=1e-7, rtol=1e-7) + _assert_allclose_host(p, q, atol=1e-7, rtol=1e-7) + _assert_matching_namespace(f, x) + _assert_matching_namespace(p, x) def test_real_onesided_odd(self): - x = np.zeros(16) + x = xp.zeros(16) x[0] = 1 x[8] = 1 f, p = welch(x, nperseg=9) - assert_allclose(f, np.arange(5.0)/9.0) + _assert_allclose_host(f, np.arange(5.0)/9.0) q = np.array([0.12477455, 0.23430933, 0.17072113, 0.17072113, 0.17072113]) - assert_allclose(p, q, atol=1e-7, rtol=1e-7) + _assert_allclose_host(p, q, atol=1e-7, rtol=1e-7) + _assert_matching_namespace(f, x) + _assert_matching_namespace(p, x) def test_real_twosided(self): - x = np.zeros(16) + x = xp.zeros(16) x[0] = 1 x[8] = 1 f, p = welch(x, nperseg=8, return_onesided=False) - assert_allclose(f, fftfreq(8, 1.0)) + _assert_allclose_host(f, fftfreq(8, 1.0)) q = np.array([0.08333333, 0.07638889, 0.11111111, 0.11111111, 0.11111111, 0.11111111, 0.11111111, 0.07638889]) - assert_allclose(p, q, atol=1e-7, rtol=1e-7) + _assert_allclose_host(p, q, atol=1e-7, rtol=1e-7) + _assert_matching_namespace(f, x) + _assert_matching_namespace(p, x) def test_real_spectrum(self): - x = np.zeros(16) + x = xp.zeros(16) x[0] = 1 x[8] = 1 f, p = welch(x, nperseg=8, scaling='spectrum') - assert_allclose(f, np.linspace(0, 0.5, 5)) + _assert_allclose_host(f, np.linspace(0, 0.5, 5)) q = np.array([0.015625, 0.02864583, 0.04166667, 0.04166667, 0.02083333]) - assert_allclose(p, q, atol=1e-7, rtol=1e-7) + _assert_allclose_host(p, q, atol=1e-7, rtol=1e-7) + _assert_matching_namespace(f, x) + _assert_matching_namespace(p, x) def test_integer_onesided_even(self): - x = np.zeros(16, dtype=int) + x = xp.zeros(16, dtype=int) x[0] = 1 x[8] = 1 f, p = welch(x, nperseg=8) - assert_allclose(f, np.linspace(0, 0.5, 5)) + _assert_allclose_host(f, np.linspace(0, 0.5, 5)) q = np.array([0.08333333, 0.15277778, 0.22222222, 0.22222222, 0.11111111]) - assert_allclose(p, q, atol=1e-7, rtol=1e-7) + _assert_allclose_host(p, q, atol=1e-7, rtol=1e-7) + _assert_matching_namespace(f, x) + _assert_matching_namespace(p, x) def test_integer_onesided_odd(self): - x = np.zeros(16, dtype=int) + x = xp.zeros(16, dtype=int) x[0] = 1 x[8] = 1 f, p = welch(x, nperseg=9) - assert_allclose(f, np.arange(5.0)/9.0) + _assert_allclose_host(f, np.arange(5.0)/9.0) q = np.array([0.12477455, 0.23430933, 0.17072113, 0.17072113, 0.17072113]) - assert_allclose(p, q, atol=1e-7, rtol=1e-7) + _assert_allclose_host(p, q, atol=1e-7, rtol=1e-7) + _assert_matching_namespace(f, x) + _assert_matching_namespace(p, x) def test_integer_twosided(self): - x = np.zeros(16, dtype=int) + x = xp.zeros(16, dtype=int) x[0] = 1 x[8] = 1 f, p = welch(x, nperseg=8, return_onesided=False) - assert_allclose(f, fftfreq(8, 1.0)) + _assert_allclose_host(f, fftfreq(8, 1.0)) q = np.array([0.08333333, 0.07638889, 0.11111111, 0.11111111, 0.11111111, 0.11111111, 0.11111111, 0.07638889]) - assert_allclose(p, q, atol=1e-7, rtol=1e-7) + _assert_allclose_host(p, q, atol=1e-7, rtol=1e-7) + _assert_matching_namespace(f, x) + _assert_matching_namespace(p, x) def test_complex(self): - x = np.zeros(16, np.complex128) + x = xp.zeros(16, dtype=xp.complex128) x[0] = 1.0 + 2.0j x[8] = 1.0 + 2.0j f, p = welch(x, nperseg=8, return_onesided=False) - assert_allclose(f, fftfreq(8, 1.0)) + _assert_allclose_host(f, fftfreq(8, 1.0)) q = np.array([0.41666667, 0.38194444, 0.55555556, 0.55555556, 0.55555556, 0.55555556, 0.55555556, 0.38194444]) - assert_allclose(p, q, atol=1e-7, rtol=1e-7) + _assert_allclose_host(p, q, atol=1e-7, rtol=1e-7) + _assert_matching_namespace(f, x) + _assert_matching_namespace(p, x) def test_unk_scaling(self): - assert_raises(ValueError, welch, np.zeros(4, np.complex128), + assert_raises(ValueError, welch, xp.zeros(4, dtype=xp.complex128), scaling='foo', nperseg=4) def test_detrend_linear(self): - x = np.arange(10, dtype=np.float64) + 0.04 + x = xp.arange(10, dtype=xp.float64) + 0.04 f, p = welch(x, nperseg=10, detrend='linear') - assert_allclose(p, np.zeros_like(p), atol=1e-15) + _assert_allclose_host(p, xp.zeros_like(p), atol=1e-15) + _assert_matching_namespace(f, x) + _assert_matching_namespace(p, x) def test_no_detrending(self): - x = np.arange(10, dtype=np.float64) + 0.04 + x = xp.arange(10, dtype=xp.float64) + 0.04 f1, p1 = welch(x, nperseg=10, detrend=False) f2, p2 = welch(x, nperseg=10, detrend=lambda x: x) - assert_allclose(f1, f2, atol=1e-15) - assert_allclose(p1, p2, atol=1e-15) + _assert_allclose_host(f1, f2, atol=1e-15) + _assert_allclose_host(p1, p2, atol=1e-15) + _assert_matching_namespace(f1, x) + _assert_matching_namespace(p1, x) + _assert_matching_namespace(f2, x) + _assert_matching_namespace(p2, x) def test_detrend_external(self): - x = np.arange(10, dtype=np.float64) + 0.04 + x = xp.arange(10, dtype=xp.float64) + 0.04 f, p = welch(x, nperseg=10, detrend=lambda seg: signal.detrend(seg, type='l')) - assert_allclose(p, np.zeros_like(p), atol=1e-15) + _assert_allclose_host(p, xp.zeros_like(p), atol=1e-15) + _assert_matching_namespace(f, x) + _assert_matching_namespace(p, x) def test_detrend_external_nd_m1(self): - x = np.arange(40, dtype=np.float64) + 0.04 + x = xp.arange(40, dtype=xp.float64) + 0.04 x = x.reshape((2,2,10)) f, p = welch(x, nperseg=10, detrend=lambda seg: signal.detrend(seg, type='l')) - assert_allclose(p, np.zeros_like(p), atol=1e-15) + _assert_allclose_host(p, xp.zeros_like(p), atol=1e-15) + _assert_matching_namespace(f, x) + _assert_matching_namespace(p, x) def test_detrend_external_nd_0(self): - x = np.arange(20, dtype=np.float64) + 0.04 + x = xp.arange(20, dtype=xp.float64) + 0.04 x = x.reshape((2,1,10)) - x = np.moveaxis(x, 2, 0) + x = xp.moveaxis(x, 2, 0) f, p = welch(x, nperseg=10, axis=0, detrend=lambda seg: signal.detrend(seg, axis=0, type='l')) - assert_allclose(p, np.zeros_like(p), atol=1e-15) + _assert_allclose_host(p, xp.zeros_like(p), atol=1e-15) + _assert_matching_namespace(f, x) + _assert_matching_namespace(p, x) def test_nd_axis_m1(self): - x = np.arange(20, dtype=np.float64) + 0.04 + x = xp.arange(20, dtype=xp.float64) + 0.04 x = x.reshape((2,1,10)) f, p = welch(x, nperseg=10) assert_array_equal(p.shape, (2, 1, 6)) - assert_allclose(p[0,0,:], p[1,0,:], atol=1e-13, rtol=1e-13) + _assert_matching_namespace(f, x) + _assert_matching_namespace(p, x) + _assert_allclose_host(p[0,0,:], p[1,0,:], atol=1e-13, rtol=1e-13) f0, p0 = welch(x[0,0,:], nperseg=10) - assert_allclose(p0[np.newaxis,:], p[1,:], atol=1e-13, rtol=1e-13) + _assert_matching_namespace(f0, x) + _assert_matching_namespace(p0, x) + _assert_allclose_host(p0[None,:], p[1,:], atol=1e-13, rtol=1e-13) def test_nd_axis_0(self): - x = np.arange(20, dtype=np.float64) + 0.04 + x = xp.arange(20, dtype=xp.float64) + 0.04 x = x.reshape((10,2,1)) f, p = welch(x, nperseg=10, axis=0) + _assert_matching_namespace(f, x) + _assert_matching_namespace(p, x) assert_array_equal(p.shape, (6,2,1)) - assert_allclose(p[:,0,0], p[:,1,0], atol=1e-13, rtol=1e-13) + _assert_allclose_host(p[:,0,0], p[:,1,0], atol=1e-13, rtol=1e-13) f0, p0 = welch(x[:,0,0], nperseg=10) - assert_allclose(p0, p[:,1,0], atol=1e-13, rtol=1e-13) + _assert_matching_namespace(f0, x) + _assert_matching_namespace(p0, x) + _assert_allclose_host(p0, p[:,1,0], atol=1e-13, rtol=1e-13) def test_window_external(self): - x = np.zeros(16) + x = xp.zeros(16) x[0] = 1 x[8] = 1 f, p = welch(x, 10, 'hann', nperseg=8) + _assert_matching_namespace(f, x) + _assert_matching_namespace(p, x) win = signal.get_window('hann', 8) fe, pe = welch(x, 10, win, nperseg=None) - assert_array_almost_equal_nulp(p, pe) - assert_array_almost_equal_nulp(f, fe) + _assert_matching_namespace(fe, x) + _assert_matching_namespace(pe, x) + _assert_array_almost_equal_nulp_host(p, pe) + _assert_array_almost_equal_nulp_host(f, fe) assert_array_equal(fe.shape, (5,)) # because win length used as nperseg assert_array_equal(pe.shape, (5,)) assert_raises(ValueError, welch, x, @@ -393,18 +437,18 @@ def test_empty_input(self): assert_array_equal(f.shape, (0,)) assert_array_equal(p.shape, (0,)) for shape in [(0,), (3,0), (0,5,2)]: - f, p = welch(np.empty(shape)) + f, p = welch(xp.empty(shape)) assert_array_equal(f.shape, shape) assert_array_equal(p.shape, shape) def test_empty_input_other_axis(self): for shape in [(3,0), (0,5,2)]: - f, p = welch(np.empty(shape), axis=1) + f, p = welch(xp.empty(shape), axis=1) assert_array_equal(f.shape, shape) assert_array_equal(p.shape, shape) def test_short_data(self): - x = np.zeros(8) + x = xp.zeros(8) x[0] = 1 #for string-like window, input signal length < nperseg value gives #UserWarning, sets nperseg to x.shape[-1] @@ -413,93 +457,110 @@ def test_short_data(self): f, p = welch(x,window='hann') # default nperseg f1, p1 = welch(x,window='hann', nperseg=256) # user-specified nperseg f2, p2 = welch(x, nperseg=8) # valid nperseg, doesn't give warning - assert_allclose(f, f2) - assert_allclose(p, p2) - assert_allclose(f1, f2) - assert_allclose(p1, p2) + _assert_allclose_host(f, f2) + _assert_allclose_host(p, p2) + _assert_allclose_host(f1, f2) + _assert_allclose_host(p1, p2) + _assert_matching_namespace(f, x) + _assert_matching_namespace(f2, x) + _assert_matching_namespace(p2, x) def test_window_long_or_nd(self): - assert_raises(ValueError, welch, np.zeros(4), 1, np.array([1,1,1,1,1])) - assert_raises(ValueError, welch, np.zeros(4), 1, - np.arange(6).reshape((2,3))) + assert_raises(ValueError, welch, xp.zeros(4), 1, xp.asarray([1,1,1,1,1])) + assert_raises(ValueError, welch, xp.zeros(4), 1, + xp.arange(6).reshape((2,3))) def test_nondefault_noverlap(self): - x = np.zeros(64) + x = xp.zeros(64) x[::8] = 1 f, p = welch(x, nperseg=16, noverlap=4) q = np.array([0, 1./12., 1./3., 1./5., 1./3., 1./5., 1./3., 1./5., 1./6.]) - assert_allclose(p, q, atol=1e-12) + _assert_allclose_host(p, q, atol=1e-12) + _assert_matching_namespace(f, x) + _assert_matching_namespace(p, x) def test_bad_noverlap(self): - assert_raises(ValueError, welch, np.zeros(4), 1, 'hann', 2, 7) + assert_raises(ValueError, welch, xp.zeros(4), 1, 'hann', 2, 7) def test_nfft_too_short(self): - assert_raises(ValueError, welch, np.ones(12), nfft=3, nperseg=4) + assert_raises(ValueError, welch, xp.ones(12), nfft=3, nperseg=4) def test_real_onesided_even_32(self): - x = np.zeros(16, 'f') + x = xp.zeros(16, dtype=xp.float32) x[0] = 1 x[8] = 1 f, p = welch(x, nperseg=8) - assert_allclose(f, np.linspace(0, 0.5, 5)) - q = np.array([0.08333333, 0.15277778, 0.22222222, 0.22222222, - 0.11111111], 'f') - assert_allclose(p, q, atol=1e-7, rtol=1e-7) + _assert_allclose_host(f, np.linspace(0, 0.5, 5)) + q = xp.asarray([0.08333333, 0.15277778, 0.22222222, 0.22222222, + 0.11111111], dtype=xp.float32) + _assert_allclose_host(p, q, atol=1e-7, rtol=1e-7) assert_(p.dtype == q.dtype) + _assert_matching_namespace(f, x) + _assert_matching_namespace(p, x) def test_real_onesided_odd_32(self): - x = np.zeros(16, 'f') + x = xp.zeros(16, dtype=xp.float32) x[0] = 1 x[8] = 1 f, p = welch(x, nperseg=9) - assert_allclose(f, np.arange(5.0)/9.0) - q = np.array([0.12477458, 0.23430935, 0.17072113, 0.17072116, - 0.17072113], 'f') - assert_allclose(p, q, atol=1e-7, rtol=1e-7) + _assert_allclose_host(f, np.arange(5.0)/9.0) + q = xp.asarray([0.12477458, 0.23430935, 0.17072113, 0.17072116, + 0.17072113], dtype=xp.float32) + _assert_allclose_host(p, q, atol=1e-7, rtol=1e-7) assert_(p.dtype == q.dtype) + _assert_matching_namespace(f, x) + _assert_matching_namespace(p, x) def test_real_twosided_32(self): - x = np.zeros(16, 'f') + x = xp.zeros(16, dtype=xp.float32) x[0] = 1 x[8] = 1 f, p = welch(x, nperseg=8, return_onesided=False) - assert_allclose(f, fftfreq(8, 1.0)) - q = np.array([0.08333333, 0.07638889, 0.11111111, + _assert_allclose_host(f, fftfreq(8, 1.0)) + q = xp.asarray([0.08333333, 0.07638889, 0.11111111, 0.11111111, 0.11111111, 0.11111111, 0.11111111, - 0.07638889], 'f') - assert_allclose(p, q, atol=1e-7, rtol=1e-7) + 0.07638889], dtype=xp.float32) + _assert_allclose_host(p, q, atol=1e-7, rtol=1e-7) assert_(p.dtype == q.dtype) + _assert_matching_namespace(f, x) + _assert_matching_namespace(p, x) def test_complex_32(self): - x = np.zeros(16, 'F') + x = xp.zeros(16, dtype=xp.complex64) x[0] = 1.0 + 2.0j x[8] = 1.0 + 2.0j f, p = welch(x, nperseg=8, return_onesided=False) - assert_allclose(f, fftfreq(8, 1.0)) - q = np.array([0.41666666, 0.38194442, 0.55555552, 0.55555552, - 0.55555558, 0.55555552, 0.55555552, 0.38194442], 'f') - assert_allclose(p, q, atol=1e-7, rtol=1e-7) + _assert_allclose_host(f, fftfreq(8, 1.0)) + q = xp.asarray([0.41666666, 0.38194442, 0.55555552, 0.55555552, + 0.55555558, 0.55555552, 0.55555552, 0.38194442], dtype=xp.float32) + _assert_allclose_host(p, q, atol=1e-7, rtol=1e-7) assert_(p.dtype == q.dtype, f'dtype mismatch, {p.dtype}, {q.dtype}') + _assert_matching_namespace(f, x) + _assert_matching_namespace(p, x) def test_padded_freqs(self): - x = np.zeros(12) + x = xp.zeros(12) nfft = 24 f = fftfreq(nfft, 1.0)[:nfft//2+1] f[-1] *= -1 fodd, _ = welch(x, nperseg=5, nfft=nfft) feven, _ = welch(x, nperseg=6, nfft=nfft) - assert_allclose(f, fodd) - assert_allclose(f, feven) + _assert_allclose_host(f, fodd) + _assert_allclose_host(f, feven) + _assert_matching_namespace(fodd, x) + _assert_matching_namespace(feven, x) nfft = 25 f = fftfreq(nfft, 1.0)[:(nfft + 1)//2] fodd, _ = welch(x, nperseg=5, nfft=nfft) feven, _ = welch(x, nperseg=6, nfft=nfft) - assert_allclose(f, fodd) - assert_allclose(f, feven) + _assert_allclose_host(f, fodd) + _assert_allclose_host(f, feven) + _assert_matching_namespace(fodd, x) + _assert_matching_namespace(feven, x) def test_window_correction(self): A = 20 @@ -508,8 +569,8 @@ def test_window_correction(self): fsig = 300 ii = int(fsig*nperseg//fs) # Freq index of fsig - tt = np.arange(fs)/fs - x = A*np.sin(2*np.pi*fsig*tt) + tt = xp.arange(fs)/fs + x = A*xp.sin(2*xp.pi*fsig*tt) for window in ['hann', 'bartlett', ('tukey', 0.1), 'flattop']: _, p_spec = welch(x, fs=fs, nperseg=nperseg, window=window, @@ -518,12 +579,14 @@ def test_window_correction(self): scaling='density') # Check peak height at signal frequency for 'spectrum' - assert_allclose(p_spec[ii], A**2/2.0) + _assert_allclose_host(p_spec[ii], A**2/2.0, rtol=4e-7) # Check integrated spectrum RMS for 'density' - assert_allclose(np.sqrt(np.trapz(p_dens, freq)), A*np.sqrt(2)/2, + _assert_allclose_host(xp.sqrt(xp.trapz(p_dens, freq)), A*np.sqrt(2)/2, rtol=1e-3) def test_axis_rolling(self): + # TODO: what to do about random seeds and alternative + # array libraries? np.random.seed(1234) x_flat = np.random.randn(1024) @@ -541,13 +604,15 @@ def test_axis_rolling(self): assert_equal(p_flat, p_minus.squeeze(), err_msg=a-x.ndim) def test_average(self): - x = np.zeros(16) + x = xp.zeros(16) x[0] = 1 x[8] = 1 f, p = welch(x, nperseg=8, average='median') - assert_allclose(f, np.linspace(0, 0.5, 5)) + _assert_allclose_host(f, np.linspace(0, 0.5, 5)) q = np.array([.1, .05, 0., 1.54074396e-33, 0.]) - assert_allclose(p, q, atol=1e-7, rtol=1e-7) + _assert_allclose_host(p, q, atol=1e-7, rtol=1e-7) + _assert_matching_namespace(f, x) + _assert_matching_namespace(p, x) assert_raises(ValueError, welch, x, nperseg=8, average='unrecognised-average') @@ -1598,8 +1663,8 @@ def test_roundtrip_scaling(self): # Since x is real, its Fourier transform is conjugate symmetric, i.e., # the missing 'second side' can be expressed through the 'first side': Zp1 = np.conj(Zp0[-2:0:-1, :]) # 'second side' is conjugate reversed - assert_allclose(Zp[:129, :], Zp0) - assert_allclose(Zp[129:, :], Zp1) + assert_allclose(Zp[:129, :], Zp0, atol=9e-16) + assert_allclose(Zp[129:, :], Zp1, atol=9e-16) # Calculate the spectral power: s2 = (np.sum(Zp0.real ** 2 + Zp0.imag ** 2, axis=0) +