6464 is_scalar ,
6565 pop_fields ,
6666)
67- from zarr .core .metadata .v2 import ArrayV2Metadata
68- from zarr .core .metadata .v3 import ArrayV3Metadata
67+ from zarr .core .metadata .common import ArrayMetadata , ArrayMetadataDict
68+ from zarr .core .metadata .v2 import ArrayV2Metadata , ArrayV2MetadataDict
69+ from zarr .core .metadata .v3 import ArrayV3Metadata , ArrayV3MetadataDict
6970from zarr .core .sync import collect_aiterator , sync
7071from zarr .registry import get_pipeline_class
7172from zarr .storage import StoreLike , make_store_path
8485logger = getLogger (__name__ )
8586
8687
87- def parse_array_metadata (data : Any ) -> ArrayV2Metadata | ArrayV3Metadata :
88- if isinstance (data , ArrayV2Metadata | ArrayV3Metadata ):
88+ def parse_array_metadata (data : Any ) -> ArrayMetadata :
89+ if isinstance (data , ArrayMetadata ):
8990 return data
9091 elif isinstance (data , dict ):
9192 if data ["zarr_format" ] == 3 :
@@ -102,7 +103,7 @@ def parse_array_metadata(data: Any) -> ArrayV2Metadata | ArrayV3Metadata:
102103 raise TypeError
103104
104105
105- def create_codec_pipeline (metadata : ArrayV2Metadata | ArrayV3Metadata ) -> CodecPipeline :
106+ def create_codec_pipeline (metadata : ArrayMetadata ) -> CodecPipeline :
106107 if isinstance (metadata , ArrayV3Metadata ):
107108 return get_pipeline_class ().from_codecs (metadata .codecs )
108109 elif isinstance (metadata , ArrayV2Metadata ):
@@ -115,7 +116,7 @@ def create_codec_pipeline(metadata: ArrayV2Metadata | ArrayV3Metadata) -> CodecP
115116
116117async def get_array_metadata (
117118 store_path : StorePath , zarr_format : ZarrFormat | None = 3
118- ) -> dict [str , Any ]:
119+ ) -> dict [str , JSON ]:
119120 if zarr_format == 2 :
120121 zarray_bytes , zattrs_bytes = await gather (
121122 (store_path / ZARRAY_JSON ).get (), (store_path / ZATTRS_JSON ).get ()
@@ -146,7 +147,7 @@ async def get_array_metadata(
146147 else :
147148 raise ValueError (f"unexpected zarr_format: { zarr_format } " )
148149
149- metadata_dict : dict [str , Any ]
150+ metadata_dict : dict [str , JSON ]
150151 if zarr_format == 2 :
151152 # V2 arrays are comprised of a .zarray and .zattrs objects
152153 assert zarray_bytes is not None
@@ -170,18 +171,38 @@ class AsyncArray(Generic[TArrayMeta]):
170171 codec_pipeline : CodecPipeline = field (init = False )
171172 order : Literal ["C" , "F" ]
172173
174+ @overload
175+ def __init__ (
176+ self : AsyncArray [ArrayV2Metadata ],
177+ metadata : ArrayV2Metadata | ArrayV2MetadataDict ,
178+ store_path : StorePath ,
179+ order : Literal ["C" , "F" ] | None = None ,
180+ ) -> None : ...
181+
182+ @overload
183+ def __init__ (
184+ self : AsyncArray [ArrayV3Metadata ],
185+ metadata : ArrayV3Metadata | ArrayV3MetadataDict ,
186+ store_path : StorePath ,
187+ order : Literal ["C" , "F" ] | None = None ,
188+ ) -> None : ...
189+
173190 def __init__ (
174191 self ,
175- metadata : ArrayV2Metadata | ArrayV3Metadata | dict [ str , Any ] ,
192+ metadata : ArrayMetadata | ArrayMetadataDict ,
176193 store_path : StorePath ,
177194 order : Literal ["C" , "F" ] | None = None ,
178195 ) -> None :
179196 if isinstance (metadata , dict ):
180197 zarr_format = metadata ["zarr_format" ]
198+ # TODO: remove this when we extensively type the dict representation of metadata
199+ _metadata = cast (dict [str , JSON ], metadata )
181200 if zarr_format == 2 :
182- metadata = ArrayV2Metadata .from_dict (metadata )
201+ metadata = ArrayV2Metadata .from_dict (_metadata )
202+ elif zarr_format == 3 :
203+ metadata = ArrayV3Metadata .from_dict (_metadata )
183204 else :
184- metadata = ArrayV3Metadata . from_dict ( metadata )
205+ raise ValueError ( f"Invalid zarr_format: { zarr_format } . Expected 2 or 3" )
185206
186207 metadata_parsed = parse_array_metadata (metadata )
187208 order_parsed = parse_indexing_order (order or config .get ("array.order" ))
@@ -222,7 +243,7 @@ async def create(
222243 # runtime
223244 exists_ok : bool = False ,
224245 data : npt .ArrayLike | None = None ,
225- ) -> AsyncArray [ArrayV2Metadata ]:...
246+ ) -> AsyncArray [ArrayV2Metadata ]: ...
226247
227248 @overload
228249 @classmethod
@@ -255,7 +276,74 @@ async def create(
255276 # runtime
256277 exists_ok : bool = False ,
257278 data : npt .ArrayLike | None = None ,
258- ) -> AsyncArray [ArrayV3Metadata ]:...
279+ ) -> AsyncArray [ArrayV3Metadata ]: ...
280+
281+ # this overload is necessary to handle the case where the `zarr_format` kwarg is unspecified
282+ @overload
283+ @classmethod
284+ async def create (
285+ cls ,
286+ store : StoreLike ,
287+ * ,
288+ # v2 and v3
289+ shape : ShapeLike ,
290+ dtype : npt .DTypeLike ,
291+ zarr_format : Literal [3 ] = 3 ,
292+ fill_value : Any | None = None ,
293+ attributes : dict [str , JSON ] | None = None ,
294+ # v3 only
295+ chunk_shape : ChunkCoords | None = None ,
296+ chunk_key_encoding : (
297+ ChunkKeyEncoding
298+ | tuple [Literal ["default" ], Literal ["." , "/" ]]
299+ | tuple [Literal ["v2" ], Literal ["." , "/" ]]
300+ | None
301+ ) = None ,
302+ codecs : Iterable [Codec | dict [str , JSON ]] | None = None ,
303+ dimension_names : Iterable [str ] | None = None ,
304+ # v2 only
305+ chunks : ShapeLike | None = None ,
306+ dimension_separator : Literal ["." , "/" ] | None = None ,
307+ order : Literal ["C" , "F" ] | None = None ,
308+ filters : list [dict [str , JSON ]] | None = None ,
309+ compressor : dict [str , JSON ] | None = None ,
310+ # runtime
311+ exists_ok : bool = False ,
312+ data : npt .ArrayLike | None = None ,
313+ ) -> AsyncArray [ArrayV3Metadata ]: ...
314+
315+ @overload
316+ @classmethod
317+ async def create (
318+ cls ,
319+ store : StoreLike ,
320+ * ,
321+ # v2 and v3
322+ shape : ShapeLike ,
323+ dtype : npt .DTypeLike ,
324+ zarr_format : ZarrFormat ,
325+ fill_value : Any | None = None ,
326+ attributes : dict [str , JSON ] | None = None ,
327+ # v3 only
328+ chunk_shape : ChunkCoords | None = None ,
329+ chunk_key_encoding : (
330+ ChunkKeyEncoding
331+ | tuple [Literal ["default" ], Literal ["." , "/" ]]
332+ | tuple [Literal ["v2" ], Literal ["." , "/" ]]
333+ | None
334+ ) = None ,
335+ codecs : Iterable [Codec | dict [str , JSON ]] | None = None ,
336+ dimension_names : Iterable [str ] | None = None ,
337+ # v2 only
338+ chunks : ShapeLike | None = None ,
339+ dimension_separator : Literal ["." , "/" ] | None = None ,
340+ order : Literal ["C" , "F" ] | None = None ,
341+ filters : list [dict [str , JSON ]] | None = None ,
342+ compressor : dict [str , JSON ] | None = None ,
343+ # runtime
344+ exists_ok : bool = False ,
345+ data : npt .ArrayLike | None = None ,
346+ ) -> AsyncArray [ArrayV3Metadata ] | AsyncArray [ArrayV2Metadata ]: ...
259347
260348 @classmethod
261349 async def create (
@@ -471,7 +559,9 @@ async def open(
471559 ) -> AsyncArray [ArrayV3Metadata ] | AsyncArray [ArrayV2Metadata ]:
472560 store_path = await make_store_path (store )
473561 metadata_dict = await get_array_metadata (store_path , zarr_format = zarr_format )
474- return cls (store_path = store_path , metadata = metadata_dict )
562+ # TODO: remove this cast when we have better type hints
563+ _metadata_dict = cast (ArrayV3MetadataDict , metadata_dict )
564+ return cls (store_path = store_path , metadata = _metadata_dict )
475565
476566 @property
477567 def store (self ) -> Store :
@@ -700,9 +790,7 @@ async def getitem(
700790 )
701791 return await self ._get_selection (indexer , prototype = prototype )
702792
703- async def _save_metadata (
704- self , metadata : ArrayV2Metadata | ArrayV3Metadata , ensure_parents : bool = False
705- ) -> None :
793+ async def _save_metadata (self , metadata : ArrayMetadata , ensure_parents : bool = False ) -> None :
706794 to_save = metadata .to_buffer_dict (default_buffer_prototype ())
707795 awaitables = [set_or_delete (self .store_path / key , value ) for key , value in to_save .items ()]
708796
@@ -948,7 +1036,7 @@ def basename(self) -> str | None:
9481036 return self ._async_array .basename
9491037
9501038 @property
951- def metadata (self ) -> ArrayV2Metadata | ArrayV3Metadata :
1039+ def metadata (self ) -> ArrayMetadata :
9521040 return self ._async_array .metadata
9531041
9541042 @property
@@ -2378,7 +2466,8 @@ def resize(self, new_shape: ChunkCoords) -> Array:
23782466 the data falling outside the new array but inside the boundary chunks
23792467 would be restored by a subsequent resize operation that grows the array size.
23802468 """
2381- return type (self )(sync (self ._async_array .resize (new_shape )))
2469+ resized = sync (self ._async_array .resize (new_shape ))
2470+ return type (self )(resized )
23822471
23832472 def update_attributes (self , new_attributes : dict [str , JSON ]) -> Array :
23842473 return type (self )(sync (self ._async_array .update_attributes (new_attributes )))
0 commit comments