@@ -1,6 +1,7 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING, TypedDict
+from collections.abc import Mapping
+from typing import TYPE_CHECKING, NotRequired, TypedDict, TypeGuard, cast

 from zarr.abc.metadata import Metadata
 from zarr.core.buffer.core import default_buffer_prototype
@@ -33,6 +34,7 @@
     JSON,
     ZARR_JSON,
     DimensionNames,
+    NamedConfig,
     parse_named_configuration,
     parse_shapelike,
 )
@@ -136,13 +138,61 @@ def parse_storage_transformers(data: object) -> tuple[dict[str, JSON], ...]:
     )


-class ArrayV3MetadataDict(TypedDict):
+class AllowedExtraField(TypedDict):
+    """
+    This class models allowed extra fields in array metadata.
+    They are ignored by Zarr Python.
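+
+    For example, a metadata document may carry an extra key such as
+    ``"my_extension"`` (hypothetical name) whose value is
+    ``{"must_understand": False}``; such a field is preserved but otherwise ignored.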
+    """
+
+    must_understand: Literal[False]
+
+
+def check_allowed_extra_field(data: object) -> TypeGuard[AllowedExtraField]:
+    """
+    Check if the extra field is allowed according to the Zarr v3 spec. The object
+    must be a mapping with a "must_understand" key set to `False`.
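+
+    For example, ``{"must_understand": False}`` passes this check, while
+    ``{"must_understand": True}`` or a non-mapping value does not.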
+    """
+    return isinstance(data, Mapping) and data.get("must_understand") is False
+
+
+def parse_extra_fields(
+    data: Mapping[str, AllowedExtraField] | None,
+) -> dict[str, AllowedExtraField]:
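+    # Normalize ``None`` to an empty dict and reject extra fields whose names
+    # collide with keys reserved for the array metadata document.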
+    if data is None:
+        return {}
+    else:
+        conflict_keys = ARRAY_METADATA_KEYS & set(data.keys())
+        if len(conflict_keys) > 0:
+            msg = (
+                "Invalid extra fields. "
+                "The following keys: "
+                f"{sorted(conflict_keys)} "
+                "are invalid because they collide with keys reserved for use by the "
+                "array metadata document."
+            )
+            raise ValueError(msg)
+        return dict(data)
+
+
+class ArrayMetadataJSON_V3(TypedDict):
     """
     A typed dictionary model for zarr v3 metadata.
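+
+    A minimal example document (illustrative values only) looks like::
+
+        {
+            "zarr_format": 3,
+            "node_type": "array",
+            "shape": [100, 100],
+            "data_type": "uint8",
+            "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": [10, 10]}},
+            "chunk_key_encoding": {"name": "default"},
+            "fill_value": 0,
+            "codecs": [{"name": "bytes", "configuration": {"endian": "little"}}]
+        }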
     """

     zarr_format: Literal[3]
-    attributes: dict[str, JSON]
+    node_type: Literal["array"]
+    data_type: str | NamedConfig[str, Mapping[str, object]]
+    shape: tuple[int, ...]
+    chunk_grid: NamedConfig[str, Mapping[str, object]]
+    chunk_key_encoding: NamedConfig[str, Mapping[str, object]]
+    fill_value: object
+    codecs: tuple[str | NamedConfig[str, Mapping[str, object]], ...]
+    attributes: NotRequired[Mapping[str, JSON]]
+    storage_transformers: NotRequired[tuple[NamedConfig[str, Mapping[str, object]], ...]]
+    dimension_names: NotRequired[tuple[str | None, ...]]
+
+
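+# The set of top-level keys reserved by the array metadata document; any other
+# key in a metadata document is treated as an extra field.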
+ARRAY_METADATA_KEYS = set(ArrayMetadataJSON_V3.__annotations__.keys())


 @dataclass(frozen=True, kw_only=True)
@@ -158,19 +208,21 @@ class ArrayV3Metadata(Metadata):
     zarr_format: Literal[3] = field(default=3, init=False)
     node_type: Literal["array"] = field(default="array", init=False)
     storage_transformers: tuple[dict[str, JSON], ...]
+    extra_fields: dict[str, AllowedExtraField]

     def __init__(
         self,
         *,
         shape: Iterable[int],
         data_type: ZDType[TBaseDType, TBaseScalar],
-        chunk_grid: dict[str, JSON] | ChunkGrid,
+        chunk_grid: dict[str, JSON] | ChunkGrid | NamedConfig[str, Any],
         chunk_key_encoding: ChunkKeyEncodingLike,
         fill_value: object,
-        codecs: Iterable[Codec | dict[str, JSON]],
+        codecs: Iterable[Codec | dict[str, JSON] | NamedConfig[str, Any] | str],
         attributes: dict[str, JSON] | None,
         dimension_names: DimensionNames,
         storage_transformers: Iterable[dict[str, JSON]] | None = None,
+        extra_fields: Mapping[str, AllowedExtraField] | None = None,
     ) -> None:
         """
         Because the class is a frozen dataclass, we set attributes using object.__setattr__
@@ -185,7 +237,7 @@ def __init__(
         attributes_parsed = parse_attributes(attributes)
         codecs_parsed_partial = parse_codecs(codecs)
         storage_transformers_parsed = parse_storage_transformers(storage_transformers)
-
+        extra_fields_parsed = parse_extra_fields(extra_fields)
         array_spec = ArraySpec(
             shape=shape_parsed,
             dtype=data_type,
@@ -205,6 +257,7 @@ def __init__(
         object.__setattr__(self, "fill_value", fill_value_parsed)
         object.__setattr__(self, "attributes", attributes_parsed)
         object.__setattr__(self, "storage_transformers", storage_transformers_parsed)
+        object.__setattr__(self, "extra_fields", extra_fields_parsed)

         self._validate_metadata()

@@ -323,16 +376,45 @@ def from_dict(cls, data: dict[str, JSON]) -> Self:
         except ValueError as e:
             raise TypeError(f"Invalid fill_value: {fill!r}") from e

-        # dimension_names key is optional, normalize missing to `None`
-        _data["dimension_names"] = _data.pop("dimension_names", None)
-
-        # attributes key is optional, normalize missing to `None`
-        _data["attributes"] = _data.pop("attributes", None)
-
-        return cls(**_data, fill_value=fill_value_parsed, data_type=data_type)  # type: ignore[arg-type]
+        # check if there are extra keys
+        extra_keys = set(_data.keys()) - ARRAY_METADATA_KEYS
+        allowed_extra_fields: dict[str, AllowedExtraField] = {}
+        invalid_extra_fields = {}
+        for key in extra_keys:
+            val = _data[key]
+            if check_allowed_extra_field(val):
+                allowed_extra_fields[key] = val
+            else:
+                invalid_extra_fields[key] = val
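+        # Extra fields that do not carry ``"must_understand": False`` are
+        # rejected, as required by the Zarr V3 spec.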
+        if len(invalid_extra_fields) > 0:
+            msg = (
+                "Got a Zarr V3 metadata document with the following disallowed extra fields: "
+                f"{sorted(invalid_extra_fields.keys())}. "
+                'Extra fields are not allowed unless they are a dict with a "must_understand" key '
+                "which is assigned the value `False`."
+            )
+            raise MetadataValidationError(msg)
+        # TODO: replace this with a real type check!
+        _data_typed = cast(ArrayMetadataJSON_V3, _data)
+
+        return cls(
+            shape=_data_typed["shape"],
+            chunk_grid=_data_typed["chunk_grid"],
+            chunk_key_encoding=_data_typed["chunk_key_encoding"],
+            codecs=_data_typed["codecs"],
+            attributes=_data_typed.get("attributes", {}),  # type: ignore[arg-type]
+            dimension_names=_data_typed.get("dimension_names", None),
+            fill_value=fill_value_parsed,
+            data_type=data_type,
+            extra_fields=allowed_extra_fields,
+            storage_transformers=_data_typed.get("storage_transformers", ()),  # type: ignore[arg-type]
+        )

     def to_dict(self) -> dict[str, JSON]:
         out_dict = super().to_dict()
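+        # Merge allowed extra fields back into the top-level metadata document
+        # so that they round-trip through serialization.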
+        extra_fields = out_dict.pop("extra_fields")
+        out_dict = out_dict | extra_fields  # type: ignore[operator]
+
         out_dict["fill_value"] = self.data_type.to_json_scalar(
             self.fill_value, zarr_format=self.zarr_format
         )
@@ -351,7 +433,6 @@ def to_dict(self) -> dict[str, JSON]:
         dtype_meta = out_dict["data_type"]
         if isinstance(dtype_meta, ZDType):
             out_dict["data_type"] = dtype_meta.to_json(zarr_format=3)  # type: ignore[unreachable]
-
         return out_dict

     def update_shape(self, shape: tuple[int, ...]) -> Self: