Skip to content

Commit 365ba5b

Browse files
committed
tests
1 parent 08a7844 commit 365ba5b

File tree

3 files changed

+111
-62
lines changed

3 files changed

+111
-62
lines changed

src/zarr/core/codec_pipeline.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -332,12 +332,21 @@ async def write_batch(
332332
drop_axes: tuple[int, ...] = (),
333333
) -> None:
334334
if self.supports_partial_encode:
335-
await self.encode_partial_batch(
336-
[
337-
(byte_setter, value[out_selection], chunk_selection, chunk_spec)
338-
for byte_setter, chunk_spec, chunk_selection, out_selection in batch_info
339-
],
340-
)
335+
# Pass scalar values as is
336+
if len(value.shape) == 0:
337+
await self.encode_partial_batch(
338+
[
339+
(byte_setter, value, chunk_selection, chunk_spec)
340+
for byte_setter, chunk_spec, chunk_selection, out_selection in batch_info
341+
],
342+
)
343+
else:
344+
await self.encode_partial_batch(
345+
[
346+
(byte_setter, value[out_selection], chunk_selection, chunk_spec)
347+
for byte_setter, chunk_spec, chunk_selection, out_selection in batch_info
348+
],
349+
)
341350

342351
else:
343352
# Read existing bytes if not total slice

tests/test_array.py

Lines changed: 67 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
VLenUTF8Codec,
2121
ZstdCodec,
2222
)
23+
from zarr.codecs.sharding import ShardingCodec
2324
from zarr.core._info import ArrayInfo
2425
from zarr.core.array import (
2526
CompressorsLike,
@@ -479,88 +480,87 @@ def test_update_attrs(zarr_format: ZarrFormat) -> None:
479480
assert arr2.attrs["foo"] == "bar"
480481

481482

482-
@pytest.mark.parametrize(("chunks", "shards"), [((2, 2), None), (2, 2), (4, 4)])
483+
@pytest.mark.parametrize(("chunks", "shards"), [((2, 2), None), ((2, 2), (4, 4))])
483484
class TestInfo:
484-
chunks: tuple[int, int]
485-
shards: tuple[int, int] | None
486-
487-
def __init__(self, chunks: tuple[int, int], shards: tuple[int, int] | None) -> None:
488-
self.chunks = chunks
489-
self.shards = shards
490-
491-
def test_info_v2(self) -> None:
492-
arr = zarr.create_array(
493-
store={}, shape=(8, 8), dtype="f8", chunks=self.chunks, zarr_format=2
494-
)
485+
def test_info_v2(self, chunks: tuple[int, int], shards: tuple[int, int] | None) -> None:
486+
arr = zarr.create_array(store={}, shape=(8, 8), dtype="f8", chunks=chunks, zarr_format=2)
495487
result = arr.info
496488
expected = ArrayInfo(
497489
_zarr_format=2,
498490
_data_type=np.dtype("float64"),
499491
_shape=(8, 8),
500-
_chunk_shape=self.chunks,
492+
_chunk_shape=chunks,
501493
_shard_shape=None,
502494
_order="C",
503495
_read_only=False,
504496
_store_type="MemoryStore",
505-
_count_bytes=128,
497+
_count_bytes=512,
506498
_compressor=numcodecs.Zstd(),
507499
)
508500
assert result == expected
509501

510-
def test_info_v3(self) -> None:
511-
arr = zarr.create_array(
512-
store={}, shape=(8, 8), dtype="f8", chunks=self.chunks, shards=self.shards
513-
)
502+
def test_info_v3(self, chunks: tuple[int, int], shards: tuple[int, int] | None) -> None:
503+
arr = zarr.create_array(store={}, shape=(8, 8), dtype="f8", chunks=chunks, shards=shards)
514504
result = arr.info
515505
expected = ArrayInfo(
516506
_zarr_format=3,
517507
_data_type=DataType.parse("float64"),
518508
_shape=(8, 8),
519-
_chunk_shape=(2, 2),
509+
_chunk_shape=chunks,
510+
_shard_shape=shards,
520511
_order="C",
521512
_read_only=False,
522513
_store_type="MemoryStore",
523-
_codecs=[BytesCodec(), ZstdCodec()],
524-
_count_bytes=128,
514+
_codecs=[BytesCodec(), ZstdCodec()]
515+
if shards is None
516+
else [ShardingCodec(chunk_shape=chunks, codecs=[BytesCodec(), ZstdCodec()])],
517+
_count_bytes=512,
525518
)
526519
assert result == expected
527520

528-
def test_info_complete(self) -> None:
521+
def test_info_complete(self, chunks: tuple[int, int], shards: tuple[int, int] | None) -> None:
529522
arr = zarr.create_array(
530523
store={},
531524
shape=(8, 8),
532525
dtype="f8",
533-
chunks=self.chunks,
534-
shards=self.shards,
535-
compressors=None,
526+
chunks=chunks,
527+
shards=shards,
528+
compressors=(),
536529
)
537530
result = arr.info_complete()
538531
expected = ArrayInfo(
539532
_zarr_format=3,
540533
_data_type=DataType.parse("float64"),
541534
_shape=(8, 8),
542-
_chunk_shape=self.chunks,
543-
_shard_shape=self.shards,
535+
_chunk_shape=chunks,
536+
_shard_shape=shards,
544537
_order="C",
545538
_read_only=False,
546539
_store_type="MemoryStore",
547-
_codecs=[BytesCodec()],
548-
_count_bytes=128,
540+
_codecs=[BytesCodec()] if shards is None else [ShardingCodec(chunk_shape=chunks)],
541+
_count_bytes=512,
549542
_count_chunks_initialized=0,
550-
_count_bytes_stored=373, # the metadata?
543+
_count_bytes_stored=373 if shards is None else 578, # the metadata?
551544
)
552545
assert result == expected
553546

554-
arr[:2, :2] = 10
547+
arr[:4, :4] = 10
555548
result = arr.info_complete()
556-
expected = dataclasses.replace(
557-
expected, _count_chunks_initialized=1, _count_bytes_stored=405
558-
)
549+
if shards is None:
550+
expected = dataclasses.replace(
551+
expected, _count_chunks_initialized=4, _count_bytes_stored=501
552+
)
553+
else:
554+
expected = dataclasses.replace(
555+
expected, _count_chunks_initialized=1, _count_bytes_stored=774
556+
)
559557
assert result == expected
560558

561-
async def test_info_v2_async(self) -> None:
559+
async def test_info_v2_async(
560+
self, chunks: tuple[int, int], shards: tuple[int, int] | None
561+
) -> None:
562562
arr = await zarr.api.asynchronous.create_array(
563-
store={}, shape=(8, 8), dtype="f8", chunks=self.chunks, zarr_format=2
563+
store={}, shape=(8, 8), dtype="f8", chunks=chunks, zarr_format=2
564564
)
565565
result = arr.info
566566
expected = ArrayInfo(
@@ -572,65 +572,76 @@ async def test_info_v2_async(self) -> None:
572572
_order="C",
573573
_read_only=False,
574574
_store_type="MemoryStore",
575-
_count_bytes=128,
575+
_count_bytes=512,
576576
_compressor=numcodecs.Zstd(),
577577
)
578578
assert result == expected
579579

580-
async def test_info_v3_async(self) -> None:
580+
async def test_info_v3_async(
581+
self, chunks: tuple[int, int], shards: tuple[int, int] | None
582+
) -> None:
581583
arr = await zarr.api.asynchronous.create_array(
582584
store={},
583585
shape=(8, 8),
584586
dtype="f8",
585-
chunks=self.chunks,
586-
shards=self.shards,
587+
chunks=chunks,
588+
shards=shards,
587589
)
588590
result = arr.info
589591
expected = ArrayInfo(
590592
_zarr_format=3,
591593
_data_type=DataType.parse("float64"),
592594
_shape=(8, 8),
593-
_chunk_shape=self.chunks,
594-
_shard_shape=self.shards,
595+
_chunk_shape=chunks,
596+
_shard_shape=shards,
595597
_order="C",
596598
_read_only=False,
597599
_store_type="MemoryStore",
598-
_codecs=[BytesCodec(), ZstdCodec()],
599-
_count_bytes=128,
600+
_codecs=[BytesCodec(), ZstdCodec()]
601+
if shards is None
602+
else [ShardingCodec(chunk_shape=chunks, codecs=[BytesCodec(), ZstdCodec()])],
603+
_count_bytes=512,
600604
)
601605
assert result == expected
602606

603-
async def test_info_complete_async(self) -> None:
607+
async def test_info_complete_async(
608+
self, chunks: tuple[int, int], shards: tuple[int, int] | None
609+
) -> None:
604610
arr = await zarr.api.asynchronous.create_array(
605611
store={},
606612
dtype="f8",
607613
shape=(8, 8),
608-
chunks=self.chunks,
609-
shards=self.shards,
614+
chunks=chunks,
615+
shards=shards,
610616
compressors=None,
611617
)
612618
result = await arr.info_complete()
613619
expected = ArrayInfo(
614620
_zarr_format=3,
615621
_data_type=DataType.parse("float64"),
616622
_shape=(8, 8),
617-
_chunk_shape=self.chunks,
618-
_shard_shape=self.shards,
623+
_chunk_shape=chunks,
624+
_shard_shape=shards,
619625
_order="C",
620626
_read_only=False,
621627
_store_type="MemoryStore",
622-
_codecs=[BytesCodec()],
623-
_count_bytes=128,
628+
_codecs=[BytesCodec()] if shards is None else [ShardingCodec(chunk_shape=chunks)],
629+
_count_bytes=512,
624630
_count_chunks_initialized=0,
625-
_count_bytes_stored=373, # the metadata?
631+
_count_bytes_stored=373 if shards is None else 578, # the metadata?
626632
)
627633
assert result == expected
628634

629-
await arr.setitem((slice(2), slice(2)), 10)
635+
await arr.setitem((slice(4), slice(4)), 10)
630636
result = await arr.info_complete()
631-
expected = dataclasses.replace(
632-
expected, _count_chunks_initialized=1, _count_bytes_stored=405
633-
)
637+
if shards is None:
638+
expected = dataclasses.replace(
639+
expected, _count_chunks_initialized=4, _count_bytes_stored=501
640+
)
641+
else:
642+
expected = dataclasses.replace(
643+
expected, _count_chunks_initialized=1, _count_bytes_stored=774
644+
)
634645
assert result == expected
635646

636647

tests/test_codecs/test_sharding.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,35 @@ def test_sharding(
7070
assert np.array_equal(data, read_data)
7171

7272

73+
@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"])
74+
@pytest.mark.parametrize("index_location", ["start", "end"])
75+
@pytest.mark.parametrize("offset", [0, 10])
76+
def test_sharding_scalar(
77+
store: Store,
78+
index_location: ShardingCodecIndexLocation,
79+
offset: int,
80+
) -> None:
81+
"""
82+
Test that we can create an array with a sharding codec, write data to that array, and get
83+
the same data out via indexing.
84+
"""
85+
spath = StorePath(store)
86+
87+
arr = zarr.create_array(
88+
spath,
89+
shape=(128, 128),
90+
chunks=(32, 32),
91+
shards={"shape": (64, 64), "index_location": index_location},
92+
dtype="uint8",
93+
fill_value=6,
94+
filters=[TransposeCodec(order=order_from_dim("F", 2))],
95+
compressors=BloscCodec(cname="lz4"),
96+
)
97+
arr[:16, :16] = 10 # intentionally write partial chunks
98+
read_data = arr[:16, :16]
99+
np.testing.assert_array_equal(read_data, 10)
100+
101+
73102
@pytest.mark.parametrize("index_location", ["start", "end"])
74103
@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"])
75104
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)