Skip to content

Commit a82fb03

Browse files
committed
Make AsyncArray generic w.r.t. metadata
1 parent 22b686f commit a82fb03

File tree

4 files changed

+121
-74
lines changed

4 files changed

+121
-74
lines changed

src/zarr/api/asynchronous.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from zarr.core.chunk_key_encodings import ChunkKeyEncoding
2929

3030
# TODO: this type could use some more thought
31-
ArrayLike = AsyncArray | Array | npt.NDArray[Any]
31+
ArrayLike = AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | Array | npt.NDArray[Any]
3232
PathLike = str
3333

3434
__all__ = [
@@ -198,7 +198,7 @@ async def open(
198198
path: str | None = None,
199199
storage_options: dict[str, Any] | None = None,
200200
**kwargs: Any, # TODO: type kwargs as valid args to open_array
201-
) -> AsyncArray | AsyncGroup:
201+
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup:
202202
"""Convenience function to open a group or array using file-mode-like semantics.
203203
204204
Parameters
@@ -400,7 +400,9 @@ async def tree(*args: Any, **kwargs: Any) -> None:
400400
raise NotImplementedError
401401

402402

403-
async def array(data: npt.ArrayLike, **kwargs: Any) -> AsyncArray:
403+
async def array(
404+
data: npt.ArrayLike, **kwargs: Any
405+
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]:
404406
"""Create an array filled with `data`.
405407
406408
Parameters
@@ -652,7 +654,7 @@ async def create(
652654
dimension_names: Iterable[str] | None = None,
653655
storage_options: dict[str, Any] | None = None,
654656
**kwargs: Any,
655-
) -> AsyncArray:
657+
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]:
656658
"""Create an array.
657659
658660
Parameters
@@ -804,7 +806,9 @@ async def create(
804806
)
805807

806808

807-
async def empty(shape: ChunkCoords, **kwargs: Any) -> AsyncArray:
809+
async def empty(
810+
shape: ChunkCoords, **kwargs: Any
811+
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]:
808812
"""Create an empty array.
809813
810814
Parameters
@@ -823,7 +827,9 @@ async def empty(shape: ChunkCoords, **kwargs: Any) -> AsyncArray:
823827
return await create(shape=shape, fill_value=None, **kwargs)
824828

825829

826-
async def empty_like(a: ArrayLike, **kwargs: Any) -> AsyncArray:
830+
async def empty_like(
831+
a: ArrayLike, **kwargs: Any
832+
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]:
827833
"""Create an empty array like `a`.
828834
829835
Parameters
@@ -843,7 +849,9 @@ async def empty_like(a: ArrayLike, **kwargs: Any) -> AsyncArray:
843849

844850

845851
# TODO: add type annotations for fill_value and kwargs
846-
async def full(shape: ChunkCoords, fill_value: Any, **kwargs: Any) -> AsyncArray:
852+
async def full(
853+
shape: ChunkCoords, fill_value: Any, **kwargs: Any
854+
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]:
847855
"""Create an array, with `fill_value` being used as the default value for
848856
uninitialized portions of the array.
849857
@@ -865,7 +873,9 @@ async def full(shape: ChunkCoords, fill_value: Any, **kwargs: Any) -> AsyncArray
865873

866874

867875
# TODO: add type annotations for kwargs
868-
async def full_like(a: ArrayLike, **kwargs: Any) -> AsyncArray:
876+
async def full_like(
877+
a: ArrayLike, **kwargs: Any
878+
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]:
869879
"""Create a filled array like `a`.
870880
871881
Parameters
@@ -886,7 +896,9 @@ async def full_like(a: ArrayLike, **kwargs: Any) -> AsyncArray:
886896
return await full(**like_kwargs)
887897

888898

889-
async def ones(shape: ChunkCoords, **kwargs: Any) -> AsyncArray:
899+
async def ones(
900+
shape: ChunkCoords, **kwargs: Any
901+
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]:
890902
"""Create an array, with one being used as the default value for
891903
uninitialized portions of the array.
892904
@@ -905,7 +917,9 @@ async def ones(shape: ChunkCoords, **kwargs: Any) -> AsyncArray:
905917
return await create(shape=shape, fill_value=1, **kwargs)
906918

