44from asyncio import gather
55from dataclasses import dataclass , field , replace
66from logging import getLogger
7- from typing import TYPE_CHECKING , Any , Literal , cast
7+ from typing import TYPE_CHECKING , Any , Generic , Literal , TypeVar , cast
88
99import numpy as np
1010import numpy .typing as npt
7373
7474if TYPE_CHECKING :
7575 from collections .abc import Iterable , Iterator , Sequence
76+ from typing import Self
7677
7778 from zarr .abc .codec import Codec , CodecPipeline
7879 from zarr .core .group import AsyncGroup
79- from zarr .core .metadata .common import ArrayMetadata
8080
8181# Array and AsyncArray are defined in the base ``zarr`` namespace
8282__all__ = ["create_codec_pipeline" , "parse_array_metadata" ]
@@ -160,16 +160,19 @@ async def get_array_metadata(
160160 return metadata_dict
161161
162162
163+ TArrayMeta = TypeVar ("TArrayMeta" , ArrayV2Metadata , ArrayV3Metadata )
164+
165+
163166@dataclass (frozen = True )
164- class AsyncArray :
165- metadata : ArrayMetadata
167+ class AsyncArray ( Generic [ TArrayMeta ]) :
168+ metadata : TArrayMeta
166169 store_path : StorePath
167170 codec_pipeline : CodecPipeline = field (init = False )
168171 order : Literal ["C" , "F" ]
169172
170173 def __init__ (
171174 self ,
172- metadata : ArrayMetadata | dict [str , Any ],
175+ metadata : ArrayV2Metadata | ArrayV3Metadata | dict [str , Any ],
173176 store_path : StorePath ,
174177 order : Literal ["C" , "F" ] | None = None ,
175178 ) -> None :
@@ -218,7 +221,7 @@ async def create(
218221 # runtime
219222 exists_ok : bool = False ,
220223 data : npt .ArrayLike | None = None ,
221- ) -> AsyncArray :
224+ ) -> AsyncArray [ ArrayV2Metadata ] | AsyncArray [ ArrayV3Metadata ] :
222225 store_path = await make_store_path (store )
223226
224227 shape = parse_shapelike (shape )
@@ -231,7 +234,7 @@ async def create(
231234 _chunks = normalize_chunks (chunks , shape , dtype .itemsize )
232235 else :
233236 _chunks = normalize_chunks (chunk_shape , shape , dtype .itemsize )
234-
237+ result : AsyncArray [ ArrayV3Metadata ] | AsyncArray [ ArrayV2Metadata ]
235238 if zarr_format == 3 :
236239 if dimension_separator is not None :
237240 raise ValueError (
@@ -313,7 +316,7 @@ async def _create_v3(
313316 dimension_names : Iterable [str ] | None = None ,
314317 attributes : dict [str , JSON ] | None = None ,
315318 exists_ok : bool = False ,
316- ) -> AsyncArray :
319+ ) -> AsyncArray [ ArrayV3Metadata ] :
317320 if not exists_ok :
318321 await ensure_no_existing_node (store_path , zarr_format = 3 )
319322
@@ -344,7 +347,9 @@ async def _create_v3(
344347
345348 array = cls (metadata = metadata , store_path = store_path )
346349 await array ._save_metadata (metadata , ensure_parents = True )
347- return array
350+ # type inference is inconsistent here and seems to conclude
351+ # that array has type Array[ArrayV2Metadata]
352+ return array # type: ignore[return-value]
348353
349354 @classmethod
350355 async def _create_v2 (
@@ -361,7 +366,7 @@ async def _create_v2(
361366 compressor : dict [str , JSON ] | None = None ,
362367 attributes : dict [str , JSON ] | None = None ,
363368 exists_ok : bool = False ,
364- ) -> AsyncArray :
369+ ) -> AsyncArray [ ArrayV2Metadata ] :
365370 if not exists_ok :
366371 await ensure_no_existing_node (store_path , zarr_format = 2 )
367372 if order is None :
@@ -383,14 +388,14 @@ async def _create_v2(
383388 )
384389 array = cls (metadata = metadata , store_path = store_path )
385390 await array ._save_metadata (metadata , ensure_parents = True )
386- return array
391+ return array # type: ignore[return-value]
387392
388393 @classmethod
389394 def from_dict (
390395 cls ,
391396 store_path : StorePath ,
392397 data : dict [str , JSON ],
393- ) -> AsyncArray :
398+ ) -> AsyncArray [ ArrayV3Metadata ] | AsyncArray [ ArrayV2Metadata ] :
394399 metadata = parse_array_metadata (data )
395400 return cls (metadata = metadata , store_path = store_path )
396401
@@ -399,7 +404,7 @@ async def open(
399404 cls ,
400405 store : StoreLike ,
401406 zarr_format : ZarrFormat | None = 3 ,
402- ) -> AsyncArray :
407+ ) -> AsyncArray [ ArrayV3Metadata ] | AsyncArray [ ArrayV2Metadata ] :
403408 store_path = await make_store_path (store )
404409 metadata_dict = await get_array_metadata (store_path , zarr_format = zarr_format )
405410 return cls (store_path = store_path , metadata = metadata_dict )
@@ -631,7 +636,9 @@ async def getitem(
631636 )
632637 return await self ._get_selection (indexer , prototype = prototype )
633638
634- async def _save_metadata (self , metadata : ArrayMetadata , ensure_parents : bool = False ) -> None :
639+ async def _save_metadata (
640+ self , metadata : ArrayV2Metadata | ArrayV3Metadata , ensure_parents : bool = False
641+ ) -> None :
635642 to_save = metadata .to_buffer_dict (default_buffer_prototype ())
636643 awaitables = [set_or_delete (self .store_path / key , value ) for key , value in to_save .items ()]
637644
@@ -719,9 +726,7 @@ async def setitem(
719726 )
720727 return await self ._set_selection (indexer , value , prototype = prototype )
721728
722- async def resize (
723- self , new_shape : ChunkCoords , delete_outside_chunks : bool = True
724- ) -> AsyncArray :
729+ async def resize (self , new_shape : ChunkCoords , delete_outside_chunks : bool = True ) -> Self :
725730 assert len (new_shape ) == len (self .metadata .shape )
726731 new_metadata = self .metadata .update_shape (new_shape )
727732
@@ -747,7 +752,7 @@ async def _delete_key(key: str) -> None:
747752 await self ._save_metadata (new_metadata )
748753 return replace (self , metadata = new_metadata )
749754
750- async def update_attributes (self , new_attributes : dict [str , JSON ]) -> AsyncArray :
755+ async def update_attributes (self , new_attributes : dict [str , JSON ]) -> Self :
751756 new_metadata = self .metadata .update_attributes (new_attributes )
752757
753758 # Write new metadata
@@ -763,7 +768,7 @@ async def info(self) -> None:
763768
764769@dataclass (frozen = True )
765770class Array :
766- _async_array : AsyncArray
771+ _async_array : AsyncArray [ ArrayV3Metadata ] | AsyncArray [ ArrayV2Metadata ]
767772
768773 @classmethod
769774 @_deprecate_positional_args
@@ -879,7 +884,7 @@ def basename(self) -> str | None:
879884 return self ._async_array .basename
880885
881886 @property
882- def metadata (self ) -> ArrayMetadata :
887+ def metadata (self ) -> ArrayV2Metadata | ArrayV3Metadata :
883888 return self ._async_array .metadata
884889
885890 @property
@@ -2309,18 +2314,10 @@ def resize(self, new_shape: ChunkCoords) -> Array:
23092314 the data falling outside the new array but inside the boundary chunks
23102315 would be restored by a subsequent resize operation that grows the array size.
23112316 """
2312- return type (self )(
2313- sync (
2314- self ._async_array .resize (new_shape ),
2315- )
2316- )
2317+ return type (self )(sync (self ._async_array .resize (new_shape ))) # type: ignore[arg-type]
23172318
23182319 def update_attributes (self , new_attributes : dict [str , JSON ]) -> Array :
2319- return type (self )(
2320- sync (
2321- self ._async_array .update_attributes (new_attributes ),
2322- )
2323- )
2320+ return type (self )(sync (self ._async_array .update_attributes (new_attributes ))) # type: ignore[arg-type]
23242321
23252322 def __repr__ (self ) -> str :
23262323 return f"<Array { self .store_path } shape={ self .shape } dtype={ self .dtype } >"
@@ -2331,7 +2328,9 @@ def info(self) -> None:
23312328 )
23322329
23332330
2334- def nchunks_initialized (array : AsyncArray | Array ) -> int :
2331+ def nchunks_initialized (
2332+ array : AsyncArray [ArrayV2Metadata ] | AsyncArray [ArrayV3Metadata ] | Array ,
2333+ ) -> int :
23352334 """
23362335 Calculate the number of chunks that have been initialized, i.e. the number of chunks that have
23372336 been persisted to the storage backend.
@@ -2353,7 +2352,9 @@ def nchunks_initialized(array: AsyncArray | Array) -> int:
23532352 return len (chunks_initialized (array ))
23542353
23552354
2356- def chunks_initialized (array : Array | AsyncArray ) -> tuple [str , ...]:
2355+ def chunks_initialized (
2356+ array : Array | AsyncArray [ArrayV2Metadata ] | AsyncArray [ArrayV3Metadata ],
2357+ ) -> tuple [str , ...]:
23572358 """
23582359 Return the keys of the chunks that have been persisted to the storage backend.
23592360
@@ -2385,7 +2386,9 @@ def chunks_initialized(array: Array | AsyncArray) -> tuple[str, ...]:
23852386 return tuple (out )
23862387
23872388
2388- def _build_parents (node : AsyncArray | AsyncGroup ) -> list [AsyncGroup ]:
2389+ def _build_parents (
2390+ node : AsyncArray [ArrayV2Metadata ] | AsyncArray [ArrayV3Metadata ] | AsyncGroup ,
2391+ ) -> list [AsyncGroup ]:
23892392 from zarr .core .group import AsyncGroup , GroupMetadata
23902393
23912394 store = node .store_path .store
0 commit comments