Skip to content

Commit 0289779

Browse files
committed
Add registry for chunk key encodings.
1 parent e76b1e0 commit 0289779

File tree

4 files changed

+74
-35
lines changed

4 files changed

+74
-35
lines changed

src/zarr/core/array.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
ChunkKeyEncodingLike,
4848
DefaultChunkKeyEncoding,
4949
V2ChunkKeyEncoding,
50+
parse_chunk_key_encoding,
5051
)
5152
from zarr.core.common import (
5253
JSON,
@@ -4934,13 +4935,11 @@ def _parse_chunk_key_encoding(
49344935
"""
49354936
if data is None:
49364937
if zarr_format == 2:
4937-
result = ChunkKeyEncoding.from_dict({"name": "v2", "separator": "."})
4938+
data = {"name": "v2", "configuration": {"separator": "."}}
49384939
else:
4939-
result = ChunkKeyEncoding.from_dict({"name": "default", "separator": "/"})
4940-
elif isinstance(data, ChunkKeyEncoding):
4941-
result = data
4942-
else:
4943-
result = ChunkKeyEncoding.from_dict(data)
4940+
data = {"name": "default", "configuration": {"separator": "/"}}
4941+
result = parse_chunk_key_encoding(data)
4942+
49444943
if zarr_format == 2 and result.name != "v2":
49454944
msg = (
49464945
"Invalid chunk key encoding. For Zarr format 2 arrays, the `name` field of the "

src/zarr/core/chunk_key_encodings.py

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55
from typing import TYPE_CHECKING, Literal, TypeAlias, TypedDict, cast
66

77
if TYPE_CHECKING:
8-
from typing import NotRequired
8+
from typing import NotRequired, Self
99

1010
from zarr.abc.metadata import Metadata
1111
from zarr.core.common import (
1212
JSON,
1313
parse_named_configuration,
1414
)
15+
from zarr.registry import get_chunk_key_encoding_class, register_chunk_key_encoding
1516

1617
SeparatorLiteral = Literal[".", "/"]
1718

@@ -38,31 +39,9 @@ def __init__(self, *, separator: SeparatorLiteral) -> None:
3839
object.__setattr__(self, "separator", separator_parsed)
3940

4041
@classmethod
41-
def from_dict(cls, data: dict[str, JSON] | ChunkKeyEncodingLike) -> ChunkKeyEncoding:
42-
if isinstance(data, ChunkKeyEncoding):
43-
return data
44-
45-
# handle ChunkKeyEncodingParams
46-
if "name" in data and "separator" in data:
47-
data = {"name": data["name"], "configuration": {"separator": data["separator"]}}
48-
49-
# TODO: remove this cast when we are statically typing the JSON metadata completely.
50-
data = cast("dict[str, JSON]", data)
51-
52-
# configuration is optional for chunk key encodings
42+
def from_dict(cls, data: dict[str, JSON]) -> Self:
5343
name_parsed, config_parsed = parse_named_configuration(data, require_configuration=False)
54-
if name_parsed == "default":
55-
if config_parsed is None:
56-
# for default, normalize missing configuration to use the "/" separator.
57-
config_parsed = {"separator": "/"}
58-
return DefaultChunkKeyEncoding(**config_parsed) # type: ignore[arg-type]
59-
if name_parsed == "v2":
60-
if config_parsed is None:
61-
# for v2, normalize missing configuration to use the "." separator.
62-
config_parsed = {"separator": "."}
63-
return V2ChunkKeyEncoding(**config_parsed) # type: ignore[arg-type]
64-
msg = f"Unknown chunk key encoding. Got {name_parsed}, expected one of ('v2', 'default')."
65-
raise ValueError(msg)
44+
return cls(**config_parsed if config_parsed else {}) # type: ignore[arg-type]
6645

6746
def to_dict(self) -> dict[str, JSON]:
6847
return {"name": self.name, "configuration": {"separator": self.separator}}
@@ -76,12 +55,13 @@ def encode_chunk_key(self, chunk_coords: tuple[int, ...]) -> str:
7655
pass
7756

7857

79-
ChunkKeyEncodingLike: TypeAlias = ChunkKeyEncodingParams | ChunkKeyEncoding
58+
ChunkKeyEncodingLike: TypeAlias = dict[str, JSON] | ChunkKeyEncodingParams | ChunkKeyEncoding
8059

8160

8261
@dataclass(frozen=True)
8362
class DefaultChunkKeyEncoding(ChunkKeyEncoding):
8463
name: Literal["default"] = "default"
64+
separator: SeparatorLiteral = "/" # default
8565

8666
def decode_chunk_key(self, chunk_key: str) -> tuple[int, ...]:
8767
if chunk_key == "c":
@@ -95,10 +75,38 @@ def encode_chunk_key(self, chunk_coords: tuple[int, ...]) -> str:
9575
@dataclass(frozen=True)
9676
class V2ChunkKeyEncoding(ChunkKeyEncoding):
9777
name: Literal["v2"] = "v2"
78+
separator: SeparatorLiteral = "." # default
9879

9980
def decode_chunk_key(self, chunk_key: str) -> tuple[int, ...]:
10081
return tuple(map(int, chunk_key.split(self.separator)))
10182

10283
def encode_chunk_key(self, chunk_coords: tuple[int, ...]) -> str:
10384
chunk_identifier = self.separator.join(map(str, chunk_coords))
10485
return "0" if chunk_identifier == "" else chunk_identifier
86+
87+
88+
def parse_chunk_key_encoding(data: ChunkKeyEncodingLike) -> ChunkKeyEncoding:
89+
"""
90+
Take an implicit specification of a chunk key encoding and parse it into a ChunkKeyEncoding object.
91+
"""
92+
if isinstance(data, ChunkKeyEncoding):
93+
return data
94+
95+
# handle ChunkKeyEncodingParams
96+
if "name" in data and "separator" in data:
97+
data = {"name": data["name"], "configuration": {"separator": data["separator"]}}
98+
99+
# Now must be a named config
100+
data = cast("dict[str, JSON]", data)
101+
102+
name_parsed, _ = parse_named_configuration(data, require_configuration=False)
103+
try:
104+
chunk_key_encoding = get_chunk_key_encoding_class(name_parsed).from_dict(data)
105+
except KeyError as e:
106+
raise ValueError(f"Unknown chunk key encoding: {e.args[0]!r}") from e
107+
108+
return chunk_key_encoding
109+
110+
111+
register_chunk_key_encoding(DefaultChunkKeyEncoding, qualname="default")
112+
register_chunk_key_encoding(V2ChunkKeyEncoding, qualname="v2")

src/zarr/core/metadata/v3.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@
2424
from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec
2525
from zarr.core.array_spec import ArrayConfig, ArraySpec
2626
from zarr.core.chunk_grids import ChunkGrid, RegularChunkGrid
27-
from zarr.core.chunk_key_encodings import ChunkKeyEncoding, ChunkKeyEncodingLike
27+
from zarr.core.chunk_key_encodings import (
28+
ChunkKeyEncoding,
29+
ChunkKeyEncodingLike,
30+
parse_chunk_key_encoding,
31+
)
2832
from zarr.core.common import (
2933
JSON,
3034
ZARR_JSON,
@@ -174,7 +178,7 @@ def __init__(
174178

175179
shape_parsed = parse_shapelike(shape)
176180
chunk_grid_parsed = ChunkGrid.from_dict(chunk_grid)
177-
chunk_key_encoding_parsed = ChunkKeyEncoding.from_dict(chunk_key_encoding)
181+
chunk_key_encoding_parsed = parse_chunk_key_encoding(chunk_key_encoding)
178182
dimension_names_parsed = parse_dimension_names(dimension_names)
179183
# Note: relying on a type method is numpy-specific
180184
fill_value_parsed = data_type.cast_scalar(fill_value)

src/zarr/registry.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,19 @@
2222
)
2323
from zarr.abc.numcodec import Numcodec
2424
from zarr.core.buffer import Buffer, NDBuffer
25+
from zarr.core.chunk_key_encodings import ChunkKeyEncoding
2526
from zarr.core.common import JSON
2627

28+
# CHANGE: Consider adding here
2729
__all__ = [
2830
"Registry",
2931
"get_buffer_class",
32+
"get_chunk_key_encoding_class",
3033
"get_codec_class",
3134
"get_ndbuffer_class",
3235
"get_pipeline_class",
3336
"register_buffer",
37+
"register_chunk_key_encoding",
3438
"register_codec",
3539
"register_ndbuffer",
3640
"register_pipeline",
@@ -60,10 +64,12 @@ def register(self, cls: type[T], qualname: str | None = None) -> None:
6064
__pipeline_registry: Registry[CodecPipeline] = Registry()
6165
__buffer_registry: Registry[Buffer] = Registry()
6266
__ndbuffer_registry: Registry[NDBuffer] = Registry()
67+
__chunk_key_encoding_registry: Registry[ChunkKeyEncoding] = Registry()
6368

69+
# CHANGE: Consider updating docstring
6470
"""
6571
The registry module is responsible for managing implementations of codecs,
66-
pipelines, buffers and ndbuffers and collecting them from entrypoints.
72+
pipelines, buffers, ndbuffers, and chunk key encodings and collecting them from entrypoints.
6773
The implementation used is determined by the config.
6874
6975
The registry module is also responsible for managing dtypes.
@@ -99,6 +105,13 @@ def _collect_entrypoints() -> list[Registry[Any]]:
99105
data_type_registry._lazy_load_list.extend(entry_points.select(group="zarr.data_type"))
100106
data_type_registry._lazy_load_list.extend(entry_points.select(group="zarr", name="data_type"))
101107

108+
__chunk_key_encoding_registry.lazy_load_list.extend(
109+
entry_points.select(group="zarr.chunk_key_encoding")
110+
)
111+
__chunk_key_encoding_registry.lazy_load_list.extend(
112+
entry_points.select(group="zarr", name="chunk_key_encoding")
113+
)
114+
102115
__pipeline_registry.lazy_load_list.extend(entry_points.select(group="zarr.codec_pipeline"))
103116
__pipeline_registry.lazy_load_list.extend(
104117
entry_points.select(group="zarr", name="codec_pipeline")
@@ -114,6 +127,7 @@ def _collect_entrypoints() -> list[Registry[Any]]:
114127
__pipeline_registry,
115128
__buffer_registry,
116129
__ndbuffer_registry,
130+
__chunk_key_encoding_registry,
117131
]
118132

119133

@@ -144,6 +158,10 @@ def register_buffer(cls: type[Buffer], qualname: str | None = None) -> None:
144158
__buffer_registry.register(cls, qualname)
145159

146160

161+
def register_chunk_key_encoding(cls: type, qualname: str | None = None) -> None:
162+
__chunk_key_encoding_registry.register(cls, qualname)
163+
164+
147165
def get_codec_class(key: str, reload_config: bool = False) -> type[Codec]:
148166
if reload_config:
149167
_reload_config()
@@ -281,6 +299,16 @@ def get_ndbuffer_class(reload_config: bool = False) -> type[NDBuffer]:
281299
)
282300

283301

302+
def get_chunk_key_encoding_class(key: str) -> type[ChunkKeyEncoding]:
303+
__chunk_key_encoding_registry.lazy_load()
304+
if key not in __chunk_key_encoding_registry:
305+
raise KeyError(
306+
f"Chunk key encoding '{key}' not found in registered chunk key encodings: {list(__chunk_key_encoding_registry)}."
307+
)
308+
309+
return __chunk_key_encoding_registry[key]
310+
311+
284312
_collect_entrypoints()
285313

286314

0 commit comments

Comments
 (0)