Skip to content

Commit df6f9a7

Browse files
committed
tests for nchunks_initialized, chunks_initialized; add selection_shape kwarg to grid iteration; make chunk grid iterators consistent for array and async array
1 parent f65a6e8 commit df6f9a7

File tree

4 files changed

+304
-34
lines changed

4 files changed

+304
-34
lines changed

src/zarr/core/array.py

Lines changed: 167 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -461,25 +461,83 @@ def nchunks(self) -> int:
461461
"""
462462
return product(self.cdata_shape)
463463

464-
def _iter_chunk_coords(self, origin: Sequence[int] | None = None) -> Iterator[ChunkCoords]:
464+
@property
465+
def nchunks_initialized(self) -> int:
466+
"""
467+
The number of chunks that have been persisted in storage.
465468
"""
466-
Produce an iterator over the coordinates of each chunk, in chunk grid space, relative to
467-
an optional origin.
469+
return nchunks_initialized(self)
470+
471+
def _iter_chunk_coords(
472+
self, *, origin: Sequence[int] | None = None, selection_shape: Sequence[int] | None = None
473+
) -> Iterator[ChunkCoords]:
468474
"""
469-
return _iter_grid(self.cdata_shape, origin=origin)
475+
Create an iterator over the coordinates of chunks in chunk grid space. If the `origin`
476+
keyword is used, iteration will start at the chunk index specified by `origin`.
477+
The default behavior is to start at the origin of the grid coordinate space.
478+
If the `selection_shape` keyword is used, iteration will be bounded over a contiguous region
479+
ranging from `[origin, origin + selection_shape]`, where the upper bound is exclusive as
480+
per python indexing conventions.
470481
471-
def _iter_chunk_keys(self, origin: Sequence[int] | None = None) -> Iterator[str]:
482+
Parameters
483+
----------
484+
origin: Sequence[int] | None, default=None
485+
The origin of the selection relative to the array's chunk grid.
486+
selection_shape: Sequence[int] | None, default=None
487+
The shape of the selection in chunk grid coordinates.
488+
489+
Yields
490+
------
491+
chunk_coords: ChunkCoords
492+
The coordinates of each chunk in the selection.
472493
"""
473-
Return an iterator over the storage keys of each chunk, relative to an optional origin.
494+
return _iter_grid(self.cdata_shape, origin=origin, selection_shape=selection_shape)
495+
496+
def _iter_chunk_keys(
497+
self, *, origin: Sequence[int] | None = None, selection_shape: Sequence[int] | None = None
498+
) -> Iterator[str]:
499+
"""
500+
Iterate over the storage keys of each chunk, relative to an optional origin, and optionally
501+
limited to a contiguous region in chunk grid coordinates.
502+
503+
Parameters
504+
----------
505+
origin: Sequence[int] | None, default=None
506+
The origin of the selection relative to the array's chunk grid.
507+
selection_shape: Sequence[int] | None, default=None
508+
The shape of the selection in chunk grid coordinates.
509+
510+
Yields
511+
------
512+
key: str
513+
The storage key of each chunk in the selection.
474514
"""
475-
for k in self._iter_chunk_coords(origin=origin):
515+
# Iterate over the coordinates of chunks in chunk grid space.
516+
for k in self._iter_chunk_coords(origin=origin, selection_shape=selection_shape):
517+
# Encode the chunk key from the chunk coordinates.
476518
yield self.metadata.encode_chunk_key(k)
477519

478-
def _iter_chunk_regions(self) -> Iterator[tuple[slice, ...]]:
520+
def _iter_chunk_regions(
521+
self, *, origin: Sequence[int] | None = None, selection_shape: Sequence[int] | None = None
522+
) -> Iterator[tuple[slice, ...]]:
479523
"""
480524
Iterate over the regions spanned by each chunk.
525+
526+
Parameters
527+
----------
528+
origin: Sequence[int] | None, default=None
529+
The origin of the selection relative to the array's chunk grid.
530+
selection_shape: Sequence[int] | None, default=None
531+
The shape of the selection in chunk grid coordinates.
532+
533+
Yields
534+
------
535+
region: tuple[slice, ...]
536+
A tuple of slice objects representing the region spanned by each chunk in the selection.
481537
"""
482-
for cgrid_position in self._iter_chunk_coords():
538+
for cgrid_position in self._iter_chunk_coords(
539+
origin=origin, selection_shape=selection_shape
540+
):
483541
out: tuple[slice, ...] = ()
484542
for c_pos, c_shape in zip(cgrid_position, self.chunks, strict=False):
485543
start = c_pos * c_shape
@@ -816,12 +874,32 @@ def nchunks(self) -> int:
816874
"""
817875
return self._async_array.nchunks
818876

