Skip to content

Commit f1b01ac

Browse files
committed
fix: validate v3 dtypes when loading/creating v3 metadata
1 parent 8c5038a commit f1b01ac

File tree

6 files changed

+79
-25
lines changed

6 files changed

+79
-25
lines changed

src/zarr/core/array_spec.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from dataclasses import dataclass
44
from typing import TYPE_CHECKING, Any, Literal
55

6-
from zarr.core.common import parse_dtype, parse_fill_value, parse_order, parse_shapelike
6+
from zarr.core.common import parse_fill_value, parse_order, parse_shapelike
77

88
if TYPE_CHECKING:
99
import numpy as np
@@ -29,7 +29,7 @@ def __init__(
2929
prototype: BufferPrototype,
3030
) -> None:
3131
shape_parsed = parse_shapelike(shape)
32-
dtype_parsed = parse_dtype(dtype)
32+
dtype_parsed = dtype # parsing is likely not needed here
3333
fill_value_parsed = parse_fill_value(fill_value)
3434
order_parsed = parse_order(order)
3535

src/zarr/core/common.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919
if TYPE_CHECKING:
2020
from collections.abc import Awaitable, Callable, Iterator
2121

22-
import numpy as np
23-
import numpy.typing as npt
2422

2523
ZARR_JSON = "zarr.json"
2624
ZARRAY_JSON = ".zarray"
@@ -154,11 +152,6 @@ def parse_shapelike(data: int | Iterable[int]) -> tuple[int, ...]:
154152
return data_tuple
155153

156154

157-
def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]:
158-
# todo: real validation
159-
return np.dtype(data)
160-
161-
162155
def parse_fill_value(data: Any) -> Any:
163156
# todo: real validation
164157
return data

src/zarr/core/metadata/v2.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from zarr.core.array_spec import ArraySpec
2222
from zarr.core.chunk_grids import RegularChunkGrid
2323
from zarr.core.chunk_key_encodings import parse_separator
24-
from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON, parse_dtype, parse_shapelike
24+
from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON, parse_shapelike
2525
from zarr.core.config import config, parse_indexing_order
2626
from zarr.core.metadata.common import ArrayMetadata, parse_attributes
2727

@@ -157,6 +157,11 @@ def update_attributes(self, attributes: dict[str, JSON]) -> Self:
157157
return replace(self, attributes=attributes)
158158

159159

160+
def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]:
161+
# todo: real validation
162+
return np.dtype(data)
163+
164+
160165
def parse_zarr_format(data: object) -> Literal[2]:
161166
if data == 2:
162167
return 2

src/zarr/core/metadata/v3.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from zarr.core.buffer import default_buffer_prototype
2525
from zarr.core.chunk_grids import ChunkGrid, RegularChunkGrid
2626
from zarr.core.chunk_key_encodings import ChunkKeyEncoding
27-
from zarr.core.common import ZARR_JSON, parse_dtype, parse_named_configuration, parse_shapelike
27+
from zarr.core.common import ZARR_JSON, parse_named_configuration, parse_shapelike
2828
from zarr.core.config import config
2929
from zarr.core.metadata.common import ArrayMetadata, parse_attributes
3030
from zarr.registry import get_codec_class
@@ -215,6 +215,10 @@ def from_dict(cls, data: dict[str, JSON]) -> Self:
215215
# check that the node_type attribute is correct
216216
_ = parse_node_type_array(_data.pop("node_type"))
217217

218+
# check that the data_type attribute is valid
219+
if _data["data_type"] not in DataType:
220+
raise ValueError(f"Invalid V3 data_type: {_data['data_type']}")
221+
218222
# dimension_names key is optional, normalize missing to `None`
219223
_data["dimension_names"] = _data.pop("dimension_names", None)
220224
# attributes key is optional, normalize missing to `None`
@@ -345,8 +349,11 @@ class DataType(Enum):
345349
uint16 = "uint16"
346350
uint32 = "uint32"
347351
uint64 = "uint64"
352+
float16 = "float16"
348353
float32 = "float32"
349354
float64 = "float64"
355+
complex64 = "complex64"
356+
complex128 = "complex128"
350357

351358
@property
352359
def byte_count(self) -> int:
@@ -360,8 +367,11 @@ def byte_count(self) -> int:
360367
DataType.uint16: 2,
361368
DataType.uint32: 4,
362369
DataType.uint64: 8,
370+
DataType.float16: 2,
363371
DataType.float32: 4,
364372
DataType.float64: 8,
373+
DataType.complex64: 8,
374+
DataType.complex128: 16,
365375
}
366376
return data_type_byte_counts[self]
367377

@@ -381,8 +391,11 @@ def to_numpy_shortname(self) -> str:
381391
DataType.uint16: "u2",
382392
DataType.uint32: "u4",
383393
DataType.uint64: "u8",
394+
DataType.float16: "f2",
384395
DataType.float32: "f4",
385396
DataType.float64: "f8",
397+
DataType.complex64: "c8",
398+
DataType.complex128: "c16",
386399
}
387400
return data_type_to_numpy[self]
388401

