Skip to content

Commit 19d61c9

Browse files
committed
Unify metadata v2 fill value parsing
1 parent 260cfbc commit 19d61c9

File tree

1 file changed

+69
-53
lines changed
  • src/zarr/core/metadata

1 file changed

+69
-53
lines changed

src/zarr/core/metadata/v2.py

Lines changed: 69 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,17 @@
22

33
import base64
44
import warnings
5-
from collections.abc import Iterable
5+
from collections.abc import Iterable, Mapping, Sequence
66
from enum import Enum
77
from functools import cached_property
8-
from typing import TYPE_CHECKING, TypedDict, cast
8+
from typing import TYPE_CHECKING, Any, TypedDict, cast
99

1010
import numcodecs.abc
1111

1212
from zarr.abc.metadata import Metadata
1313

1414
if TYPE_CHECKING:
15-
from typing import Any, Literal, Self
15+
from typing import Literal, Self
1616

1717
import numpy.typing as npt
1818

@@ -109,6 +109,29 @@ def shards(self) -> ChunkCoords | None:
109109
return None
110110

111111
def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]:
112+
def _serialize_fill_value(fv: Any) -> JSON:
    """
    Convert a fill value into the form written to the JSON metadata document.

    Parameters
    ----------
    fv : Any
        The fill value (or, on recursive calls, one component of a complex
        fill value) to convert.

    Returns
    -------
    JSON
        * ``fv`` unchanged when ``self.fill_value`` is ``None``.
        * A base64-encoded ASCII string of ``self.fill_value`` when
          ``self.dtype.kind`` is ``"S"`` or ``"V"``.
        * ``"NaT"`` or an ISO string for ``np.datetime64`` values.
        * ``"NaN"``, ``"Infinity"`` or ``"-Infinity"`` for non-finite reals;
          finite reals are returned unchanged.
        * A two-element list ``[real, imag]`` for complex values, each part
          serialized by the rules above.
        * Anything else is returned unchanged.
    """
    if self.fill_value is None:
        # Nothing configured on the metadata object: hand ``fv`` back as-is.
        return cast(JSON, fv)

    if self.dtype.kind in "SV":
        # There's a relationship between self.dtype and self.fill_value
        # that mypy isn't aware of. The fact that we have S or V dtype here
        # means we should have a bytes-type fill_value.
        # Note that this branch encodes ``self.fill_value``, not ``fv``.
        raw_bytes = cast(bytes, self.fill_value)
        return cast(JSON, base64.standard_b64encode(raw_bytes).decode("ascii"))

    if isinstance(fv, np.datetime64):
        return cast(JSON, "NaT" if np.isnat(fv) else np.datetime_as_string(fv))

    if isinstance(fv, numbers.Real):
        as_float = float(fv)
        if np.isnan(as_float):
            return cast(JSON, "NaN")
        if np.isinf(as_float):
            return cast(JSON, "Infinity" if as_float > 0 else "-Infinity")
        # Finite real numbers need no special encoding.
        return cast(JSON, fv)

    if isinstance(fv, numbers.Complex):
        # Complex values are stored as a [real, imag] pair.
        parts = [_serialize_fill_value(fv.real), _serialize_fill_value(fv.imag)]
        return cast(JSON, parts)

    return cast(JSON, fv)
