44import pickle
55import re
66from itertools import accumulate
7- from typing import TYPE_CHECKING , Any , Literal , cast
7+ from typing import TYPE_CHECKING , Any , Literal
88
99import numcodecs
1010import numpy as np
@@ -1258,26 +1258,23 @@ async def test_create_array_v2_no_shards(store: MemoryStore) -> None:
12581258 )
12591259
12601260
1261- @pytest .mark .parametrize ("value" , [1 , 1.4 , "a" , b"a" , np .array (1 )])
1262- def test_scalar_array (value : Any ) -> None :
1263- arr = zarr .array (value )
1264- assert arr [...] == value
1265- assert arr .shape == ()
1266- assert arr .ndim == 0
1267-
1268- x = arr [()]
1269- assert isinstance (arr [()], ScalarWrapper )
1270- assert isinstance (arr [()], NDArrayLike )
1271- assert x .shape == arr .shape
1272- assert x .ndim == arr .ndim
1261+ @pytest .mark .parametrize ("value" , [1 , 1.4 , "a" , b"a" , np .array (1 ), False , True ])
1262+ def test_scalar_wrapper (value : Any ) -> None :
1263+ x = ScalarWrapper (value )
12731264 assert x == value
12741265 assert value == x
1266+ assert x == x [()]
1267+ assert x .view (str ) == x
1268+ assert x .copy () == x
1269+ assert x .transpose () == x
1270+ assert x .ravel () == x
1271+ assert x .all () == bool (value )
12751272 if isinstance (value , (int , float )):
1276- x = cast (ScalarWrapper , x )
12771273 assert - x == - value
12781274 assert abs (x ) == abs (value )
12791275 assert int (x ) == int (value )
12801276 assert float (x ) == float (value )
1277+ assert complex (x ) == complex (value )
12811278 assert x + 1 == value + 1
12821279 assert x - 1 == value - 1
12831280 assert x * 2 == value * 2
@@ -1291,5 +1288,34 @@ def test_scalar_array(value: Any) -> None:
12911288 assert hash (x ) == hash (value )
12921289 assert str (x ) == str (value )
12931290 assert format (x , "" ) == format (value , "" )
1291+ x .fill (2 )
1292+ x [()] += 1
1293+ assert x == 3
12941294 elif isinstance (value , str ):
12951295 assert str (x ) == value
1296+ with pytest .raises (TypeError , match = re .escape ("bad operand type for abs(): 'str'" )):
1297+ abs (x )
1298+
1299+ with pytest .raises (ValueError , match = "Cannot reshape scalar to non-scalar shape." ):
1300+ x .reshape ((1 , 2 ))
1301+ with pytest .raises (IndexError , match = "Invalid index for scalar." ):
1302+ x [10 ] = value
1303+ with pytest .raises (IndexError , match = "Invalid index for scalar." ):
1304+ x [10 ]
1305+ with pytest .raises (TypeError , match = re .escape ("len() of unsized object." )):
1306+ len (x )
1307+
1308+
1309+ @pytest .mark .parametrize ("value" , [1 , 1.4 , "a" , b"a" , np .array (1 )])
1310+ @pytest .mark .parametrize ("zarr_format" , [2 , 3 ])
1311+ def test_scalar_array (value : Any , zarr_format : ZarrFormat ) -> None :
1312+ arr = zarr .array (value , zarr_format = zarr_format )
1313+ assert arr [...] == value
1314+ assert arr .shape == ()
1315+ assert arr .ndim == 0
1316+
1317+ x = arr [()]
1318+ assert isinstance (arr [()], ScalarWrapper )
1319+ assert isinstance (arr [()], NDArrayLike )
1320+ assert x .shape == arr .shape
1321+ assert x .ndim == arr .ndim
0 commit comments