Skip to content

Commit 93b5bc0

Browse files
committed
support for datetime and timedelta dtypes (#2616)
* Add support for the datetime dtypes * Add support for the timedelta dtypes * Add test to validate the fill_values for datetime * Add test to validate the fill_values for timedelta * Add towncrier file for changes
1 parent 8d2fb47 commit 93b5bc0

File tree

5 files changed

+210
-44
lines changed

5 files changed

+210
-44
lines changed

changes/2616.feature.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
NumPy’s datetime64 (‘M8’) and timedelta64 (‘m8’) dtypes are supported for Zarr arrays, as long as the units are specified.

docs/user-guide/arrays.rst

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,22 @@ In this example a shard shape of (1000, 1000) and a chunk shape of (100, 100) is
619619
This means that 10*10 chunks are stored in each shard, and there are 10*10 shards in total.
620620
Without the ``shards`` argument, there would be 10,000 chunks stored as individual files.
621621

622+
.. _user-guide-datetime:
623+
624+
Datetime and Timedelta arrays
625+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
626+
NumPy’s datetime64 (‘M8’) and timedelta64 (‘m8’) dtypes are supported for Zarr arrays, as long as the units are specified (for example ‘M8[D]’ or ‘m8[ns]’). E.g.:
627+
628+
>>> data = np.array(['2007-07-13', '2006-01-13', '2010-08-13'], dtype='M8[D]')
629+
>>> z = zarr.create_array(store='data/example-datetime.zarr', shape=data.shape, dtype=data.dtype)
630+
>>> z[:] = data
631+
>>> z[:]
632+
array(['2007-07-13', '2006-01-13', '2010-08-13'], dtype='datetime64[D]')
633+
>>> z[0] = '1999-12-31'
634+
>>> z[:]
635+
array(['1999-12-31', '2006-01-13', '2010-08-13'], dtype='datetime64[D]')
636+
637+
622638
Missing features in 3.0
623639
-----------------------
624640

@@ -639,13 +655,6 @@ Fixed-length string arrays
639655

640656
See the Zarr-Python 2 documentation on `Fixed-length string arrays <https://zarr.readthedocs.io/en/support-v2/tutorial.html#string-arrays>`_ for more details.
641657

642-
.. _user-guide-datetime:
643-
644-
Datetime and Timedelta arrays
645-
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
646-
647-
See the Zarr-Python 2 documentation on `Datetime and Timedelta <https://zarr.readthedocs.io/en/support-v2/tutorial.html#datetimes-and-timedeltas>`_ for more details.
648-
649658
.. _user-guide-copy:
650659

651660
Copying and migrating data

src/zarr/core/metadata/v3.py

Lines changed: 129 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import re
34
import warnings
45
from typing import TYPE_CHECKING, TypedDict, overload
56

@@ -174,11 +175,13 @@ def default(self, o: object) -> Any:
174175
return str(o)
175176
if np.isscalar(o):
176177
out: Any
177-
if hasattr(o, "dtype") and o.dtype.kind == "M" and hasattr(o, "view"):
178+
if hasattr(o, "dtype") and o.dtype.kind in "Mm" and hasattr(o, "view"):
178179
# https://github.com/zarr-developers/zarr-python/issues/2119
179180
# `.item()` on a datetime type might or might not return an
180181
# integer, depending on the value.
181182
# Explicitly cast to an int first, and then grab .item()
183+
if np.isnat(o):
184+
return "NaT"
182185
out = o.view("i8").item()
183186
else:
184187
# convert numpy scalar to python type, and pass
@@ -440,12 +443,25 @@ def update_attributes(self, attributes: dict[str, JSON]) -> Self:
440443
FLOAT = np.float16 | np.float32 | np.float64
441444
COMPLEX_DTYPE = Literal["complex64", "complex128"]
442445
COMPLEX = np.complex64 | np.complex128
446+
DATETIME_DTYPE = Literal["datetime64"]
447+
DATETIME = np.datetime64
448+
TIMEDELTA_DTYPE = Literal["timedelta64"]
449+
TIMEDELTA = np.timedelta64
443450
STRING_DTYPE = Literal["string"]
444451
STRING = np.str_
445452
BYTES_DTYPE = Literal["bytes"]
446453
BYTES = np.bytes_
447454

448-
ALL_DTYPES = BOOL_DTYPE | INTEGER_DTYPE | FLOAT_DTYPE | COMPLEX_DTYPE | STRING_DTYPE | BYTES_DTYPE
455+
ALL_DTYPES = (
456+
BOOL_DTYPE
457+
| INTEGER_DTYPE
458+
| FLOAT_DTYPE
459+
| COMPLEX_DTYPE
460+
| DATETIME_DTYPE
461+
| TIMEDELTA_DTYPE
462+
| STRING_DTYPE
463+
| BYTES_DTYPE
464+
)
449465

450466

451467
@overload
@@ -490,6 +506,20 @@ def parse_fill_value(
490506
) -> BYTES: ...
491507

492508

509+
@overload
510+
def parse_fill_value(
511+
fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool,
512+
dtype: DATETIME_DTYPE,
513+
) -> DATETIME: ...
514+
515+
516+
@overload
517+
def parse_fill_value(
518+
fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool,
519+
dtype: TIMEDELTA_DTYPE,
520+
) -> TIMEDELTA: ...
521+
522+
493523
def parse_fill_value(
494524
fill_value: Any,
495525
dtype: ALL_DTYPES,
@@ -551,12 +581,24 @@ def parse_fill_value(
551581
# fill_value != casted_value below.
552582
with warnings.catch_warnings():
553583
warnings.filterwarnings("ignore", category=DeprecationWarning)
554-
casted_value = np.dtype(np_dtype).type(fill_value)
584+
if np.dtype(np_dtype).kind in "Mm":
585+
# datetime64 values have an associated precision
586+
match = re.search(r"\[(.*?)\]", np.dtype(np_dtype).str)
587+
if match:
588+
precision = match.group(1)
589+
else:
590+
precision = "s"
591+
casted_value = np.dtype(np_dtype).type(fill_value, precision)
592+
else:
593+
casted_value = np.dtype(np_dtype).type(fill_value)
555594
except (ValueError, OverflowError, TypeError) as e:
556595
raise ValueError(f"fill value {fill_value!r} is not valid for dtype {data_type}") from e
557596
# Check if the value is still representable by the dtype
558-
if (fill_value == "NaN" and np.isnan(casted_value)) or (
559-
fill_value in ["Infinity", "-Infinity"] and not np.isfinite(casted_value)
597+
if (
598+
(fill_value == "NaN" and np.isnan(casted_value))
599+
or (fill_value in ["Infinity", "-Infinity"] and not np.isfinite(casted_value))
600+
or (fill_value == "NaT" and np.isnat(casted_value))
601+
or (np.dtype(np_dtype).kind in "Mm" and np.isnat(casted_value) and np.isnat(fill_value))
560602
):
561603
pass
562604
elif np_dtype.kind == "f":
@@ -576,7 +618,6 @@ def parse_fill_value(
576618
else:
577619
if fill_value != casted_value:
578620
raise ValueError(f"fill value {fill_value!r} is not valid for dtype {data_type}")
579-
580621
return casted_value
581622

582623

@@ -585,9 +626,17 @@ def default_fill_value(dtype: DataType) -> str | bytes | np.generic:
585626
return ""
586627
elif dtype == DataType.bytes:
587628
return b""
629+
np_dtype = dtype.to_numpy()
630+
np_dtype = cast(np.dtype[Any], np_dtype)
631+
if np_dtype.kind in "Mm":
632+
# datetime64 values have an associated precision
633+
match = re.search(r"\[(.*?)\]", np_dtype.str)
634+
if match:
635+
precision = match.group(1)
636+
else:
637+
precision = "s"
638+
return np_dtype.type("nat", precision) # type: ignore[misc,call-arg]
588639
else:
589-
np_dtype = dtype.to_numpy()
590-
np_dtype = cast(np.dtype[Any], np_dtype)
591640
return np_dtype.type(0) # type: ignore[misc]
592641

593642

@@ -610,6 +659,24 @@ class DataType(Enum):
610659
float64 = "float64"
611660
complex64 = "complex64"
612661
complex128 = "complex128"
662+
datetime64ns = ("datetime[ns]",)
663+
datetime64ms = ("datetime[ms]",)
664+
datetime64s = ("datetime[s]",)
665+
datetime64m = ("datetime[m]",)
666+
datetime64h = ("datetime[h]",)
667+
datetime64D = ("datetime[D]",)
668+
datetime64W = ("datetime[W]",)
669+
datetime64M = ("datetime[M]",)
670+
datetime64Y = ("datetime[Y]",)
671+
timedelta64ns = ("deltatime[ns]",)
672+
timedelta64ms = ("deltatime[ms]",)
673+
timedelta64s = ("deltatime[s]",)
674+
timedelta64m = ("deltatime[m]",)
675+
timedelta64h = ("deltatime[h]",)
676+
timedelta64D = ("deltatime[D]",)
677+
timedelta64W = ("deltatime[W]",)
678+
timedelta64M = ("deltatime[M]",)
679+
timedelta64Y = ("deltatime[Y]",)
613680
string = "string"
614681
bytes = "bytes"
615682

@@ -630,6 +697,24 @@ def byte_count(self) -> int | None:
630697
DataType.float64: 8,
631698
DataType.complex64: 8,
632699
DataType.complex128: 16,
700+
DataType.datetime64ns: 8,
701+
DataType.datetime64ms: 8,
702+
DataType.datetime64s: 8,
703+
DataType.datetime64m: 8,
704+
DataType.datetime64h: 8,
705+
DataType.datetime64D: 8,
706+
DataType.datetime64W: 8,
707+
DataType.datetime64M: 8,
708+
DataType.datetime64Y: 8,
709+
DataType.timedelta64ns: 8,
710+
DataType.timedelta64ms: 8,
711+
DataType.timedelta64s: 8,
712+
DataType.timedelta64m: 8,
713+
DataType.timedelta64h: 8,
714+
DataType.timedelta64D: 8,
715+
DataType.timedelta64W: 8,
716+
DataType.timedelta64M: 8,
717+
DataType.timedelta64Y: 8,
633718
}
634719
try:
635720
return data_type_byte_counts[self]
@@ -657,6 +742,24 @@ def to_numpy_shortname(self) -> str:
657742
DataType.float64: "f8",
658743
DataType.complex64: "c8",
659744
DataType.complex128: "c16",
745+
DataType.datetime64ns: "M8[ns]",
746+
DataType.datetime64ms: "M8[ms]",
747+
DataType.datetime64s: "M8[s]",
748+
DataType.datetime64m: "M8[m]",
749+
DataType.datetime64h: "M8[h]",
750+
DataType.datetime64D: "M8[D]",
751+
DataType.datetime64W: "M8[W]",
752+
DataType.datetime64M: "M8[M]",
753+
DataType.datetime64Y: "M8[Y]",
754+
DataType.timedelta64ns: "m8[ns]",
755+
DataType.timedelta64ms: "m8[ms]",
756+
DataType.timedelta64s: "m8[s]",
757+
DataType.timedelta64m: "m8[m]",
758+
DataType.timedelta64h: "m8[h]",
759+
DataType.timedelta64D: "m8[D]",
760+
DataType.timedelta64W: "m8[W]",
761+
DataType.timedelta64M: "m8[M]",
762+
DataType.timedelta64Y: "m8[Y]",
660763
}
661764
return data_type_to_numpy[self]
662765

@@ -700,6 +803,24 @@ def from_numpy(cls, dtype: np.dtype[Any]) -> DataType:
700803
"<f8": "float64",
701804
"<c8": "complex64",
702805
"<c16": "complex128",
806+
"<M8[ns]": "datetime64ns",
807+
"<M8[ms]": "datetime64ms",
808+
"<M8[s]": "datetime64s",
809+
"<M8[m]": "datetime64m",
810+
"<M8[h]": "datetime64h",
811+
"<M8[D]": "datetime64D",
812+
"<M8[W]": "datetime64W",
813+
"<M8[M]": "datetime64M",
814+
"<M8[Y]": "datetime64Y",
815+
"<m8[ns]": "timedelta64ns",
816+
"<m8[ms]": "timedelta64ms",
817+
"<m8[s]": "timedelta64s",
818+
"<m8[m]": "timedelta64m",
819+
"<m8[h]": "timedelta64h",
820+
"<m8[D]": "timedelta64D",
821+
"<m8[W]": "timedelta64W",
822+
"<m8[M]": "timedelta64M",
823+
"<m8[Y]": "timedelta64Y",
703824
}
704825
return DataType[dtype_to_data_type[dtype.str]]
705826

tests/test_array.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,15 @@ def test_array_v3_fill_value_default(
209209
@pytest.mark.parametrize("store", ["memory"], indirect=True)
210210
@pytest.mark.parametrize(
211211
("dtype_str", "fill_value"),
212-
[("bool", True), ("uint8", 99), ("float32", -99.9), ("complex64", 3 + 4j)],
212+
[
213+
("bool", True),
214+
("uint8", 99),
215+
("float32", -99.9),
216+
("complex64", 3 + 4j),
217+
("m8[ns]", 0),
218+
("M8[s]", None),
219+
("<m8[D]", "NaT"),
220+
],
213221
)
214222
def test_array_v3_fill_value(store: MemoryStore, fill_value: int, dtype_str: str) -> None:
215223
shape = (10,)
@@ -221,9 +229,13 @@ def test_array_v3_fill_value(store: MemoryStore, fill_value: int, dtype_str: str
221229
chunks=shape,
222230
fill_value=fill_value,
223231
)
224-
225-
assert arr.fill_value == np.dtype(dtype_str).type(fill_value)
226232
assert arr.fill_value.dtype == arr.dtype
233+
if np.isfinite(arr.fill_value):
234+
assert arr.fill_value == np.dtype(dtype_str).type(fill_value)
235+
else:
236+
if arr.dtype.kind in "Mm":
237+
assert np.isnat(arr.fill_value)
238+
assert np.isnat(np.dtype(dtype_str).type(fill_value))
227239

228240

229241
def test_create_positional_args_deprecated() -> None:

tests/test_metadata/test_v3.py

Lines changed: 49 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -313,43 +313,66 @@ def test_json_indent(indent: int):
313313
assert d == json.dumps(json.loads(d), indent=indent).encode()
314314

315315

316-
# @pytest.mark.parametrize("fill_value", [-1, 0, 1, 2932897])
317-
# @pytest.mark.parametrize("precision", ["ns", "D"])
318-
# async def test_datetime_metadata(fill_value: int, precision: str) -> None:
319-
# metadata_dict = {
320-
# "zarr_format": 3,
321-
# "node_type": "array",
322-
# "shape": (1,),
323-
# "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}},
324-
# "data_type": f"<M8[{precision}]",
325-
# "chunk_key_encoding": {"name": "default", "separator": "."},
326-
# "codecs": (),
327-
# "fill_value": np.datetime64(fill_value, precision),
328-
# }
329-
# metadata = ArrayV3Metadata.from_dict(metadata_dict)
330-
# # ensure there isn't a TypeError here.
331-
# d = metadata.to_buffer_dict(default_buffer_prototype())
332-
333-
# result = json.loads(d["zarr.json"].to_bytes())
334-
# assert result["fill_value"] == fill_value
335-
336-
337-
def test_invalid_dtype_raises() -> None:
316+
@pytest.mark.parametrize("fill_value", [-1, 0, 1, 2932897, "NaT"])
317+
@pytest.mark.parametrize("precision", ["ns", "ms", "s", "m", "h", "D", "W", "M", "Y"])
318+
async def test_datetime_metadata(fill_value: int, precision: str) -> None:
338319
metadata_dict = {
339320
"zarr_format": 3,
340321
"node_type": "array",
341322
"shape": (1,),
342323
"chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}},
343-
"data_type": "<M8[ns]",
324+
"data_type": f"<M8[{precision}]",
344325
"chunk_key_encoding": {"name": "default", "separator": "."},
345-
"codecs": (),
346-
"fill_value": np.datetime64(0, "ns"),
326+
"codecs": [BytesCodec()],
327+
"fill_value": np.datetime64(fill_value, precision),
328+
}
329+
metadata = ArrayV3Metadata.from_dict(metadata_dict)
330+
# ensure there isn't a TypeError here.
331+
d = metadata.to_buffer_dict(default_buffer_prototype())
332+
result = json.loads(d["zarr.json"].to_bytes())
333+
assert result["fill_value"] == fill_value
334+
335+
336+
@pytest.mark.parametrize("fill_value", [None, -1, 0, 1, 2932897, "NaT"])
337+
@pytest.mark.parametrize("precision", ["ns", "ms", "s", "m", "h", "D", "W", "M", "Y"])
338+
async def test_deltatime_metadata(fill_value: int, precision: str) -> None:
339+
metadata_dict = {
340+
"zarr_format": 3,
341+
"node_type": "array",
342+
"shape": (1,),
343+
"chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}},
344+
"data_type": f"<m8[{precision}]",
345+
"chunk_key_encoding": {"name": "default", "separator": "."},
346+
"codecs": [BytesCodec()],
347+
"fill_value": None if fill_value is None else np.timedelta64(fill_value, precision),
348+
}
349+
metadata = ArrayV3Metadata.from_dict(metadata_dict)
350+
# ensure there isn't a TypeError here.
351+
d = metadata.to_buffer_dict(default_buffer_prototype())
352+
result = json.loads(d["zarr.json"].to_bytes())
353+
if fill_value is None:
354+
assert result["fill_value"] == "NaT"
355+
else:
356+
assert result["fill_value"] == fill_value
357+
358+
359+
@pytest.mark.parametrize("data", ["foo", object()])
360+
def test_invalid_dtype_raises(data) -> None:
361+
metadata_dict = {
362+
"zarr_format": 3,
363+
"node_type": "array",
364+
"shape": (1,),
365+
"chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}},
366+
"data_type": data,
367+
"chunk_key_encoding": {"name": "default", "separator": "."},
368+
"codecs": [BytesCodec()],
369+
"fill_value": "",
347370
}
348371
with pytest.raises(ValueError, match=r"Invalid Zarr format 3 data_type: .*"):
349372
ArrayV3Metadata.from_dict(metadata_dict)
350373

351374

352-
@pytest.mark.parametrize("data", ["datetime64[s]", "foo", object()])
375+
@pytest.mark.parametrize("data", ["foo", object()])
353376
def test_parse_invalid_dtype_raises(data):
354377
with pytest.raises(ValueError, match=r"Invalid Zarr format 3 data_type: .*"):
355378
DataType.parse(data)

0 commit comments

Comments
 (0)