Skip to content

Commit 61ca06b

Browse files
authored
fix: accept extra fields in array metadata (#3530)
* add support for reading and preserving unknown, permitted fields in array metadata. * sort in error message * sort keys in the right place * changelog * rename changelog
1 parent e3ee591 commit 61ca06b

File tree

8 files changed

+194
-32
lines changed

8 files changed

+194
-32
lines changed

changes/3530.bugfix.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Fixed a bug that prevented Zarr Python from opening Zarr V3 array metadata documents that contained
2+
extra keys with permissible values (dicts with a `"must_understand"` key set to `false`).

src/zarr/core/array.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,10 @@
103103
from zarr.core.metadata import (
104104
ArrayMetadata,
105105
ArrayMetadataDict,
106+
ArrayMetadataJSON_V3,
106107
ArrayV2Metadata,
107108
ArrayV2MetadataDict,
108109
ArrayV3Metadata,
109-
ArrayV3MetadataDict,
110110
T_ArrayMetadata,
111111
)
112112
from zarr.core.metadata.io import save_metadata
@@ -319,7 +319,7 @@ def __init__(
319319
@overload
320320
def __init__(
321321
self: AsyncArray[ArrayV3Metadata],
322-
metadata: ArrayV3Metadata | ArrayV3MetadataDict,
322+
metadata: ArrayV3Metadata | ArrayMetadataJSON_V3,
323323
store_path: StorePath,
324324
config: ArrayConfigLike | None = None,
325325
) -> None: ...
@@ -1004,7 +1004,7 @@ async def example():
10041004
store_path = await make_store_path(store)
10051005
metadata_dict = await get_array_metadata(store_path, zarr_format=zarr_format)
10061006
# TODO: remove this cast when we have better type hints
1007-
_metadata_dict = cast("ArrayV3MetadataDict", metadata_dict)
1007+
_metadata_dict = cast("ArrayMetadataJSON_V3", metadata_dict)
10081008
return cls(store_path=store_path, metadata=_metadata_dict)
10091009

10101010
@property

src/zarr/core/chunk_grids.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from zarr.abc.metadata import Metadata
1616
from zarr.core.common import (
1717
JSON,
18+
NamedConfig,
1819
ShapeLike,
1920
ceildiv,
2021
parse_named_configuration,
@@ -152,7 +153,7 @@ def normalize_chunks(chunks: Any, shape: tuple[int, ...], typesize: int) -> tupl
152153
@dataclass(frozen=True)
153154
class ChunkGrid(Metadata):
154155
@classmethod
155-
def from_dict(cls, data: dict[str, JSON] | ChunkGrid) -> ChunkGrid:
156+
def from_dict(cls, data: dict[str, JSON] | ChunkGrid | NamedConfig[str, Any]) -> ChunkGrid:
156157
if isinstance(data, ChunkGrid):
157158
return data
158159

@@ -180,7 +181,7 @@ def __init__(self, *, chunk_shape: ShapeLike) -> None:
180181
object.__setattr__(self, "chunk_shape", chunk_shape_parsed)
181182

182183
@classmethod
183-
def _from_dict(cls, data: dict[str, JSON]) -> Self:
184+
def _from_dict(cls, data: dict[str, JSON] | NamedConfig[str, Any]) -> Self:
184185
_, configuration_parsed = parse_named_configuration(data, "regular")
185186

186187
return cls(**configuration_parsed) # type: ignore[arg-type]

src/zarr/core/chunk_key_encodings.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22

33
from abc import ABC, abstractmethod
44
from dataclasses import dataclass
5-
from typing import TYPE_CHECKING, ClassVar, Literal, TypeAlias, TypedDict, cast
5+
from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeAlias, TypedDict, cast
66

77
if TYPE_CHECKING:
88
from typing import NotRequired, Self
99

1010
from zarr.abc.metadata import Metadata
1111
from zarr.core.common import (
1212
JSON,
13+
NamedConfig,
1314
parse_named_configuration,
1415
)
1516
from zarr.registry import get_chunk_key_encoding_class, register_chunk_key_encoding
@@ -61,7 +62,9 @@ def encode_chunk_key(self, chunk_coords: tuple[int, ...]) -> str:
6162
"""
6263

6364

64-
ChunkKeyEncodingLike: TypeAlias = dict[str, JSON] | ChunkKeyEncodingParams | ChunkKeyEncoding
65+
ChunkKeyEncodingLike: TypeAlias = (
66+
dict[str, JSON] | ChunkKeyEncodingParams | ChunkKeyEncoding | NamedConfig[str, Any]
67+
)
6568

6669

6770
@dataclass(frozen=True)
@@ -108,7 +111,7 @@ def parse_chunk_key_encoding(data: ChunkKeyEncodingLike) -> ChunkKeyEncoding:
108111

109112
# handle ChunkKeyEncodingParams
110113
if "name" in data and "separator" in data:
111-
data = {"name": data["name"], "configuration": {"separator": data["separator"]}}
114+
data = {"name": data["name"], "configuration": {"separator": data["separator"]}} # type: ignore[typeddict-item]
112115

113116
# Now must be a named config
114117
data = cast("dict[str, JSON]", data)

src/zarr/core/common.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Final,
1515
Generic,
1616
Literal,
17+
NotRequired,
1718
TypedDict,
1819
TypeVar,
1920
cast,
@@ -63,8 +64,8 @@ class NamedConfig(TypedDict, Generic[TName, TConfig]):
6364
name: ReadOnly[TName]
6465
"""The name of the object."""
6566

66-
configuration: ReadOnly[TConfig]
67-
"""The configuration of the object."""
67+
configuration: NotRequired[ReadOnly[TConfig]]
68+
"""The configuration of the object. Not required."""
6869

6970

7071
def product(tup: tuple[int, ...]) -> int:
@@ -134,18 +135,24 @@ def parse_configuration(data: JSON) -> JSON:
134135

135136
@overload
136137
def parse_named_configuration(
137-
data: JSON, expected_name: str | None = None
138+
data: JSON | NamedConfig[str, Any], expected_name: str | None = None
138139
) -> tuple[str, dict[str, JSON]]: ...
139140

140141

141142
@overload
142143
def parse_named_configuration(
143-
data: JSON, expected_name: str | None = None, *, require_configuration: bool = True
144+
data: JSON | NamedConfig[str, Any],
145+
expected_name: str | None = None,
146+
*,
147+
require_configuration: bool = True,
144148
) -> tuple[str, dict[str, JSON] | None]: ...
145149

146150

147151
def parse_named_configuration(
148-
data: JSON, expected_name: str | None = None, *, require_configuration: bool = True
152+
data: JSON | NamedConfig[str, Any],
153+
expected_name: str | None = None,
154+
*,
155+
require_configuration: bool = True,
149156
) -> tuple[str, JSON | None]:
150157
if not isinstance(data, dict):
151158
raise TypeError(f"Expected dict, got {type(data)}")

src/zarr/core/metadata/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
from typing import TypeAlias, TypeVar
22

33
from .v2 import ArrayV2Metadata, ArrayV2MetadataDict
4-
from .v3 import ArrayV3Metadata, ArrayV3MetadataDict
4+
from .v3 import ArrayMetadataJSON_V3, ArrayV3Metadata
55

66
ArrayMetadata: TypeAlias = ArrayV2Metadata | ArrayV3Metadata
7-
ArrayMetadataDict: TypeAlias = ArrayV2MetadataDict | ArrayV3MetadataDict
7+
ArrayMetadataDict: TypeAlias = ArrayV2MetadataDict | ArrayMetadataJSON_V3
88
T_ArrayMetadata = TypeVar("T_ArrayMetadata", ArrayV2Metadata, ArrayV3Metadata)
99

1010
__all__ = [
1111
"ArrayMetadata",
1212
"ArrayMetadataDict",
13+
"ArrayMetadataJSON_V3",
1314
"ArrayV2Metadata",
1415
"ArrayV2MetadataDict",
1516
"ArrayV3Metadata",
16-
"ArrayV3MetadataDict",
1717
]

src/zarr/core/metadata/v3.py

Lines changed: 95 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, TypedDict
3+
from collections.abc import Mapping
4+
from typing import TYPE_CHECKING, NotRequired, TypedDict, TypeGuard, cast
45

56
from zarr.abc.metadata import Metadata
67
from zarr.core.buffer.core import default_buffer_prototype
@@ -33,6 +34,7 @@
3334
JSON,
3435
ZARR_JSON,
3536
DimensionNames,
37+
NamedConfig,
3638
parse_named_configuration,
3739
parse_shapelike,
3840
)
@@ -136,13 +138,61 @@ def parse_storage_transformers(data: object) -> tuple[dict[str, JSON], ...]:
136138
)
137139

138140

139-
class ArrayV3MetadataDict(TypedDict):
141+
class AllowedExtraField(TypedDict):
    """
    Model of an extra field that is permitted in a Zarr v3 array metadata document.

    An extra field is permitted only when it is a mapping whose ``"must_understand"``
    key is set to ``False``. Zarr Python does not interpret such fields, but it does
    not discard them either: ``ArrayV3Metadata`` keeps them in ``extra_fields`` and
    ``ArrayV3Metadata.to_dict`` merges them back into the serialized document.
    """

    # The only key this model constrains; other keys of the field are not described here.
    must_understand: Literal[False]
148+
149+
150+
def check_allowed_extra_field(data: object) -> TypeGuard[AllowedExtraField]:
151+
"""
152+
Check if the extra field is allowed according to the Zarr v3 spec. The object
153+
must be a mapping with a "must_understand" key set to `False`.
154+
"""
155+
return isinstance(data, Mapping) and data.get("must_understand") is False
156+
157+
158+
def parse_extra_fields(
    data: Mapping[str, AllowedExtraField] | None,
) -> dict[str, AllowedExtraField]:
    """
    Normalize the extra fields of an array metadata document to a plain dict.

    ``None`` is normalized to an empty dict; any other input is returned as a
    shallow ``dict`` copy.

    Raises
    ------
    ValueError
        If any key of ``data`` is also a member of ``ARRAY_METADATA_KEYS``.
    """
    if data is None:
        return {}

    # Extra fields may not shadow keys that belong to the metadata document itself.
    reserved = ARRAY_METADATA_KEYS & set(data.keys())
    if reserved:
        raise ValueError(
            "Invalid extra fields. "
            "The following keys: "
            f"{sorted(reserved)} "
            "are invalid because they collide with keys reserved for use by the "
            "array metadata document."
        )
    return dict(data)
175+
176+
177+
class ArrayMetadataJSON_V3(TypedDict):
140178
"""
141179
A typed dictionary model for zarr v3 metadata.
142180
"""
143181

144182
zarr_format: Literal[3]
145-
attributes: dict[str, JSON]
183+
node_type: Literal["array"]
184+
data_type: str | NamedConfig[str, Mapping[str, object]]
185+
shape: tuple[int, ...]
186+
chunk_grid: NamedConfig[str, Mapping[str, object]]
187+
chunk_key_encoding: NamedConfig[str, Mapping[str, object]]
188+
fill_value: object
189+
codecs: tuple[str | NamedConfig[str, Mapping[str, object]], ...]
190+
attributes: NotRequired[Mapping[str, JSON]]
191+
storage_transformers: NotRequired[tuple[NamedConfig[str, Mapping[str, object]], ...]]
192+
dimension_names: NotRequired[tuple[str | None]]
193+
194+
195+
# Keys declared on ArrayMetadataJSON_V3; used to tell core metadata keys apart from extra fields.
ARRAY_METADATA_KEYS = set(ArrayMetadataJSON_V3.__annotations__)
146196

147197

148198
@dataclass(frozen=True, kw_only=True)
@@ -158,19 +208,21 @@ class ArrayV3Metadata(Metadata):
158208
zarr_format: Literal[3] = field(default=3, init=False)
159209
node_type: Literal["array"] = field(default="array", init=False)
160210
storage_transformers: tuple[dict[str, JSON], ...]
211+
extra_fields: dict[str, AllowedExtraField]
161212

162213
def __init__(
163214
self,
164215
*,
165216
shape: Iterable[int],
166217
data_type: ZDType[TBaseDType, TBaseScalar],
167-
chunk_grid: dict[str, JSON] | ChunkGrid,
218+
chunk_grid: dict[str, JSON] | ChunkGrid | NamedConfig[str, Any],
168219
chunk_key_encoding: ChunkKeyEncodingLike,
169220
fill_value: object,
170-
codecs: Iterable[Codec | dict[str, JSON]],
221+
codecs: Iterable[Codec | dict[str, JSON] | NamedConfig[str, Any] | str],
171222
attributes: dict[str, JSON] | None,
172223
dimension_names: DimensionNames,
173224
storage_transformers: Iterable[dict[str, JSON]] | None = None,
225+
extra_fields: Mapping[str, AllowedExtraField] | None = None,
174226
) -> None:
175227
"""
176228
Because the class is a frozen dataclass, we set attributes using object.__setattr__
@@ -185,7 +237,7 @@ def __init__(
185237
attributes_parsed = parse_attributes(attributes)
186238
codecs_parsed_partial = parse_codecs(codecs)
187239
storage_transformers_parsed = parse_storage_transformers(storage_transformers)
188-
240+
extra_fields_parsed = parse_extra_fields(extra_fields)
189241
array_spec = ArraySpec(
190242
shape=shape_parsed,
191243
dtype=data_type,
@@ -205,6 +257,7 @@ def __init__(
205257
object.__setattr__(self, "fill_value", fill_value_parsed)
206258
object.__setattr__(self, "attributes", attributes_parsed)
207259
object.__setattr__(self, "storage_transformers", storage_transformers_parsed)
260+
object.__setattr__(self, "extra_fields", extra_fields_parsed)
208261

209262
self._validate_metadata()
210263

@@ -323,16 +376,45 @@ def from_dict(cls, data: dict[str, JSON]) -> Self:
323376
except ValueError as e:
324377
raise TypeError(f"Invalid fill_value: {fill!r}") from e
325378

326-
# dimension_names key is optional, normalize missing to `None`
327-
_data["dimension_names"] = _data.pop("dimension_names", None)
328-
329-
# attributes key is optional, normalize missing to `None`
330-
_data["attributes"] = _data.pop("attributes", None)
331-
332-
return cls(**_data, fill_value=fill_value_parsed, data_type=data_type) # type: ignore[arg-type]
379+
# check if there are extra keys
380+
extra_keys = set(_data.keys()) - ARRAY_METADATA_KEYS
381+
allowed_extra_fields: dict[str, AllowedExtraField] = {}
382+
invalid_extra_fields = {}
383+
for key in extra_keys:
384+
val = _data[key]
385+
if check_allowed_extra_field(val):
386+
allowed_extra_fields[key] = val
387+
else:
388+
invalid_extra_fields[key] = val
389+
if len(invalid_extra_fields) > 0:
390+
msg = (
391+
"Got a Zarr V3 metadata document with the following disallowed extra fields:"
392+
f"{sorted(invalid_extra_fields.keys())}."
393+
'Extra fields are not allowed unless they are a dict with a "must_understand" key'
394+
"which is assigned the value `False`."
395+
)
396+
raise MetadataValidationError(msg)
397+
# TODO: replace this with a real type check!
398+
_data_typed = cast(ArrayMetadataJSON_V3, _data)
399+
400+
return cls(
401+
shape=_data_typed["shape"],
402+
chunk_grid=_data_typed["chunk_grid"],
403+
chunk_key_encoding=_data_typed["chunk_key_encoding"],
404+
codecs=_data_typed["codecs"],
405+
attributes=_data_typed.get("attributes", {}), # type: ignore[arg-type]
406+
dimension_names=_data_typed.get("dimension_names", None),
407+
fill_value=fill_value_parsed,
408+
data_type=data_type,
409+
extra_fields=allowed_extra_fields,
410+
storage_transformers=_data_typed.get("storage_transformers", ()), # type: ignore[arg-type]
411+
)
333412

334413
def to_dict(self) -> dict[str, JSON]:
335414
out_dict = super().to_dict()
415+
extra_fields = out_dict.pop("extra_fields")
416+
out_dict = out_dict | extra_fields # type: ignore[operator]
417+
336418
out_dict["fill_value"] = self.data_type.to_json_scalar(
337419
self.fill_value, zarr_format=self.zarr_format
338420
)
@@ -351,7 +433,6 @@ def to_dict(self) -> dict[str, JSON]:
351433
dtype_meta = out_dict["data_type"]
352434
if isinstance(dtype_meta, ZDType):
353435
out_dict["data_type"] = dtype_meta.to_json(zarr_format=3) # type: ignore[unreachable]
354-
355436
return out_dict
356437

357438
def update_shape(self, shape: tuple[int, ...]) -> Self:

0 commit comments

Comments
 (0)