|
31 | 31 | chunks_initialized, |
32 | 32 | create_array, |
33 | 33 | ) |
34 | | -from zarr.core.buffer import default_buffer_prototype |
| 34 | +from zarr.core.buffer import default_buffer_prototype, NDArrayLike |
35 | 35 | from zarr.core.buffer.core import ScalarWrapper |
36 | 36 | from zarr.core.buffer.cpu import NDBuffer |
37 | 37 | from zarr.core.chunk_grids import _auto_partition |
@@ -1257,13 +1257,37 @@ async def test_create_array_v2_no_shards(store: MemoryStore) -> None: |
1257 | 1257 | zarr_format=2, |
1258 | 1258 | ) |
1259 | 1259 |
|
1260 | | - |
1261 | | -async def test_scalar_array() -> None: |
1262 | | - arr = zarr.array(1.5) |
1263 | | - assert arr[...] == 1.5 |
1264 | | - assert arr[()] == 1.5 |
| 1260 | +@pytest.mark.parametrize("value", [1, 1.4, "a", b"a", np.array(1)]) |
| 1261 | +def test_scalar_array(value: Any) -> None: |
| 1262 | + arr = zarr.array(value) |
| 1263 | + assert arr[...] == value |
1265 | 1264 | assert arr.shape == () |
1266 | | - assert arr[()].shape == () |
1267 | 1265 | assert arr.ndim == 0 |
1268 | | - assert arr[()].ndim == 0 |
| 1266 | + |
| 1267 | + x = arr[()] |
1269 | 1268 | assert isinstance(arr[()], ScalarWrapper) |
| 1269 | + assert isinstance(arr[()], NDArrayLike) |
| 1270 | + assert x.shape == arr.shape |
| 1271 | + assert x.ndim == arr.ndim |
| 1272 | + assert x == value |
| 1273 | + assert value == x |
| 1274 | + if isinstance(value, (int, float)): |
| 1275 | + assert -x == -value |
| 1276 | + assert abs(x) == abs(value) |
| 1277 | + assert int(x) == int(value) |
| 1278 | + assert float(x) == float(value) |
| 1279 | + assert x + 1 == value + 1 |
| 1280 | + assert x - 1 == value - 1 |
| 1281 | + assert x * 2 == value * 2 |
| 1282 | + assert x / 2 == value / 2 |
| 1283 | + assert x // 2 == value // 2 |
| 1284 | + assert x % 2 == value % 2 |
| 1285 | + assert x ** 2 == value ** 2 |
| 1286 | + assert x == value |
| 1287 | + assert x != value + 1 |
| 1288 | + assert bool(x) == bool(value) |
| 1289 | + assert hash(x) == hash(value) |
| 1290 | + assert str(x) == str(value) |
| 1291 | + assert format(x, "") == format(value, "") |
| 1292 | + elif isinstance(value, str): |
| 1293 | + assert str(x) == value |
0 commit comments