@@ -399,7 +412,24 @@ def from_dtype(cls, dtype: np.dtype[Any]) -> DataType:
399412
"<u2": "uint16",
400413
"<u4": "uint32",
401414
"<u8": "uint64",
415+
"<f2": "float16",
402416
"<f4": "float32",
403417
"<f8": "float64",
418+
"<c8": "complex64",
419+
"<c16": "complex128",
404420
}
405421
return DataType[dtype_to_data_type[dtype.str]]
422+
423+
424+
def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]:
425+
try:
426+
dtype = np.dtype(data)
427+
except TypeError as e:
428+
raise ValueError(f"Invalid V3 data_type: {data}") from e
429+
# check that this is a valid v3 data_type
430+
try:
431+
_ = DataType.from_dtype(dtype)
432+
except KeyError as e:
433+
raise ValueError(f"Invalid V3 data_type: {dtype}") from e
434+
435+
return dtype

src/zarr/testing/strategies.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@
3535
paths = st.lists(node_names, min_size=1).map(lambda x: "/".join(x)) | st.just("/")
3636
np_arrays = npst.arrays(
3737
# TODO: re-enable timedeltas once they are supported
38-
dtype=npst.scalar_dtypes().filter(lambda x: x.kind != "m"),
38+
dtype=npst.scalar_dtypes().filter(
39+
lambda x: (x.kind not in ["m", "M"]) and (x.byteorder not in [">"])
40+
),
3941
shape=npst.array_shapes(max_dims=4),
4042
)
4143
stores = st.builds(MemoryStore, st.just({}), mode=st.just("w"))

tests/v3/test_metadata/test_v3.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
from __future__ import annotations
22

3-
import json
43
import re
54
from typing import TYPE_CHECKING, Literal
65

76
from zarr.codecs.bytes import BytesCodec
8-
from zarr.core.buffer import default_buffer_prototype
97
from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding, V2ChunkKeyEncoding
108
from zarr.core.metadata.v3 import ArrayV3Metadata
119

@@ -19,7 +17,12 @@
1917
import numpy as np
2018
import pytest
2119

22-
from zarr.core.metadata.v3 import parse_dimension_names, parse_fill_value, parse_zarr_format
20+
from zarr.core.metadata.v3 import (
21+
parse_dimension_names,
22+
parse_dtype,
23+
parse_fill_value,
24+
parse_zarr_format,
25+
)
2326

2427
bool_dtypes = ("bool",)
2528

@@ -234,22 +237,43 @@ def test_metadata_to_dict(
234237
assert observed == expected
235238

236239

237-
@pytest.mark.parametrize("fill_value", [-1, 0, 1, 2932897])
238-
@pytest.mark.parametrize("precision", ["ns", "D"])
239-
async def test_datetime_metadata(fill_value: int, precision: str) -> None:
240+
# @pytest.mark.parametrize("fill_value", [-1, 0, 1, 2932897])
241+
# @pytest.mark.parametrize("precision", ["ns", "D"])
242+
# async def test_datetime_metadata(fill_value: int, precision: str) -> None:
243+
# metadata_dict = {
244+
# "zarr_format": 3,
245+
# "node_type": "array",
246+
# "shape": (1,),
247+
# "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}},
248+
# "data_type": f"<M8[{precision}]",
249+
# "chunk_key_encoding": {"name": "default", "separator": "."},
250+
# "codecs": (),
251+
# "fill_value": np.datetime64(fill_value, precision),
252+
# }
253+
# metadata = ArrayV3Metadata.from_dict(metadata_dict)
254+
# # ensure there isn't a TypeError here.
255+
# d = metadata.to_buffer_dict(default_buffer_prototype())
256+
257+
# result = json.loads(d["zarr.json"].to_bytes())
258+
# assert result["fill_value"] == fill_value
259+
260+
261+
async def test_invalid_dtype_raises() -> None:
240262
metadata_dict = {
241263
"zarr_format": 3,
242264
"node_type": "array",
243265
"shape": (1,),
244266
"chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}},
245-
"data_type": f"<M8[{precision}]",
267+
"data_type": "<M8[ns]",
246268
"chunk_key_encoding": {"name": "default", "separator": "."},
247269
"codecs": (),
248-
"fill_value": np.datetime64(fill_value, precision),
270+
"fill_value": np.datetime64(0, "ns"),
249271
}
250-
metadata = ArrayV3Metadata.from_dict(metadata_dict)
251-
# ensure there isn't a TypeError here.
252-
d = metadata.to_buffer_dict(default_buffer_prototype())
272+
with pytest.raises(ValueError, match=r"Invalid V3 data_type"):
273+
ArrayV3Metadata.from_dict(metadata_dict)
274+
253275

254-
result = json.loads(d["zarr.json"].to_bytes())
255-
assert result["fill_value"] == fill_value
276+
@pytest.mark.parametrize("data", ["datetime64[s]", "foo", object()])
277+
def test_parse_invalid_dtype_raises(data):
278+
with pytest.raises(ValueError, match=r"Invalid V3 data_type"):
279+
parse_dtype(data)

0 commit comments

Comments
 (0)