diff --git a/changes/2616.feature.rst b/changes/2616.feature.rst
new file mode 100644
index 0000000000..6a9422f757
--- /dev/null
+++ b/changes/2616.feature.rst
@@ -0,0 +1 @@
+NumPy’s datetime64 (‘M8’) and timedelta64 (‘m8’) dtypes are supported for Zarr arrays, as long as the units are specified.
diff --git a/docs/user-guide/arrays.rst b/docs/user-guide/arrays.rst
index a62b2ea0fa..90b144b8cf 100644
--- a/docs/user-guide/arrays.rst
+++ b/docs/user-guide/arrays.rst
@@ -619,6 +619,23 @@ In this example a shard shape of (1000, 1000) and a chunk shape of (100, 100) is
 This means that 10*10 chunks are stored in each shard, and there are 10*10 shards in total.
 Without the ``shards`` argument, there would be 10,000 chunks stored as individual files.
 
+.. _user-guide-datetime:
+
+Datetime and Timedelta arrays
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+NumPy’s datetime64 (‘M8’) and timedelta64 (‘m8’) dtypes are supported for Zarr arrays, as long as the units are specified. E.g.:
+
+.. code-block:: python
+    >>> data = np.array(['2007-07-13', '2006-01-13', '2010-08-13'], dtype='M8[D]')
+    >>> z = zarr.create_array(store='data/example-datetime.zarr', shape=data.shape, dtype=data.dtype)
+    >>> z[:] = data
+    >>> z[:]
+    array(['2007-07-13', '2006-01-13', '2010-08-13'], dtype='datetime64[D]')
+    >>> z[0] = '1999-12-31'
+    >>> z[:]
+    array(['1999-12-31', '2006-01-13', '2010-08-13'], dtype='datetime64[D]')
+
+
 Missing features in 3.0
 -----------------------
 
@@ -639,13 +656,6 @@ Fixed-length string arrays
 
 See the Zarr-Python 2 documentation on `Fixed-length string arrays `_ for more details.
 
-.. _user-guide-datetime:
-
-Datetime and Timedelta arrays
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-See the Zarr-Python 2 documentation on `Datetime and Timedelta `_ for more details.
-
 .. _user-guide-copy:
 
 Copying and migrating data
diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py
index 9154762648..10f2589624 100644
--- a/src/zarr/core/metadata/v3.py
+++ b/src/zarr/core/metadata/v3.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import re
 import warnings
 from typing import TYPE_CHECKING, TypedDict, overload
 
@@ -174,11 +175,13 @@ def default(self, o: object) -> Any:
             return str(o)
         if np.isscalar(o):
             out: Any
-            if hasattr(o, "dtype") and o.dtype.kind == "M" and hasattr(o, "view"):
+            if hasattr(o, "dtype") and o.dtype.kind in "Mm" and hasattr(o, "view"):
                 # https://github.com/zarr-developers/zarr-python/issues/2119
                 # `.item()` on a datetime type might or might not return an
                 # integer, depending on the value.
                 # Explicitly cast to an int first, and then grab .item()
+                if np.isnat(o):
+                    return "NaT"
                 out = o.view("i8").item()
             else:
                 # convert numpy scalar to python type, and pass
@@ -440,12 +443,25 @@ def update_attributes(self, attributes: dict[str, JSON]) -> Self:
 FLOAT = np.float16 | np.float32 | np.float64
 COMPLEX_DTYPE = Literal["complex64", "complex128"]
 COMPLEX = np.complex64 | np.complex128
+DATETIME_DTYPE = Literal["datetime64"]
+DATETIME = np.datetime64
+TIMEDELTA_DTYPE = Literal["timedelta64"]
+TIMEDELTA = np.timedelta64
 STRING_DTYPE = Literal["string"]
 STRING = np.str_
 BYTES_DTYPE = Literal["bytes"]
 BYTES = np.bytes_
 
-ALL_DTYPES = BOOL_DTYPE | INTEGER_DTYPE | FLOAT_DTYPE | COMPLEX_DTYPE | STRING_DTYPE | BYTES_DTYPE
+ALL_DTYPES = (
+    BOOL_DTYPE
+    | INTEGER_DTYPE
+    | FLOAT_DTYPE
+    | COMPLEX_DTYPE
+    | DATETIME_DTYPE
+    | TIMEDELTA_DTYPE
+    | STRING_DTYPE
+    | BYTES_DTYPE
+)
 
 
 @overload
@@ -490,6 +506,20 @@ def parse_fill_value(
 ) -> BYTES: ...
 
 
+@overload
+def parse_fill_value(
+    fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool,
+    dtype: DATETIME_DTYPE,
+) -> DATETIME: ...
+
+
+@overload
+def parse_fill_value(
+    fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool,
+    dtype: TIMEDELTA_DTYPE,
+) -> TIMEDELTA: ...
+
+
 def parse_fill_value(
     fill_value: Any,
     dtype: ALL_DTYPES,
@@ -551,12 +581,24 @@ def parse_fill_value(
         # fill_value != casted_value below.
         with warnings.catch_warnings():
             warnings.filterwarnings("ignore", category=DeprecationWarning)
-            casted_value = np.dtype(np_dtype).type(fill_value)
+            if np.dtype(np_dtype).kind in "Mm":
+                # datetime64 values have an associated precision
+                match = re.search(r"\[(.*?)\]", np.dtype(np_dtype).str)
+                if match:
+                    precision = match.group(1)
+                else:
+                    precision = "s"
+                casted_value = np.dtype(np_dtype).type(fill_value, precision)
+            else:
+                casted_value = np.dtype(np_dtype).type(fill_value)
     except (ValueError, OverflowError, TypeError) as e:
         raise ValueError(f"fill value {fill_value!r} is not valid for dtype {data_type}") from e
     # Check if the value is still representable by the dtype
-    if (fill_value == "NaN" and np.isnan(casted_value)) or (
-        fill_value in ["Infinity", "-Infinity"] and not np.isfinite(casted_value)
+    if (
+        (fill_value == "NaN" and np.isnan(casted_value))
+        or (fill_value in ["Infinity", "-Infinity"] and not np.isfinite(casted_value))
+        or (fill_value == "NaT" and np.isnat(casted_value))
+        or (np.dtype(np_dtype).kind in "Mm" and np.isnat(casted_value) and np.isnat(fill_value))
     ):
         pass
     elif np_dtype.kind == "f":
@@ -576,7 +618,6 @@ def parse_fill_value(
     else:
         if fill_value != casted_value:
             raise ValueError(f"fill value {fill_value!r} is not valid for dtype {data_type}")
-
     return casted_value
 
 
@@ -585,9 +626,17 @@ def default_fill_value(dtype: DataType) -> str | bytes | np.generic:
         return ""
     elif dtype == DataType.bytes:
         return b""
+    np_dtype = dtype.to_numpy()
+    np_dtype = cast(np.dtype[Any], np_dtype)
+    if np_dtype.kind in "Mm":
+        # datetime64 values have an associated precision
+        match = re.search(r"\[(.*?)\]", np_dtype.str)
+        if match:
+            precision = match.group(1)
+        else:
+            precision = "s"
+        return np_dtype.type("nat", precision)  # type: ignore[misc,call-arg]
     else:
-        np_dtype = dtype.to_numpy()
-        np_dtype = cast(np.dtype[Any], np_dtype)
         return np_dtype.type(0)  # type: ignore[misc]
 
 
@@ -610,6 +659,24 @@ class DataType(Enum):
     float64 = "float64"
     complex64 = "complex64"
     complex128 = "complex128"
+    datetime64ns = "datetime64ns"
+    datetime64ms = "datetime64ms"
+    datetime64s = "datetime64s"
+    datetime64m = "datetime64m"
+    datetime64h = "datetime64h"
+    datetime64D = "datetime64D"
+    datetime64W = "datetime64W"
+    datetime64M = "datetime64M"
+    datetime64Y = "datetime64Y"
+    timedelta64ns = "timedelta64ns"
+    timedelta64ms = "timedelta64ms"
+    timedelta64s = "timedelta64s"
+    timedelta64m = "timedelta64m"
+    timedelta64h = "timedelta64h"
+    timedelta64D = "timedelta64D"
+    timedelta64W = "timedelta64W"
+    timedelta64M = "timedelta64M"
+    timedelta64Y = "timedelta64Y"
     string = "string"
     bytes = "bytes"
 
@@ -630,6 +697,24 @@ def byte_count(self) -> int | None:
             DataType.float64: 8,
             DataType.complex64: 8,
             DataType.complex128: 16,
+            DataType.datetime64ns: 8,
+            DataType.datetime64ms: 8,
+            DataType.datetime64s: 8,
+            DataType.datetime64m: 8,
+            DataType.datetime64h: 8,
+            DataType.datetime64D: 8,
+            DataType.datetime64W: 8,
+            DataType.datetime64M: 8,
+            DataType.datetime64Y: 8,
+            DataType.timedelta64ns: 8,
+            DataType.timedelta64ms: 8,
+            DataType.timedelta64s: 8,
+            DataType.timedelta64m: 8,
+            DataType.timedelta64h: 8,
+            DataType.timedelta64D: 8,
+            DataType.timedelta64W: 8,
+            DataType.timedelta64M: 8,
+            DataType.timedelta64Y: 8,
         }
         try:
             return data_type_byte_counts[self]
@@ -657,6 +742,24 @@ def to_numpy_shortname(self) -> str:
             DataType.float64: "f8",
             DataType.complex64: "c8",
             DataType.complex128: "c16",
+            DataType.datetime64ns: "M8[ns]",
+            DataType.datetime64ms: "M8[ms]",
+            DataType.datetime64s: "M8[s]",
+            DataType.datetime64m: "M8[m]",
+            DataType.datetime64h: "M8[h]",
+            DataType.datetime64D: "M8[D]",
+            DataType.datetime64W: "M8[W]",
+            DataType.datetime64M: "M8[M]",
+            DataType.datetime64Y: "M8[Y]",
+            DataType.timedelta64ns: "m8[ns]",
+            DataType.timedelta64ms: "m8[ms]",
+            DataType.timedelta64s: "m8[s]",
+            DataType.timedelta64m: "m8[m]",
+            DataType.timedelta64h: "m8[h]",
+            DataType.timedelta64D: "m8[D]",
+            DataType.timedelta64W: "m8[W]",
+            DataType.timedelta64M: "m8[M]",
+            DataType.timedelta64Y: "m8[Y]",
         }
         return data_type_to_numpy[self]
 
@@ -700,6 +803,24 @@ def from_numpy(cls, dtype: np.dtype[Any]) -> DataType:
             " None:
     shape = (10,)
@@ -221,9 +229,13 @@ def test_array_v3_fill_value(store: MemoryStore, fill_value: int, dtype_str: str
         chunks=shape,
         fill_value=fill_value,
     )
-
-    assert arr.fill_value == np.dtype(dtype_str).type(fill_value)
     assert arr.fill_value.dtype == arr.dtype
+    if np.isfinite(arr.fill_value):
+        assert arr.fill_value == np.dtype(dtype_str).type(fill_value)
+    else:
+        if arr.dtype.kind in "Mm":
+            assert np.isnat(arr.fill_value)
+            assert np.isnat(np.dtype(dtype_str).type(fill_value))
 
 
 def test_create_positional_args_deprecated() -> None:
diff --git a/tests/test_metadata/test_v3.py b/tests/test_metadata/test_v3.py
index a47cbf43bb..89d2e1d5e3 100644
--- a/tests/test_metadata/test_v3.py
+++ b/tests/test_metadata/test_v3.py
@@ -56,7 +56,37 @@
 complex_dtypes = ("complex64", "complex128")
 vlen_dtypes = ("string", "bytes")
 
-dtypes = (*bool_dtypes, *int_dtypes, *float_dtypes, *complex_dtypes, *vlen_dtypes)
+datetime_dtypes = (
+    "datetime64ns",
+    "datetime64ms",
+    "datetime64s",
+    "datetime64m",
+    "datetime64h",
+    "datetime64D",
+    "datetime64W",
+    "datetime64M",
+    "datetime64Y",
+)
+deltatime_dtypes = (
+    "timedelta64ns",
+    "timedelta64ms",
+    "timedelta64s",
+    "timedelta64m",
+    "timedelta64h",
+    "timedelta64D",
+    "timedelta64W",
+    "timedelta64M",
+    "timedelta64Y",
+)
+dtypes = (
+    *bool_dtypes,
+    *int_dtypes,
+    *float_dtypes,
+    *complex_dtypes,
+    *vlen_dtypes,
+    *datetime_dtypes,
+    *deltatime_dtypes,
+)
 
 
 @pytest.mark.parametrize("data", [None, 1, 2, 4, 5, "3"])
@@ -119,6 +149,8 @@ def test_default_fill_value(dtype_str: str) -> None:
         assert fill_value == ""
     elif dtype == DataType.bytes:
         assert fill_value == b""
+    elif np.dtype(dtype.to_numpy()).kind in "Mm":
+        assert np.isnat(fill_value.view())
     else:
         assert fill_value == dtype.to_numpy().type(0)
 
@@ -313,43 +345,66 @@ def test_json_indent(indent: int):
     assert d == json.dumps(json.loads(d), indent=indent).encode()
 
 
-# @pytest.mark.parametrize("fill_value", [-1, 0, 1, 2932897])
-# @pytest.mark.parametrize("precision", ["ns", "D"])
-# async def test_datetime_metadata(fill_value: int, precision: str) -> None:
-#     metadata_dict = {
-#         "zarr_format": 3,
-#         "node_type": "array",
-#         "shape": (1,),
-#         "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}},
-#         "data_type": f" None:
+@pytest.mark.parametrize("fill_value", [-1, 0, 1, 2932897, "NaT"])
+@pytest.mark.parametrize("precision", ["ns", "ms", "s", "m", "h", "D", "W", "M", "Y"])
+async def test_datetime_metadata(fill_value: int, precision: str) -> None:
     metadata_dict = {
         "zarr_format": 3,
         "node_type": "array",
         "shape": (1,),
         "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}},
-        "data_type": " None:
+    metadata_dict = {
+        "zarr_format": 3,
+        "node_type": "array",
+        "shape": (1,),
+        "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}},
+        "data_type": f" None:
+    metadata_dict = {
+        "zarr_format": 3,
+        "node_type": "array",
+        "shape": (1,),
+        "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}},
+        "data_type": data,
+        "chunk_key_encoding": {"name": "default", "separator": "."},
+        "codecs": [BytesCodec()],
+        "fill_value": "",
     }
     with pytest.raises(ValueError, match=r"Invalid Zarr format 3 data_type: .*"):
         ArrayV3Metadata.from_dict(metadata_dict)
 
 
-@pytest.mark.parametrize("data", ["datetime64[s]", "foo", object()])
+@pytest.mark.parametrize("data", ["foo", object()])
 def test_parse_invalid_dtype_raises(data):
     with pytest.raises(ValueError, match=r"Invalid Zarr format 3 data_type: .*"):
         DataType.parse(data)
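For a quick end-to-end check of what this change set enables, the sketch below mirrors the new docs example but exercises the timedelta64 path together with the ``"NaT"`` fill-value handling added to ``parse_fill_value``. It assumes this patch is installed; the store path and the seconds unit are illustrative choices, not values taken from the patch:

    >>> import numpy as np
    >>> import zarr
    >>> deltas = np.array([10, 20, 30], dtype='m8[s]')   # timedelta64 with an explicit unit
    >>> z = zarr.create_array(store='data/example-timedelta.zarr', shape=deltas.shape, dtype=deltas.dtype, fill_value='NaT')
    >>> z[:] = deltas
    >>> z[:]
    array([10, 20, 30], dtype='timedelta64[s]')
    >>> assert np.isnat(z.fill_value)   # the "NaT" string is parsed into a proper NaT scalar of the array's dtype

As with the datetime64 example in ``docs/user-guide/arrays.rst`` above, a unit-less ``'m8'`` or ``'M8'`` dtype is not accepted; the unit is what maps the array onto one of the new ``DataType`` members.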