Skip to content

Commit 98e13a8

Browse files
committed
remove ScalarWrapper usage
1 parent ab7afbb commit 98e13a8

File tree

3 files changed

+6
-57
lines changed

3 files changed

+6
-57
lines changed

src/zarr/core/buffer/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,7 @@ def as_numpy_array(self) -> npt.NDArray[Any]:
552552
"""
553553
...
554554

555-
def as_scalar(self) -> ScalarWrapper:
555+
def as_scalar(self) -> np.generic:
556556
"""Returns the buffer as a scalar value
557557
558558
Returns
@@ -561,7 +561,7 @@ def as_scalar(self) -> ScalarWrapper:
561561
"""
562562
if self._data.size != 1:
563563
raise ValueError("Buffer does not contain a single scalar value")
564-
return ScalarWrapper(self.as_numpy_array().item(), np.dtype(self.dtype))
564+
return self.dtype.type(self.as_numpy_array().item())
565565

566566
@property
567567
def dtype(self) -> np.dtype[Any]:

tests/test_array.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@
3535
chunks_initialized,
3636
create_array,
3737
)
38-
from zarr.core.buffer import NDArrayLike, default_buffer_prototype
39-
from zarr.core.buffer.core import ScalarWrapper
38+
from zarr.core.buffer import default_buffer_prototype
4039
from zarr.core.buffer.cpu import NDBuffer
4140
from zarr.core.chunk_grids import _auto_partition
4241
from zarr.core.common import JSON, MemoryOrder, ZarrFormat
@@ -1335,8 +1334,8 @@ def test_scalar_array(value: Any, zarr_format: ZarrFormat) -> None:
13351334
assert arr.ndim == 0
13361335

13371336
x = arr[()]
1338-
assert isinstance(arr[()], ScalarWrapper)
1339-
assert isinstance(arr[()], NDArrayLike)
1337+
assert isinstance(arr[()], np.generic)
1338+
# assert isinstance(arr[()], NDArrayLike)
13401339
assert x.shape == arr.shape
13411340
assert x.ndim == arr.ndim
13421341

tests/test_buffer.py

Lines changed: 1 addition & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

3-
import re
4-
from typing import TYPE_CHECKING, Any
3+
from typing import TYPE_CHECKING
54

65
import numpy as np
76
import pytest
@@ -32,7 +31,6 @@
3231

3332

3433
import zarr.api.asynchronous
35-
from zarr.core.buffer.core import ScalarWrapper
3634

3735
if TYPE_CHECKING:
3836
import types
@@ -44,54 +42,6 @@ def test_nd_array_like(xp: types.ModuleType) -> None:
4442
assert isinstance(ary, NDArrayLike)
4543

4644

47-
@pytest.mark.parametrize("value", [1, 1.4, "a", b"a", np.array(1), False, True])
48-
def test_scalar_wrapper(value: Any) -> None:
49-
x = ScalarWrapper(value)
50-
assert x == value
51-
assert value == x
52-
assert x == x[()]
53-
assert x.view(str) == x
54-
assert x.copy() == x
55-
assert x.transpose() == x
56-
assert x.ravel() == x
57-
assert x.all() == bool(value)
58-
if isinstance(value, (int, float)):
59-
assert -x == -value
60-
assert abs(x) == abs(value)
61-
assert int(x) == int(value)
62-
assert float(x) == float(value)
63-
assert complex(x) == complex(value)
64-
assert x + 1 == value + 1
65-
assert x - 1 == value - 1
66-
assert x * 2 == value * 2
67-
assert x / 2 == value / 2
68-
assert x // 2 == value // 2
69-
assert x % 2 == value % 2
70-
assert x**2 == value**2
71-
assert x == value
72-
assert x != value + 1
73-
assert bool(x) == bool(value)
74-
assert hash(x) == hash(value)
75-
assert str(x) == str(value)
76-
assert format(x, "") == format(value, "")
77-
x.fill(2)
78-
x[()] += 1
79-
assert x == 3
80-
elif isinstance(value, str):
81-
assert str(x) == value
82-
with pytest.raises(TypeError, match=re.escape("bad operand type for abs(): 'str'")):
83-
abs(x)
84-
85-
with pytest.raises(ValueError, match="Cannot reshape scalar to non-scalar shape."):
86-
x.reshape((1, 2))
87-
with pytest.raises(IndexError, match="Invalid index for scalar."):
88-
x[10] = value
89-
with pytest.raises(IndexError, match="Invalid index for scalar."):
90-
x[10]
91-
with pytest.raises(TypeError, match=re.escape("len() of unsized object.")):
92-
len(x)
93-
94-
9545
@pytest.mark.asyncio
9646
async def test_async_array_prototype() -> None:
9747
"""Test the use of a custom buffer prototype"""

0 commit comments

Comments
 (0)