Skip to content

Commit a1ceba5

Browse files
committed
Refactor Codec interface
This refactors the interface of `Codec.read` and `Codec.write` to move from an `Iterable[tuple[...]]` to an `Iterable[BatchInfo]`. Two things motivate this change: 1. Readability: I struggle to remember what the 4th member of these complex tuples means. Having the name `info.out_selection` to remind me is helpful. 2. Possible future-proofing: right now, any change to the interface is a hard break since the number of elements in the tuple will change. There may be a class of changes to the interface where we can add additional information to `BatchInfo` without breaking backwards compatibility. I don't want to oversell motivation 2 though. If something is important enough to add to the interface, then presumably we expect implementations to, you know, use it.
1 parent b873691 commit a1ceba5

File tree

6 files changed

+249
-114
lines changed

6 files changed

+249
-114
lines changed

src/zarr/abc/codec.py

Lines changed: 71 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from __future__ import annotations
22

33
from abc import abstractmethod
4-
from collections.abc import Mapping
5-
from typing import TYPE_CHECKING, Generic, TypeGuard, TypeVar
4+
from collections.abc import Iterator, Mapping
5+
from dataclasses import dataclass
6+
from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeGuard, TypeVar
67

78
from typing_extensions import ReadOnly, TypedDict
89

910
from zarr.abc.metadata import Metadata
11+
from zarr.abc.store import ByteGetter, ByteSetter
1012
from zarr.core.buffer import Buffer, NDBuffer
1113
from zarr.core.common import NamedConfig, concurrent_map
1214
from zarr.core.config import config
@@ -15,7 +17,7 @@
1517
from collections.abc import Awaitable, Callable, Iterable
1618
from typing import Self
1719

18-
from zarr.abc.store import ByteGetter, ByteSetter, Store
20+
from zarr.abc.store import Store
1921
from zarr.core.array_spec import ArraySpec
2022
from zarr.core.chunk_grids import ChunkGrid
2123
from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType
@@ -28,10 +30,13 @@
2830
"ArrayBytesCodecPartialDecodeMixin",
2931
"ArrayBytesCodecPartialEncodeMixin",
3032
"BaseCodec",
33+
"BatchInfo",
3134
"BytesBytesCodec",
3235
"CodecInput",
3336
"CodecOutput",
3437
"CodecPipeline",
38+
"ReadBatchInfo",
39+
"WriteBatchInfo",
3540
]
3641

3742
CodecInput = TypeVar("CodecInput", bound=NDBuffer | Buffer)
@@ -59,6 +64,58 @@ def _check_codecjson_v2(data: object) -> TypeGuard[CodecJSON_V2[str]]:
5964
"""The widest type of JSON-like input that could specify a codec."""
6065

6166

