Skip to content

Commit cc3698b

Browse files
committed
factor array element iteration routines into stand-alone functions, and add a failing test
1 parent a4e9660 commit cc3698b

File tree

2 files changed

+128
-2
lines changed

2 files changed

+128
-2
lines changed

src/zarr/core/array.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4928,3 +4928,96 @@ def _parse_data_params(
49284928
raise ValueError(msg)
49294929
dtype_out = data.dtype
49304930
return data, shape_out, dtype_out
4931+
4932+
4933+
def iter_chunk_coords(
4934+
array: Array | AsyncArray[Any],
4935+
*,
4936+
origin: Sequence[int] | None = None,
4937+
selection_shape: Sequence[int] | None = None,
4938+
) -> Iterator[ChunkCoords]:
4939+
"""
4940+
Create an iterator over the coordinates of chunks in chunk grid space. If the `origin`
4941+
keyword is used, iteration will start at the chunk index specified by `origin`.
4942+
The default behavior is to start at the origin of the grid coordinate space.
4943+
If the `selection_shape` keyword is used, iteration will be bounded over a contiguous region
4944+
ranging from `[origin, origin selection_shape]`, where the upper bound is exclusive as
4945+
per python indexing conventions.
4946+
4947+
Parameters
4948+
----------
4949+
array : Array | AsyncArray
4950+
The array to iterate over.
4951+
origin : Sequence[int] | None, default=None
4952+
The origin of the selection relative to the array's chunk grid.
4953+
selection_shape : Sequence[int] | None, default=None
4954+
The shape of the selection in chunk grid coordinates.
4955+
4956+
Yields
4957+
------
4958+
chunk_coords: ChunkCoords
4959+
The coordinates of each chunk in the selection.
4960+
"""
4961+
return _iter_grid(array.cdata_shape, origin=origin, selection_shape=selection_shape)
4962+
4963+
4964+
def iter_chunk_keys(
4965+
array: Array | AsyncArray[Any],
4966+
*,
4967+
origin: Sequence[int] | None = None,
4968+
selection_shape: Sequence[int] | None = None,
4969+
) -> Iterator[str]:
4970+
"""
4971+
Iterate over the storage keys of each chunk, relative to an optional origin, and optionally
4972+
limited to a contiguous region in chunk grid coordinates.
4973+
4974+
Parameters
4975+
----------
4976+
array : Array | AsyncArray
4977+
The array to iterate over.
4978+
origin : Sequence[int] | None, default=None
4979+
The origin of the selection relative to the array's chunk grid.
4980+
selection_shape : Sequence[int] | None, default=None
4981+
The shape of the selection in chunk grid coordinates.
4982+
4983+
Yields
4984+
------
4985+
key: str
4986+
The storage key of each chunk in the selection.
4987+
"""
4988+
# Iterate over the coordinates of chunks in chunk grid space.
4989+
for k in iter_chunk_coords(array, origin=origin, selection_shape=selection_shape):
4990+
# Encode the chunk key from the chunk coordinates.
4991+
yield array.metadata.encode_chunk_key(k)
4992+
4993+
4994+
def iter_chunk_regions(
4995+
array: Array | AsyncArray[Any],
4996+
*,
4997+
origin: Sequence[int] | None = None,
4998+
selection_shape: Sequence[int] | None = None,
4999+
) -> Iterator[tuple[slice, ...]]:
5000+
"""
5001+
Iterate over the regions spanned by each chunk.
5002+
5003+
Parameters
5004+
----------
5005+
array : Array | AsyncArray
5006+
The array to iterate over.
5007+
origin : Sequence[int] | None, default=None
5008+
The origin of the selection relative to the array's chunk grid.
5009+
selection_shape : Sequence[int] | None, default=None
5010+
The shape of the selection in chunk grid coordinates.
5011+
5012+
Yields
5013+
------
5014+
region: tuple[slice, ...]
5015+
A tuple of slice objects representing the region spanned by each chunk in the selection.
5016+
"""
5017+
for cgrid_position in iter_chunk_coords(array, origin=origin, selection_shape=selection_shape):
5018+
out: tuple[slice, ...] = ()
5019+
for c_pos, c_shape in zip(cgrid_position, array.chunks, strict=False):
5020+
start = c_pos * c_shape
5021+
stop = start + c_shape
5022+
out += (slice(start, stop, 1),)
5023+
yield out

tests/test_array.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pickle
77
import re
88
import sys
9-
from itertools import accumulate
9+
from itertools import accumulate, starmap
1010
from typing import TYPE_CHECKING, Any, Literal
1111
from unittest import mock
1212

@@ -37,6 +37,7 @@
3737
create_array,
3838
default_filters_v2,
3939
default_serializer_v3,
40+
iter_chunk_keys,
4041
)
4142
from zarr.core.buffer import NDArrayLike, NDArrayLikeOrScalar, default_buffer_prototype
4243
from zarr.core.chunk_grids import _auto_partition
@@ -59,7 +60,7 @@
5960
from zarr.core.dtype.npy.common import NUMPY_ENDIANNESS_STR, endianness_from_numpy_str
6061
from zarr.core.dtype.npy.string import UTF8Base
6162
from zarr.core.group import AsyncGroup
62-
from zarr.core.indexing import BasicIndexer, ceildiv
63+
from zarr.core.indexing import BasicIndexer, _iter_grid, ceildiv
6364
from zarr.core.metadata.v2 import ArrayV2Metadata
6465
from zarr.core.metadata.v3 import ArrayV3Metadata
6566
from zarr.core.sync import sync
@@ -1835,3 +1836,35 @@ def test_unknown_object_codec_default_filters_v2() -> None:
18351836
msg = f"Data type {dtype} requires an unknown object codec: {dtype.object_codec_id!r}."
18361837
with pytest.raises(ValueError, match=re.escape(msg)):
18371838
default_filters_v2(dtype)
1839+
1840+
1841+
@pytest.mark.parametrize(
1842+
("shard_size", "chunk_size"),
1843+
[
1844+
((8,), (8,)),
1845+
((8,), (2,)),
1846+
(
1847+
(
1848+
8,
1849+
10,
1850+
),
1851+
(2, 2),
1852+
),
1853+
],
1854+
)
1855+
def test_iter_chunk_keys(shard_size: tuple[int, ...], chunk_size: tuple[int, ...]) -> None:
1856+
store = {}
1857+
arr = zarr.create_array(
1858+
store,
1859+
dtype="uint8",
1860+
shape=tuple(2 * x for x in shard_size),
1861+
chunks=chunk_size,
1862+
shards=shard_size,
1863+
zarr_format=3,
1864+
)
1865+
shard_grid_shape = tuple(starmap(ceildiv, zip(arr.shape, arr.shards, strict=True)))
1866+
expected_keys = tuple(
1867+
arr.metadata.chunk_key_encoding.encode_chunk_key(region)
1868+
for region in _iter_grid(shard_grid_shape)
1869+
)
1870+
assert tuple(iter_chunk_keys(arr)) == expected_keys

0 commit comments

Comments
 (0)