Skip to content

Commit 236487d

Browse files
committed
fixup
1 parent eaf8063 commit 236487d

File tree

3 files changed

+47
-21
lines changed

3 files changed

+47
-21
lines changed

src/zarr/core/metadata/v3.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,7 @@ def from_dict(cls, data: dict[str, JSON]) -> Self:
216216
_ = parse_node_type_array(_data.pop("node_type"))
217217

218218
# check that the data_type attribute is valid
219-
if _data["data_type"] not in DataType:
220-
raise ValueError(f"Invalid V3 data_type: {_data['data_type']}")
219+
_ = DataType(_data["data_type"])
221220

222221
# dimension_names key is optional, normalize missing to `None`
223222
_data["dimension_names"] = _data.pop("dimension_names", None)
@@ -264,23 +263,38 @@ def update_attributes(self, attributes: dict[str, JSON]) -> Self:
264263

265264

266265
@overload
267-
def parse_fill_value(fill_value: object, dtype: BOOL_DTYPE) -> BOOL: ...
266+
def parse_fill_value(
267+
fill_value: int | float | complex | str | bytes | np.generic | Sequence[Any] | bool | None,
268+
dtype: BOOL_DTYPE,
269+
) -> BOOL: ...
268270

269271

270272
@overload
271-
def parse_fill_value(fill_value: object, dtype: INTEGER_DTYPE) -> INTEGER: ...
273+
def parse_fill_value(
274+
fill_value: int | float | complex | str | bytes | np.generic | Sequence[Any] | bool | None,
275+
dtype: INTEGER_DTYPE,
276+
) -> INTEGER: ...
272277

273278

274279
@overload
275-
def parse_fill_value(fill_value: object, dtype: FLOAT_DTYPE) -> FLOAT: ...
280+
def parse_fill_value(
281+
fill_value: int | float | complex | str | bytes | np.generic | Sequence[Any] | bool | None,
282+
dtype: FLOAT_DTYPE,
283+
) -> FLOAT: ...
276284

277285

278286
@overload
279-
def parse_fill_value(fill_value: object, dtype: COMPLEX_DTYPE) -> COMPLEX: ...
287+
def parse_fill_value(
288+
fill_value: int | float | complex | str | bytes | np.generic | Sequence[Any] | bool | None,
289+
dtype: COMPLEX_DTYPE,
290+
) -> COMPLEX: ...
280291

281292

282293
@overload
283-
def parse_fill_value(fill_value: object, dtype: np.dtype[Any]) -> Any:
294+
def parse_fill_value(
295+
fill_value: int | float | complex | str | bytes | np.generic | Sequence[Any] | bool | None,
296+
dtype: np.dtype[Any],
297+
) -> Any:
284298
# This dtype[Any] is unfortunately necessary right now.
285299
# See https://github.com/zarr-developers/zarr-python/issues/2131#issuecomment-2318010899
286300
# for more details, but `dtype` here (which comes from `parse_dtype`)
@@ -292,7 +306,7 @@ def parse_fill_value(fill_value: object, dtype: np.dtype[Any]) -> Any:
292306

293307