907919

908-
async def ones_like(a: ArrayLike, **kwargs: Any) -> AsyncArray:
920+
async def ones_like(
921+
a: ArrayLike, **kwargs: Any
922+
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]:
909923
"""Create an array of ones like `a`.
910924
911925
Parameters
@@ -932,7 +946,7 @@ async def open_array(
932946
path: PathLike | None = None,
933947
storage_options: dict[str, Any] | None = None,
934948
**kwargs: Any, # TODO: type kwargs as valid args to save
935-
) -> AsyncArray:
949+
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]:
936950
"""Open an array using file-mode-like semantics.
937951
938952
Parameters
@@ -975,7 +989,9 @@ async def open_array(
975989
raise
976990

977991

978-
async def open_like(a: ArrayLike, path: str, **kwargs: Any) -> AsyncArray:
992+
async def open_like(
993+
a: ArrayLike, path: str, **kwargs: Any
994+
) -> AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata]:
979995
"""Open a persistent array like `a`.
980996
981997
Parameters
@@ -998,7 +1014,9 @@ async def open_like(a: ArrayLike, path: str, **kwargs: Any) -> AsyncArray:
9981014
return await open_array(path=path, **like_kwargs)
9991015

10001016

1001-
async def zeros(shape: ChunkCoords, **kwargs: Any) -> AsyncArray:
1017+
async def zeros(
1018+
shape: ChunkCoords, **kwargs: Any
1019+
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]:
10021020
"""Create an array, with zero being used as the default value for
10031021
uninitialized portions of the array.
10041022
@@ -1017,7 +1035,9 @@ async def zeros(shape: ChunkCoords, **kwargs: Any) -> AsyncArray:
10171035
return await create(shape=shape, fill_value=0, **kwargs)
10181036

10191037