67+
TByteOperator = TypeVar("TByteOperator", bound="ByteGetter")
68+
69+
70+
@dataclass(frozen=True)
71+
class BatchInfo(Generic[TByteOperator]):
72+
"""Information about a chunk to be read/written from/to the store.
73+
74+
This class is generic over the byte operator type:
75+
- BatchInfo[ByteGetter] (aliased as ReadBatchInfo) for read operations
76+
- BatchInfo[ByteSetter] (aliased as WriteBatchInfo) for write operations
77+
78+
Attributes
79+
----------
80+
byte_operator : TByteOperator
81+
Used to fetch/write the chunk bytes from/to the store.
82+
For reads, this is a ByteGetter. For writes, this is a ByteSetter.
83+
array_spec : ArraySpec
84+
Specification of the chunk array (shape, dtype, fill value, etc.).
85+
chunk_selection : SelectorTuple
86+
Slice selection determining which parts of the chunk to read/encode.
87+
out_selection : SelectorTuple
88+
Slice selection determining where in the output/value array the chunk data will be written/is located.
89+
is_complete_chunk : bool
90+
Whether this represents a complete chunk (vs. a partial chunk at array boundaries).
91+
"""
92+
93+
byte_operator: TByteOperator
94+
array_spec: ArraySpec
95+
chunk_selection: SelectorTuple
96+
out_selection: SelectorTuple
97+
is_complete_chunk: bool
98+
99+
def __iter__(self) -> Iterator[Any]:
100+
"""Iterate over fields for backwards compatibility with tuple unpacking."""
101+
yield self.byte_operator
102+
yield self.array_spec
103+
yield self.chunk_selection
104+
yield self.out_selection
105+
yield self.is_complete_chunk
106+
107+
def __getitem__(self, index: int) -> Any:
108+
"""Index access for backwards compatibility with tuple indexing."""
109+
return list(self)[index]
110+
111+
112+
ReadBatchInfo: TypeAlias = BatchInfo[ByteGetter]
113+
"""Information about a chunk to be read from the store and decoded."""
114+
115+
WriteBatchInfo: TypeAlias = BatchInfo[ByteSetter]
116+
"""Information about a chunk to be encoded and written to the store."""
117+
118+
62119
class BaseCodec(Metadata, Generic[CodecInput, CodecOutput]):
63120
"""Generic base class for codecs.
64121
@@ -412,7 +469,7 @@ async def encode(
412469
@abstractmethod
413470
async def read(
414471
self,
415-
batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
472+
batch_info: Iterable[ReadBatchInfo],
416473
out: NDBuffer,
417474
drop_axes: tuple[int, ...] = (),
418475
) -> None:
@@ -421,25 +478,24 @@ async def read(
421478
422479
Parameters
423480
----------
424-
batch_info : Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple]]
481+
batch_info : Iterable[ReadBatchInfo]
425482
Ordered set of information about the chunks.
426-
The first slice selection determines which parts of the chunk will be fetched.
427-
The second slice selection determines where in the output array the chunk data will be written.
428-
The ByteGetter is used to fetch the necessary bytes.
429-
The chunk spec contains information about the construction of an array from the bytes.
483+
See ReadBatchInfo for details on the fields.
430484
431485
If the Store returns ``None`` for a chunk, then the chunk was not
432486
written and the implementation must set the values of that chunk (or
433487
``out``) to the fill value for the array.
434488
435489
out : NDBuffer
490+
drop_axes : tuple[int, ...]
491+
Axes to drop from the chunk data when writing to the output array.
436492
"""
437493
...
438494

