Skip to content

Commit 5746680

Browse files
committed
fix: sharding codec with fancy indexing
1 parent 2f8b88a commit 5746680

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

src/zarr/codecs/sharding.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,8 +481,12 @@ async def _decode_partial_single(
481481
)
482482

483483
# setup output array
484+
if hasattr(indexer, "sel_shape"):
485+
out_shape = indexer.sel_shape
486+
else:
487+
out_shape = indexer.shape
484488
out = shard_spec.prototype.nd_buffer.create(
485-
shape=indexer.shape, dtype=shard_spec.dtype, order=shard_spec.order, fill_value=0
489+
shape=out_shape, dtype=shard_spec.dtype, order=shard_spec.order, fill_value=0
486490
)
487491

488492
indexed_chunks = list(indexer)

tests/test_array.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1420,3 +1420,18 @@ def test_multiprocessing(store: Store, method: Literal["fork", "spawn", "forkser
14201420

14211421
results = pool.starmap(_index_array, [(arr, slice(len(data)))])
14221422
assert all(np.array_equal(r, data) for r in results)
1423+
1424+
1425+
async def test_sharding_coordinate_selection() -> None:
1426+
store = MemoryStore()
1427+
g = zarr.open_group(store, mode="w")
1428+
arr = g.create_array(
1429+
name="a",
1430+
shape=(10, 20, 30),
1431+
chunks=(5, 1, 30),
1432+
overwrite=True,
1433+
dtype=np.float32,
1434+
shards=(5, 20, 30),
1435+
)
1436+
arr[:] = 1
1437+
assert (arr[5, [0, 1]] == 1).all()

0 commit comments

Comments
 (0)