1020-
async def zeros_like(a: ArrayLike, **kwargs: Any) -> AsyncArray:
1038+
async def zeros_like(
1039+
a: ArrayLike, **kwargs: Any
1040+
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]:
10211041
"""Create an array of zeros like `a`.
10221042
10231043
Parameters

src/zarr/core/array.py

Lines changed: 36 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from asyncio import gather
55
from dataclasses import dataclass, field, replace
66
from logging import getLogger
7-
from typing import TYPE_CHECKING, Any, Literal, cast
7+
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast
88

99
import numpy as np
1010
import numpy.typing as npt
@@ -73,10 +73,10 @@
7373

7474
if TYPE_CHECKING:
7575
from collections.abc import Iterable, Iterator, Sequence
76+
from typing import Self
7677

7778
from zarr.abc.codec import Codec, CodecPipeline
7879
from zarr.core.group import AsyncGroup
79-
from zarr.core.metadata.common import ArrayMetadata
8080

8181
# Array and AsyncArray are defined in the base ``zarr`` namespace
8282
__all__ = ["create_codec_pipeline", "parse_array_metadata"]
@@ -160,16 +160,19 @@ async def get_array_metadata(
160160
return metadata_dict
161161

162162

163+
TArrayMeta = TypeVar("TArrayMeta", ArrayV2Metadata, ArrayV3Metadata)
164+
165+
163166
@dataclass(frozen=True)
164-
class AsyncArray:
165-
metadata: ArrayMetadata
167+
class AsyncArray(Generic[TArrayMeta]):
168+
metadata: TArrayMeta
166169
store_path: StorePath
167170
codec_pipeline: CodecPipeline = field(init=False)
168171
order: Literal["C", "F"]
169172

170173
def __init__(
171174
self,
172-
metadata: ArrayMetadata | dict[str, Any],
175+
metadata: ArrayV2Metadata | ArrayV3Metadata | dict[str, Any],
173176
store_path: StorePath,
174177
order: Literal["C", "F"] | None = None,
175178
) -> None:
@@ -218,7 +221,7 @@ async def create(
218221
# runtime
219222
exists_ok: bool = False,
220223
data: npt.ArrayLike | None = None,
221-
) -> AsyncArray:
224+
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]:
222225
store_path = await make_store_path(store)
223226

224227
shape = parse_shapelike(shape)
@@ -231,7 +234,7 @@ async def create(
231234
_chunks = normalize_chunks(chunks, shape, dtype.itemsize)
232235
else:
233236
_chunks = normalize_chunks(chunk_shape, shape, dtype.itemsize)
234-
237+
result: AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata]
235238
if zarr_format == 3:
236239
if dimension_separator is not None:
237240
raise ValueError(
@@ -313,7 +316,7 @@ async def _create_v3(
313316
dimension_names: Iterable[str] | None = None,
314317
attributes: dict[str, JSON] | None = None,
315318
exists_ok: bool = False,
316-
) -> AsyncArray:
319+
) -> AsyncArray[ArrayV3Metadata]:
317320
if not exists_ok:
318321
await ensure_no_existing_node(store_path, zarr_format=3)
319322

@@ -344,7 +347,9 @@ async def _create_v3(
344347

345348
array = cls(metadata=metadata, store_path=store_path)
346349
await array._save_metadata(metadata, ensure_parents=True)
347-
return array
350+
# type inference is inconsistent here and seems to conclude
351+
# that array has type AsyncArray[ArrayV2Metadata]
352+
return array # type: ignore[return-value]
348353

349354
@classmethod
350355
async def _create_v2(
@@ -361,7 +366,7 @@ async def _create_v2(
361366
compressor: dict[str, JSON] | None = None,
362367
attributes: dict[str, JSON] | None = None,
363368
exists_ok: bool = False,
364-
) -> AsyncArray:
369+
) -> AsyncArray[ArrayV2Metadata]:
365370
if not exists_ok:
366371
await ensure_no_existing_node(store_path, zarr_format=2)
367372
if order is None:
@@ -383,14 +388,14 @@ async def _create_v2(
383388
)
384389
array = cls(metadata=metadata, store_path=store_path)
385390
await array._save_metadata(metadata, ensure_parents=True)
386-
return array
391+
return array # type: ignore[return-value]
387392

388393
@classmethod
389394
def from_dict(
390395
cls,
391396
store_path: StorePath,
392397
data: dict[str, JSON],
393-
) -> AsyncArray:
398+
) -> AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata]:
394399
metadata = parse_array_metadata(data)
395400
return cls(metadata=metadata, store_path=store_path)
396401

@@ -399,7 +404,7 @@ async def open(
399404
cls,
400405
store: StoreLike,
401406
zarr_format: ZarrFormat | None = 3,
402-
) -> AsyncArray:
407+
) -> AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata]:
403408
store_path = await make_store_path(store)
404409
metadata_dict = await get_array_metadata(store_path, zarr_format=zarr_format)
405410
return cls(store_path=store_path, metadata=metadata_dict)
@@ -631,7 +636,9 @@ async def getitem(
631636
)
632637
return await self._get_selection(indexer, prototype=prototype)
633638

634-
async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = False) -> None:
639+
async def _save_metadata(
640+
self, metadata: ArrayV2Metadata | ArrayV3Metadata, ensure_parents: bool = False
641+
) -> None:
635642
to_save = metadata.to_buffer_dict(default_buffer_prototype())
636643
awaitables = [set_or_delete(self.store_path / key, value) for key, value in to_save.items()]
637644

@@ -719,9 +726,7 @@ async def setitem(
719726
)
720727
return await self._set_selection(indexer, value, prototype=prototype)
721728

722-
async def resize(
723-
self, new_shape: ChunkCoords, delete_outside_chunks: bool = True
724-
) -> AsyncArray:
729+
async def resize(self, new_shape: ChunkCoords, delete_outside_chunks: bool = True) -> Self:
725730
assert len(new_shape) == len(self.metadata.shape)
726731
new_metadata = self.metadata.update_shape(new_shape)
727732

@@ -747,7 +752,7 @@ async def _delete_key(key: str) -> None:
747752
await self._save_metadata(new_metadata)
748753
return replace(self, metadata=new_metadata)
749754

750-
async def update_attributes(self, new_attributes: dict[str, JSON]) -> AsyncArray:
755+
async def update_attributes(self, new_attributes: dict[str, JSON]) -> Self:
751756
new_metadata = self.metadata.update_attributes(new_attributes)
752757

753758
# Write new metadata
@@ -763,7 +768,7 @@ async def info(self) -> None:
763768

764769
@dataclass(frozen=True)
765770
class Array:
766-
_async_array: AsyncArray
771+
_async_array: AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata]
767772

768773
@classmethod
769774
@_deprecate_positional_args
@@ -879,7 +884,7 @@ def basename(self) -> str | None:
879884
return self._async_array.basename
880885

881886
@property
882-
def metadata(self) -> ArrayMetadata:
887+
def metadata(self) -> ArrayV2Metadata | ArrayV3Metadata:
883888
return self._async_array.metadata
884889

885890
@property
@@ -2309,18 +2314,10 @@ def resize(self, new_shape: ChunkCoords) -> Array:
23092314
the data falling outside the new array but inside the boundary chunks
23102315
would be restored by a subsequent resize operation that grows the array size.
23112316
"""
2312-
return type(self)(
2313-
sync(
2314-
self._async_array.resize(new_shape),
2315-
)
2316-
)
2317+
return type(self)(sync(self._async_array.resize(new_shape))) # type: ignore[arg-type]
23172318

