Skip to content

Commit 6805332

Browse files
RFLeijenaard-v-b
andauthored
Add registry for chunk key encodings for extensibility (#3436)
* Add registry for chunk key encodings. * Fix error message for unknown chunk key encoding in create_array test * Removed unneccsary type ignore * Use entrypoint.name as the key for registering chunk key encodings. - Change register_chunk_key_encoding function to take key as first arg similar to codec. * Move parsing of init args in CKE to __post_init__. This enables users to add additional fields to a custom ChunkKeyEncoding without having to override __init__ and taking care of immutability of the attrs. * Clarify ChunkKeyEncoding base class - Enforce encode_chunk_key to be implemented (abstractmethod in ABC) - Make decode_chunk_key optional (raise NotImplementedError by default) Note, the latter is never raised by the current zarr implementation. * Make `name` a ClassVar in ChunkKeyEncoding. This automatically removes it as an init argument. * Remove `separator` from ChunkKeyEncoding base. * Fix typing errors. * Update docs output to match code changes. * Add release notes --------- Co-authored-by: Davis Bennett <[email protected]>
1 parent 3c883a3 commit 6805332

File tree

9 files changed

+146
-94
lines changed

9 files changed

+146
-94
lines changed

changes/3436.feature.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Adds a registry for chunk key encodings for extensibility.
2+
This allows users to implement a custom `ChunkKeyEncoding`, which can be registered via `register_chunk_key_encoding` or as an entry point under `zarr.chunk_key_encoding`.

docs/user-guide/consolidated_metadata.rst

Lines changed: 35 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -49,44 +49,41 @@ that can be used.:
4949
>>> from pprint import pprint
5050
>>> pprint(dict(consolidated_metadata.items()))
5151
{'a': ArrayV3Metadata(shape=(1,),
52-
data_type=Float64(endianness='little'),
53-
chunk_grid=RegularChunkGrid(chunk_shape=(1,)),
54-
chunk_key_encoding=DefaultChunkKeyEncoding(name='default',
55-
separator='/'),
56-
fill_value=np.float64(0.0),
57-
codecs=(BytesCodec(endian=<Endian.little: 'little'>),
58-
ZstdCodec(level=0, checksum=False)),
59-
attributes={},
60-
dimension_names=None,
61-
zarr_format=3,
62-
node_type='array',
63-
storage_transformers=()),
64-
'b': ArrayV3Metadata(shape=(2, 2),
65-
data_type=Float64(endianness='little'),
66-
chunk_grid=RegularChunkGrid(chunk_shape=(2, 2)),
67-
chunk_key_encoding=DefaultChunkKeyEncoding(name='default',
68-
separator='/'),
69-
fill_value=np.float64(0.0),
70-
codecs=(BytesCodec(endian=<Endian.little: 'little'>),
71-
ZstdCodec(level=0, checksum=False)),
72-
attributes={},
73-
dimension_names=None,
74-
zarr_format=3,
75-
node_type='array',
76-
storage_transformers=()),
77-
'c': ArrayV3Metadata(shape=(3, 3, 3),
78-
data_type=Float64(endianness='little'),
79-
chunk_grid=RegularChunkGrid(chunk_shape=(3, 3, 3)),
80-
chunk_key_encoding=DefaultChunkKeyEncoding(name='default',
81-
separator='/'),
82-
fill_value=np.float64(0.0),
83-
codecs=(BytesCodec(endian=<Endian.little: 'little'>),
84-
ZstdCodec(level=0, checksum=False)),
85-
attributes={},
86-
dimension_names=None,
87-
zarr_format=3,
88-
node_type='array',
89-
storage_transformers=())}
52+
data_type=Float64(endianness='little'),
53+
chunk_grid=RegularChunkGrid(chunk_shape=(1,)),
54+
chunk_key_encoding=DefaultChunkKeyEncoding(separator='/'),
55+
fill_value=np.float64(0.0),
56+
codecs=(BytesCodec(endian=<Endian.little: 'little'>),
57+
ZstdCodec(level=0, checksum=False)),
58+
attributes={},
59+
dimension_names=None,
60+
zarr_format=3,
61+
node_type='array',
62+
storage_transformers=()),
63+
'b': ArrayV3Metadata(shape=(2, 2),
64+
data_type=Float64(endianness='little'),
65+
chunk_grid=RegularChunkGrid(chunk_shape=(2, 2)),
66+
chunk_key_encoding=DefaultChunkKeyEncoding(separator='/'),
67+
fill_value=np.float64(0.0),
68+
codecs=(BytesCodec(endian=<Endian.little: 'little'>),
69+
ZstdCodec(level=0, checksum=False)),
70+
attributes={},
71+
dimension_names=None,
72+
zarr_format=3,
73+
node_type='array',
74+
storage_transformers=()),
75+
'c': ArrayV3Metadata(shape=(3, 3, 3),
76+
data_type=Float64(endianness='little'),
77+
chunk_grid=RegularChunkGrid(chunk_shape=(3, 3, 3)),
78+
chunk_key_encoding=DefaultChunkKeyEncoding(separator='/'),
79+
fill_value=np.float64(0.0),
80+
codecs=(BytesCodec(endian=<Endian.little: 'little'>),
81+
ZstdCodec(level=0, checksum=False)),
82+
attributes={},
83+
dimension_names=None,
84+
zarr_format=3,
85+
node_type='array',
86+
storage_transformers=())}
9087

