Skip to content

Commit 0c513fc

Browse files
authored
Followup on codecs (#1889)
1 parent 549cf28 commit 0c513fc

File tree

12 files changed

+128 additions, -201 deletions (lines changed)

12 files changed

+128 additions, -201 deletions (lines changed)

src/zarr/abc/codec.py

Lines changed: 67 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,14 @@
11
from __future__ import annotations
22

33
from abc import abstractmethod
4-
from collections.abc import Iterable
4+
from collections.abc import Awaitable, Callable, Iterable
55
from typing import TYPE_CHECKING, Generic, TypeVar
66

77
from zarr.abc.metadata import Metadata
88
from zarr.abc.store import ByteGetter, ByteSetter
99
from zarr.buffer import Buffer, NDBuffer
10+
from zarr.common import concurrent_map
11+
from zarr.config import config
1012

1113
if TYPE_CHECKING:
1214
from typing_extensions import Self
@@ -59,7 +61,7 @@ def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
5961
"""
6062
return chunk_spec
6163

62-
def evolve(self, array_spec: ArraySpec) -> Self:
64+
def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
6365
"""Fills in codec configuration parameters that can be automatically
6466
inferred from the array metadata.
6567
@@ -83,7 +85,9 @@ def validate(self, array_metadata: ArrayMetadata) -> None:
8385
"""
8486
...
8587

86-
@abstractmethod
88+
async def _decode_single(self, chunk_data: CodecOutput, chunk_spec: ArraySpec) -> CodecInput:
89+
raise NotImplementedError
90+
8791
async def decode(
8892
self,
8993
chunks_and_specs: Iterable[tuple[CodecOutput | None, ArraySpec]],
@@ -100,9 +104,13 @@ async def decode(
100104
-------
101105
Iterable[CodecInput | None]
102106
"""
103-
...
107+
return await batching_helper(self._decode_single, chunks_and_specs)
108+
109+
async def _encode_single(
110+
self, chunk_data: CodecInput, chunk_spec: ArraySpec
111+
) -> CodecOutput | None:
112+
raise NotImplementedError
104113

105-
@abstractmethod
106114
async def encode(
107115
self,
108116
chunks_and_specs: Iterable[tuple[CodecInput | None, ArraySpec]],
@@ -119,7 +127,7 @@ async def encode(
119127
-------
120128
Iterable[CodecOutput | None]
121129
"""
122-
...
130+
return await batching_helper(self._encode_single, chunks_and_specs)
123131

124132

125133
class ArrayArrayCodec(_Codec[NDBuffer, NDBuffer]):
@@ -146,7 +154,11 @@ class BytesBytesCodec(_Codec[Buffer, Buffer]):
146154
class ArrayBytesCodecPartialDecodeMixin:
147155
"""Mixin for array-to-bytes codecs that implement partial decoding."""
148156

149-
@abstractmethod
157+
async def _decode_partial_single(
158+
self, byte_getter: ByteGetter, selection: SliceSelection, chunk_spec: ArraySpec
159+
) -> NDBuffer | None:
160+
raise NotImplementedError
161+
150162
async def decode_partial(
151163
self,
152164
batch_info: Iterable[tuple[ByteGetter, SliceSelection, ArraySpec]],
@@ -167,13 +179,28 @@ async def decode_partial(
167179
-------
168180
Iterable[NDBuffer | None]
169181
"""
170-
...
182+
return await concurrent_map(
183+
[
184+
(byte_getter, selection, chunk_spec)
185+
for byte_getter, selection, chunk_spec in batch_info
186+
],
187+
self._decode_partial_single,
188+
config.get("async.concurrency"),
189+
)
171190

172191

173192
class ArrayBytesCodecPartialEncodeMixin:
174193
"""Mixin for array-to-bytes codecs that implement partial encoding."""
175194

176-
@abstractmethod
195+
async def _encode_partial_single(
196+
self,
197+
byte_setter: ByteSetter,
198+
chunk_array: NDBuffer,
199+
selection: SliceSelection,
200+
chunk_spec: ArraySpec,
201+
) -> None:
202+
raise NotImplementedError
203+
177204
async def encode_partial(
178205
self,
179206
batch_info: Iterable[tuple[ByteSetter, NDBuffer, SliceSelection, ArraySpec]],
@@ -192,7 +219,14 @@ async def encode_partial(
192219
The ByteSetter is used to write the necessary bytes and fetch bytes for existing chunk data.
193220
The chunk spec contains information about the chunk.
194221
"""
195-
...
222+
await concurrent_map(
223+
[
224+
(byte_setter, chunk_array, selection, chunk_spec)
225+
for byte_setter, chunk_array, selection, chunk_spec in batch_info
226+
],
227+
self._encode_partial_single,
228+
config.get("async.concurrency"),
229+
)
196230

197231

198232
class CodecPipeline(Metadata):
@@ -203,7 +237,7 @@ class CodecPipeline(Metadata):
203237
and writes them to a store (via ByteSetter)."""
204238

205239
@abstractmethod
206-
def evolve(self, array_spec: ArraySpec) -> Self:
240+
def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
207241
"""Fills in codec configuration parameters that can be automatically
208242
inferred from the array metadata.
209243
@@ -347,3 +381,25 @@ async def write(
347381
value : NDBuffer
348382
"""
349383
...
384+
385+
386+
async def batching_helper(
387+
func: Callable[[CodecInput, ArraySpec], Awaitable[CodecOutput | None]],
388+
batch_info: Iterable[tuple[CodecInput | None, ArraySpec]],
389+
) -> list[CodecOutput | None]:
390+
return await concurrent_map(
391+
[(chunk_array, chunk_spec) for chunk_array, chunk_spec in batch_info],
392+
noop_for_none(func),
393+
config.get("async.concurrency"),
394+
)
395+
396+
397+
def noop_for_none(
398+
func: Callable[[CodecInput, ArraySpec], Awaitable[CodecOutput | None]],
399+
) -> Callable[[CodecInput | None, ArraySpec], Awaitable[CodecOutput | None]]:
400+
async def wrap(chunk: CodecInput | None, chunk_spec: ArraySpec) -> CodecOutput | None:
401+
if chunk is None:
402+
return None
403+
return await func(chunk, chunk_spec)
404+
405+
return wrap

src/zarr/codecs/_v2.py

Lines changed: 15 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -5,18 +5,18 @@
55
import numcodecs
66
from numcodecs.compat import ensure_bytes, ensure_ndarray
77

8+
from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec
89
from zarr.buffer import Buffer, NDBuffer
9-
from zarr.codecs.mixins import ArrayArrayCodecBatchMixin, ArrayBytesCodecBatchMixin
1010
from zarr.common import JSON, ArraySpec, to_thread
1111

1212

1313
@dataclass(frozen=True)
14-
class V2Compressor(ArrayBytesCodecBatchMixin):
14+
class V2Compressor(ArrayBytesCodec):
1515
compressor: dict[str, JSON] | None
1616

1717
is_fixed_size = False
1818

19-
async def decode_single(
19+
async def _decode_single(
2020
self,
2121
chunk_bytes: Buffer,
2222
chunk_spec: ArraySpec,
@@ -38,7 +38,7 @@ async def decode_single(
3838

3939
return NDBuffer.from_numpy_array(chunk_numpy_array)
4040

41-
async def encode_single(
41+
async def _encode_single(
4242
self,
4343
chunk_array: NDBuffer,
4444
_chunk_spec: ArraySpec,
@@ -64,44 +64,44 @@ def compute_encoded_size(self, _input_byte_length: int, _chunk_spec: ArraySpec)
6464

6565

6666
@dataclass(frozen=True)
67-
class V2Filters(ArrayArrayCodecBatchMixin):
67+
class V2Filters(ArrayArrayCodec):
6868
filters: list[dict[str, JSON]]
6969

7070
is_fixed_size = False
7171

72-
async def decode_single(
72+
async def _decode_single(
7373
self,
7474
chunk_array: NDBuffer,
7575
chunk_spec: ArraySpec,
7676
) -> NDBuffer:
77-
chunk_numpy_array = chunk_array.as_numpy_array()
77+
chunk_ndarray = chunk_array.as_ndarray_like()
7878
# apply filters in reverse order
7979
if self.filters is not None:
8080
for filter_metadata in self.filters[::-1]:
8181
filter = numcodecs.get_codec(filter_metadata)
82-
chunk_numpy_array = await to_thread(filter.decode, chunk_numpy_array)
82+
chunk_ndarray = await to_thread(filter.decode, chunk_ndarray)
8383

8484
# ensure correct chunk shape
85-
if chunk_numpy_array.shape != chunk_spec.shape:
86-
chunk_numpy_array = chunk_numpy_array.reshape(
85+
if chunk_ndarray.shape != chunk_spec.shape:
86+
chunk_ndarray = chunk_ndarray.reshape(
8787
chunk_spec.shape,
8888
order=chunk_spec.order,
8989
)
9090

91-
return NDBuffer.from_numpy_array(chunk_numpy_array)
91+
return NDBuffer.from_ndarray_like(chunk_ndarray)
9292

93-
async def encode_single(
93+
async def _encode_single(
9494
self,
9595
chunk_array: NDBuffer,
9696
chunk_spec: ArraySpec,
9797
) -> NDBuffer | None:
98-
chunk_numpy_array = chunk_array.as_numpy_array().ravel(order=chunk_spec.order)
98+
chunk_ndarray = chunk_array.as_ndarray_like().ravel(order=chunk_spec.order)
9999

100100
for filter_metadata in self.filters:
101101
filter = numcodecs.get_codec(filter_metadata)
102-
chunk_numpy_array = await to_thread(filter.encode, chunk_numpy_array)
102+
chunk_ndarray = await to_thread(filter.encode, chunk_ndarray)
103103

104-
return NDBuffer.from_numpy_array(chunk_numpy_array)
104+
return NDBuffer.from_ndarray_like(chunk_ndarray)
105105

106106
def compute_encoded_size(self, _input_byte_length: int, _chunk_spec: ArraySpec) -> int:
107107
raise NotImplementedError

src/zarr/codecs/blosc.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -8,8 +8,8 @@
88
import numcodecs
99
from numcodecs.blosc import Blosc
1010

11+
from zarr.abc.codec import BytesBytesCodec
1112
from zarr.buffer import Buffer, as_numpy_array_wrapper
12-
from zarr.codecs.mixins import BytesBytesCodecBatchMixin
1313
from zarr.codecs.registry import register_codec
1414
from zarr.common import parse_enum, parse_named_configuration, to_thread
1515

@@ -74,7 +74,7 @@ def parse_blocksize(data: JSON) -> int:
7474

7575

7676
@dataclass(frozen=True)
77-
class BloscCodec(BytesBytesCodecBatchMixin):
77+
class BloscCodec(BytesBytesCodec):
7878
is_fixed_size = False
7979

8080
typesize: int
@@ -125,7 +125,7 @@ def to_dict(self) -> dict[str, JSON]:
125125
},
126126
}
127127

128-
def evolve(self, array_spec: ArraySpec) -> Self:
128+
def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
129129
new_codec = self
130130
if new_codec.typesize is None:
131131
new_codec = replace(new_codec, typesize=array_spec.dtype.itemsize)
@@ -158,14 +158,14 @@ def _blosc_codec(self) -> Blosc:
158158
}
159159
return Blosc.from_config(config_dict)
160160

161-
async def decode_single(
161+
async def _decode_single(
162162
self,
163163
chunk_bytes: Buffer,
164164
_chunk_spec: ArraySpec,
165165
) -> Buffer:
166166
return await to_thread(as_numpy_array_wrapper, self._blosc_codec.decode, chunk_bytes)
167167

168-
async def encode_single(
168+
async def _encode_single(
169169
self,
170170
chunk_bytes: Buffer,
171171
chunk_spec: ArraySpec,

src/zarr/codecs/bytes.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -7,8 +7,8 @@
77

88
import numpy as np
99

10+
from zarr.abc.codec import ArrayBytesCodec
1011
from zarr.buffer import Buffer, NDBuffer
11-
from zarr.codecs.mixins import ArrayBytesCodecBatchMixin
1212
from zarr.codecs.registry import register_codec
1313
from zarr.common import parse_enum, parse_named_configuration
1414

@@ -27,7 +27,7 @@ class Endian(Enum):
2727

2828

2929
@dataclass(frozen=True)
30-
class BytesCodec(ArrayBytesCodecBatchMixin):
30+
class BytesCodec(ArrayBytesCodec):
3131
is_fixed_size = True
3232

3333
endian: Endian | None
@@ -51,7 +51,7 @@ def to_dict(self) -> dict[str, JSON]:
5151
else:
5252
return {"name": "bytes", "configuration": {"endian": self.endian}}
5353

54-
def evolve(self, array_spec: ArraySpec) -> Self:
54+
def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
5555
if array_spec.dtype.itemsize == 0:
5656
if self.endian is not None:
5757
return replace(self, endian=None)
@@ -61,7 +61,7 @@ def evolve(self, array_spec: ArraySpec) -> Self:
6161
)
6262
return self
6363

64-
async def decode_single(
64+
async def _decode_single(
6565
self,
6666
chunk_bytes: Buffer,
6767
chunk_spec: ArraySpec,
@@ -84,7 +84,7 @@ async def decode_single(
8484
)
8585
return chunk_array
8686

87-
async def encode_single(
87+
async def _encode_single(
8888
self,
8989
chunk_array: NDBuffer,
9090
_chunk_spec: ArraySpec,

src/zarr/codecs/crc32c_.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -6,8 +6,8 @@
66
import numpy as np
77
from crc32c import crc32c
88

9+
from zarr.abc.codec import BytesBytesCodec
910
from zarr.buffer import Buffer
10-
from zarr.codecs.mixins import BytesBytesCodecBatchMixin
1111
from zarr.codecs.registry import register_codec
1212
from zarr.common import parse_named_configuration
1313

@@ -18,7 +18,7 @@
1818

1919

2020
@dataclass(frozen=True)
21-
class Crc32cCodec(BytesBytesCodecBatchMixin):
21+
class Crc32cCodec(BytesBytesCodec):
2222
is_fixed_size = True
2323

2424
@classmethod
@@ -29,7 +29,7 @@ def from_dict(cls, data: dict[str, JSON]) -> Self:
2929
def to_dict(self) -> dict[str, JSON]:
3030
return {"name": "crc32c"}
3131

32-
async def decode_single(
32+
async def _decode_single(
3333
self,
3434
chunk_bytes: Buffer,
3535
_chunk_spec: ArraySpec,
@@ -46,7 +46,7 @@ async def decode_single(
4646
)
4747
return Buffer.from_array_like(inner_bytes)
4848

49-
async def encode_single(
49+
async def _encode_single(
5050
self,
5151
chunk_bytes: Buffer,
5252
_chunk_spec: ArraySpec,

0 commit comments

Comments
 (0)