294308
def parse_fill_value(
295-
fill_value: object,
309+
fill_value: int | float | complex | str | bytes | np.generic | Sequence[Any] | bool | None,
296310
dtype: BOOL_DTYPE | INTEGER_DTYPE | FLOAT_DTYPE | COMPLEX_DTYPE | np.dtype[Any],
297311
) -> BOOL | INTEGER | FLOAT | COMPLEX | Any:
298312
"""
@@ -326,11 +340,11 @@ def parse_fill_value(
326340
else:
327341
msg = (
328342
f"Got an invalid fill value for complex data type {dtype}."
329-
f"Expected a sequence with 2 elements, but {fill_value} has "
343+
f"Expected a sequence with 2 elements, but {fill_value!r} has "
330344
f"length {len(fill_value)}."
331345
)
332346
raise ValueError(msg)
333-
msg = f"Cannot parse non-string sequence {fill_value} as a scalar with type {dtype}."
347+
msg = f"Cannot parse non-string sequence {fill_value!r} as a scalar with type {dtype}."
334348
raise TypeError(msg)
335349

336350
# Cast the fill_value to the given dtype
@@ -339,7 +353,7 @@ def parse_fill_value(
339353
except (ValueError, OverflowError, TypeError) as e:
340354
raise ValueError(f"fill value {fill_value!r} is not valid for dtype {dtype}") from e
341355
# Check if the value is still representable by the dtype
342-
if fill_value != casted_value:
356+
if fill_value != casted_value and not (np.isnan(fill_value) and np.isnan(casted_value)):
343357
raise ValueError(f"fill value {fill_value!r} is not valid for dtype {dtype}")
344358

345359
return casted_value
@@ -434,7 +448,7 @@ def from_dtype(cls, dtype: np.dtype[Any]) -> DataType:
434448
def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]:
435449
try:
436450
dtype = np.dtype(data)
437-
except TypeError as e:
451+
except (ValueError, TypeError) as e:
438452
raise ValueError(f"Invalid V3 data_type: {data}") from e
439453
# check that this is a valid v3 data_type
440454
try:

tests/v3/test_array.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,10 @@ def test_array_v3_fill_value_default(
122122

123123

124124
@pytest.mark.parametrize("store", ["memory"], indirect=True)
125-
@pytest.mark.parametrize("fill_value", [False, 0.0, 1, 2.3])
126-
@pytest.mark.parametrize("dtype_str", ["bool", "uint8", "float32", "complex64"])
125+
@pytest.mark.parametrize(
126+
"dtype_str,fill_value",
127+
[("bool", True), ("uint8", 99), ("float32", -99.9), ("complex64", 3 + 4j)],
128+
)
127129
def test_array_v3_fill_value(store: MemoryStore, fill_value: int, dtype_str: str) -> None:
128130
shape = (10,)
129131
arr = Array.create(

tests/v3/test_metadata/test_v3.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,19 @@ def test_parse_auto_fill_value(dtype_str: str) -> None:
7979
assert parse_fill_value(fill_value, dtype) == dtype.type(0)
8080

8181

82-
@pytest.mark.parametrize("fill_value", [0, 1.11, False, True])
83-
@pytest.mark.parametrize("dtype_str", dtypes)
82+
@pytest.mark.parametrize(
83+
"fill_value,dtype_str",
84+
[
85+
(True, "bool"),
86+
(False, "bool"),
87+
(-8, "int8"),
88+
(0, "int16"),
89+
(1e10, "uint64"),
90+
(-999, "float32"),
91+
(1e32, "float64"),
92+
(0j, "complex64"),
93+
],
94+
)
8495
def test_parse_fill_value_valid(fill_value: Any, dtype_str: str) -> None:
8596
"""
8697
Test that parse_fill_value(fill_value, dtype) casts fill_value to the given dtype.
@@ -141,8 +152,7 @@ def test_parse_fill_value_invalid_type(fill_value: Any, dtype_str: str) -> None:
141152
This test excludes bool because the bool constructor takes anything.
142153
"""
143154
dtype = np.dtype(dtype_str)
144-
match = "must be"
145-
with pytest.raises(TypeError, match=match):
155+
with pytest.raises(ValueError, match=r"fill value .* is not valid for dtype .*"):
146156
parse_fill_value(fill_value, dtype)
147157

148158

@@ -269,13 +279,13 @@ async def test_invalid_dtype_raises() -> None:
269279
"codecs": (),
270280
"fill_value": np.datetime64(0, "ns"),
271281
}
272-
with pytest.raises(ValueError, match=r"Invalid V3 data_type"):
282+
with pytest.raises(ValueError, match=r".* is not a valid DataType"):
273283
ArrayV3Metadata.from_dict(metadata_dict)
274284

275285

276286
@pytest.mark.parametrize("data", ["datetime64[s]", "foo", object()])
277287
def test_parse_invalid_dtype_raises(data):
278-
with pytest.raises(ValueError, match=r"Invalid V3 data_type"):
288+
with pytest.raises(ValueError, match=r"Invalid V3 data_type: .*"):
279289
parse_dtype(data)
280290

281291

@@ -293,5 +303,5 @@ async def test_invalid_fill_value_raises(data_type: str, fill_value: int | float
293303
"codecs": (),
294304
"fill_value": fill_value, # this is not a valid fill value for uint8
295305
}
296-
with pytest.raises(ValueError, match=rf"fill value .* is not valid for dtype {data_type}"):
306+
with pytest.raises(ValueError, match=r"fill value .* is not valid for dtype .*"):
297307
ArrayV3Metadata.from_dict(metadata_dict)

0 commit comments

Comments
 (0)