Skip to content

Commit cbc0887

Browse files
brokkoli71normanrz
andauthored
Use config to select implementation (#1982)
* make codec pipeline implementation configurable * add test_config_codec_pipeline_class_in_env * make codec implementation configurable * remove snake case support for class names in config * use registry for codec pipeline config * typing * load codec pipeline from entrypoints * test if configured codec implementation and codec pipeline is used * make ndbuffer implementation configurable * fix circular import * change class method calls on NDBuffer to use get_ndbuffer_class() * make buffer implementation configurable * format * fix tests * ignore mypy in tests * add test to lazy load (nd)buffer from entrypoint * better assertion message * fix merge * fix merge * formatting * fix mypy * fix ruff formatting * fix merge * fix mypy * use numpy_buffer_prototype for reading shard index * rename buffer and entrypoint test-classes * document interaction registry and config * change config prefix from zarr_python to zarr * use fully_qualified_name for implementation config * refactor registry dicts * fix default_buffer_prototype access in tests * allow multiple implementations per entry_point * add tests for multiple implementations per entry_point * fix DeprecationWarning: SelectableGroups in registry.py * fix DeprecationWarning: EntryPoints list interface in registry.py * clarify _collect_entrypoints docstring Co-authored-by: Norman Rzepka <[email protected]> --------- Co-authored-by: Norman Rzepka <[email protected]>
1 parent 325786a commit cbc0887

33 files changed

+802
-228
lines changed

src/zarr/abc/codec.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing_extensions import Self
1818

1919
from zarr.array_spec import ArraySpec
20+
from zarr.common import JSON
2021
from zarr.indexing import SelectorTuple
2122

2223
CodecInput = TypeVar("CodecInput", bound=NDBuffer | Buffer)
@@ -254,7 +255,7 @@ def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
254255

255256
@classmethod
256257
@abstractmethod
257-
def from_list(cls, codecs: list[Codec]) -> Self:
258+
def from_list(cls, codecs: Iterable[Codec]) -> Self:
258259
"""Creates a codec pipeline from a list of codecs.
259260
260261
Parameters
@@ -390,6 +391,15 @@ async def write(
390391
"""
391392
...
392393

394+
@classmethod
395+
def from_dict(cls, data: Iterable[JSON | Codec]) -> Self:
396+
"""
397+
Create an instance of the model from a dictionary
398+
"""
399+
...
400+
401+
return cls(**data)
402+
393403

394404
async def batching_helper(
395405
func: Callable[[CodecInput, ArraySpec], Awaitable[CodecOutput | None]],

src/zarr/array.py

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from zarr.chunk_key_encodings import ChunkKeyEncoding, DefaultChunkKeyEncoding, V2ChunkKeyEncoding
2626
from zarr.codecs import BytesCodec
2727
from zarr.codecs._v2 import V2Compressor, V2Filters
28-
from zarr.codecs.pipeline import BatchedCodecPipeline
2928
from zarr.common import (
3029
JSON,
3130
ZARR_JSON,
@@ -61,6 +60,7 @@
6160
pop_fields,
6261
)
6362
from zarr.metadata import ArrayMetadata, ArrayV2Metadata, ArrayV3Metadata
63+
from zarr.registry import get_pipeline_class
6464
from zarr.store import StoreLike, StorePath, make_store_path
6565
from zarr.store.core import (
6666
ensure_no_existing_node,
@@ -79,11 +79,11 @@ def parse_array_metadata(data: Any) -> ArrayV2Metadata | ArrayV3Metadata:
7979
raise TypeError
8080

8181

82-
def create_codec_pipeline(metadata: ArrayV2Metadata | ArrayV3Metadata) -> BatchedCodecPipeline:
82+
def create_codec_pipeline(metadata: ArrayV2Metadata | ArrayV3Metadata) -> CodecPipeline:
8383
if isinstance(metadata, ArrayV3Metadata):
84-
return BatchedCodecPipeline.from_list(metadata.codecs)
84+
return get_pipeline_class().from_list(metadata.codecs)
8585
elif isinstance(metadata, ArrayV2Metadata):
86-
return BatchedCodecPipeline.from_list(
86+
return get_pipeline_class().from_list(
8787
[V2Filters(metadata.filters or []), V2Compressor(metadata.compressor)]
8888
)
8989
else:
@@ -483,8 +483,13 @@ async def _get_selection(
483483
return out_buffer.as_ndarray_like()
484484

485485
async def getitem(
486-
self, selection: BasicSelection, *, prototype: BufferPrototype = default_buffer_prototype
486+
self,
487+
selection: BasicSelection,
488+
*,
489+
prototype: BufferPrototype | None = None,
487490
) -> NDArrayLike:
491+
if prototype is None:
492+
prototype = default_buffer_prototype()
488493
indexer = BasicIndexer(
489494
selection,
490495
shape=self.metadata.shape,
@@ -493,7 +498,7 @@ async def getitem(
493498
return await self._get_selection(indexer, prototype=prototype)
494499

495500
async def _save_metadata(self, metadata: ArrayMetadata) -> None:
496-
to_save = metadata.to_buffer_dict()
501+
to_save = metadata.to_buffer_dict(default_buffer_prototype())
497502
awaitables = [set_or_delete(self.store_path / key, value) for key, value in to_save.items()]
498503
await gather(*awaitables)
499504

@@ -545,8 +550,10 @@ async def setitem(
545550
self,
546551
selection: BasicSelection,
547552
value: npt.ArrayLike,
548-
prototype: BufferPrototype = default_buffer_prototype,
553+
prototype: BufferPrototype | None = None,
549554
) -> None:
555+
if prototype is None:
556+
prototype = default_buffer_prototype()
550557
indexer = BasicIndexer(
551558
selection,
552559
shape=self.metadata.shape,
@@ -1001,7 +1008,7 @@ def get_basic_selection(
10011008
selection: BasicSelection = Ellipsis,
10021009
*,
10031010
out: NDBuffer | None = None,
1004-
prototype: BufferPrototype = default_buffer_prototype,
1011+
prototype: BufferPrototype | None = None,
10051012
fields: Fields | None = None,
10061013
) -> NDArrayLike:
10071014
"""Retrieve data for an item or region of the array.
@@ -1108,6 +1115,8 @@ def get_basic_selection(
11081115
11091116
"""
11101117

1118+
if prototype is None:
1119+
prototype = default_buffer_prototype()
11111120
return sync(
11121121
self._async_array._get_selection(
11131122
BasicIndexer(selection, self.shape, self.metadata.chunk_grid),
@@ -1123,7 +1132,7 @@ def set_basic_selection(
11231132
value: npt.ArrayLike,
11241133
*,
11251134
fields: Fields | None = None,
1126-
prototype: BufferPrototype = default_buffer_prototype,
1135+
prototype: BufferPrototype | None = None,
11271136
) -> None:
11281137
"""Modify data for an item or region of the array.
11291138
@@ -1207,6 +1216,8 @@ def set_basic_selection(
12071216
vindex, oindex, blocks, __getitem__, __setitem__
12081217
12091218
"""
1219+
if prototype is None:
1220+
prototype = default_buffer_prototype()
12101221
indexer = BasicIndexer(selection, self.shape, self.metadata.chunk_grid)
12111222
sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype))
12121223

@@ -1216,7 +1227,7 @@ def get_orthogonal_selection(
12161227
*,
12171228
out: NDBuffer | None = None,
12181229
fields: Fields | None = None,
1219-
prototype: BufferPrototype = default_buffer_prototype,
1230+
prototype: BufferPrototype | None = None,
12201231
) -> NDArrayLike:
12211232
"""Retrieve data by making a selection for each dimension of the array. For
12221233
example, if an array has 2 dimensions, allows selecting specific rows and/or
@@ -1325,6 +1336,8 @@ def get_orthogonal_selection(
13251336
vindex, oindex, blocks, __getitem__, __setitem__
13261337
13271338
"""
1339+
if prototype is None:
1340+
prototype = default_buffer_prototype()
13281341
indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid)
13291342
return sync(
13301343
self._async_array._get_selection(
@@ -1338,7 +1351,7 @@ def set_orthogonal_selection(
13381351
value: npt.ArrayLike,
13391352
*,
13401353
fields: Fields | None = None,
1341-
prototype: BufferPrototype = default_buffer_prototype,
1354+
prototype: BufferPrototype | None = None,
13421355
) -> None:
13431356
"""Modify data via a selection for each dimension of the array.
13441357
@@ -1435,6 +1448,8 @@ def set_orthogonal_selection(
14351448
vindex, oindex, blocks, __getitem__, __setitem__
14361449
14371450
"""
1451+
if prototype is None:
1452+
prototype = default_buffer_prototype()
14381453
indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid)
14391454
return sync(
14401455
self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype)
@@ -1446,7 +1461,7 @@ def get_mask_selection(
14461461
*,
14471462
out: NDBuffer | None = None,
14481463
fields: Fields | None = None,
1449-
prototype: BufferPrototype = default_buffer_prototype,
1464+
prototype: BufferPrototype | None = None,
14501465
) -> NDArrayLike:
14511466
"""Retrieve a selection of individual items, by providing a Boolean array of the
14521467
same shape as the array against which the selection is being made, where True
@@ -1513,6 +1528,8 @@ def get_mask_selection(
15131528
vindex, oindex, blocks, __getitem__, __setitem__
15141529
"""
15151530

1531+
if prototype is None:
1532+
prototype = default_buffer_prototype()
15161533
indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid)
15171534
return sync(
15181535
self._async_array._get_selection(
@@ -1526,7 +1543,7 @@ def set_mask_selection(
15261543
value: npt.ArrayLike,
15271544
*,
15281545
fields: Fields | None = None,
1529-
prototype: BufferPrototype = default_buffer_prototype,
1546+
prototype: BufferPrototype | None = None,
15301547
) -> None:
15311548
"""Modify a selection of individual items, by providing a Boolean array of the
15321549
same shape as the array against which the selection is being made, where True
@@ -1593,6 +1610,8 @@ def set_mask_selection(
15931610
vindex, oindex, blocks, __getitem__, __setitem__
15941611
15951612
"""
1613+
if prototype is None:
1614+
prototype = default_buffer_prototype()
15961615
indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid)
15971616
sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype))
15981617

@@ -1602,7 +1621,7 @@ def get_coordinate_selection(
16021621
*,
16031622
out: NDBuffer | None = None,
16041623
fields: Fields | None = None,
1605-
prototype: BufferPrototype = default_buffer_prototype,
1624+
prototype: BufferPrototype | None = None,
16061625
) -> NDArrayLike:
16071626
"""Retrieve a selection of individual items, by providing the indices
16081627
(coordinates) for each selected item.
@@ -1671,6 +1690,8 @@ def get_coordinate_selection(
16711690
vindex, oindex, blocks, __getitem__, __setitem__
16721691
16731692
"""
1693+
if prototype is None:
1694+
prototype = default_buffer_prototype()
16741695
indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid)
16751696
out_array = sync(
16761697
self._async_array._get_selection(
@@ -1689,7 +1710,7 @@ def set_coordinate_selection(
16891710
value: npt.ArrayLike,
16901711
*,
16911712
fields: Fields | None = None,
1692-
prototype: BufferPrototype = default_buffer_prototype,
1713+
prototype: BufferPrototype | None = None,
16931714
) -> None:
16941715
"""Modify a selection of individual items, by providing the indices (coordinates)
16951716
for each item to be modified.
@@ -1753,6 +1774,8 @@ def set_coordinate_selection(
17531774
vindex, oindex, blocks, __getitem__, __setitem__
17541775
17551776
"""
1777+
if prototype is None:
1778+
prototype = default_buffer_prototype()
17561779
# setup indexer
17571780
indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid)
17581781

@@ -1776,7 +1799,7 @@ def get_block_selection(
17761799
*,
17771800
out: NDBuffer | None = None,
17781801
fields: Fields | None = None,
1779-
prototype: BufferPrototype = default_buffer_prototype,
1802+
prototype: BufferPrototype | None = None,
17801803
) -> NDArrayLike:
17811804
"""Retrieve a selection of individual items, by providing the indices
17821805
(coordinates) for each selected item.
@@ -1859,6 +1882,8 @@ def get_block_selection(
18591882
vindex, oindex, blocks, __getitem__, __setitem__
18601883
18611884
"""
1885+
if prototype is None:
1886+
prototype = default_buffer_prototype()
18621887
indexer = BlockIndexer(selection, self.shape, self.metadata.chunk_grid)
18631888
return sync(
18641889
self._async_array._get_selection(
@@ -1872,7 +1897,7 @@ def set_block_selection(
18721897
value: npt.ArrayLike,
18731898
*,
18741899
fields: Fields | None = None,
1875-
prototype: BufferPrototype = default_buffer_prototype,
1900+
prototype: BufferPrototype | None = None,
18761901
) -> None:
18771902
"""Modify a selection of individual blocks, by providing the chunk indices
18781903
(coordinates) for each block to be modified.
@@ -1950,6 +1975,8 @@ def set_block_selection(
19501975
vindex, oindex, blocks, __getitem__, __setitem__
19511976
19521977
"""
1978+
if prototype is None:
1979+
prototype = default_buffer_prototype()
19531980
indexer = BlockIndexer(selection, self.shape, self.metadata.chunk_grid)
19541981
sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype))
19551982

src/zarr/buffer.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@
1616
import numpy.typing as npt
1717

1818
from zarr.common import ChunkCoords
19+
from zarr.registry import (
20+
get_buffer_class,
21+
get_ndbuffer_class,
22+
register_buffer,
23+
register_ndbuffer,
24+
)
1925

2026
if TYPE_CHECKING:
2127
from typing_extensions import Self
@@ -479,4 +485,14 @@ class BufferPrototype(NamedTuple):
479485

480486

481487
# The default buffer prototype used throughout the Zarr codebase.
482-
default_buffer_prototype = BufferPrototype(buffer=Buffer, nd_buffer=NDBuffer)
488+
def default_buffer_prototype() -> BufferPrototype:
489+
return BufferPrototype(buffer=get_buffer_class(), nd_buffer=get_ndbuffer_class())
490+
491+
492+
# The numpy prototype used for E.g. when reading the shard index
493+
def numpy_buffer_prototype() -> BufferPrototype:
494+
return BufferPrototype(buffer=Buffer, nd_buffer=NDBuffer)
495+
496+
497+
register_buffer(Buffer)
498+
register_ndbuffer(NDBuffer)

src/zarr/codecs/_v2.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77

88
from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec
99
from zarr.array_spec import ArraySpec
10-
from zarr.buffer import Buffer, NDBuffer
10+
from zarr.buffer import Buffer, NDBuffer, default_buffer_prototype
1111
from zarr.common import JSON, to_thread
12+
from zarr.registry import get_ndbuffer_class
1213

1314

1415
@dataclass(frozen=True)
@@ -34,7 +35,7 @@ async def _decode_single(
3435
if str(chunk_numpy_array.dtype) != chunk_spec.dtype:
3536
chunk_numpy_array = chunk_numpy_array.view(chunk_spec.dtype)
3637

37-
return NDBuffer.from_numpy_array(chunk_numpy_array)
38+
return get_ndbuffer_class().from_numpy_array(chunk_numpy_array)
3839

3940
async def _encode_single(
4041
self,
@@ -55,7 +56,7 @@ async def _encode_single(
5556
else:
5657
encoded_chunk_bytes = ensure_bytes(chunk_numpy_array)
5758

58-
return Buffer.from_bytes(encoded_chunk_bytes)
59+
return default_buffer_prototype().buffer.from_bytes(encoded_chunk_bytes)
5960

6061
def compute_encoded_size(self, _input_byte_length: int, _chunk_spec: ArraySpec) -> int:
6162
raise NotImplementedError
@@ -86,7 +87,7 @@ async def _decode_single(
8687
order=chunk_spec.order,
8788
)
8889

89-
return NDBuffer.from_ndarray_like(chunk_ndarray)
90+
return get_ndbuffer_class().from_ndarray_like(chunk_ndarray)
9091

9192
async def _encode_single(
9293
self,
@@ -99,7 +100,7 @@ async def _encode_single(
99100
filter = numcodecs.get_codec(filter_metadata)
100101
chunk_ndarray = await to_thread(filter.encode, chunk_ndarray)
101102

102-
return NDBuffer.from_ndarray_like(chunk_ndarray)
103+
return get_ndbuffer_class().from_ndarray_like(chunk_ndarray)
103104

104105
def compute_encoded_size(self, _input_byte_length: int, _chunk_spec: ArraySpec) -> int:
105106
raise NotImplementedError

src/zarr/codecs/blosc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from zarr.abc.codec import BytesBytesCodec
1212
from zarr.array_spec import ArraySpec
1313
from zarr.buffer import Buffer, as_numpy_array_wrapper
14-
from zarr.codecs.registry import register_codec
1514
from zarr.common import JSON, parse_enum, parse_named_configuration, to_thread
15+
from zarr.registry import register_codec
1616

1717
if TYPE_CHECKING:
1818
from typing_extensions import Self

src/zarr/codecs/bytes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
from zarr.abc.codec import ArrayBytesCodec
1111
from zarr.array_spec import ArraySpec
1212
from zarr.buffer import Buffer, NDArrayLike, NDBuffer
13-
from zarr.codecs.registry import register_codec
1413
from zarr.common import JSON, parse_enum, parse_named_configuration
14+
from zarr.registry import register_codec
1515

1616
if TYPE_CHECKING:
1717
from typing_extensions import Self

src/zarr/codecs/crc32c_.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
from zarr.abc.codec import BytesBytesCodec
1010
from zarr.array_spec import ArraySpec
1111
from zarr.buffer import Buffer
12-
from zarr.codecs.registry import register_codec
1312
from zarr.common import JSON, parse_named_configuration
13+
from zarr.registry import register_codec
1414

1515
if TYPE_CHECKING:
1616
from typing_extensions import Self

src/zarr/codecs/gzip.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from zarr.abc.codec import BytesBytesCodec
99
from zarr.array_spec import ArraySpec
1010
from zarr.buffer import Buffer, as_numpy_array_wrapper
11-
from zarr.codecs.registry import register_codec
1211
from zarr.common import JSON, parse_named_configuration, to_thread
12+
from zarr.registry import register_codec
1313

1414
if TYPE_CHECKING:
1515
from typing_extensions import Self

0 commit comments

Comments
 (0)