Skip to content

Commit b8d9789

Browse files
committed
minimal typedicts for array metadata
1 parent: 67a0df6 · commit: b8d9789

File tree

5 files changed

+140
-68
lines changed

5 files changed

+140
-68
lines changed

src/zarr/api/asynchronous.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
from zarr.core.common import JSON, AccessModeLiteral, ChunkCoords, MemoryOrder, ZarrFormat
1313
from zarr.core.config import config
1414
from zarr.core.group import AsyncGroup
15+
from zarr.core.metadata.common import ArrayMetadataDict
1516
from zarr.core.metadata.v2 import ArrayV2Metadata
1617
from zarr.core.metadata.v3 import ArrayV3Metadata
1718
from zarr.storage import (
@@ -236,11 +237,13 @@ async def open(
236237
if "shape" not in kwargs and mode in {"a", "w", "w-"}:
237238
try:
238239
metadata_dict = await get_array_metadata(store_path, zarr_format=zarr_format)
240+
# TODO: remove this cast when we fix typing for array metadata dicts
241+
_metadata_dict = cast(ArrayMetadataDict, metadata_dict)
239242
# for v2, the above would already have raised an exception if not an array
240-
zarr_format = metadata_dict["zarr_format"]
241-
is_v3_array = zarr_format == 3 and metadata_dict.get("node_type") == "array"
243+
zarr_format = _metadata_dict["zarr_format"]
244+
is_v3_array = zarr_format == 3 and _metadata_dict.get("node_type") == "array"
242245
if is_v3_array or zarr_format == 2:
243-
return AsyncArray(store_path=store_path, metadata=metadata_dict)
246+
return AsyncArray(store_path=store_path, metadata=_metadata_dict)
244247
except (AssertionError, FileNotFoundError):
245248
pass
246249
return await open_group(store=store_path, zarr_format=zarr_format, mode=mode, **kwargs)

src/zarr/core/array.py

Lines changed: 107 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -64,8 +64,9 @@
6464
is_scalar,
6565
pop_fields,
6666
)
67-
from zarr.core.metadata.v2 import ArrayV2Metadata
68-
from zarr.core.metadata.v3 import ArrayV3Metadata
67+
from zarr.core.metadata.common import ArrayMetadata, ArrayMetadataDict
68+
from zarr.core.metadata.v2 import ArrayV2Metadata, ArrayV2MetadataDict
69+
from zarr.core.metadata.v3 import ArrayV3Metadata, ArrayV3MetadataDict
6970
from zarr.core.sync import collect_aiterator, sync
7071
from zarr.registry import get_pipeline_class
7172
from zarr.storage import StoreLike, make_store_path
@@ -84,8 +85,8 @@
8485
logger = getLogger(__name__)
8586

8687

87-
def parse_array_metadata(data: Any) -> ArrayV2Metadata | ArrayV3Metadata:
88-
if isinstance(data, ArrayV2Metadata | ArrayV3Metadata):
88+
def parse_array_metadata(data: Any) -> ArrayMetadata:
89+
if isinstance(data, ArrayMetadata):
8990
return data
9091
elif isinstance(data, dict):
9192
if data["zarr_format"] == 3:
@@ -102,7 +103,7 @@ def parse_array_metadata(data: Any) -> ArrayV2Metadata | ArrayV3Metadata:
102103
raise TypeError
103104

104105

105-
def create_codec_pipeline(metadata: ArrayV2Metadata | ArrayV3Metadata) -> CodecPipeline:
106+
def create_codec_pipeline(metadata: ArrayMetadata) -> CodecPipeline:
106107
if isinstance(metadata, ArrayV3Metadata):
107108
return get_pipeline_class().from_codecs(metadata.codecs)
108109
elif isinstance(metadata, ArrayV2Metadata):
@@ -115,7 +116,7 @@ def create_codec_pipeline(metadata: ArrayV2Metadata | ArrayV3Metadata) -> CodecP
115116

116117
async def get_array_metadata(
117118
store_path: StorePath, zarr_format: ZarrFormat | None = 3
118-
) -> dict[str, Any]:
119+
) -> dict[str, JSON]:
119120
if zarr_format == 2:
120121
zarray_bytes, zattrs_bytes = await gather(
121122
(store_path / ZARRAY_JSON).get(), (store_path / ZATTRS_JSON).get()
@@ -146,7 +147,7 @@ async def get_array_metadata(
146147
else:
147148
raise ValueError(f"unexpected zarr_format: {zarr_format}")
148149

149-
metadata_dict: dict[str, Any]
150+
metadata_dict: dict[str, JSON]
150151
if zarr_format == 2:
151152
# V2 arrays are comprised of a .zarray and .zattrs objects
152153
assert zarray_bytes is not None
@@ -170,18 +171,38 @@ class AsyncArray(Generic[TArrayMeta]):
170171
codec_pipeline: CodecPipeline = field(init=False)
171172
order: Literal["C", "F"]
172173

174+
@overload
175+
def __init__(
176+
self: AsyncArray[ArrayV2Metadata],
177+
metadata: ArrayV2Metadata | ArrayV2MetadataDict,
178+
store_path: StorePath,
179+
order: Literal["C", "F"] | None = None,
180+
) -> None: ...
181+
182+
@overload
183+
def __init__(
184+
self: AsyncArray[ArrayV3Metadata],
185+
metadata: ArrayV3Metadata | ArrayV3MetadataDict,
186+
store_path: StorePath,
187+
order: Literal["C", "F"] | None = None,
188+
) -> None: ...
189+
173190
def __init__(
174191
self,
175-
metadata: ArrayV2Metadata | ArrayV3Metadata | dict[str, Any],
192+
metadata: ArrayMetadata | ArrayMetadataDict,
176193
store_path: StorePath,
177194
order: Literal["C", "F"] | None = None,
178195
) -> None:
179196
if isinstance(metadata, dict):
180197
zarr_format = metadata["zarr_format"]
198+
# TODO: remove this when we extensively type the dict representation of metadata
199+
_metadata = cast(dict[str, JSON], metadata)
181200
if zarr_format == 2:
182-
metadata = ArrayV2Metadata.from_dict(metadata)
201+
metadata = ArrayV2Metadata.from_dict(_metadata)
202+
elif zarr_format == 3:
203+
metadata = ArrayV3Metadata.from_dict(_metadata)
183204
else:
184-
metadata = ArrayV3Metadata.from_dict(metadata)
205+
raise ValueError(f"Invalid zarr_format: {zarr_format}. Expected 2 or 3")
185206

186207
metadata_parsed = parse_array_metadata(metadata)
187208
order_parsed = parse_indexing_order(order or config.get("array.order"))
@@ -222,7 +243,7 @@ async def create(
222243
# runtime
223244
exists_ok: bool = False,
224245
data: npt.ArrayLike | None = None,
225-
) -> AsyncArray[ArrayV2Metadata]:...
246+
) -> AsyncArray[ArrayV2Metadata]: ...
226247

227248
@overload
228249
@classmethod
@@ -255,7 +276,74 @@ async def create(
255276
# runtime
256277
exists_ok: bool = False,
257278
data: npt.ArrayLike | None = None,
258-
) -> AsyncArray[ArrayV3Metadata]:...
279+
) -> AsyncArray[ArrayV3Metadata]: ...
280+
281+
# this overload is necessary to handle the case where the `zarr_format` kwarg is unspecified
282+
@overload
283+
@classmethod
284+
async def create(
285+
cls,
286+
store: StoreLike,
287+
*,
288+
# v2 and v3
289+
shape: ShapeLike,
290+
dtype: npt.DTypeLike,
291+
zarr_format: Literal[3] = 3,
292+
fill_value: Any | None = None,
293+
attributes: dict[str, JSON] | None = None,
294+
# v3 only
295+
chunk_shape: ChunkCoords | None = None,
296+
chunk_key_encoding: (
297+
ChunkKeyEncoding
298+
| tuple[Literal["default"], Literal[".", "/"]]
299+
| tuple[Literal["v2"], Literal[".", "/"]]
300+
| None
301+
) = None,
302+
codecs: Iterable[Codec | dict[str, JSON]] | None = None,
303+
dimension_names: Iterable[str] | None = None,
304+
# v2 only
305+
chunks: ShapeLike | None = None,
306+
dimension_separator: Literal[".", "/"] | None = None,
307+
order: Literal["C", "F"] | None = None,
308+
filters: list[dict[str, JSON]] | None = None,
309+
compressor: dict[str, JSON] | None = None,
310+
# runtime
311+
exists_ok: bool = False,
312+
data: npt.ArrayLike | None = None,
313+
) -> AsyncArray[ArrayV3Metadata]: ...
314+
315+
@overload
316+
@classmethod
317+
async def create(
318+
cls,
319+
store: StoreLike,
320+
*,
321+
# v2 and v3
322+
shape: ShapeLike,
323+
dtype: npt.DTypeLike,
324+
zarr_format: ZarrFormat,
325+
fill_value: Any | None = None,
326+
attributes: dict[str, JSON] | None = None,
327+
# v3 only
328+
chunk_shape: ChunkCoords | None = None,
329+
chunk_key_encoding: (
330+
ChunkKeyEncoding
331+
| tuple[Literal["default"], Literal[".", "/"]]
332+
| tuple[Literal["v2"], Literal[".", "/"]]
333+
| None
334+
) = None,
335+
codecs: Iterable[Codec | dict[str, JSON]] | None = None,
336+
dimension_names: Iterable[str] | None = None,
337+
# v2 only
338+
chunks: ShapeLike | None = None,
339+
dimension_separator: Literal[".", "/"] | None = None,
340+
order: Literal["C", "F"] | None = None,
341+
filters: list[dict[str, JSON]] | None = None,
342+
compressor: dict[str, JSON] | None = None,
343+
# runtime
344+
exists_ok: bool = False,
345+
data: npt.ArrayLike | None = None,
346+
) -> AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata]: ...
259347

