Skip to content

Commit 5dcd80b

Browse files
committed
test for auto sharding
1 parent ae1832d commit 5dcd80b

File tree

3 files changed

+61
-13
lines changed

3 files changed

+61
-13
lines changed

src/zarr/core/array.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3472,8 +3472,8 @@ async def create_array(
34723472
name: str | None = None,
34733473
shape: ShapeLike,
34743474
dtype: npt.DTypeLike,
3475-
chunk_shape: ChunkCoords | Literal["auto"] = "auto",
3476-
shard_shape: ChunkCoords | None = None,
3475+
chunks: ChunkCoords | Literal["auto"] = "auto",
3476+
shards: ChunkCoords | Literal["auto"] | None = None,
34773477
filters: FiltersParam = "auto",
34783478
compression: CompressionParam = "auto",
34793479
fill_value: Any | None = 0,
@@ -3500,9 +3500,9 @@ async def create_array(
35003500
Shape of the array.
35013501
dtype : npt.DTypeLike
35023502
Data type of the array.
3503-
chunk_shape : ChunkCoords
3503+
chunks : ChunkCoords
35043504
Chunk shape of the array.
3505-
shard_shape : ChunkCoords, optional
3505+
shards : ChunkCoords, optional
35063506
Shard shape of the array. The default value of ``None`` results in no sharding at all.
35073507
filters : Iterable[Codec], optional
35083508
List of filters to apply to the array.
@@ -3552,15 +3552,16 @@ async def create_array(
35523552
)
35533553
store_path = await make_store_path(store, path=name, mode=mode, storage_options=storage_options)
35543554
shard_shape_parsed, chunk_shape_parsed = _auto_partition(
3555-
shape_parsed, shard_shape, chunk_shape, dtype_parsed
3555+
array_shape=shape_parsed, shard_shape=shards, chunk_shape=chunks, dtype=dtype_parsed
35563556
)
3557+
chunks_out: tuple[int, ...]
35573558
result: AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata]
35583559

35593560
if zarr_format == 2:
35603561
if shard_shape_parsed is not None:
35613562
msg = (
3562-
'Zarr v2 arrays can only be created with `shard_shape` set to `None` or `"auto"`.'
3563-
f"Got `shard_shape={shard_shape}` instead."
3563+
"Zarr v2 arrays can only be created with `shard_shape` set to `None`."
3564+
f"Got `shard_shape={shards}` instead."
35643565
)
35653566

35663567
raise ValueError(msg)
@@ -3604,10 +3605,10 @@ async def create_array(
36043605
sharding_codec.validate(
36053606
shape=chunk_shape_parsed,
36063607
dtype=dtype_parsed,
3607-
chunk_grid=RegularChunkGrid(chunk_shape=shard_shape),
3608+
chunk_grid=RegularChunkGrid(chunk_shape=shard_shape_parsed),
36083609
)
36093610
codecs_out = (sharding_codec,)
3610-
chunks_out = shard_shape
3611+
chunks_out = shard_shape_parsed
36113612
else:
36123613
chunks_out = chunk_shape_parsed
36133614
codecs_out = sub_codecs

src/zarr/core/chunk_grids.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,10 @@ def get_nchunks(self, array_shape: ChunkCoords) -> int:
197197

198198

199199
def _auto_partition(
200+
*,
200201
array_shape: tuple[int, ...],
201-
shard_shape: tuple[int, ...] | Literal["auto"] | None,
202202
chunk_shape: tuple[int, ...] | Literal["auto"],
203+
shard_shape: tuple[int, ...] | Literal["auto"] | None,
203204
dtype: np.dtype[Any],
204205
) -> tuple[tuple[int, ...] | None, tuple[int, ...]]:
205206
"""
@@ -210,7 +211,6 @@ def _auto_partition(
210211
of the array; if the `chunk_shape` is also "auto", then the chunks will be set heuristically as well,
211212
given the dtype and shard shape. Otherwise, the chunks will be returned as-is.
212213
"""
213-
214214
item_size = dtype.itemsize
215215
if shard_shape is None:
216216
_shards_out: None | tuple[int, ...] = None
@@ -229,9 +229,9 @@ def _auto_partition(
229229
_shards_out = ()
230230
for a_shape, c_shape in zip(array_shape, _chunks_out, strict=True):
231231
# TODO: make a better heuristic than this.
232-
# for each axis, if there are more than 16 chunks along that axis, then make put
232+
# for each axis, if there are more than 8 chunks along that axis, then put
233233
# 2 chunks in each shard for that axis.
234-
if a_shape // c_shape > 16:
234+
if a_shape // c_shape > 8:
235235
_shards_out += (c_shape * 2,)
236236
else:
237237
_shards_out += (c_shape,)

tests/test_array.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,13 @@
1313
import zarr.api.asynchronous
1414
from zarr import Array, AsyncArray, Group
1515
from zarr.codecs import BytesCodec, VLenBytesCodec, ZstdCodec
16+
from zarr.codecs.sharding import ShardingCodec
1617
from zarr.core._info import ArrayInfo
1718
from zarr.core.array import chunks_initialized
1819
from zarr.core.buffer import default_buffer_prototype
1920
from zarr.core.buffer.cpu import NDBuffer
21+
from zarr.core.chunk_grids import _auto_partition
22+
from zarr.core.codec_pipeline import BatchedCodecPipeline
2023
from zarr.core.common import JSON, MemoryOrder, ZarrFormat
2124
from zarr.core.group import AsyncGroup
2225
from zarr.core.indexing import ceildiv
@@ -881,3 +884,47 @@ async def test_nbytes(
881884
assert arr._async_array.nbytes == np.prod(arr.shape) * arr.dtype.itemsize
882885
else:
883886
assert arr.nbytes == np.prod(arr.shape) * arr.dtype.itemsize
887+
888+
889+
def _get_partitioning(data: AsyncArray) -> tuple[tuple[int, ...], tuple[int, ...] | None]:
890+
"""
891+
Get the shard shape and chunk shape of an array. If the array is not sharded, the shard shape
892+
will be None.
893+
"""
894+
895+
shard_shape: tuple[int, ...] | None
896+
chunk_shape: tuple[int, ...]
897+
codecs = data.codec_pipeline
898+
if isinstance(codecs, BatchedCodecPipeline):
899+
if isinstance(codecs.array_bytes_codec, ShardingCodec):
900+
chunk_shape = codecs.array_bytes_codec.chunk_shape
901+
shard_shape = data.chunks
902+
else:
903+
chunk_shape = data.chunks
904+
shard_shape = None
905+
return chunk_shape, shard_shape
906+
907+
908+
@pytest.mark.parametrize(
909+
("array_shape", "chunk_shape"),
910+
[((256,), (2,))],
911+
)
912+
def test_auto_partition_auto_shards(
913+
array_shape: tuple[int, ...], chunk_shape: tuple[int, ...]
914+
) -> None:
915+
"""
916+
Test that automatically picking a shard size returns a tuple of 2 * the chunk shape for any axis
917+
where there are 8 or more chunks.
918+
"""
919+
dtype = np.dtype("uint8")
920+
expected_shards: tuple[int, ...] = ()
921+
for cs, a_len in zip(chunk_shape, array_shape, strict=False):
922+
if a_len // cs >= 8:
923+
expected_shards += (2 * cs,)
924+
else:
925+
expected_shards += (cs,)
926+
927+
auto_shards, _ = _auto_partition(
928+
array_shape=array_shape, chunk_shape=chunk_shape, shard_shape="auto", dtype=dtype
929+
)
930+
assert auto_shards == expected_shards

0 commit comments

Comments
 (0)