134+
112135
def _json_convert(
113136
o: Any,
114137
) -> Any:
@@ -147,6 +170,7 @@ def _json_convert(
147170
raise TypeError
148171

149172
zarray_dict = self.to_dict()
173+
zarray_dict["fill_value"] = _serialize_fill_value(zarray_dict["fill_value"])
150174
zattrs_dict = zarray_dict.pop("attributes", {})
151175
json_indent = config.get("json_indent")
152176
return {
@@ -161,38 +185,24 @@ def _json_convert(
161185
}
162186

163187
@classmethod
164-
def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata:
188+
def from_dict(cls, data: dict[str, JSON]) -> ArrayV2Metadata:
165189
# Make a copy to protect the original from modification.
166190
_data = data.copy()
167191
# Check that the zarr_format attribute is correct.
168192
_ = parse_zarr_format(_data.pop("zarr_format"))
169-
dtype = parse_dtype(_data["dtype"])
170193

171-
if dtype.kind in "SV":
172-
fill_value_encoded = _data.get("fill_value")
173-
if fill_value_encoded is not None:
174-
fill_value: Any = base64.standard_b64decode(fill_value_encoded)
175-
_data["fill_value"] = fill_value
176-
else:
177-
fill_value = _data.get("fill_value")
178-
if fill_value is not None:
179-
if np.issubdtype(dtype, np.datetime64):
180-
if fill_value == "NaT":
181-
_data["fill_value"] = np.array("NaT", dtype=dtype)[()]
182-
else:
183-
_data["fill_value"] = np.array(fill_value, dtype=dtype)[()]
184-
elif dtype.kind == "c" and isinstance(fill_value, list) and len(fill_value) == 2:
185-
val = complex(float(fill_value[0]), float(fill_value[1]))
186-
_data["fill_value"] = np.array(val, dtype=dtype)[()]
187-
elif dtype.kind in "f" and fill_value in {"NaN", "Infinity", "-Infinity"}:
188-
_data["fill_value"] = np.array(fill_value, dtype=dtype)[()]
189194
# zarr v2 allowed arbitrary keys in the metadata.
190195
# Filter the keys to only those expected by the constructor.
191196
expected = {x.name for x in fields(cls)}
192197
expected |= {"dtype", "chunks"}
193198

194199
# check if `filters` is an empty sequence; if so use None instead and raise a warning
195-
if _data["filters"] is not None and len(_data["filters"]) == 0:
200+
filters = _data.get("filters")
201+
if (
202+
isinstance(filters, Sequence)
203+
and not isinstance(filters, (str, bytes))
204+
and len(filters) == 0
205+
):
196206
msg = (
197207
"Found an empty list of filters in the array metadata document. "
198208
"This is contrary to the Zarr V2 specification, and will cause an error in the future. "
@@ -203,36 +213,11 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata:
203213

204214
_data = {k: v for k, v in _data.items() if k in expected}
205215

206-
return cls(**_data)
216+
return cls(**cast(Mapping[str, Any], _data))
207217

208218
def to_dict(self) -> dict[str, JSON]:
209-
def _sanitize_fill_value(fv: Any) -> JSON:
210-
if fv is None:
211-
return fv
212-
elif isinstance(fv, np.datetime64):
213-
if np.isnat(fv):
214-
return "NaT"
215-
return np.datetime_as_string(fv)
216-
elif isinstance(fv, numbers.Real):
217-
float_fv = float(fv)
218-
if np.isnan(float_fv):
219-
fv = "NaN"
220-
elif np.isinf(float_fv):
221-
fv = "Infinity" if float_fv > 0 else "-Infinity"
222-
elif isinstance(fv, numbers.Complex):
223-
fv = [_sanitize_fill_value(fv.real), _sanitize_fill_value(fv.imag)]
224-
return cast(JSON, fv)
225-
226219
zarray_dict = super().to_dict()
227220

228-
if self.dtype.kind in "SV" and self.fill_value is not None:
229-
# There's a relationship between self.dtype and self.fill_value
230-
# that mypy isn't aware of. The fact that we have S or V dtype here
231-
# means we should have a bytes-type fill_value.
232-
fill_value = base64.standard_b64encode(cast(bytes, self.fill_value)).decode("ascii")
233-
zarray_dict["fill_value"] = fill_value
234-
235-
zarray_dict["fill_value"] = _sanitize_fill_value(zarray_dict["fill_value"])
236221
_ = zarray_dict.pop("dtype")
237222
dtype_json: JSON
238223
# In the case of zarr v2, the simplest i.e., '|VXX' dtype is represented as a string
@@ -330,7 +315,25 @@ def parse_metadata(data: ArrayV2Metadata) -> ArrayV2Metadata:
330315
return data
331316

332317

333-
def parse_fill_value(fill_value: object, dtype: np.dtype[Any]) -> Any:
318+
def parse_structured_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> Any:
319+
"""Handle structured dtype/fill value pairs"""
320+
try:
321+
if isinstance(fill_value, (tuple, list)):
322+
fill_value = np.array([fill_value], dtype=dtype)[0]
323+
elif isinstance(fill_value, bytes):
324+
fill_value = np.frombuffer(fill_value, dtype=dtype)[0]
325+
elif isinstance(fill_value, str):
326+
decoded = base64.standard_b64decode(fill_value)
327+
fill_value = np.frombuffer(decoded, dtype=dtype)[0]
328+
else:
329+
fill_value = np.array(fill_value, dtype=dtype)[()]
330+
except Exception as e:
331+
msg = f"Fill_value {fill_value} is not valid for dtype {dtype}."
332+
raise ValueError(msg) from e
333+
return fill_value
334+
335+
336+
def parse_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> Any:
334337
"""
335338
Parse a potential fill value into a value that is compatible with the provided dtype.
336339
@@ -345,14 +348,15 @@ def parse_fill_value(fill_value: object, dtype: np.dtype[Any]) -> Any:
345348
-------
346349
An instance of `dtype`, or `None`, or any python object (in the case of an object dtype)
347350
"""
351+
348352
if fill_value is None or dtype.hasobject:
349-
# no fill value
350353
pass
354+
elif dtype.fields is not None:
355+
fill_value = parse_structured_fill_value(fill_value, dtype)
351356
elif not isinstance(fill_value, np.void) and fill_value == 0:
352357
# this should be compatible across numpy versions for any array type, including
353358
# structured arrays
354359
fill_value = np.zeros((), dtype=dtype)[()]
355-
356360
elif dtype.kind == "U":
357361
# special case unicode because of encoding issues on Windows if passed through numpy
358362
# https://github.com/alimanfoo/zarr/pull/172#issuecomment-343782713
@@ -361,6 +365,18 @@ def parse_fill_value(fill_value: object, dtype: np.dtype[Any]) -> Any:
361365
raise ValueError(
362366
f"fill_value {fill_value!r} is not valid for dtype {dtype}; must be a unicode string"
363367
)
368+
elif dtype.kind in "SV" and isinstance(fill_value, str):
369+
fill_value = base64.standard_b64decode(fill_value)
370+
elif np.issubdtype(dtype, np.datetime64):
371+
if fill_value == "NaT":
372+
fill_value = np.array("NaT", dtype=dtype)[()]
373+
else:
374+
fill_value = np.array(fill_value, dtype=dtype)[()]
375+
elif dtype.kind == "c" and isinstance(fill_value, list) and len(fill_value) == 2:
376+
complex_val = complex(float(fill_value[0]), float(fill_value[1]))
377+
fill_value = np.array(complex_val, dtype=dtype)[()]
378+
elif dtype.kind in "f" and fill_value in {"NaN", "Infinity", "-Infinity"}:
379+
fill_value = np.array(fill_value, dtype=dtype)[()]
364380
else:
365381
try:
366382
if isinstance(fill_value, bytes) and dtype.kind == "V":

0 commit comments

Comments
 (0)