diff --git a/numcodecs/compat.py b/numcodecs/compat.py index d1844e10..355338f5 100644 --- a/numcodecs/compat.py +++ b/numcodecs/compat.py @@ -2,11 +2,12 @@ import codecs import numpy as np +import numpy.typing as npt -from .ndarray_like import NDArrayLike, is_ndarray_like +from .ndarray_like import is_ndarray_like -def ensure_ndarray_like(buf) -> NDArrayLike: +def ensure_ndarray_like(buf) -> npt.NDArray: """Convenience function to coerce `buf` to ndarray-like array. Parameters @@ -38,7 +39,7 @@ def ensure_ndarray_like(buf) -> NDArrayLike: mem = memoryview(buf) # instantiate array from memoryview, ensures no copy buf = np.array(mem, copy=False) - return buf + return np.asanyarray(buf, copy=False) def ensure_ndarray(buf) -> np.ndarray: @@ -63,7 +64,7 @@ def ensure_ndarray(buf) -> np.ndarray: return np.array(ensure_ndarray_like(buf), copy=False) -def ensure_contiguous_ndarray_like(buf, max_buffer_size=None, flatten=True) -> NDArrayLike: +def ensure_contiguous_ndarray_like(buf, max_buffer_size=None, flatten=True) -> npt.ArrayLike: """Convenience function to coerce `buf` to ndarray-like array. Also ensures that the returned value exports fully contiguous memory, and supports the new-style buffer interface. If the optional max_buffer_size is @@ -174,7 +175,7 @@ def ensure_text(s, encoding="utf-8"): return s -def ndarray_copy(src, dst) -> NDArrayLike: +def ndarray_copy(src, dst) -> npt.NDArray: """Copy the contents of the array from `src` to `dst`.""" if dst is None: diff --git a/numcodecs/ndarray_like.py b/numcodecs/ndarray_like.py index 06e15ea9..170428c5 100644 --- a/numcodecs/ndarray_like.py +++ b/numcodecs/ndarray_like.py @@ -1,65 +1,6 @@ -from typing import Any, ClassVar, Protocol, runtime_checkable - - -class _CachedProtocolMeta(Protocol.__class__): # type: ignore[name-defined] - """Custom implementation of @runtime_checkable - - The native implementation of @runtime_checkable is slow, - see . - - This metaclass keeps an unbounded cache of the result of - isinstance checks using the object's class as the cache key. - """ - - _instancecheck_cache: ClassVar[dict[tuple[type, type], bool]] = {} - - def __instancecheck__(self, instance): - key = (self, instance.__class__) - ret = self._instancecheck_cache.get(key) - if ret is None: - ret = super().__instancecheck__(instance) - self._instancecheck_cache[key] = ret - return ret - - -@runtime_checkable -class DType(Protocol, metaclass=_CachedProtocolMeta): - itemsize: int - name: str - kind: str - - -@runtime_checkable -class FlagsObj(Protocol, metaclass=_CachedProtocolMeta): - c_contiguous: bool - f_contiguous: bool - owndata: bool - - -@runtime_checkable -class NDArrayLike(Protocol, metaclass=_CachedProtocolMeta): - dtype: DType - shape: tuple[int, ...] - strides: tuple[int, ...] - ndim: int - size: int - itemsize: int - nbytes: int - flags: FlagsObj - - def __len__(self) -> int: ... # pragma: no cover - - def __getitem__(self, key) -> Any: ... # pragma: no cover - - def __setitem__(self, key, value): ... # pragma: no cover - - def tobytes(self, order: str | None = ...) -> bytes: ... # pragma: no cover - - def reshape(self, *shape: int, order: str = ...) -> "NDArrayLike": ... # pragma: no cover - - def view(self, dtype: DType = ...) -> "NDArrayLike": ... # pragma: no cover +import numpy.typing as npt def is_ndarray_like(obj: object) -> bool: """Return True when `obj` is ndarray-like""" - return isinstance(obj, NDArrayLike) + return isinstance(obj, npt.ArrayLike) diff --git a/numcodecs/tests/test_ndarray_like.py b/numcodecs/tests/test_ndarray_like.py deleted file mode 100644 index 6c16e7db..00000000 --- a/numcodecs/tests/test_ndarray_like.py +++ /dev/null @@ -1,48 +0,0 @@ -import pytest - -from numcodecs.ndarray_like import DType, FlagsObj, NDArrayLike - - -@pytest.mark.parametrize("module", ["numpy", "cupy"]) -def test_is_ndarray_like(module): - m = pytest.importorskip(module) - a = m.arange(10) - assert isinstance(a, NDArrayLike) - - -def test_is_not_ndarray_like(): - assert not isinstance([1, 2, 3], NDArrayLike) - assert not isinstance(b"1,2,3", NDArrayLike) - - -@pytest.mark.parametrize("module", ["numpy", "cupy"]) -def test_is_dtype_like(module): - m = pytest.importorskip(module) - d = m.dtype("u8") - assert isinstance(d, DType) - - -def test_is_not_dtype_like(): - assert not isinstance([1, 2, 3], DType) - assert not isinstance(b"1,2,3", DType) - - -@pytest.mark.parametrize("module", ["numpy", "cupy"]) -def test_is_flags_like(module): - m = pytest.importorskip(module) - d = m.arange(10).flags - assert isinstance(d, FlagsObj) - - -def test_is_not_flags_like(): - assert not isinstance([1, 2, 3], FlagsObj) - assert not isinstance(b"1,2,3", FlagsObj) - - -@pytest.mark.parametrize("module", ["numpy", "cupy"]) -def test_cached_isinstance_check(module): - m = pytest.importorskip(module) - a = m.arange(10) - assert isinstance(a, NDArrayLike) - assert not isinstance(a, DType) - assert not isinstance(a, FlagsObj)