260348
@classmethod
261349
async def create(
@@ -471,7 +559,9 @@ async def open(
471559
) -> AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata]:
472560
store_path = await make_store_path(store)
473561
metadata_dict = await get_array_metadata(store_path, zarr_format=zarr_format)
474-
return cls(store_path=store_path, metadata=metadata_dict)
562+
# TODO: remove this cast when we have better type hints
563+
_metadata_dict = cast(ArrayV3MetadataDict, metadata_dict)
564+
return cls(store_path=store_path, metadata=_metadata_dict)
475565

476566
@property
477567
def store(self) -> Store:
@@ -700,9 +790,7 @@ async def getitem(
700790
)
701791
return await self._get_selection(indexer, prototype=prototype)
702792

703-
async def _save_metadata(
704-
self, metadata: ArrayV2Metadata | ArrayV3Metadata, ensure_parents: bool = False
705-
) -> None:
793+
async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = False) -> None:
706794
to_save = metadata.to_buffer_dict(default_buffer_prototype())
707795
awaitables = [set_or_delete(self.store_path / key, value) for key, value in to_save.items()]
708796

@@ -948,7 +1036,7 @@ def basename(self) -> str | None:
9481036
return self._async_array.basename
9491037

9501038
@property
951-
def metadata(self) -> ArrayV2Metadata | ArrayV3Metadata:
1039+
def metadata(self) -> ArrayMetadata:
9521040
return self._async_array.metadata
9531041