9188
Operations on the group to get children automatically use the consolidated metadata.:
9289

src/zarr/core/array.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
ChunkKeyEncodingLike,
4848
DefaultChunkKeyEncoding,
4949
V2ChunkKeyEncoding,
50+
parse_chunk_key_encoding,
5051
)
5152
from zarr.core.common import (
5253
JSON,
@@ -4602,6 +4603,7 @@ async def init_array(
46024603
order_parsed = zarr_config.get("array.order")
46034604
else:
46044605
order_parsed = order
4606+
chunk_key_encoding_parsed = cast("V2ChunkKeyEncoding", chunk_key_encoding_parsed)
46054607

46064608
meta = AsyncArray._create_metadata_v2(
46074609
shape=shape_parsed,
@@ -4951,13 +4953,11 @@ def _parse_chunk_key_encoding(
49514953
"""
49524954
if data is None:
49534955
if zarr_format == 2:
4954-
result = ChunkKeyEncoding.from_dict({"name": "v2", "separator": "."})
4956+
data = {"name": "v2", "configuration": {"separator": "."}}
49554957
else:
4956-
result = ChunkKeyEncoding.from_dict({"name": "default", "separator": "/"})
4957-
elif isinstance(data, ChunkKeyEncoding):
4958-
result = data
4959-
else:
4960-
result = ChunkKeyEncoding.from_dict(data)
4958+
data = {"name": "default", "configuration": {"separator": "/"}}
4959+
result = parse_chunk_key_encoding(data)
4960+
49614961
if zarr_format == 2 and result.name != "v2":
49624962
msg = (
49634963
"Invalid chunk key encoding. For Zarr format 2 arrays, the `name` field of the "
Lines changed: 63 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
from __future__ import annotations
22

3-
from abc import abstractmethod
3+
from abc import ABC, abstractmethod
44
from dataclasses import dataclass
5-
from typing import TYPE_CHECKING, Literal, TypeAlias, TypedDict, cast
5+
from typing import TYPE_CHECKING, ClassVar, Literal, TypeAlias, TypedDict, cast
66

77
if TYPE_CHECKING:
8-
from typing import NotRequired
8+
from typing import NotRequired, Self
99

1010
from zarr.abc.metadata import Metadata
1111
from zarr.core.common import (
1212
JSON,
1313
parse_named_configuration,
1414
)
15+
from zarr.registry import get_chunk_key_encoding_class, register_chunk_key_encoding
1516

1617
SeparatorLiteral = Literal[".", "/"]
1718

@@ -28,60 +29,49 @@ class ChunkKeyEncodingParams(TypedDict):
2829

2930

3031
@dataclass(frozen=True)
31-
class ChunkKeyEncoding(Metadata):
32-
name: str
33-
separator: SeparatorLiteral = "."
32+
class ChunkKeyEncoding(ABC, Metadata):
33+
"""
34+
Defines how chunk coordinates are mapped to store keys.
3435
35-
def __init__(self, *, separator: SeparatorLiteral) -> None:
36-
separator_parsed = parse_separator(separator)
36+
Subclasses must define a class variable `name` and implement `encode_chunk_key`.
37+
"""
3738

38-
object.__setattr__(self, "separator", separator_parsed)
39+
name: ClassVar[str]
3940

4041
@classmethod
41-
def from_dict(cls, data: dict[str, JSON] | ChunkKeyEncodingLike) -> ChunkKeyEncoding:
42-
if isinstance(data, ChunkKeyEncoding):
43-
return data
44-
45-
# handle ChunkKeyEncodingParams
46-
if "name" in data and "separator" in data:
47-
data = {"name": data["name"], "configuration": {"separator": data["separator"]}}
48-
49-
# TODO: remove this cast when we are statically typing the JSON metadata completely.
50-
data = cast("dict[str, JSON]", data)
51-
52-
# configuration is optional for chunk key encodings
53-
name_parsed, config_parsed = parse_named_configuration(data, require_configuration=False)
54-
if name_parsed == "default":
55-
if config_parsed is None:
56-
# for default, normalize missing configuration to use the "/" separator.
57-
config_parsed = {"separator": "/"}
58-
return DefaultChunkKeyEncoding(**config_parsed) # type: ignore[arg-type]
59-
if name_parsed == "v2":
60-
if config_parsed is None:
61-
# for v2, normalize missing configuration to use the "." separator.
62-
config_parsed = {"separator": "."}
63-
return V2ChunkKeyEncoding(**config_parsed) # type: ignore[arg-type]
64-
msg = f"Unknown chunk key encoding. Got {name_parsed}, expected one of ('v2', 'default')."
65-
raise ValueError(msg)
42+
def from_dict(cls, data: dict[str, JSON]) -> Self:
43+
_, config_parsed = parse_named_configuration(data, require_configuration=False)
44+
return cls(**config_parsed if config_parsed else {})
6645

6746
def to_dict(self) -> dict[str, JSON]:
68-
return {"name": self.name, "configuration": {"separator": self.separator}}
47+
return {"name": self.name, "configuration": super().to_dict()}
6948

70-
@abstractmethod
7149
def decode_chunk_key(self, chunk_key: str) -> tuple[int, ...]:
72-
pass
50+
"""
51+
Optional: decode a chunk key string into chunk coordinates.
52+
Not required for normal operation; override if needed for testing or debugging.
53+
"""
54+
raise NotImplementedError(f"{self.__class__.__name__} does not implement decode_chunk_key.")
7355

7456
@abstractmethod
7557
def encode_chunk_key(self, chunk_coords: tuple[int, ...]) -> str:
76-
pass
58+
"""
59+
Encode chunk coordinates into a chunk key string.
60+
Must be implemented by subclasses.
61+
"""
7762

7863

79-
ChunkKeyEncodingLike: TypeAlias = ChunkKeyEncodingParams | ChunkKeyEncoding
64+
ChunkKeyEncodingLike: TypeAlias = dict[str, JSON] | ChunkKeyEncodingParams | ChunkKeyEncoding
8065

8166

8267
@dataclass(frozen=True)
8368
class DefaultChunkKeyEncoding(ChunkKeyEncoding):
84-
name: Literal["default"] = "default"
69+
name: ClassVar[Literal["default"]] = "default"
70+
separator: SeparatorLiteral = "/"
71+
72+
def __post_init__(self) -> None:
73+
separator_parsed = parse_separator(self.separator)
74+
object.__setattr__(self, "separator", separator_parsed)
8575

8676
def decode_chunk_key(self, chunk_key: str) -> tuple[int, ...]:
8777
if chunk_key == "c":
@@ -94,11 +84,43 @@ def encode_chunk_key(self, chunk_coords: tuple[int, ...]) -> str:
9484

9585
@dataclass(frozen=True)
9686
class V2ChunkKeyEncoding(ChunkKeyEncoding):
97-
name: Literal["v2"] = "v2"
87+
name: ClassVar[Literal["v2"]] = "v2"
88+
separator: SeparatorLiteral = "."
89+
90+
def __post_init__(self) -> None:
91+
separator_parsed = parse_separator(self.separator)
92+
object.__setattr__(self, "separator", separator_parsed)
9893

9994
def decode_chunk_key(self, chunk_key: str) -> tuple[int, ...]:
10095
return tuple(map(int, chunk_key.split(self.separator)))
10196

10297
def encode_chunk_key(self, chunk_coords: tuple[int, ...]) -> str:
10398
chunk_identifier = self.separator.join(map(str, chunk_coords))
10499
return "0" if chunk_identifier == "" else chunk_identifier
100+
101+
102+
def parse_chunk_key_encoding(data: ChunkKeyEncodingLike) -> ChunkKeyEncoding:
103+
"""
104+
Take an implicit specification of a chunk key encoding and parse it into a ChunkKeyEncoding object.
105+
"""
106+
if isinstance(data, ChunkKeyEncoding):
107+
return data
108+
109+
# handle ChunkKeyEncodingParams
110+
if "name" in data and "separator" in data:
111+
data = {"name": data["name"], "configuration": {"separator": data["separator"]}}
112+
113+
# Now must be a named config
114+
data = cast("dict[str, JSON]", data)
115+
116+
name_parsed, _ = parse_named_configuration(data, require_configuration=False)
117+
try:
118+
chunk_key_encoding = get_chunk_key_encoding_class(name_parsed).from_dict(data)
119+
except KeyError as e:
120+
raise ValueError(f"Unknown chunk key encoding: {e.args[0]!r}") from e
121+
122+
return chunk_key_encoding
123+
124+
125+
register_chunk_key_encoding("default", DefaultChunkKeyEncoding)
126+
register_chunk_key_encoding("v2", V2ChunkKeyEncoding)

src/zarr/core/metadata/v3.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@
2424
from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec
2525
from zarr.core.array_spec import ArrayConfig, ArraySpec
2626
from zarr.core.chunk_grids import ChunkGrid, RegularChunkGrid
27-
from zarr.core.chunk_key_encodings import ChunkKeyEncoding, ChunkKeyEncodingLike
27+
from zarr.core.chunk_key_encodings import (
28+
ChunkKeyEncoding,
29+
ChunkKeyEncodingLike,
30+
parse_chunk_key_encoding,
31+
)
2832
from zarr.core.common import (
2933
JSON,
3034
ZARR_JSON,
@@ -174,7 +178,7 @@ def __init__(
174178

175179
shape_parsed = parse_shapelike(shape)
176180
chunk_grid_parsed = ChunkGrid.from_dict(chunk_grid)
177-
chunk_key_encoding_parsed = ChunkKeyEncoding.from_dict(chunk_key_encoding)
181+
chunk_key_encoding_parsed = parse_chunk_key_encoding(chunk_key_encoding)
178182
dimension_names_parsed = parse_dimension_names(dimension_names)
179183
# Note: relying on a type method is numpy-specific
180184
fill_value_parsed = data_type.cast_scalar(fill_value)

src/zarr/registry.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,18 @@
2222
)
2323
from zarr.abc.numcodec import Numcodec
2424
from zarr.core.buffer import Buffer, NDBuffer
25+
from zarr.core.chunk_key_encodings import ChunkKeyEncoding
2526
from zarr.core.common import JSON
2627

2728
__all__ = [
2829
"Registry",
2930
"get_buffer_class",
31+
"get_chunk_key_encoding_class",
3032
"get_codec_class",
3133
"get_ndbuffer_class",
3234
"get_pipeline_class",
3335
"register_buffer",
36+
"register_chunk_key_encoding",
3437
"register_codec",
3538
"register_ndbuffer",
3639
"register_pipeline",
@@ -44,9 +47,9 @@ def __init__(self) -> None:
4447
super().__init__()
4548
self.lazy_load_list: list[EntryPoint] = []
4649

47-
def lazy_load(self) -> None:
50+
def lazy_load(self, use_entrypoint_name: bool = False) -> None:
4851
for e in self.lazy_load_list:
49-
self.register(e.load())
52+
self.register(e.load(), qualname=e.name if use_entrypoint_name else None)
5053

5154
self.lazy_load_list.clear()
5255

@@ -60,10 +63,11 @@ def register(self, cls: type[T], qualname: str | None = None) -> None:
6063
__pipeline_registry: Registry[CodecPipeline] = Registry()
6164
__buffer_registry: Registry[Buffer] = Registry()
6265
__ndbuffer_registry: Registry[NDBuffer] = Registry()
66+
__chunk_key_encoding_registry: Registry[ChunkKeyEncoding] = Registry()
6367

6468
"""
6569
The registry module is responsible for managing implementations of codecs,
66-
pipelines, buffers and ndbuffers and collecting them from entrypoints.
70+
pipelines, buffers, ndbuffers, and chunk key encodings and collecting them from entrypoints.
6771
The implementation used is determined by the config.
6872
6973
The registry module is also responsible for managing dtypes.
@@ -99,6 +103,13 @@ def _collect_entrypoints() -> list[Registry[Any]]:
99103
data_type_registry._lazy_load_list.extend(entry_points.select(group="zarr.data_type"))
100104
data_type_registry._lazy_load_list.extend(entry_points.select(group="zarr", name="data_type"))
101105

106+
__chunk_key_encoding_registry.lazy_load_list.extend(
107+
entry_points.select(group="zarr.chunk_key_encoding")
108+
)
109+
__chunk_key_encoding_registry.lazy_load_list.extend(
110+
entry_points.select(group="zarr", name="chunk_key_encoding")
111+
)
112+
102113
__pipeline_registry.lazy_load_list.extend(entry_points.select(group="zarr.codec_pipeline"))
103114
__pipeline_registry.lazy_load_list.extend(
104115
entry_points.select(group="zarr", name="codec_pipeline")
@@ -114,6 +125,7 @@ def _collect_entrypoints() -> list[Registry[Any]]:
114125
__pipeline_registry,
115126
__buffer_registry,
116127
__ndbuffer_registry,
128+
__chunk_key_encoding_registry,
117129
]
118130

119131

@@ -144,6 +156,10 @@ def register_buffer(cls: type[Buffer], qualname: str | None = None) -> None:
144156
__buffer_registry.register(cls, qualname)
145157

146158

159+
def register_chunk_key_encoding(key: str, cls: type) -> None:
160+
__chunk_key_encoding_registry.register(cls, key)
161+
162+
147163
def get_codec_class(key: str, reload_config: bool = False) -> type[Codec]:
148164
if reload_config:
149165
_reload_config()
@@ -280,6 +296,15 @@ def get_ndbuffer_class(reload_config: bool = False) -> type[NDBuffer]:
280296
)
281297

282298

299+
def get_chunk_key_encoding_class(key: str) -> type[ChunkKeyEncoding]:
300+
__chunk_key_encoding_registry.lazy_load(use_entrypoint_name=True)
301+
if key not in __chunk_key_encoding_registry:
302+
raise KeyError(
303+
f"Chunk key encoding '{key}' not found in registered chunk key encodings: {list(__chunk_key_encoding_registry)}."
304+
)
305+
return __chunk_key_encoding_registry[key]
306+
307+
283308
_collect_entrypoints()
284309

285310

0 commit comments

Comments
 (0)