819-
def _iter_chunks(self, origin: Sequence[int] | None = None) -> Iterator[ChunkCoords]:
877+
def _iter_chunk_coords(
878+
self, origin: Sequence[int] | None = None, selection_shape: Sequence[int] | None = None
879+
) -> Iterator[ChunkCoords]:
820880
"""
821-
Produce an iterator over the coordinates of each chunk, in chunk grid space, relative
822-
to an optional origin.
881+
Create an iterator over the coordinates of chunks in chunk grid space. If the `origin`
882+
keyword is used, iteration will start at the chunk index specified by `origin`.
883+
The default behavior is to start at the origin of the grid coordinate space.
884+
If the `selection_shape` keyword is used, iteration will be bounded over a contiguous region
885+
ranging from `[origin, origin + selection_shape]`, where the upper bound is exclusive as
886+
per python indexing conventions.
887+
888+
Parameters
889+
----------
890+
origin: Sequence[int] | None, default=None
891+
The origin of the selection relative to the array's chunk grid.
892+
selection_shape: Sequence[int] | None, default=None
893+
The shape of the selection in chunk grid coordinates.
894+
895+
Yields
896+
------
897+
chunk_coords: ChunkCoords
898+
The coordinates of each chunk in the selection.
823899
"""
824-
yield from self._async_array._iter_chunk_coords(origin=origin)
900+
yield from self._async_array._iter_chunk_coords(
901+
origin=origin, selection_shape=selection_shape
902+
)
825903

