Skip to content

Commit 1d3d7a5

Browse files
committed
Change the type of ArrayV3Metadata.data_type (zarr.core.metadata.v3) from np.dtype to DataType
1 parent 6454c69 commit 1d3d7a5

File tree

2 files changed

+40
-31
lines changed

2 files changed

+40
-31
lines changed

src/zarr/core/metadata/v3.py

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
if TYPE_CHECKING:
77
from typing import Self
88

9-
import numpy.typing as npt
10-
119
from zarr.core.buffer import Buffer, BufferPrototype
1210
from zarr.core.chunk_grids import ChunkGrid
1311
from zarr.core.common import JSON, ChunkCoords
@@ -20,6 +18,7 @@
2018

2119
import numcodecs.abc
2220
import numpy as np
21+
import numpy.typing as npt
2322

2423
from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec
2524
from zarr.core.array_spec import ArraySpec
@@ -152,7 +151,7 @@ def _replace_special_floats(obj: object) -> Any:
152151
@dataclass(frozen=True, kw_only=True)
153152
class ArrayV3Metadata(ArrayMetadata):
154153
shape: ChunkCoords
155-
data_type: np.dtype[Any]
154+
data_type: DataType
156155
chunk_grid: ChunkGrid
157156
chunk_key_encoding: ChunkKeyEncoding
158157
fill_value: Any
@@ -167,7 +166,7 @@ def __init__(
167166
self,
168167
*,
169168
shape: Iterable[int],
170-
data_type: npt.DTypeLike,
169+
data_type: npt.DTypeLike | DataType,
171170
chunk_grid: dict[str, JSON] | ChunkGrid,
172171
chunk_key_encoding: dict[str, JSON] | ChunkKeyEncoding,
173172
fill_value: Any,
@@ -180,18 +179,18 @@ def __init__(
180179
Because the class is a frozen dataclass, we set attributes using object.__setattr__
181180
"""
182181
shape_parsed = parse_shapelike(shape)
183-
data_type_parsed = parse_dtype(data_type)
182+
data_type_parsed = DataType.parse(data_type)
184183
chunk_grid_parsed = ChunkGrid.from_dict(chunk_grid)
185184
chunk_key_encoding_parsed = ChunkKeyEncoding.from_dict(chunk_key_encoding)
186185
dimension_names_parsed = parse_dimension_names(dimension_names)
187-
fill_value_parsed = parse_fill_value(fill_value, dtype=data_type_parsed)
186+
fill_value_parsed = parse_fill_value(fill_value, dtype=data_type_parsed.to_numpy_dtype())
188187
attributes_parsed = parse_attributes(attributes)
189188
codecs_parsed_partial = parse_codecs(codecs)
190189
storage_transformers_parsed = parse_storage_transformers(storage_transformers)
191190

192191
array_spec = ArraySpec(
193192
shape=shape_parsed,
194-
dtype=data_type_parsed,
193+
dtype=data_type_parsed.to_numpy_dtype(),
195194
fill_value=fill_value_parsed,
196195
order="C", # TODO: order is not needed here.
197196
prototype=default_buffer_prototype(), # TODO: prototype is not needed here.
@@ -224,11 +223,14 @@ def _validate_metadata(self) -> None:
224223
if self.fill_value is None:
225224
raise ValueError("`fill_value` is required.")
226225
for codec in self.codecs:
227-
codec.validate(shape=self.shape, dtype=self.data_type, chunk_grid=self.chunk_grid)
226+
codec.validate(
227+
shape=self.shape, dtype=self.data_type.to_numpy_dtype(), chunk_grid=self.chunk_grid
228+
)
228229

229230
@property
230231
def dtype(self) -> np.dtype[Any]:
231-
return self.data_type
232+
"""Interpret Zarr dtype as NumPy dtype"""
233+
return self.data_type.to_numpy_dtype()
232234

233235
@property
234236
def ndim(self) -> int:
@@ -266,13 +268,13 @@ def from_dict(cls, data: dict[str, JSON]) -> Self:
266268
_ = parse_node_type_array(_data.pop("node_type"))
267269

268270
# check that the data_type attribute is valid
269-
_ = DataType(_data["data_type"])
271+
data_type = DataType.parse(_data.pop("data_type"))
270272

271273
# dimension_names key is optional, normalize missing to `None`
272274
_data["dimension_names"] = _data.pop("dimension_names", None)
273275
# attributes key is optional, normalize missing to `None`
274276
_data["attributes"] = _data.pop("attributes", None)
275-
return cls(**_data) # type: ignore[arg-type]
277+
return cls(**_data, data_type=data_type) # type: ignore[arg-type]
276278

277279
def to_dict(self) -> dict[str, JSON]:
278280
out_dict = super().to_dict()
@@ -490,8 +492,11 @@ def to_numpy_shortname(self) -> str:
490492
}
491493
return data_type_to_numpy[self]
492494

495+
def to_numpy_dtype(self) -> np.dtype[Any]:
496+
return np.dtype(self.to_numpy_shortname())
497+
493498
@classmethod
494-
def from_dtype(cls, dtype: np.dtype[Any]) -> DataType:
499+
def from_numpy_dtype(cls, dtype: np.dtype[Any]) -> DataType:
495500
dtype_to_data_type = {
496501
"|b1": "bool",
497502
"bool": "bool",
@@ -511,16 +516,21 @@ def from_dtype(cls, dtype: np.dtype[Any]) -> DataType:
511516
}
512517
return DataType[dtype_to_data_type[dtype.str]]
513518

514-
515-
def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]:
516-
try:
517-
dtype = np.dtype(data)
518-
except (ValueError, TypeError) as e:
519-
raise ValueError(f"Invalid V3 data_type: {data}") from e
520-
# check that this is a valid v3 data_type
521-
try:
522-
_ = DataType.from_dtype(dtype)
523-
except KeyError as e:
524-
raise ValueError(f"Invalid V3 data_type: {dtype}") from e
525-
526-
return dtype
519+
@classmethod
520+
def parse(cls, dtype: None | DataType | Any) -> DataType:
521+
if dtype is None:
522+
# the default dtype
523+
return DataType.float64
524+
if isinstance(dtype, DataType):
525+
return dtype
526+
else:
527+
try:
528+
dtype = np.dtype(dtype)
529+
except (ValueError, TypeError) as e:
530+
raise ValueError(f"Invalid V3 data_type: {dtype}") from e
531+
# check that this is a valid v3 data_type
532+
try:
533+
data_type = DataType.from_numpy_dtype(dtype)
534+
except KeyError as e:
535+
raise ValueError(f"Invalid V3 data_type: {dtype}") from e
536+
return data_type

tests/v3/test_metadata/test_v3.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from zarr.codecs.bytes import BytesCodec
88
from zarr.core.buffer import default_buffer_prototype
99
from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding, V2ChunkKeyEncoding
10-
from zarr.core.metadata.v3 import ArrayV3Metadata
10+
from zarr.core.metadata.v3 import ArrayV3Metadata, DataType
1111

1212
if TYPE_CHECKING:
1313
from collections.abc import Sequence
@@ -22,7 +22,6 @@
2222

2323
from zarr.core.metadata.v3 import (
2424
parse_dimension_names,
25-
parse_dtype,
2625
parse_fill_value,
2726
parse_zarr_format,
2827
)
@@ -209,7 +208,7 @@ def test_metadata_to_dict(
209208
storage_transformers: None | tuple[dict[str, JSON]],
210209
) -> None:
211210
shape = (1, 2, 3)
212-
data_type = "uint8"
211+
data_type = DataType.uint8
213212
if chunk_grid == "regular":
214213
cgrid = {"name": "regular", "configuration": {"chunk_shape": (1, 1, 1)}}
215214

@@ -290,7 +289,7 @@ def test_metadata_to_dict(
290289
# assert result["fill_value"] == fill_value
291290

292291

293-
async def test_invalid_dtype_raises() -> None:
292+
def test_invalid_dtype_raises() -> None:
294293
metadata_dict = {
295294
"zarr_format": 3,
296295
"node_type": "array",
@@ -301,14 +300,14 @@ async def test_invalid_dtype_raises() -> None:
301300
"codecs": (),
302301
"fill_value": np.datetime64(0, "ns"),
303302
}
304-
with pytest.raises(ValueError, match=r".* is not a valid DataType"):
303+
with pytest.raises(ValueError, match=r"Invalid V3 data_type: .*"):
305304
ArrayV3Metadata.from_dict(metadata_dict)
306305

307306

308307
@pytest.mark.parametrize("data", ["datetime64[s]", "foo", object()])
309308
def test_parse_invalid_dtype_raises(data):
310309
with pytest.raises(ValueError, match=r"Invalid V3 data_type: .*"):
311-
parse_dtype(data)
310+
DataType.parse(data)
312311

313312

314313
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)