Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changes/3436.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Adds a registry for chunk key encodings for extensibility.
This allows users to implement a custom `ChunkKeyEncoding`, which can be registered via `register_chunk_key_encoding` or as an entry point under `zarr.chunk_key_encoding`.
73 changes: 35 additions & 38 deletions docs/user-guide/consolidated_metadata.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,44 +49,41 @@ that can be used.:
>>> from pprint import pprint
>>> pprint(dict(consolidated_metadata.items()))
{'a': ArrayV3Metadata(shape=(1,),
data_type=Float64(endianness='little'),
chunk_grid=RegularChunkGrid(chunk_shape=(1,)),
chunk_key_encoding=DefaultChunkKeyEncoding(name='default',
separator='/'),
fill_value=np.float64(0.0),
codecs=(BytesCodec(endian=<Endian.little: 'little'>),
ZstdCodec(level=0, checksum=False)),
attributes={},
dimension_names=None,
zarr_format=3,
node_type='array',
storage_transformers=()),
'b': ArrayV3Metadata(shape=(2, 2),
data_type=Float64(endianness='little'),
chunk_grid=RegularChunkGrid(chunk_shape=(2, 2)),
chunk_key_encoding=DefaultChunkKeyEncoding(name='default',
separator='/'),
fill_value=np.float64(0.0),
codecs=(BytesCodec(endian=<Endian.little: 'little'>),
ZstdCodec(level=0, checksum=False)),
attributes={},
dimension_names=None,
zarr_format=3,
node_type='array',
storage_transformers=()),
'c': ArrayV3Metadata(shape=(3, 3, 3),
data_type=Float64(endianness='little'),
chunk_grid=RegularChunkGrid(chunk_shape=(3, 3, 3)),
chunk_key_encoding=DefaultChunkKeyEncoding(name='default',
separator='/'),
fill_value=np.float64(0.0),
codecs=(BytesCodec(endian=<Endian.little: 'little'>),
ZstdCodec(level=0, checksum=False)),
attributes={},
dimension_names=None,
zarr_format=3,
node_type='array',
storage_transformers=())}
data_type=Float64(endianness='little'),
chunk_grid=RegularChunkGrid(chunk_shape=(1,)),
chunk_key_encoding=DefaultChunkKeyEncoding(separator='/'),
fill_value=np.float64(0.0),
codecs=(BytesCodec(endian=<Endian.little: 'little'>),
ZstdCodec(level=0, checksum=False)),
attributes={},
dimension_names=None,
zarr_format=3,
node_type='array',
storage_transformers=()),
'b': ArrayV3Metadata(shape=(2, 2),
data_type=Float64(endianness='little'),
chunk_grid=RegularChunkGrid(chunk_shape=(2, 2)),
chunk_key_encoding=DefaultChunkKeyEncoding(separator='/'),
fill_value=np.float64(0.0),
codecs=(BytesCodec(endian=<Endian.little: 'little'>),
ZstdCodec(level=0, checksum=False)),
attributes={},
dimension_names=None,
zarr_format=3,
node_type='array',
storage_transformers=()),
'c': ArrayV3Metadata(shape=(3, 3, 3),
data_type=Float64(endianness='little'),
chunk_grid=RegularChunkGrid(chunk_shape=(3, 3, 3)),
chunk_key_encoding=DefaultChunkKeyEncoding(separator='/'),
fill_value=np.float64(0.0),
codecs=(BytesCodec(endian=<Endian.little: 'little'>),
ZstdCodec(level=0, checksum=False)),
attributes={},
dimension_names=None,
zarr_format=3,
node_type='array',
storage_transformers=())}

Operations on the group to get children automatically use the consolidated metadata.:

Expand Down
12 changes: 6 additions & 6 deletions src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
ChunkKeyEncodingLike,
DefaultChunkKeyEncoding,
V2ChunkKeyEncoding,
parse_chunk_key_encoding,
)
from zarr.core.common import (
JSON,
Expand Down Expand Up @@ -4602,6 +4603,7 @@ async def init_array(
order_parsed = zarr_config.get("array.order")
else:
order_parsed = order
chunk_key_encoding_parsed = cast("V2ChunkKeyEncoding", chunk_key_encoding_parsed)

meta = AsyncArray._create_metadata_v2(
shape=shape_parsed,
Expand Down Expand Up @@ -4951,13 +4953,11 @@ def _parse_chunk_key_encoding(
"""
if data is None:
if zarr_format == 2:
result = ChunkKeyEncoding.from_dict({"name": "v2", "separator": "."})
data = {"name": "v2", "configuration": {"separator": "."}}
else:
result = ChunkKeyEncoding.from_dict({"name": "default", "separator": "/"})
elif isinstance(data, ChunkKeyEncoding):
result = data
else:
result = ChunkKeyEncoding.from_dict(data)
data = {"name": "default", "configuration": {"separator": "/"}}
result = parse_chunk_key_encoding(data)

if zarr_format == 2 and result.name != "v2":
msg = (
"Invalid chunk key encoding. For Zarr format 2 arrays, the `name` field of the "
Expand Down
104 changes: 63 additions & 41 deletions src/zarr/core/chunk_key_encodings.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from __future__ import annotations

from abc import abstractmethod
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal, TypeAlias, TypedDict, cast
from typing import TYPE_CHECKING, ClassVar, Literal, TypeAlias, TypedDict, cast

if TYPE_CHECKING:
from typing import NotRequired
from typing import NotRequired, Self

from zarr.abc.metadata import Metadata
from zarr.core.common import (
JSON,
parse_named_configuration,
)
from zarr.registry import get_chunk_key_encoding_class, register_chunk_key_encoding

SeparatorLiteral = Literal[".", "/"]

Expand All @@ -28,60 +29,49 @@ class ChunkKeyEncodingParams(TypedDict):


@dataclass(frozen=True)
class ChunkKeyEncoding(Metadata):
name: str
separator: SeparatorLiteral = "."
class ChunkKeyEncoding(ABC, Metadata):
"""
Defines how chunk coordinates are mapped to store keys.

def __init__(self, *, separator: SeparatorLiteral) -> None:
separator_parsed = parse_separator(separator)
Subclasses must define a class variable `name` and implement `encode_chunk_key`.
"""

object.__setattr__(self, "separator", separator_parsed)
name: ClassVar[str]

@classmethod
def from_dict(cls, data: dict[str, JSON] | ChunkKeyEncodingLike) -> ChunkKeyEncoding:
if isinstance(data, ChunkKeyEncoding):
return data

# handle ChunkKeyEncodingParams
if "name" in data and "separator" in data:
data = {"name": data["name"], "configuration": {"separator": data["separator"]}}

# TODO: remove this cast when we are statically typing the JSON metadata completely.
data = cast("dict[str, JSON]", data)

# configuration is optional for chunk key encodings
name_parsed, config_parsed = parse_named_configuration(data, require_configuration=False)
if name_parsed == "default":
if config_parsed is None:
# for default, normalize missing configuration to use the "/" separator.
config_parsed = {"separator": "/"}
return DefaultChunkKeyEncoding(**config_parsed) # type: ignore[arg-type]
if name_parsed == "v2":
if config_parsed is None:
# for v2, normalize missing configuration to use the "." separator.
config_parsed = {"separator": "."}
return V2ChunkKeyEncoding(**config_parsed) # type: ignore[arg-type]
msg = f"Unknown chunk key encoding. Got {name_parsed}, expected one of ('v2', 'default')."
raise ValueError(msg)
def from_dict(cls, data: dict[str, JSON]) -> Self:
_, config_parsed = parse_named_configuration(data, require_configuration=False)
return cls(**config_parsed if config_parsed else {})

def to_dict(self) -> dict[str, JSON]:
return {"name": self.name, "configuration": {"separator": self.separator}}
return {"name": self.name, "configuration": super().to_dict()}

@abstractmethod
def decode_chunk_key(self, chunk_key: str) -> tuple[int, ...]:
pass
"""
Optional: decode a chunk key string into chunk coordinates.
Not required for normal operation; override if needed for testing or debugging.
"""
raise NotImplementedError(f"{self.__class__.__name__} does not implement decode_chunk_key.")

@abstractmethod
def encode_chunk_key(self, chunk_coords: tuple[int, ...]) -> str:
pass
"""
Encode chunk coordinates into a chunk key string.
Must be implemented by subclasses.
"""


ChunkKeyEncodingLike: TypeAlias = ChunkKeyEncodingParams | ChunkKeyEncoding
ChunkKeyEncodingLike: TypeAlias = dict[str, JSON] | ChunkKeyEncodingParams | ChunkKeyEncoding


@dataclass(frozen=True)
class DefaultChunkKeyEncoding(ChunkKeyEncoding):
name: Literal["default"] = "default"
name: ClassVar[Literal["default"]] = "default"
separator: SeparatorLiteral = "/"

def __post_init__(self) -> None:
separator_parsed = parse_separator(self.separator)
object.__setattr__(self, "separator", separator_parsed)

def decode_chunk_key(self, chunk_key: str) -> tuple[int, ...]:
if chunk_key == "c":
Expand All @@ -94,11 +84,43 @@ def encode_chunk_key(self, chunk_coords: tuple[int, ...]) -> str:

@dataclass(frozen=True)
class V2ChunkKeyEncoding(ChunkKeyEncoding):
name: Literal["v2"] = "v2"
name: ClassVar[Literal["v2"]] = "v2"
separator: SeparatorLiteral = "."

def __post_init__(self) -> None:
separator_parsed = parse_separator(self.separator)
object.__setattr__(self, "separator", separator_parsed)

def decode_chunk_key(self, chunk_key: str) -> tuple[int, ...]:
return tuple(map(int, chunk_key.split(self.separator)))

def encode_chunk_key(self, chunk_coords: tuple[int, ...]) -> str:
chunk_identifier = self.separator.join(map(str, chunk_coords))
return "0" if chunk_identifier == "" else chunk_identifier


def parse_chunk_key_encoding(data: ChunkKeyEncodingLike) -> ChunkKeyEncoding:
"""
Take an implicit specification of a chunk key encoding and parse it into a ChunkKeyEncoding object.
"""
if isinstance(data, ChunkKeyEncoding):
return data

# handle ChunkKeyEncodingParams
if "name" in data and "separator" in data:
data = {"name": data["name"], "configuration": {"separator": data["separator"]}}

# Now must be a named config
data = cast("dict[str, JSON]", data)

name_parsed, _ = parse_named_configuration(data, require_configuration=False)
try:
chunk_key_encoding = get_chunk_key_encoding_class(name_parsed).from_dict(data)
except KeyError as e:
raise ValueError(f"Unknown chunk key encoding: {e.args[0]!r}") from e

return chunk_key_encoding


register_chunk_key_encoding("default", DefaultChunkKeyEncoding)
register_chunk_key_encoding("v2", V2ChunkKeyEncoding)
8 changes: 6 additions & 2 deletions src/zarr/core/metadata/v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@
from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec
from zarr.core.array_spec import ArrayConfig, ArraySpec
from zarr.core.chunk_grids import ChunkGrid, RegularChunkGrid
from zarr.core.chunk_key_encodings import ChunkKeyEncoding, ChunkKeyEncodingLike
from zarr.core.chunk_key_encodings import (
ChunkKeyEncoding,
ChunkKeyEncodingLike,
parse_chunk_key_encoding,
)
from zarr.core.common import (
JSON,
ZARR_JSON,
Expand Down Expand Up @@ -174,7 +178,7 @@ def __init__(

shape_parsed = parse_shapelike(shape)
chunk_grid_parsed = ChunkGrid.from_dict(chunk_grid)
chunk_key_encoding_parsed = ChunkKeyEncoding.from_dict(chunk_key_encoding)
chunk_key_encoding_parsed = parse_chunk_key_encoding(chunk_key_encoding)
dimension_names_parsed = parse_dimension_names(dimension_names)
# Note: relying on a type method is numpy-specific
fill_value_parsed = data_type.cast_scalar(fill_value)
Expand Down
31 changes: 28 additions & 3 deletions src/zarr/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,18 @@
)
from zarr.abc.numcodec import Numcodec
from zarr.core.buffer import Buffer, NDBuffer
from zarr.core.chunk_key_encodings import ChunkKeyEncoding
from zarr.core.common import JSON

__all__ = [
"Registry",
"get_buffer_class",
"get_chunk_key_encoding_class",
"get_codec_class",
"get_ndbuffer_class",
"get_pipeline_class",
"register_buffer",
"register_chunk_key_encoding",
"register_codec",
"register_ndbuffer",
"register_pipeline",
Expand All @@ -44,9 +47,9 @@ def __init__(self) -> None:
super().__init__()
self.lazy_load_list: list[EntryPoint] = []

def lazy_load(self) -> None:
def lazy_load(self, use_entrypoint_name: bool = False) -> None:
for e in self.lazy_load_list:
self.register(e.load())
self.register(e.load(), qualname=e.name if use_entrypoint_name else None)

self.lazy_load_list.clear()

Expand All @@ -60,10 +63,11 @@ def register(self, cls: type[T], qualname: str | None = None) -> None:
__pipeline_registry: Registry[CodecPipeline] = Registry()
__buffer_registry: Registry[Buffer] = Registry()
__ndbuffer_registry: Registry[NDBuffer] = Registry()
__chunk_key_encoding_registry: Registry[ChunkKeyEncoding] = Registry()

"""
The registry module is responsible for managing implementations of codecs,
pipelines, buffers and ndbuffers and collecting them from entrypoints.
pipelines, buffers, ndbuffers, and chunk key encodings and collecting them from entrypoints.
The implementation used is determined by the config.

The registry module is also responsible for managing dtypes.
Expand Down Expand Up @@ -99,6 +103,13 @@ def _collect_entrypoints() -> list[Registry[Any]]:
data_type_registry._lazy_load_list.extend(entry_points.select(group="zarr.data_type"))
data_type_registry._lazy_load_list.extend(entry_points.select(group="zarr", name="data_type"))

__chunk_key_encoding_registry.lazy_load_list.extend(
entry_points.select(group="zarr.chunk_key_encoding")
)
__chunk_key_encoding_registry.lazy_load_list.extend(
entry_points.select(group="zarr", name="chunk_key_encoding")
)

__pipeline_registry.lazy_load_list.extend(entry_points.select(group="zarr.codec_pipeline"))
__pipeline_registry.lazy_load_list.extend(
entry_points.select(group="zarr", name="codec_pipeline")
Expand All @@ -114,6 +125,7 @@ def _collect_entrypoints() -> list[Registry[Any]]:
__pipeline_registry,
__buffer_registry,
__ndbuffer_registry,
__chunk_key_encoding_registry,
]


Expand Down Expand Up @@ -144,6 +156,10 @@ def register_buffer(cls: type[Buffer], qualname: str | None = None) -> None:
__buffer_registry.register(cls, qualname)


def register_chunk_key_encoding(key: str, cls: type) -> None:
__chunk_key_encoding_registry.register(cls, key)


def get_codec_class(key: str, reload_config: bool = False) -> type[Codec]:
if reload_config:
_reload_config()
Expand Down Expand Up @@ -280,6 +296,15 @@ def get_ndbuffer_class(reload_config: bool = False) -> type[NDBuffer]:
)


def get_chunk_key_encoding_class(key: str) -> type[ChunkKeyEncoding]:
__chunk_key_encoding_registry.lazy_load(use_entrypoint_name=True)
if key not in __chunk_key_encoding_registry:
raise KeyError(
f"Chunk key encoding '{key}' not found in registered chunk key encodings: {list(__chunk_key_encoding_registry)}."
)
return __chunk_key_encoding_registry[key]


_collect_entrypoints()


Expand Down
Loading
Loading