826904
@property
827905
def nbytes(self) -> int:
@@ -830,17 +908,57 @@ def nbytes(self) -> int:
830908
"""
831909
return self._async_array.nbytes
832910

833-
def _iter_chunk_keys(self, origin: Sequence[int] | None = None) -> Iterator[str]:
911+
@property
912+
def nchunks_initialized(self) -> int:
913+
"""
914+
The number of chunks that have been initialized in the stored representation of this array.
915+
"""
916+
return self._async_array.nchunks_initialized
917+
918+
def _iter_chunk_keys(
919+
self, origin: Sequence[int] | None = None, selection_shape: Sequence[int] | None = None
920+
) -> Iterator[str]:
834921
"""
835-
Return an iterator over the keys of each chunk, relative to an optional origin
922+
Iterate over the storage keys of each chunk, relative to an optional origin, and optionally
923+
limited to a contiguous region in chunk grid coordinates.
924+
925+
Parameters
926+
----------
927+
origin: Sequence[int] | None, default=None
928+
The origin of the selection relative to the array's chunk grid.
929+
selection_shape: Sequence[int] | None, default=None
930+
The shape of the selection in chunk grid coordinates.
931+
932+
Yields
933+
------
934+
key: str
935+
The storage key of each chunk in the selection.
836936
"""
837-
yield from self._async_array._iter_chunk_keys(origin=origin)
937+
yield from self._async_array._iter_chunk_keys(
938+
origin=origin, selection_shape=selection_shape
939+
)
838940

839-
def _iter_chunk_regions(self) -> Iterator[tuple[slice, ...]]:
941+
def _iter_chunk_regions(
942+
self, origin: Sequence[int] | None = None, selection_shape: Sequence[int] | None = None
943+
) -> Iterator[tuple[slice, ...]]:
840944
"""
841945
Iterate over the regions spanned by each chunk.
946+
947+
Parameters
948+
----------
949+
origin: Sequence[int] | None, default=None
950+
The origin of the selection relative to the array's chunk grid.
951+
selection_shape: Sequence[int] | None, default=None
952+
The shape of the selection in chunk grid coordinates.
953+
954+
Yields
955+
------
956+
region: tuple[slice, ...]
957+
A tuple of slice objects representing the region spanned by each chunk in the selection.
842958
"""
843-
yield from self._async_array._iter_chunk_regions()
959+
yield from self._async_array._iter_chunk_regions(
960+
origin=origin, selection_shape=selection_shape
961+
)
844962

845963
def __array__(
846964
self, dtype: npt.DTypeLike | None = None, copy: bool | None = None
@@ -2175,17 +2293,46 @@ def info(self) -> None:
21752293
)
21762294

21772295

2178-
def nchunks_initialized(array: Array) -> int:
2296+
def nchunks_initialized(array: AsyncArray | Array) -> int:
21792297
"""
21802298
Calculate the number of chunks that have been initialized, i.e. the number of chunks that have
21812299
been persisted to the storage backend.
2300+
2301+
Parameters
2302+
----------
2303+
array : Array
2304+
The array to inspect.
2305+
2306+
Returns
2307+
-------
2308+
nchunks_initialized : int
2309+
The number of chunks that have been initialized.
2310+
2311+
See Also
2312+
--------
2313+
chunks_initialized
21822314
"""
21832315
return len(chunks_initialized(array))
21842316

21852317

2186-
def chunks_initialized(array: Array) -> tuple[str, ...]:
2318+
def chunks_initialized(array: Array | AsyncArray) -> tuple[str, ...]:
21872319
"""
21882320
Return the keys of the chunks that have been persisted to the storage backend.
2321+
2322+
Parameters
2323+
----------
2324+
array : Array
2325+
The array to inspect.
2326+
2327+
Returns
2328+
-------
2329+
chunks_initialized : tuple[str, ...]
2330+
The keys of the chunks that have been initialized.
2331+
2332+
See Also
2333+
--------
2334+
nchunks_initialized
2335+
21892336
"""
21902337
# TODO: make this compose with the underlying async iterator
21912338
store_contents = list(

src/zarr/core/indexing.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -93,23 +93,24 @@ def ceildiv(a: float, b: float) -> int:
9393

9494

9595
def _iter_grid(
96-
shape: Sequence[int],
96+
grid_shape: Sequence[int],
9797
*,
9898
origin: Sequence[int] | None = None,
99+
selection_shape: Sequence[int] | None = None,
99100
order: _ArrayIndexingOrder = "lexicographic",
100101
) -> Iterator[ChunkCoords]:
101102
"""
102-
Iterate over the elements of grid of integers.
103-
104-
Takes a grid shape expressed as a sequence of integers and an optional origin and
105-
yields tuples bounded by [origin, origin + grid_shape].
103+
Iterate over the elements of grid of integers, with the option to restrict the domain of
104+
iteration to those from a contiguous subregion of that grid.
106105
107106
Parameters
108107
---------
109-
shape: Sequence[int]
108+
grid_shape: Sequence[int]
110109
The size of the domain to iterate over.
111110
origin: Sequence[int] | None, default=None
112-
The first coordinate of the domain.
111+
The first coordinate of the domain to return.
112+
selection_shape: Sequence[int] | None, default=None
113+
The shape of the selection.
113114
order: Literal["lexicographic"], default="lexicographic"
114115
The linear indexing order to use.
115116
@@ -129,22 +130,37 @@ def _iter_grid(
129130
130131
>>> tuple(iter_grid((2,3)), origin=(1,1))
131132
((1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3))
133+
134+
>>> tuple(iter_grid((2,3)), origin=(1,1), selection_shape=(2,2))
135+
((1, 1), (1, 2), (1, 3), (2, 1))
132136
"""
133137
if origin is None:
134-
origin_parsed = (0,) * len(shape)
138+
origin_parsed = (0,) * len(grid_shape)
135139
else:
136-
if len(origin) != len(shape):
140+
if len(origin) != len(grid_shape):
137141
msg = (
138142
"Shape and origin parameters must have the same length."
139-
f"Got {len(shape)} elements in shape, but {len(origin)} elements in origin."
143+
f"Got {len(grid_shape)} elements in shape, but {len(origin)} elements in origin."
140144
)
141145
raise ValueError(msg)
142146
origin_parsed = tuple(origin)
143-
144-
if order == "lexicographic":
145-
yield from itertools.product(
146-
*(range(o, o + s) for o, s in zip(origin_parsed, shape, strict=True))
147+
if selection_shape is None:
148+
selection_shape_parsed = tuple(
149+
g - o for o, g in zip(origin_parsed, grid_shape, strict=True)
147150
)
151+
else:
152+
selection_shape_parsed = tuple(selection_shape)
153+
if order == "lexicographic":
154+
dimensions: tuple[range, ...] = ()
155+
for idx, (o, gs, ss) in enumerate(
156+
zip(origin_parsed, grid_shape, selection_shape_parsed, strict=True)
157+
):
158+
if o + ss > gs:
159+
raise IndexError(
160+
f"Invalid selection shape ({selection_shape}) for origin ({origin}) and grid shape ({grid_shape}) at axis {idx}."
161+
)
162+
dimensions += (range(o, o + ss),)
163+
yield from itertools.product(*(dimensions))
148164

149165
else:
150166
msg = f"Indexing order {order} is not supported at this time." # type: ignore[unreachable]

tests/v3/test_array.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import pickle
2+
from itertools import accumulate
23
from typing import Literal
34

45
import numpy as np
56
import pytest
67

8+
from src.zarr.core.array import chunks_initialized
79
from zarr import Array, AsyncArray, Group
810
from zarr.core.buffer.cpu import NDBuffer
911
from zarr.core.common import ZarrFormat
12+
from zarr.core.sync import sync
1013
from zarr.errors import ContainsArrayError, ContainsGroupError
1114
from zarr.store import LocalStore, MemoryStore
1215
from zarr.store.common import StorePath
@@ -232,3 +235,55 @@ def test_serializable_sync_array(store: LocalStore, zarr_format: ZarrFormat) ->
232235

233236
assert actual == expected
234237
np.testing.assert_array_equal(actual[:], expected[:])
238+
239+
240+
@pytest.mark.parametrize("test_cls", [Array, AsyncArray])
241+
def test_nchunks_initialized(test_cls: type[Array] | type[AsyncArray]) -> None:
242+
"""
243+
Test that nchunks_initialized accurately returns the number of stored chunks.
244+
"""
245+
store = MemoryStore({}, mode="w")
246+
arr = Array.create(store, shape=(100,), chunks=(10,), dtype="i4")
247+
248+
# write chunks one at a time
249+
for idx, region in enumerate(arr._iter_chunk_regions()):
250+
arr[region] = 1
251+
expected = idx + 1
252+
if test_cls == Array:
253+
observed = arr.nchunks_initialized
254+
else:
255+
observed = arr._async_array.nchunks_initialized
256+
assert observed == expected
257+
258+
# delete chunks
259+
for idx, key in enumerate(arr._iter_chunk_keys()):
260+
sync(arr.store_path.store.delete(key))
261+
if test_cls == Array:
262+
observed = arr.nchunks_initialized
263+
else:
264+
observed = arr._async_array.nchunks_initialized
265+
expected = arr.nchunks - idx - 1
266+
assert observed == expected
267+
268+
269+
@pytest.mark.parametrize("test_cls", [Array, AsyncArray])
270+
def test_chunks_initialized(test_cls: type[Array] | type[AsyncArray]) -> None:
271+
"""
272+
Test that chunks_initialized accurately returns the keys of stored chunks.
273+
"""
274+
store = MemoryStore({}, mode="w")
275+
arr = Array.create(store, shape=(100,), chunks=(10,), dtype="i4")
276+
277+
chunks_accumulated = tuple(
278+
accumulate(tuple(map(lambda v: tuple(v.split(" ")), arr._iter_chunk_keys())))
279+
)
280+
for keys, region in zip(chunks_accumulated, arr._iter_chunk_regions(), strict=False):
281+
arr[region] = 1
282+
283+
if test_cls == Array:
284+
observed = sorted(chunks_initialized(arr))
285+
else:
286+
observed = sorted(chunks_initialized(arr._async_array))
287+
288+
expected = sorted(keys)
289+
assert observed == expected

0 commit comments

Comments
 (0)