diff --git a/changes/3436.feature.rst b/changes/3436.feature.rst new file mode 100644 index 0000000000..85e28bb8b1 --- /dev/null +++ b/changes/3436.feature.rst @@ -0,0 +1,2 @@ +Adds a registry for chunk key encodings for extensibility. +This allows users to implement a custom `ChunkKeyEncoding`, which can be registered via `register_chunk_key_encoding` or as an entry point under `zarr.chunk_key_encoding`. diff --git a/docs/user-guide/consolidated_metadata.rst b/docs/user-guide/consolidated_metadata.rst index 9d05231f4a..ae50c602ca 100644 --- a/docs/user-guide/consolidated_metadata.rst +++ b/docs/user-guide/consolidated_metadata.rst @@ -49,44 +49,41 @@ that can be used.: >>> from pprint import pprint >>> pprint(dict(consolidated_metadata.items())) {'a': ArrayV3Metadata(shape=(1,), - data_type=Float64(endianness='little'), - chunk_grid=RegularChunkGrid(chunk_shape=(1,)), - chunk_key_encoding=DefaultChunkKeyEncoding(name='default', - separator='/'), - fill_value=np.float64(0.0), - codecs=(BytesCodec(endian=), - ZstdCodec(level=0, checksum=False)), - attributes={}, - dimension_names=None, - zarr_format=3, - node_type='array', - storage_transformers=()), - 'b': ArrayV3Metadata(shape=(2, 2), - data_type=Float64(endianness='little'), - chunk_grid=RegularChunkGrid(chunk_shape=(2, 2)), - chunk_key_encoding=DefaultChunkKeyEncoding(name='default', - separator='/'), - fill_value=np.float64(0.0), - codecs=(BytesCodec(endian=), - ZstdCodec(level=0, checksum=False)), - attributes={}, - dimension_names=None, - zarr_format=3, - node_type='array', - storage_transformers=()), - 'c': ArrayV3Metadata(shape=(3, 3, 3), - data_type=Float64(endianness='little'), - chunk_grid=RegularChunkGrid(chunk_shape=(3, 3, 3)), - chunk_key_encoding=DefaultChunkKeyEncoding(name='default', - separator='/'), - fill_value=np.float64(0.0), - codecs=(BytesCodec(endian=), - ZstdCodec(level=0, checksum=False)), - attributes={}, - dimension_names=None, - zarr_format=3, - node_type='array', - storage_transformers=())} + data_type=Float64(endianness='little'), + chunk_grid=RegularChunkGrid(chunk_shape=(1,)), + chunk_key_encoding=DefaultChunkKeyEncoding(separator='/'), + fill_value=np.float64(0.0), + codecs=(BytesCodec(endian=), + ZstdCodec(level=0, checksum=False)), + attributes={}, + dimension_names=None, + zarr_format=3, + node_type='array', + storage_transformers=()), + 'b': ArrayV3Metadata(shape=(2, 2), + data_type=Float64(endianness='little'), + chunk_grid=RegularChunkGrid(chunk_shape=(2, 2)), + chunk_key_encoding=DefaultChunkKeyEncoding(separator='/'), + fill_value=np.float64(0.0), + codecs=(BytesCodec(endian=), + ZstdCodec(level=0, checksum=False)), + attributes={}, + dimension_names=None, + zarr_format=3, + node_type='array', + storage_transformers=()), + 'c': ArrayV3Metadata(shape=(3, 3, 3), + data_type=Float64(endianness='little'), + chunk_grid=RegularChunkGrid(chunk_shape=(3, 3, 3)), + chunk_key_encoding=DefaultChunkKeyEncoding(separator='/'), + fill_value=np.float64(0.0), + codecs=(BytesCodec(endian=), + ZstdCodec(level=0, checksum=False)), + attributes={}, + dimension_names=None, + zarr_format=3, + node_type='array', + storage_transformers=())} Operations on the group to get children automatically use the consolidated metadata.: diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index 960b322a25..e5fa451914 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -47,6 +47,7 @@ ChunkKeyEncodingLike, DefaultChunkKeyEncoding, V2ChunkKeyEncoding, + parse_chunk_key_encoding, ) from zarr.core.common import ( JSON, @@ -4602,6 +4603,7 @@ async def init_array( order_parsed = zarr_config.get("array.order") else: order_parsed = order + chunk_key_encoding_parsed = cast("V2ChunkKeyEncoding", chunk_key_encoding_parsed) meta = AsyncArray._create_metadata_v2( shape=shape_parsed, @@ -4951,13 +4953,11 @@ def _parse_chunk_key_encoding( """ if data is None: if zarr_format == 2: - result = ChunkKeyEncoding.from_dict({"name": "v2", "separator": "."}) + data = {"name": "v2", "configuration": {"separator": "."}} else: - result = ChunkKeyEncoding.from_dict({"name": "default", "separator": "/"}) - elif isinstance(data, ChunkKeyEncoding): - result = data - else: - result = ChunkKeyEncoding.from_dict(data) + data = {"name": "default", "configuration": {"separator": "/"}} + result = parse_chunk_key_encoding(data) + if zarr_format == 2 and result.name != "v2": msg = ( "Invalid chunk key encoding. For Zarr format 2 arrays, the `name` field of the " diff --git a/src/zarr/core/chunk_key_encodings.py b/src/zarr/core/chunk_key_encodings.py index 89a34e6052..42d7615c61 100644 --- a/src/zarr/core/chunk_key_encodings.py +++ b/src/zarr/core/chunk_key_encodings.py @@ -1,17 +1,18 @@ from __future__ import annotations -from abc import abstractmethod +from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Literal, TypeAlias, TypedDict, cast +from typing import TYPE_CHECKING, ClassVar, Literal, TypeAlias, TypedDict, cast if TYPE_CHECKING: - from typing import NotRequired + from typing import NotRequired, Self from zarr.abc.metadata import Metadata from zarr.core.common import ( JSON, parse_named_configuration, ) +from zarr.registry import get_chunk_key_encoding_class, register_chunk_key_encoding SeparatorLiteral = Literal[".", "/"] @@ -28,60 +29,49 @@ class ChunkKeyEncodingParams(TypedDict): @dataclass(frozen=True) -class ChunkKeyEncoding(Metadata): - name: str - separator: SeparatorLiteral = "." +class ChunkKeyEncoding(ABC, Metadata): + """ + Defines how chunk coordinates are mapped to store keys. - def __init__(self, *, separator: SeparatorLiteral) -> None: - separator_parsed = parse_separator(separator) + Subclasses must define a class variable `name` and implement `encode_chunk_key`. + """ - object.__setattr__(self, "separator", separator_parsed) + name: ClassVar[str] @classmethod - def from_dict(cls, data: dict[str, JSON] | ChunkKeyEncodingLike) -> ChunkKeyEncoding: - if isinstance(data, ChunkKeyEncoding): - return data - - # handle ChunkKeyEncodingParams - if "name" in data and "separator" in data: - data = {"name": data["name"], "configuration": {"separator": data["separator"]}} - - # TODO: remove this cast when we are statically typing the JSON metadata completely. - data = cast("dict[str, JSON]", data) - - # configuration is optional for chunk key encodings - name_parsed, config_parsed = parse_named_configuration(data, require_configuration=False) - if name_parsed == "default": - if config_parsed is None: - # for default, normalize missing configuration to use the "/" separator. - config_parsed = {"separator": "/"} - return DefaultChunkKeyEncoding(**config_parsed) # type: ignore[arg-type] - if name_parsed == "v2": - if config_parsed is None: - # for v2, normalize missing configuration to use the "." separator. - config_parsed = {"separator": "."} - return V2ChunkKeyEncoding(**config_parsed) # type: ignore[arg-type] - msg = f"Unknown chunk key encoding. Got {name_parsed}, expected one of ('v2', 'default')." - raise ValueError(msg) + def from_dict(cls, data: dict[str, JSON]) -> Self: + _, config_parsed = parse_named_configuration(data, require_configuration=False) + return cls(**config_parsed if config_parsed else {}) def to_dict(self) -> dict[str, JSON]: - return {"name": self.name, "configuration": {"separator": self.separator}} + return {"name": self.name, "configuration": super().to_dict()} - @abstractmethod def decode_chunk_key(self, chunk_key: str) -> tuple[int, ...]: - pass + """ + Optional: decode a chunk key string into chunk coordinates. + Not required for normal operation; override if needed for testing or debugging. + """ + raise NotImplementedError(f"{self.__class__.__name__} does not implement decode_chunk_key.") @abstractmethod def encode_chunk_key(self, chunk_coords: tuple[int, ...]) -> str: - pass + """ + Encode chunk coordinates into a chunk key string. + Must be implemented by subclasses. + """ -ChunkKeyEncodingLike: TypeAlias = ChunkKeyEncodingParams | ChunkKeyEncoding +ChunkKeyEncodingLike: TypeAlias = dict[str, JSON] | ChunkKeyEncodingParams | ChunkKeyEncoding @dataclass(frozen=True) class DefaultChunkKeyEncoding(ChunkKeyEncoding): - name: Literal["default"] = "default" + name: ClassVar[Literal["default"]] = "default" + separator: SeparatorLiteral = "/" + + def __post_init__(self) -> None: + separator_parsed = parse_separator(self.separator) + object.__setattr__(self, "separator", separator_parsed) def decode_chunk_key(self, chunk_key: str) -> tuple[int, ...]: if chunk_key == "c": @@ -94,7 +84,12 @@ def encode_chunk_key(self, chunk_coords: tuple[int, ...]) -> str: @dataclass(frozen=True) class V2ChunkKeyEncoding(ChunkKeyEncoding): - name: Literal["v2"] = "v2" + name: ClassVar[Literal["v2"]] = "v2" + separator: SeparatorLiteral = "." + + def __post_init__(self) -> None: + separator_parsed = parse_separator(self.separator) + object.__setattr__(self, "separator", separator_parsed) def decode_chunk_key(self, chunk_key: str) -> tuple[int, ...]: return tuple(map(int, chunk_key.split(self.separator))) @@ -102,3 +97,30 @@ def decode_chunk_key(self, chunk_key: str) -> tuple[int, ...]: def encode_chunk_key(self, chunk_coords: tuple[int, ...]) -> str: chunk_identifier = self.separator.join(map(str, chunk_coords)) return "0" if chunk_identifier == "" else chunk_identifier + + +def parse_chunk_key_encoding(data: ChunkKeyEncodingLike) -> ChunkKeyEncoding: + """ + Take an implicit specification of a chunk key encoding and parse it into a ChunkKeyEncoding object. + """ + if isinstance(data, ChunkKeyEncoding): + return data + + # handle ChunkKeyEncodingParams + if "name" in data and "separator" in data: + data = {"name": data["name"], "configuration": {"separator": data["separator"]}} + + # Now must be a named config + data = cast("dict[str, JSON]", data) + + name_parsed, _ = parse_named_configuration(data, require_configuration=False) + try: + chunk_key_encoding = get_chunk_key_encoding_class(name_parsed).from_dict(data) + except KeyError as e: + raise ValueError(f"Unknown chunk key encoding: {e.args[0]!r}") from e + + return chunk_key_encoding + + +register_chunk_key_encoding("default", DefaultChunkKeyEncoding) +register_chunk_key_encoding("v2", V2ChunkKeyEncoding) diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index 649a490409..cafcb99281 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -24,7 +24,11 @@ from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec from zarr.core.array_spec import ArrayConfig, ArraySpec from zarr.core.chunk_grids import ChunkGrid, RegularChunkGrid -from zarr.core.chunk_key_encodings import ChunkKeyEncoding, ChunkKeyEncodingLike +from zarr.core.chunk_key_encodings import ( + ChunkKeyEncoding, + ChunkKeyEncodingLike, + parse_chunk_key_encoding, +) from zarr.core.common import ( JSON, ZARR_JSON, @@ -174,7 +178,7 @@ def __init__( shape_parsed = parse_shapelike(shape) chunk_grid_parsed = ChunkGrid.from_dict(chunk_grid) - chunk_key_encoding_parsed = ChunkKeyEncoding.from_dict(chunk_key_encoding) + chunk_key_encoding_parsed = parse_chunk_key_encoding(chunk_key_encoding) dimension_names_parsed = parse_dimension_names(dimension_names) # Note: relying on a type method is numpy-specific fill_value_parsed = data_type.cast_scalar(fill_value) diff --git a/src/zarr/registry.py b/src/zarr/registry.py index 5483b65c54..092b4cafc0 100644 --- a/src/zarr/registry.py +++ b/src/zarr/registry.py @@ -22,15 +22,18 @@ ) from zarr.abc.numcodec import Numcodec from zarr.core.buffer import Buffer, NDBuffer + from zarr.core.chunk_key_encodings import ChunkKeyEncoding from zarr.core.common import JSON __all__ = [ "Registry", "get_buffer_class", + "get_chunk_key_encoding_class", "get_codec_class", "get_ndbuffer_class", "get_pipeline_class", "register_buffer", + "register_chunk_key_encoding", "register_codec", "register_ndbuffer", "register_pipeline", @@ -44,9 +47,9 @@ def __init__(self) -> None: super().__init__() self.lazy_load_list: list[EntryPoint] = [] - def lazy_load(self) -> None: + def lazy_load(self, use_entrypoint_name: bool = False) -> None: for e in self.lazy_load_list: - self.register(e.load()) + self.register(e.load(), qualname=e.name if use_entrypoint_name else None) self.lazy_load_list.clear() @@ -60,10 +63,11 @@ def register(self, cls: type[T], qualname: str | None = None) -> None: __pipeline_registry: Registry[CodecPipeline] = Registry() __buffer_registry: Registry[Buffer] = Registry() __ndbuffer_registry: Registry[NDBuffer] = Registry() +__chunk_key_encoding_registry: Registry[ChunkKeyEncoding] = Registry() """ The registry module is responsible for managing implementations of codecs, -pipelines, buffers and ndbuffers and collecting them from entrypoints. +pipelines, buffers, ndbuffers, and chunk key encodings and collecting them from entrypoints. The implementation used is determined by the config. The registry module is also responsible for managing dtypes. @@ -99,6 +103,13 @@ def _collect_entrypoints() -> list[Registry[Any]]: data_type_registry._lazy_load_list.extend(entry_points.select(group="zarr.data_type")) data_type_registry._lazy_load_list.extend(entry_points.select(group="zarr", name="data_type")) + __chunk_key_encoding_registry.lazy_load_list.extend( + entry_points.select(group="zarr.chunk_key_encoding") + ) + __chunk_key_encoding_registry.lazy_load_list.extend( + entry_points.select(group="zarr", name="chunk_key_encoding") + ) + __pipeline_registry.lazy_load_list.extend(entry_points.select(group="zarr.codec_pipeline")) __pipeline_registry.lazy_load_list.extend( entry_points.select(group="zarr", name="codec_pipeline") @@ -114,6 +125,7 @@ def _collect_entrypoints() -> list[Registry[Any]]: __pipeline_registry, __buffer_registry, __ndbuffer_registry, + __chunk_key_encoding_registry, ] @@ -144,6 +156,10 @@ def register_buffer(cls: type[Buffer], qualname: str | None = None) -> None: __buffer_registry.register(cls, qualname) +def register_chunk_key_encoding(key: str, cls: type) -> None: + __chunk_key_encoding_registry.register(cls, key) + + def get_codec_class(key: str, reload_config: bool = False) -> type[Codec]: if reload_config: _reload_config() @@ -280,6 +296,15 @@ def get_ndbuffer_class(reload_config: bool = False) -> type[NDBuffer]: ) +def get_chunk_key_encoding_class(key: str) -> type[ChunkKeyEncoding]: + __chunk_key_encoding_registry.lazy_load(use_entrypoint_name=True) + if key not in __chunk_key_encoding_registry: + raise KeyError( + f"Chunk key encoding '{key}' not found in registered chunk key encodings: {list(__chunk_key_encoding_registry)}." + ) + return __chunk_key_encoding_registry[key] + + _collect_entrypoints() diff --git a/tests/conftest.py b/tests/conftest.py index 91975408aa..63c8950cff 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,7 @@ import sys from collections.abc import Mapping, Sequence from dataclasses import dataclass, field -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import numpy as np import numpy.typing as npt @@ -52,6 +52,7 @@ from zarr.core.chunk_key_encodings import ( ChunkKeyEncoding, ChunkKeyEncodingLike, + V2ChunkKeyEncoding, ) from zarr.core.dtype.wrapper import ZDType @@ -339,6 +340,7 @@ def create_array_metadata( filters_parsed, compressor_parsed = _parse_chunk_encoding_v2( compressor=compressors, filters=filters, dtype=dtype_parsed ) + chunk_key_encoding_parsed = cast("V2ChunkKeyEncoding", chunk_key_encoding_parsed) return ArrayV2Metadata( shape=shape_parsed, dtype=dtype_parsed, diff --git a/tests/test_array.py b/tests/test_array.py index 92a5dc77e9..5e3c10dce4 100644 --- a/tests/test_array.py +++ b/tests/test_array.py @@ -1247,11 +1247,11 @@ async def test_chunk_key_encoding( chunk_key_encoding = ChunkKeyEncodingParams(name=name, separator=separator) # type: ignore[typeddict-item] error_msg = "" if name == "invalid": - error_msg = "Unknown chunk key encoding." + error_msg = r'Unknown chunk key encoding: "Chunk key encoding \'invalid\' not found in registered chunk key encodings: \[.*\]."' if zarr_format == 2 and name == "default": error_msg = "Invalid chunk key encoding. For Zarr format 2 arrays, the `name` field of the chunk key encoding must be 'v2'." if error_msg: - with pytest.raises(ValueError, match=re.escape(error_msg)): + with pytest.raises(ValueError, match=error_msg): arr = await create_array( store=store, dtype="uint8", diff --git a/tests/test_codecs/test_codecs.py b/tests/test_codecs/test_codecs.py index dfedbb83de..1884d501a5 100644 --- a/tests/test_codecs/test_codecs.py +++ b/tests/test_codecs/test_codecs.py @@ -308,7 +308,7 @@ def test_invalid_metadata(codecs: tuple[Codec, ...]) -> None: ArrayV3Metadata( shape=shape, chunk_grid={"name": "regular", "configuration": {"chunk_shape": chunks}}, - chunk_key_encoding={"name": "default", "configuration": {"separator": "/"}}, # type: ignore[arg-type] + chunk_key_encoding={"name": "default", "configuration": {"separator": "/"}}, fill_value=0, data_type=data_type, codecs=codecs,