23182319
def update_attributes(self, new_attributes: dict[str, JSON]) -> Array:
2319-
return type(self)(
2320-
sync(
2321-
self._async_array.update_attributes(new_attributes),
2322-
)
2323-
)
2320+
return type(self)(sync(self._async_array.update_attributes(new_attributes))) # type: ignore[arg-type]
23242321

23252322
def __repr__(self) -> str:
23262323
return f"<Array {self.store_path} shape={self.shape} dtype={self.dtype}>"
@@ -2331,7 +2328,9 @@ def info(self) -> None:
23312328
)
23322329

23332330

2334-
def nchunks_initialized(array: AsyncArray | Array) -> int:
2331+
def nchunks_initialized(
2332+
array: AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | Array,
2333+
) -> int:
23352334
"""
23362335
Calculate the number of chunks that have been initialized, i.e. the number of chunks that have
23372336
been persisted to the storage backend.
@@ -2353,7 +2352,9 @@ def nchunks_initialized(array: AsyncArray | Array) -> int:
23532352
return len(chunks_initialized(array))
23542353

23552354

2356-
def chunks_initialized(array: Array | AsyncArray) -> tuple[str, ...]:
2355+
def chunks_initialized(
2356+
array: Array | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata],
2357+
) -> tuple[str, ...]:
23572358
"""
23582359
Return the keys of the chunks that have been persisted to the storage backend.
23592360
@@ -2385,7 +2386,9 @@ def chunks_initialized(array: Array | AsyncArray) -> tuple[str, ...]:
23852386
return tuple(out)
23862387

23872388

2388-
def _build_parents(node: AsyncArray | AsyncGroup) -> list[AsyncGroup]:
2389+
def _build_parents(
2390+
node: AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup,
2391+
) -> list[AsyncGroup]:
23892392
from zarr.core.group import AsyncGroup, GroupMetadata
23902393

23912394
store = node.store_path.store

0 commit comments

Comments
 (0)