 import zarr.api.asynchronous
 from zarr import Array, AsyncArray, Group
 from zarr.codecs import BytesCodec, VLenBytesCodec, ZstdCodec
+from zarr.codecs.sharding import ShardingCodec
 from zarr.core._info import ArrayInfo
 from zarr.core.array import chunks_initialized
 from zarr.core.buffer import default_buffer_prototype
 from zarr.core.buffer.cpu import NDBuffer
+from zarr.core.chunk_grids import _auto_partition
+from zarr.core.codec_pipeline import BatchedCodecPipeline
 from zarr.core.common import JSON, MemoryOrder, ZarrFormat
 from zarr.core.group import AsyncGroup
 from zarr.core.indexing import ceildiv
@@ -881,3 +884,49 @@ async def test_nbytes(
         assert arr._async_array.nbytes == np.prod(arr.shape) * arr.dtype.itemsize
     else:
         assert arr.nbytes == np.prod(arr.shape) * arr.dtype.itemsize
+
+
+def _get_partitioning(data: AsyncArray) -> tuple[tuple[int, ...], tuple[int, ...] | None]:
+    """
+    Get the chunk shape and shard shape of an array. If the array is not sharded, the shard
+    shape will be None.
+    """
+
+    shard_shape: tuple[int, ...] | None
+    chunk_shape: tuple[int, ...]
+    codecs = data.codec_pipeline
+    if not isinstance(codecs, BatchedCodecPipeline):
+        raise TypeError(f"Expected a BatchedCodecPipeline, got {type(codecs)}")
+    if isinstance(codecs.array_bytes_codec, ShardingCodec):
+        # sharded: the outer chunks are shards; the inner chunk shape lives on the codec
+        chunk_shape = codecs.array_bytes_codec.chunk_shape
+        shard_shape = data.chunks
+    else:
+        chunk_shape = data.chunks
+        shard_shape = None
+    return chunk_shape, shard_shape
+
+
+@pytest.mark.parametrize(
+    ("array_shape", "chunk_shape"),
+    [((256,), (2,))],
+)
+def test_auto_partition_auto_shards(
+    array_shape: tuple[int, ...], chunk_shape: tuple[int, ...]
+) -> None:
+    """
+    Test that automatically picking a shard size returns 2 * the chunk shape along any axis
+    with 8 or more chunks, and the chunk shape itself otherwise.
+    """
+    dtype = np.dtype("uint8")
+    expected_shards: tuple[int, ...] = ()
+    for cs, a_len in zip(chunk_shape, array_shape, strict=True):
+        if a_len // cs >= 8:
+            expected_shards += (2 * cs,)
+        else:
+            expected_shards += (cs,)
+
+    auto_shards, _ = _auto_partition(
+        array_shape=array_shape, chunk_shape=chunk_shape, shard_shape="auto", dtype=dtype
+    )
+    assert auto_shards == expected_shards
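
For the single parametrized case, the arithmetic the test encodes can be spelled out directly; everything below comes straight from the diff above:

```python
import numpy as np

from zarr.core.chunk_grids import _auto_partition

# 256 elements in chunks of 2 -> 128 chunks along the axis; 128 >= 8, so the
# auto-picked shard shape should be twice the chunk shape: (4,).
auto_shards, _ = _auto_partition(
    array_shape=(256,), chunk_shape=(2,), shard_shape="auto", dtype=np.dtype("uint8")
)
assert auto_shards == (4,)
```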
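A minimal usage sketch of the `_get_partitioning` helper, assuming a `zarr.create_array` entry point that accepts a `shards` keyword and a dict as the store; only `_get_partitioning` and `Array._async_array` are taken from the diff itself:

```python
import zarr

# Hypothetical creation call (not confirmed by this diff): 128 chunks of 2
# elements grouped into shards of 4 elements.
arr = zarr.create_array(store={}, shape=(256,), chunks=(2,), shards=(4,), dtype="uint8")

# With sharding, the array's outer chunks are the shards, and the inner
# chunk shape is reported by the sharding codec.
chunk_shape, shard_shape = _get_partitioning(arr._async_array)
assert chunk_shape == (2,)
assert shard_shape == (4,)

# Without shards, the shard shape comes back as None.
plain = zarr.create_array(store={}, shape=(256,), chunks=(2,), dtype="uint8")
assert _get_partitioning(plain._async_array) == ((2,), None)
```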