Skip to content

Commit be0d2df

Browse files
committed
Use `None` to denote the default fill value; remove old structured tests; use `cast_value` where appropriate
1 parent 8a976d6 commit be0d2df

File tree

8 files changed

+41
-142
lines changed

8 files changed

+41
-142
lines changed

src/zarr/api/synchronous.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,7 @@ def create(
601601
chunks: ChunkCoords | int | bool | None = None,
602602
dtype: npt.DTypeLike | None = None,
603603
compressor: CompressorLike = "auto",
604-
fill_value: Any | None = 0, # TODO: need type
604+
fill_value: Any | None = None, # TODO: need type
605605
order: MemoryOrder | None = None,
606606
store: str | StoreLike | None = None,
607607
synchronizer: Any | None = None,

src/zarr/core/array.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -778,7 +778,8 @@ def _create_metadata_v2(
778778
) -> ArrayV2Metadata:
779779
if dimension_separator is None:
780780
dimension_separator = "."
781-
781+
if fill_value is None:
782+
fill_value = dtype.default_value() # type: ignore[assignment]
782783
return ArrayV2Metadata(
783784
shape=shape,
784785
dtype=dtype,

src/zarr/core/dtype/npy/sized.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> np.bytes_:
7979
raise TypeError(f"Invalid type: {data}. Expected a string.") # pragma: no cover
8080

8181
def check_value(self, data: object) -> bool:
82-
return isinstance(data, np.bytes_ | str | bytes)
82+
# this is generous for backwards compatibility
83+
return isinstance(data, np.bytes_ | str | bytes | int)
8384

8485
def _cast_value_unsafe(self, value: object) -> np.bytes_:
8586
return self.to_dtype().type(value)
@@ -168,7 +169,11 @@ def check_value(self, data: object) -> bool:
168169
return isinstance(data, np.bytes_ | str | bytes | np.void)
169170

170171
def _cast_value_unsafe(self, value: object) -> np.void:
171-
return self.to_dtype().type(value) # type: ignore[call-overload, no-any-return]
172+
native_dtype = self.to_dtype()
173+
# Without the second argument, numpy will return a void scalar for dtype V1.
174+
# The second argument ensures that, if native_dtype is something like V10,
175+
# the result will actually be a V10 scalar.
176+
return native_dtype.type(value, native_dtype)
172177

173178

174179
@dataclass(frozen=True, kw_only=True)
@@ -239,7 +244,8 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> np.str_:
239244
raise TypeError(f"Invalid type: {data}. Expected a string.") # pragma: no cover
240245

241246
def check_value(self, data: object) -> bool:
242-
return isinstance(data, str | np.str_ | bytes)
247+
# this is generous for backwards compatibility
248+
return isinstance(data, str | np.str_ | bytes | int)
243249

244250
def _cast_value_unsafe(self, value: object) -> np.str_:
245251
return self.to_dtype().type(value)
@@ -254,8 +260,15 @@ class Structured(ZDType[np.dtypes.VoidDType[int], np.void]):
254260
def default_value(self) -> np.void:
255261
return self._cast_value_unsafe(0)
256262

257-
def _cast_value_unsafe(self, value: object) -> np.void:
258-
return cast("np.void", np.array([value], dtype=self.to_dtype())[0])
263+
def _cast_value_unsafe(self, data: object) -> np.void:
264+
na_dtype = self.to_dtype()
265+
if isinstance(data, bytes):
266+
res = np.frombuffer(data, dtype=na_dtype)[0]
267+
elif isinstance(data, list | tuple):
268+
res = np.array([tuple(data)], dtype=na_dtype)[0]
269+
else:
270+
res = np.array([data], dtype=na_dtype)[0]
271+
return cast("np.void", res)
259272

260273
@classmethod
261274
def check_dtype(cls, dtype: TBaseDType) -> TypeGuard[np.dtypes.VoidDType[int]]:

src/zarr/core/dtype/wrapper.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,9 @@ def cast_value(self, data: object) -> TScalar_co:
160160
if self.check_value(data):
161161
return self._cast_value_unsafe(data)
162162
msg = (
163-
f"The value {data} failed a type check."
164-
f"It cannot be safely cast to a scalar compatible with {self.dtype_cls}."
165-
f"Consult the documentation for {self} to determine the possible values that can"
163+
f"The value {data} failed a type check. "
164+
f"It cannot be safely cast to a scalar compatible with {self.dtype_cls}. "
165+
f"Consult the documentation for {self} to determine the possible values that can "
166166
"be cast to scalars of the wrapped data type."
167167
)
168168
raise TypeError(msg)

src/zarr/core/metadata/v2.py

Lines changed: 10 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import base64
43
import warnings
54
from collections.abc import Iterable, Sequence
65
from functools import cached_property
@@ -52,7 +51,7 @@ class ArrayV2Metadata(Metadata):
5251
shape: ChunkCoords
5352
chunks: ChunkCoords
5453
dtype: ZDType[TBaseDType, TBaseScalar]
55-
fill_value: int | float | str | bytes | None = 0
54+
fill_value: int | float | str | bytes | None = None
5655
order: MemoryOrder = "C"
5756
filters: tuple[numcodecs.abc.Codec, ...] | None = None
5857
dimension_separator: Literal[".", "/"] = "."
@@ -85,7 +84,11 @@ def __init__(
8584
order_parsed = parse_indexing_order(order)
8685
dimension_separator_parsed = parse_separator(dimension_separator)
8786
filters_parsed = parse_filters(filters)
88-
fill_value_parsed = parse_fill_value(fill_value, dtype=dtype.to_dtype())
87+
fill_value_parsed: TBaseScalar | None
88+
if fill_value is not None:
89+
fill_value_parsed = dtype.cast_value(fill_value)
90+
else:
91+
fill_value_parsed = fill_value
8992
attributes_parsed = parse_attributes(attributes)
9093

9194
object.__setattr__(self, "shape", shape_parsed)
@@ -134,11 +137,10 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata:
134137
_ = parse_zarr_format(_data.pop("zarr_format"))
135138
dtype = get_data_type_from_native_dtype(_data["dtype"])
136139
_data["dtype"] = dtype
137-
if dtype.to_dtype().kind in "SV":
138-
fill_value_encoded = _data.get("fill_value")
139-
if fill_value_encoded is not None:
140-
fill_value = base64.standard_b64decode(fill_value_encoded)
141-
_data["fill_value"] = fill_value
140+
fill_value_encoded = _data.get("fill_value")
141+
if fill_value_encoded is not None:
142+
fill_value = dtype.from_json_value(fill_value_encoded, zarr_format=2)
143+
_data["fill_value"] = fill_value
142144

143145
# zarr v2 allowed arbitrary keys here.
144146
# We don't want the ArrayV2Metadata constructor to fail just because someone put an
@@ -281,76 +283,3 @@ def parse_metadata(data: ArrayV2Metadata) -> ArrayV2Metadata:
281283
)
282284
raise ValueError(msg)
283285
return data
284-
285-
286-
def _parse_structured_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> Any:
287-
"""Handle structured dtype/fill value pairs"""
288-
try:
289-
if isinstance(fill_value, list):
290-
return np.array([tuple(fill_value)], dtype=dtype)[0]
291-
elif isinstance(fill_value, tuple):
292-
return np.array([fill_value], dtype=dtype)[0]
293-
elif isinstance(fill_value, bytes):
294-
return np.frombuffer(fill_value, dtype=dtype)[0]
295-
elif isinstance(fill_value, str):
296-
decoded = base64.standard_b64decode(fill_value)
297-
return np.frombuffer(decoded, dtype=dtype)[0]
298-
else:
299-
return np.array(fill_value, dtype=dtype)[()]
300-
except Exception as e:
301-
raise ValueError(f"Fill_value {fill_value} is not valid for dtype {dtype}.") from e
302-
303-
304-
def parse_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> Any:
305-
"""
306-
Parse a potential fill value into a value that is compatible with the provided dtype.
307-
308-
Parameters
309-
----------
310-
fill_value : Any
311-
A potential fill value.
312-
dtype : np.dtype[Any]
313-
A numpy dtype.
314-
315-
Returns
316-
-------
317-
An instance of `dtype`, or `None`, or any python object (in the case of an object dtype)
318-
"""
319-
320-
if fill_value is None or dtype.hasobject:
321-
pass
322-
elif dtype.fields is not None:
323-
# the dtype is structured (has multiple fields), so the fill_value might be a
324-
# compound value (e.g., a tuple or dict) that needs field-wise processing.
325-
# We use parse_structured_fill_value to correctly convert each component.
326-
fill_value = _parse_structured_fill_value(fill_value, dtype)
327-
elif not isinstance(fill_value, np.void) and fill_value == 0:
328-
# this should be compatible across numpy versions for any array type, including
329-
# structured arrays
330-
fill_value = np.zeros((), dtype=dtype)[()]
331-
elif dtype.kind == "U":
332-
# special case unicode because of encoding issues on Windows if passed through numpy
333-
# https://github.com/alimanfoo/zarr/pull/172#issuecomment-343782713
334-
335-
if not isinstance(fill_value, str):
336-
raise ValueError(
337-
f"fill_value {fill_value!r} is not valid for dtype {dtype}; must be a unicode string"
338-
)
339-
elif dtype.kind in "SV" and isinstance(fill_value, str):
340-
fill_value = base64.standard_b64decode(fill_value)
341-
elif dtype.kind == "c" and isinstance(fill_value, list) and len(fill_value) == 2:
342-
complex_val = complex(float(fill_value[0]), float(fill_value[1]))
343-
fill_value = np.array(complex_val, dtype=dtype)[()]
344-
else:
345-
try:
346-
if isinstance(fill_value, bytes) and dtype.kind == "V":
347-
# special case for numpy 1.14 compatibility
348-
fill_value = np.array(fill_value, dtype=dtype.str).view(dtype)[()]
349-
else:
350-
fill_value = np.array(fill_value, dtype=dtype)[()]
351-
352-
except Exception as e:
353-
msg = f"Fill_value {fill_value} is not valid for dtype {dtype}."
354-
raise ValueError(msg) from e
355-
356-
return fill_value

src/zarr/core/metadata/v3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def __init__(
175175
chunk_key_encoding_parsed = ChunkKeyEncoding.from_dict(chunk_key_encoding)
176176
dimension_names_parsed = parse_dimension_names(dimension_names)
177177
# Note: relying on a type method is numpy-specific
178-
fill_value_parsed = data_type.to_dtype().type(fill_value)
178+
fill_value_parsed = data_type.cast_value(fill_value)
179179
attributes_parsed = parse_attributes(attributes)
180180
codecs_parsed_partial = parse_codecs(codecs)
181181
storage_transformers_parsed = parse_storage_transformers(storage_transformers)

tests/test_metadata/test_v2.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ async def v2_consolidated_metadata(
128128
"chunks": [730],
129129
"compressor": None,
130130
"dtype": "<f4",
131-
"fill_value": "0.0",
131+
"fill_value": 0.0,
132132
"filters": None,
133133
"order": "C",
134134
"shape": [730],
@@ -147,7 +147,7 @@ async def v2_consolidated_metadata(
147147
"chunks": [730],
148148
"compressor": None,
149149
"dtype": "<f4",
150-
"fill_value": "0.0",
150+
"fill_value": 0.0,
151151
"filters": None,
152152
"order": "C",
153153
"shape": [730],
@@ -318,9 +318,7 @@ def test_zstd_checksum() -> None:
318318
assert "checksum" not in metadata["compressor"]
319319

320320

321-
@pytest.mark.parametrize(
322-
"fill_value", [None, np.void((0, 0), np.dtype([("foo", "i4"), ("bar", "i4")]))]
323-
)
321+
@pytest.mark.parametrize("fill_value", [np.void((0, 0), np.dtype([("foo", "i4"), ("bar", "i4")]))])
324322
def test_structured_dtype_fill_value_serialization(tmp_path, fill_value):
325323
zarr_format = 2
326324
group_path = tmp_path / "test.zarr"

tests/test_v2.py

Lines changed: 3 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from zarr import config
1616
from zarr.abc.store import Store
1717
from zarr.core.buffer.core import default_buffer_prototype
18-
from zarr.core.metadata.v2 import _parse_structured_fill_value
18+
from zarr.core.dtype.npy.sized import Structured
1919
from zarr.core.sync import sync
2020
from zarr.storage import MemoryStore, StorePath
2121

@@ -261,67 +261,25 @@ def test_structured_dtype_roundtrip(fill_value, tmp_path) -> None:
261261
np.dtype([("x", "i4"), ("y", "i4")]),
262262
np.array([(1, 2)], dtype=[("x", "i4"), ("y", "i4")])[0],
263263
),
264-
(
265-
"BQAAAA==",
266-
np.dtype([("val", "i4")]),
267-
np.array([(5,)], dtype=[("val", "i4")])[0],
268-
),
269-
(
270-
{"x": 1, "y": 2},
271-
np.dtype([("location", "O")]),
272-
np.array([({"x": 1, "y": 2},)], dtype=[("location", "O")])[0],
273-
),
274-
(
275-
{"x": 1, "y": 2, "z": 3},
276-
np.dtype([("location", "O")]),
277-
np.array([({"x": 1, "y": 2, "z": 3},)], dtype=[("location", "O")])[0],
278-
),
279264
],
280265
ids=[
281266
"tuple_input",
282267
"list_input",
283268
"bytes_input",
284-
"string_input",
285-
"dictionary_input",
286-
"dictionary_input_extra_fields",
287269
],
288270
)
289271
def test_parse_structured_fill_value_valid(
290272
fill_value: Any, dtype: np.dtype[Any], expected_result: Any
291273
) -> None:
292-
result = _parse_structured_fill_value(fill_value, dtype)
274+
zdtype = Structured.from_dtype(dtype)
275+
result = zdtype.cast_value(fill_value)
293276
assert result.dtype == expected_result.dtype
294277
assert result == expected_result
295278
if isinstance(expected_result, np.void):
296279
for name in expected_result.dtype.names or []:
297280
assert result[name] == expected_result[name]
298281

299282

300-
@pytest.mark.parametrize(
301-
(
302-
"fill_value",
303-
"dtype",
304-
),
305-
[
306-
(("Alice", 30), np.dtype([("name", "U10"), ("age", "i4"), ("city", "U20")])),
307-
(b"\x01\x00\x00\x00", np.dtype([("x", "i4"), ("y", "i4")])),
308-
("this_is_not_base64", np.dtype([("val", "i4")])),
309-
("hello", np.dtype([("age", "i4")])),
310-
({"x": 1, "y": 2}, np.dtype([("location", "i4")])),
311-
],
312-
ids=[
313-
"tuple_list_wrong_length",
314-
"bytes_wrong_length",
315-
"invalid_base64",
316-
"wrong_data_type",
317-
"wrong_dictionary",
318-
],
319-
)
320-
def test_parse_structured_fill_value_invalid(fill_value: Any, dtype: np.dtype[Any]) -> None:
321-
with pytest.raises(ValueError):
322-
_parse_structured_fill_value(fill_value, dtype)
323-
324-
325283
@pytest.mark.parametrize("fill_value", [None, b"x"], ids=["no_fill", "fill"])
326284
def test_other_dtype_roundtrip(fill_value, tmp_path) -> None:
327285
a = np.array([b"a\0\0", b"bb", b"ccc"], dtype="V7")

0 commit comments

Comments
 (0)