439495
@abstractmethod
440496
async def write(
441497
self,
442-
batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
498+
batch_info: Iterable[WriteBatchInfo],
443499
value: NDBuffer,
444500
drop_axes: tuple[int, ...] = (),
445501
) -> None:
@@ -449,13 +505,13 @@ async def write(
449505
450506
Parameters
451507
----------
452-
batch_info : Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple]]
508+
batch_info : Iterable[WriteBatchInfo]
453509
Ordered set of information about the chunks.
454-
The first slice selection determines which parts of the chunk will be encoded.
455-
The second slice selection determines where in the value array the chunk data is located.
456-
The ByteSetter is used to fetch and write the necessary bytes.
457-
The chunk spec contains information about the chunk.
510+
See WriteBatchInfo for details on the fields.
458511
value : NDBuffer
512+
The data to write.
513+
drop_axes : tuple[int, ...]
514+
Axes to drop from the chunk data when reading from the value array.
459515
"""
460516
...
461517

src/zarr/codecs/sharding.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
ArrayBytesCodecPartialEncodeMixin,
1717
Codec,
1818
CodecPipeline,
19+
ReadBatchInfo,
20+
WriteBatchInfo,
1921
)
2022
from zarr.abc.store import (
2123
ByteGetter,
@@ -358,12 +360,12 @@ async def _decode_single(
358360
# decoding chunks and writing them into the output buffer
359361
await self.codec_pipeline.read(
360362
[
361-
(
362-
_ShardingByteGetter(shard_dict, chunk_coords),
363-
chunk_spec,
364-
chunk_selection,
365-
out_selection,
366-
is_complete_shard,
363+
ReadBatchInfo(
364+
byte_operator=_ShardingByteGetter(shard_dict, chunk_coords),
365+
array_spec=chunk_spec,
366+
chunk_selection=chunk_selection,
367+
out_selection=out_selection,
368+
is_complete_chunk=is_complete_shard,
367369
)
368370
for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer
369371
],
@@ -430,12 +432,12 @@ async def _decode_partial_single(
430432
# decoding chunks and writing them into the output buffer
431433
await self.codec_pipeline.read(
432434
[
433-
(
434-
_ShardingByteGetter(shard_dict, chunk_coords),
435-
chunk_spec,
436-
chunk_selection,
437-
out_selection,
438-
is_complete_shard,
435+
ReadBatchInfo(
436+
byte_operator=_ShardingByteGetter(shard_dict, chunk_coords),
437+
array_spec=chunk_spec,
438+
chunk_selection=chunk_selection,
439+
out_selection=out_selection,
440+
is_complete_chunk=is_complete_shard,
439441
)
440442
for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer
441443
],
@@ -469,12 +471,12 @@ async def _encode_single(
469471

470472
await self.codec_pipeline.write(
471473
[
472-
(
473-
_ShardingByteSetter(shard_builder, chunk_coords),
474-
chunk_spec,
475-
chunk_selection,
476-
out_selection,
477-
is_complete_shard,
474+
WriteBatchInfo(
475+
byte_operator=_ShardingByteSetter(shard_builder, chunk_coords),
476+
array_spec=chunk_spec,
477+
chunk_selection=chunk_selection,
478+
out_selection=out_selection,
479+
is_complete_chunk=is_complete_shard,
478480
)
479481
for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer
480482
],
@@ -515,12 +517,12 @@ async def _encode_partial_single(
515517

516518
await self.codec_pipeline.write(
517519
[
518-
(
519-
_ShardingByteSetter(shard_dict, chunk_coords),
520-
chunk_spec,
521-
chunk_selection,
522-
out_selection,
523-
is_complete_shard,
520+
WriteBatchInfo(
521+
byte_operator=_ShardingByteSetter(shard_dict, chunk_coords),
522+
array_spec=chunk_spec,
523+
chunk_selection=chunk_selection,
524+
out_selection=out_selection,
525+
is_complete_chunk=is_complete_shard,
524526
)
525527
for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer
526528
],

src/zarr/core/array.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,14 @@
2323
from typing_extensions import deprecated
2424

2525
import zarr
26-
from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec
26+
from zarr.abc.codec import (
27+
ArrayArrayCodec,
28+
ArrayBytesCodec,
29+
BytesBytesCodec,
30+
Codec,
31+
ReadBatchInfo,
32+
WriteBatchInfo,
33+
)
2734
from zarr.abc.numcodec import Numcodec, _is_numcodec
2835
from zarr.codecs._v2 import V2Codec
2936
from zarr.codecs.bytes import BytesCodec
@@ -1564,12 +1571,15 @@ async def _get_selection(
15641571
# reading chunks and decoding them
15651572
await self.codec_pipeline.read(
15661573
[
1567-
(
1568-
self.store_path / self.metadata.encode_chunk_key(chunk_coords),
1569-
self.metadata.get_chunk_spec(chunk_coords, _config, prototype=prototype),
1570-
chunk_selection,
1571-
out_selection,
1572-
is_complete_chunk,
1574+
ReadBatchInfo(
1575+
byte_operator=self.store_path
1576+
/ self.metadata.encode_chunk_key(chunk_coords),
1577+
array_spec=self.metadata.get_chunk_spec(
1578+
chunk_coords, _config, prototype=prototype
1579+
),
1580+
chunk_selection=chunk_selection,
1581+
out_selection=out_selection,
1582+
is_complete_chunk=is_complete_chunk,
15731583
)
15741584
for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer
15751585
],
@@ -1735,12 +1745,12 @@ async def _set_selection(
17351745
# merging with existing data and encoding chunks
17361746
await self.codec_pipeline.write(
17371747
[
1738-
(
1739-
self.store_path / self.metadata.encode_chunk_key(chunk_coords),
1740-
self.metadata.get_chunk_spec(chunk_coords, _config, prototype),
1741-
chunk_selection,
1742-
out_selection,
1743-
is_complete_chunk,
1748+
WriteBatchInfo(
1749+
byte_operator=self.store_path / self.metadata.encode_chunk_key(chunk_coords),
1750+
array_spec=self.metadata.get_chunk_spec(chunk_coords, _config, prototype),
1751+
chunk_selection=chunk_selection,
1752+
out_selection=out_selection,
1753+
is_complete_chunk=is_complete_chunk,
17441754
)
17451755
for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer
17461756
],

0 commit comments

Comments
 (0)