Skip to content

Commit 2aca6c2

Browse files
committed
move test_scalar_wrapper to test_buffer.py
1 parent 50fd5ff commit 2aca6c2

File tree

3 files changed

+55
-51
lines changed

3 files changed

+55
-51
lines changed

src/zarr/core/buffer/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,8 @@ def astype(
159159
self, dtype: npt.DTypeLike, order: Literal["K", "A", "C", "F"] = "K", *, copy: bool = True
160160
) -> Self:
161161
if copy:
162-
return self.__class__(self._value, dtype)
163-
self._dtype = dtype
162+
return self.__class__(self._value, np.dtype(dtype))
163+
self._dtype = np.dtype(dtype)
164164
return self
165165

166166
def fill(self, value: Any) -> None:

tests/test_array.py

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1258,54 +1258,6 @@ async def test_create_array_v2_no_shards(store: MemoryStore) -> None:
12581258
)
12591259

12601260

1261-
@pytest.mark.parametrize("value", [1, 1.4, "a", b"a", np.array(1), False, True])
1262-
def test_scalar_wrapper(value: Any) -> None:
1263-
x = ScalarWrapper(value)
1264-
assert x == value
1265-
assert value == x
1266-
assert x == x[()]
1267-
assert x.view(str) == x
1268-
assert x.copy() == x
1269-
assert x.transpose() == x
1270-
assert x.ravel() == x
1271-
assert x.all() == bool(value)
1272-
if isinstance(value, (int, float)):
1273-
assert -x == -value
1274-
assert abs(x) == abs(value)
1275-
assert int(x) == int(value)
1276-
assert float(x) == float(value)
1277-
assert complex(x) == complex(value)
1278-
assert x + 1 == value + 1
1279-
assert x - 1 == value - 1
1280-
assert x * 2 == value * 2
1281-
assert x / 2 == value / 2
1282-
assert x // 2 == value // 2
1283-
assert x % 2 == value % 2
1284-
assert x**2 == value**2
1285-
assert x == value
1286-
assert x != value + 1
1287-
assert bool(x) == bool(value)
1288-
assert hash(x) == hash(value)
1289-
assert str(x) == str(value)
1290-
assert format(x, "") == format(value, "")
1291-
x.fill(2)
1292-
x[()] += 1
1293-
assert x == 3
1294-
elif isinstance(value, str):
1295-
assert str(x) == value
1296-
with pytest.raises(TypeError, match=re.escape("bad operand type for abs(): 'str'")):
1297-
abs(x)
1298-
1299-
with pytest.raises(ValueError, match="Cannot reshape scalar to non-scalar shape."):
1300-
x.reshape((1, 2))
1301-
with pytest.raises(IndexError, match="Invalid index for scalar."):
1302-
x[10] = value
1303-
with pytest.raises(IndexError, match="Invalid index for scalar."):
1304-
x[10]
1305-
with pytest.raises(TypeError, match=re.escape("len() of unsized object.")):
1306-
len(x)
1307-
1308-
13091261
@pytest.mark.parametrize("value", [1, 1.4, "a", b"a", np.array(1)])
13101262
@pytest.mark.parametrize("zarr_format", [2, 3])
13111263
def test_scalar_array(value: Any, zarr_format: ZarrFormat) -> None:

tests/test_buffer.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
3+
import re
4+
from typing import TYPE_CHECKING, Any
45

56
import numpy as np
67
import pytest
@@ -30,6 +31,9 @@
3031
cp = None
3132

3233

34+
import zarr.api.asynchronous
35+
from zarr.core.buffer.core import ScalarWrapper
36+
3337
if TYPE_CHECKING:
3438
import types
3539

@@ -40,6 +44,54 @@ def test_nd_array_like(xp: types.ModuleType) -> None:
4044
assert isinstance(ary, NDArrayLike)
4145

4246

47+
@pytest.mark.parametrize("value", [1, 1.4, "a", b"a", np.array(1), False, True])
48+
def test_scalar_wrapper(value: Any) -> None:
49+
x = ScalarWrapper(value)
50+
assert x == value
51+
assert value == x
52+
assert x == x[()]
53+
assert x.view(str) == x
54+
assert x.copy() == x
55+
assert x.transpose() == x
56+
assert x.ravel() == x
57+
assert x.all() == bool(value)
58+
if isinstance(value, (int, float)):
59+
assert -x == -value
60+
assert abs(x) == abs(value)
61+
assert int(x) == int(value)
62+
assert float(x) == float(value)
63+
assert complex(x) == complex(value)
64+
assert x + 1 == value + 1
65+
assert x - 1 == value - 1
66+
assert x * 2 == value * 2
67+
assert x / 2 == value / 2
68+
assert x // 2 == value // 2
69+
assert x % 2 == value % 2
70+
assert x**2 == value**2
71+
assert x == value
72+
assert x != value + 1
73+
assert bool(x) == bool(value)
74+
assert hash(x) == hash(value)
75+
assert str(x) == str(value)
76+
assert format(x, "") == format(value, "")
77+
x.fill(2)
78+
x[()] += 1
79+
assert x == 3
80+
elif isinstance(value, str):
81+
assert str(x) == value
82+
with pytest.raises(TypeError, match=re.escape("bad operand type for abs(): 'str'")):
83+
abs(x)
84+
85+
with pytest.raises(ValueError, match="Cannot reshape scalar to non-scalar shape."):
86+
x.reshape((1, 2))
87+
with pytest.raises(IndexError, match="Invalid index for scalar."):
88+
x[10] = value
89+
with pytest.raises(IndexError, match="Invalid index for scalar."):
90+
x[10]
91+
with pytest.raises(TypeError, match=re.escape("len() of unsized object.")):
92+
len(x)
93+
94+
4395
@pytest.mark.asyncio
4496
async def test_async_array_prototype() -> None:
4597
"""Test the use of a custom buffer prototype"""

0 commit comments

Comments
 (0)