Skip to content

Commit 0441ee3

Browse files
committed
Fix fill_value handling for complex & datetime dtypes
1 parent 8c5038a commit 0441ee3

File tree

2 files changed

+14
-14
lines changed

2 files changed

+14
-14
lines changed

src/zarr/core/metadata/v3.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ def update_attributes(self, attributes: dict[str, JSON]) -> Self:
257257
FLOAT = np.float16 | np.float32 | np.float64
258258
COMPLEX_DTYPE = np.dtypes.Complex64DType | np.dtypes.Complex128DType
259259
COMPLEX = np.complex64 | np.complex128
260+
DATETIME_DTYPE = np.dtypes.DateTime64DType
260261

261262

262263
@overload
@@ -275,6 +276,10 @@ def parse_fill_value(fill_value: object, dtype: FLOAT_DTYPE) -> FLOAT: ...
275276
def parse_fill_value(fill_value: object, dtype: COMPLEX_DTYPE) -> COMPLEX: ...
276277

277278

279+
@overload
280+
def parse_fill_value(fill_value: object, dtype: DATETIME_DTYPE) -> DATETIME_DTYPE: ...
281+
282+
278283
@overload
279284
def parse_fill_value(fill_value: object, dtype: np.dtype[Any]) -> Any:
280285
# This dtype[Any] is unfortunately necessary right now.
@@ -314,7 +319,7 @@ def parse_fill_value(
314319
if fill_value is None:
315320
return dtype.type(0)
316321
if isinstance(fill_value, Sequence) and not isinstance(fill_value, str):
317-
if dtype in (np.complex64, np.complex128):
322+
if dtype.type in (np.complex64, np.complex128):
318323
dtype = cast(COMPLEX_DTYPE, dtype)
319324
if len(fill_value) == 2:
320325
# complex datatypes serialize to JSON arrays with two elements
@@ -328,7 +333,12 @@ def parse_fill_value(
328333
raise ValueError(msg)
329334
msg = f"Cannot parse non-string sequence {fill_value} as a scalar with type {dtype}."
330335
raise TypeError(msg)
331-
return dtype.type(fill_value) # type: ignore[arg-type]
336+
if np.issubdtype(dtype, np.datetime64):
337+
if TYPE_CHECKING:
338+
assert isinstance(dtype.type, np.datetime64)
339+
return dtype.type(fill_value, np.datetime_data(dtype)) # type: ignore[unreachable]
340+
else:
341+
return dtype.type(fill_value) # type: ignore[arg-type]
332342

333343

334344
# For type checking

src/zarr/testing/strategies.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import re
21
from typing import Any
32

43
import hypothesis.extra.numpy as npst
@@ -75,6 +74,7 @@ def arrays(
7574
path = draw(paths)
7675
name = draw(array_names)
7776
attributes = draw(attrs)
77+
fill_value = draw(npst.from_dtype(nparray.dtype))
7878
# compressor = draw(compressors)
7979

8080
# TODO: clean this up
@@ -100,16 +100,6 @@ def arrays(
100100

101101
array_path = path + ("/" if not path.endswith("/") else "") + name
102102
root = Group.create(store)
103-
fill_value_args: tuple[Any, ...] = tuple()
104-
if nparray.dtype.kind == "M":
105-
m = re.search(r"\[(.+)\]", nparray.dtype.str)
106-
if not m:
107-
raise ValueError(f"Couldn't find precision for dtype '{nparray.dtype}.")
108-
109-
fill_value_args = (
110-
# e.g. ns, D
111-
m.groups()[0],
112-
)
113103

114104
a = root.create_array(
115105
array_path,
@@ -118,7 +108,7 @@ def arrays(
118108
dtype=nparray.dtype.str,
119109
attributes=attributes,
120110
# compressor=compressor, # TODO: FIXME
121-
fill_value=nparray.dtype.type(0, *fill_value_args),
111+
fill_value=fill_value,
122112
)
123113

124114
assert isinstance(a, Array)

0 commit comments

Comments
 (0)