9541042
@property
@@ -2378,7 +2466,8 @@ def resize(self, new_shape: ChunkCoords) -> Array:
23782466
the data falling outside the new array but inside the boundary chunks
23792467
would be restored by a subsequent resize operation that grows the array size.
23802468
"""
2381-
return type(self)(sync(self._async_array.resize(new_shape)))
2469+
resized = sync(self._async_array.resize(new_shape))
2470+
return type(self)(resized)
23822471

23832472
def update_attributes(self, new_attributes: dict[str, JSON]) -> Array:
23842473
return type(self)(sync(self._async_array.update_attributes(new_attributes)))

src/zarr/core/metadata/common.py

Lines changed: 7 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -2,54 +2,16 @@
22

33
from typing import TYPE_CHECKING
44

5-
if TYPE_CHECKING:
6-
from typing import Any, Literal, Self
7-
8-
from zarr.core.array_spec import ArraySpec
9-
from zarr.core.buffer import Buffer, BufferPrototype
10-
from zarr.core.chunk_grids import ChunkGrid
11-
from zarr.core.common import JSON, ChunkCoords, ZarrFormat
12-
13-
from abc import ABC, abstractmethod
14-
from dataclasses import dataclass
15-
16-
from zarr.abc.metadata import Metadata
17-
5+
from .v2 import ArrayV2Metadata, ArrayV2MetadataDict
6+
from .v3 import ArrayV3Metadata, ArrayV3MetadataDict
187

19-
@dataclass(frozen=True, kw_only=True)
20-
class ArrayMetadata(Metadata, ABC):
21-
shape: ChunkCoords
22-
fill_value: Any
23-
chunk_grid: ChunkGrid
24-
attributes: dict[str, JSON]
25-
zarr_format: ZarrFormat
26-
27-
@property
28-
@abstractmethod
29-
def ndim(self) -> int:
30-
pass
31-
32-
@abstractmethod
33-
def get_chunk_spec(
34-
self, _chunk_coords: ChunkCoords, order: Literal["C", "F"], prototype: BufferPrototype
35-
) -> ArraySpec:
36-
pass
37-
38-
@abstractmethod
39-
def encode_chunk_key(self, chunk_coords: ChunkCoords) -> str:
40-
pass
41-
42-
@abstractmethod
43-
def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]:
44-
pass
8+
if TYPE_CHECKING:
9+
from typing import TypeAlias
4510

46-
@abstractmethod
47-
def update_shape(self, shape: ChunkCoords) -> Self:
48-
pass
11+
from zarr.core.common import JSON
4912

50-
@abstractmethod
51-
def update_attributes(self, attributes: dict[str, JSON]) -> Self:
52-
pass
13+
ArrayMetadata: TypeAlias = ArrayV2Metadata | ArrayV3Metadata
14+
ArrayMetadataDict: TypeAlias = ArrayV2MetadataDict | ArrayV3MetadataDict
5315

5416

5517
def parse_attributes(data: None | dict[str, JSON]) -> dict[str, JSON]:

src/zarr/core/metadata/v2.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22

33
from collections.abc import Iterable
44
from enum import Enum
5-
from typing import TYPE_CHECKING
5+
from typing import TYPE_CHECKING, TypedDict
66

77
from zarr.abc.metadata import Metadata
88

@@ -28,6 +28,15 @@
2828
from zarr.core.metadata.common import parse_attributes
2929

3030

31+
class ArrayV2MetadataDict(TypedDict):
32+
"""
33+
A typed dictionary model for zarr v2 metadata.
34+
"""
35+
36+
zarr_format: Literal[2]
37+
attributes: dict[str, JSON]
38+
39+
3140
@dataclass(frozen=True, kw_only=True)
3241
class ArrayV2Metadata(Metadata):
3342
shape: ChunkCoords

src/zarr/core/metadata/v3.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import warnings
4-
from typing import TYPE_CHECKING, cast, overload
4+
from typing import TYPE_CHECKING, TypedDict, cast, overload
55

66
from zarr.abc.metadata import Metadata
77

@@ -151,6 +151,15 @@ def _replace_special_floats(obj: object) -> Any:
151151
return obj
152152

153153

154+
class ArrayV3MetadataDict(TypedDict):
155+
"""
156+
A typed dictionary model for zarr v3 metadata.
157+
"""
158+
159+
zarr_format: Literal[3]
160+
attributes: dict[str, JSON]
161+
162+
154163
@dataclass(frozen=True, kw_only=True)
155164
class ArrayV3Metadata(Metadata):
156165
shape: ChunkCoords

0 commit comments

Comments
 (0)