diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py index 73b1a598b9..d1ff0ab701 100644 --- a/src/zarr/abc/codec.py +++ b/src/zarr/abc/codec.py @@ -1,9 +1,9 @@ from __future__ import annotations from abc import abstractmethod -from typing import TYPE_CHECKING, Any, Generic, TypeVar +from typing import TYPE_CHECKING, Any, Generic, NotRequired, TypedDict, TypeVar -from zarr.abc.metadata import Metadata +from zarr.abc.metadata import Metadata, T from zarr.core.buffer import Buffer, NDBuffer from zarr.core.common import ChunkCoords, concurrent_map from zarr.core.config import config @@ -35,7 +35,7 @@ CodecOutput = TypeVar("CodecOutput", bound=NDBuffer | Buffer) -class BaseCodec(Metadata, Generic[CodecInput, CodecOutput]): +class BaseCodec(Generic[CodecInput, CodecOutput, T], Metadata[T]): """Generic base class for codecs. Codecs can be registered via zarr.codecs.registry. @@ -153,25 +153,41 @@ async def encode( return await _batching_helper(self._encode_single, chunks_and_specs) -class ArrayArrayCodec(BaseCodec[NDBuffer, NDBuffer]): +class ArrayArrayCodec(BaseCodec[NDBuffer, NDBuffer, T]): """Base class for array-to-array codecs.""" ... -class ArrayBytesCodec(BaseCodec[NDBuffer, Buffer]): +class ArrayBytesCodec(BaseCodec[NDBuffer, Buffer, T]): """Base class for array-to-bytes codecs.""" ... -class BytesBytesCodec(BaseCodec[Buffer, Buffer]): +class BytesBytesCodec(BaseCodec[Buffer, Buffer, T]): """Base class for bytes-to-bytes codecs.""" ... -Codec = ArrayArrayCodec | ArrayBytesCodec | BytesBytesCodec +Codec = ArrayArrayCodec[Any] | ArrayBytesCodec[Any] | BytesBytesCodec[Any] + + +class CodecConfigDict(TypedDict): + """A dictionary representing a codec configuration.""" + + ... 
+ + +CodecConfigDictType = TypeVar("CodecConfigDictType", bound=CodecConfigDict) + + +class CodecDict(Generic[CodecConfigDictType], TypedDict): + """A generic dictionary representing a codec.""" + + name: str + configuration: NotRequired[CodecConfigDictType] class ArrayBytesCodecPartialDecodeMixin: diff --git a/src/zarr/abc/metadata.py b/src/zarr/abc/metadata.py index 291ceb459c..32bee7345f 100644 --- a/src/zarr/abc/metadata.py +++ b/src/zarr/abc/metadata.py @@ -1,21 +1,22 @@ from __future__ import annotations -from collections.abc import Sequence -from typing import TYPE_CHECKING +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING, Generic, TypeVar, cast if TYPE_CHECKING: from typing import Self - from zarr.core.common import JSON from dataclasses import dataclass, fields __all__ = ["Metadata"] +T = TypeVar("T", bound=Mapping[str, object]) + @dataclass(frozen=True) -class Metadata: - def to_dict(self) -> dict[str, JSON]: +class Metadata(Generic[T]): + def to_dict(self) -> T: """ Recursively serialize this model to a dictionary. 
This method inspects the fields of self and calls `x.to_dict()` for any fields that @@ -35,10 +36,10 @@ def to_dict(self) -> dict[str, JSON]: else: out_dict[key] = value - return out_dict + return cast(T, out_dict) @classmethod - def from_dict(cls, data: dict[str, JSON]) -> Self: + def from_dict(cls, data: T) -> Self: """ Create an instance of the model from a dictionary """ diff --git a/src/zarr/codecs/_v2.py b/src/zarr/codecs/_v2.py index cc6129e604..3ef4f4dfcd 100644 --- a/src/zarr/codecs/_v2.py +++ b/src/zarr/codecs/_v2.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import numcodecs from numcodecs.compat import ensure_bytes, ensure_ndarray @@ -18,7 +18,7 @@ @dataclass(frozen=True) -class V2Compressor(ArrayBytesCodec): +class V2Compressor(ArrayBytesCodec[Any]): compressor: numcodecs.abc.Codec | None is_fixed_size = False @@ -66,7 +66,7 @@ def compute_encoded_size(self, _input_byte_length: int, _chunk_spec: ArraySpec) @dataclass(frozen=True) -class V2Filters(ArrayArrayCodec): +class V2Filters(ArrayArrayCodec[Any]): filters: tuple[numcodecs.abc.Codec, ...] 
| None is_fixed_size = False diff --git a/src/zarr/codecs/blosc.py b/src/zarr/codecs/blosc.py index 16bcf48a34..8da742ed4c 100644 --- a/src/zarr/codecs/blosc.py +++ b/src/zarr/codecs/blosc.py @@ -3,12 +3,12 @@ from dataclasses import dataclass, replace from enum import Enum from functools import cached_property -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import numcodecs from numcodecs.blosc import Blosc -from zarr.abc.codec import BytesBytesCodec +from zarr.abc.codec import BytesBytesCodec, CodecConfigDict, CodecDict from zarr.core.buffer.cpu import as_numpy_array_wrapper from zarr.core.common import JSON, parse_enum, parse_named_configuration, to_thread from zarr.registry import register_codec @@ -54,6 +54,22 @@ class BloscCname(Enum): zlib = "zlib" +class BloscCodecConfigDict(CodecConfigDict): + """A dictionary representing a Blosc codec configuration.""" + + typesize: int + cname: BloscCname + clevel: int + shuffle: BloscShuffle + blocksize: int + + +class BloscCodecDict(CodecDict[BloscCodecConfigDict]): + """A dictionary representing a Blosc codec.""" + + ... 
+ + # See https://zarr.readthedocs.io/en/stable/tutorial.html#configuring-blosc numcodecs.blosc.use_threads = False @@ -83,7 +99,7 @@ def parse_blocksize(data: JSON) -> int: @dataclass(frozen=True) -class BloscCodec(BytesBytesCodec): +class BloscCodec(BytesBytesCodec[BloscCodecDict]): is_fixed_size = False typesize: int | None @@ -114,16 +130,16 @@ def __init__( object.__setattr__(self, "blocksize", blocksize_parsed) @classmethod - def from_dict(cls, data: dict[str, JSON]) -> Self: + def from_dict(cls, data: BloscCodecDict) -> Self: _, configuration_parsed = parse_named_configuration(data, "blosc") return cls(**configuration_parsed) # type: ignore[arg-type] - def to_dict(self) -> dict[str, JSON]: + def to_dict(self) -> BloscCodecDict: if self.typesize is None: raise ValueError("`typesize` needs to be set for serialization.") if self.shuffle is None: raise ValueError("`shuffle` needs to be set for serialization.") - return { + out_dict = { "name": "blosc", "configuration": { "typesize": self.typesize, @@ -134,6 +150,8 @@ def to_dict(self) -> dict[str, JSON]: }, } + return cast(BloscCodecDict, out_dict) + def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: dtype = array_spec.dtype new_codec = self diff --git a/src/zarr/codecs/bytes.py b/src/zarr/codecs/bytes.py index 78c7b22fbc..4e2b8707eb 100644 --- a/src/zarr/codecs/bytes.py +++ b/src/zarr/codecs/bytes.py @@ -7,13 +7,13 @@ import numpy as np -from zarr.abc.codec import ArrayBytesCodec +from zarr.abc.codec import ArrayBytesCodec, CodecConfigDict, CodecDict from zarr.core.buffer import Buffer, NDArrayLike, NDBuffer -from zarr.core.common import JSON, parse_enum, parse_named_configuration +from zarr.core.common import parse_enum, parse_named_configuration from zarr.registry import register_codec if TYPE_CHECKING: - from typing import Self + from typing import Literal, Self from zarr.core.array_spec import ArraySpec @@ -30,8 +30,21 @@ class Endian(Enum): default_system_endian = Endian(sys.byteorder) +class 
BytesCodecConfigDict(CodecConfigDict): + """A dictionary representing a bytes codec configuration.""" + + # TODO: Why not type this w/ the Endian Enum + endian: Literal["big", "little"] + + +class BytesCodecDict(CodecDict[BytesCodecConfigDict]): + """A dictionary representing a bytes codec.""" + + ... + + @dataclass(frozen=True) -class BytesCodec(ArrayBytesCodec): +class BytesCodec(ArrayBytesCodec[BytesCodecDict]): is_fixed_size = True endian: Endian | None @@ -42,18 +55,20 @@ def __init__(self, *, endian: Endian | str | None = default_system_endian) -> No object.__setattr__(self, "endian", endian_parsed) @classmethod - def from_dict(cls, data: dict[str, JSON]) -> Self: + def from_dict(cls, data: BytesCodecDict) -> Self: _, configuration_parsed = parse_named_configuration( data, "bytes", require_configuration=False ) + configuration_parsed = configuration_parsed or {} return cls(**configuration_parsed) # type: ignore[arg-type] - def to_dict(self) -> dict[str, JSON]: - if self.endian is None: - return {"name": "bytes"} - else: - return {"name": "bytes", "configuration": {"endian": self.endian.value}} + def to_dict(self) -> BytesCodecDict: + out_dict: BytesCodecDict = {"name": "bytes"} + if self.endian is not None: + out_dict["configuration"] = {"endian": self.endian.value} + + return out_dict def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: if array_spec.dtype.itemsize == 0: diff --git a/src/zarr/codecs/crc32c_.py b/src/zarr/codecs/crc32c_.py index 3a6624ad25..fb952da809 100644 --- a/src/zarr/codecs/crc32c_.py +++ b/src/zarr/codecs/crc32c_.py @@ -7,8 +7,8 @@ import typing_extensions from crc32c import crc32c -from zarr.abc.codec import BytesBytesCodec -from zarr.core.common import JSON, parse_named_configuration +from zarr.abc.codec import BytesBytesCodec, CodecConfigDict, CodecDict +from zarr.core.common import parse_named_configuration from zarr.registry import register_codec if TYPE_CHECKING: @@ -18,17 +18,24 @@ from zarr.core.buffer import Buffer 
+class Crc32cCodecDict(CodecDict[CodecConfigDict]): + """A dictionary representing a CRC32C codec.""" + + ... + + @dataclass(frozen=True) -class Crc32cCodec(BytesBytesCodec): +class Crc32cCodec(BytesBytesCodec[Crc32cCodecDict]): is_fixed_size = True @classmethod - def from_dict(cls, data: dict[str, JSON]) -> Self: + def from_dict(cls, data: Crc32cCodecDict) -> Self: parse_named_configuration(data, "crc32c", require_configuration=False) return cls() - def to_dict(self) -> dict[str, JSON]: - return {"name": "crc32c"} + def to_dict(self) -> Crc32cCodecDict: + out_dict = {"name": "crc32c"} + return cast(Crc32cCodecDict, out_dict) async def _decode_single( self, diff --git a/src/zarr/codecs/gzip.py b/src/zarr/codecs/gzip.py index 6cc8517f20..2b7ab0fb3d 100644 --- a/src/zarr/codecs/gzip.py +++ b/src/zarr/codecs/gzip.py @@ -1,11 +1,11 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from numcodecs.gzip import GZip -from zarr.abc.codec import BytesBytesCodec +from zarr.abc.codec import BytesBytesCodec, CodecConfigDict, CodecDict from zarr.core.buffer.cpu import as_numpy_array_wrapper from zarr.core.common import JSON, parse_named_configuration, to_thread from zarr.registry import register_codec @@ -17,6 +17,18 @@ from zarr.core.buffer import Buffer +class GzipCodecConfigDict(CodecConfigDict): + """A dictionary representing a gzip codec configuration.""" + + level: int + + +class GzipCodecDict(CodecDict[GzipCodecConfigDict]): + """A dictionary representing a gzip codec.""" + + ... 
+ + def parse_gzip_level(data: JSON) -> int: if not isinstance(data, (int)): raise TypeError(f"Expected int, got {type(data)}") @@ -28,7 +40,7 @@ def parse_gzip_level(data: JSON) -> int: @dataclass(frozen=True) -class GzipCodec(BytesBytesCodec): +class GzipCodec(BytesBytesCodec[GzipCodecDict]): is_fixed_size = False level: int = 5 @@ -39,12 +51,13 @@ def __init__(self, *, level: int = 5) -> None: object.__setattr__(self, "level", level_parsed) @classmethod - def from_dict(cls, data: dict[str, JSON]) -> Self: + def from_dict(cls, data: GzipCodecDict) -> Self: _, configuration_parsed = parse_named_configuration(data, "gzip") return cls(**configuration_parsed) # type: ignore[arg-type] - def to_dict(self) -> dict[str, JSON]: - return {"name": "gzip", "configuration": {"level": self.level}} + def to_dict(self) -> GzipCodecDict: + out_dict = {"name": "gzip", "configuration": {"level": self.level}} + return cast(GzipCodecDict, out_dict) async def _decode_single( self, diff --git a/src/zarr/codecs/pipeline.py b/src/zarr/codecs/pipeline.py index 1226a04f06..0f9e157d23 100644 --- a/src/zarr/codecs/pipeline.py +++ b/src/zarr/codecs/pipeline.py @@ -56,6 +56,7 @@ def resolve_batched(codec: Codec, chunk_specs: Iterable[ArraySpec]) -> Iterable[ return [codec.resolve_metadata(chunk_spec) for chunk_spec in chunk_specs] +# TODO: Double-check whether `CodecDict[Any]` is appropriate @dataclass(frozen=True) class BatchedCodecPipeline(CodecPipeline): """Default codec pipeline. @@ -65,9 +66,9 @@ class BatchedCodecPipeline(CodecPipeline): lock step for each mini-batch. Multiple mini-batches are processing concurrently. """ - array_array_codecs: tuple[ArrayArrayCodec, ...] - array_bytes_codec: ArrayBytesCodec - bytes_bytes_codecs: tuple[BytesBytesCodec, ...] + array_array_codecs: tuple[ArrayArrayCodec[Any], ...] + array_bytes_codec: ArrayBytesCodec[Any] + bytes_bytes_codecs: tuple[BytesBytesCodec[Any], ...] 
batch_size: int def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: @@ -134,11 +135,11 @@ def compute_encoded_size(self, byte_length: int, array_spec: ArraySpec) -> int: def _codecs_with_resolved_metadata_batched( self, chunk_specs: Iterable[ArraySpec] ) -> tuple[ - list[tuple[ArrayArrayCodec, list[ArraySpec]]], - tuple[ArrayBytesCodec, list[ArraySpec]], - list[tuple[BytesBytesCodec, list[ArraySpec]]], + list[tuple[ArrayArrayCodec[Any], list[ArraySpec]]], + tuple[ArrayBytesCodec[Any], list[ArraySpec]], + list[tuple[BytesBytesCodec[Any], list[ArraySpec]]], ]: - aa_codecs_with_spec: list[tuple[ArrayArrayCodec, list[ArraySpec]]] = [] + aa_codecs_with_spec: list[tuple[ArrayArrayCodec[Any], list[ArraySpec]]] = [] chunk_specs = list(chunk_specs) for aa_codec in self.array_array_codecs: aa_codecs_with_spec.append((aa_codec, chunk_specs)) @@ -149,7 +150,7 @@ def _codecs_with_resolved_metadata_batched( self.array_bytes_codec.resolve_metadata(chunk_spec) for chunk_spec in chunk_specs ] - bb_codecs_with_spec: list[tuple[BytesBytesCodec, list[ArraySpec]]] = [] + bb_codecs_with_spec: list[tuple[BytesBytesCodec[Any], list[ArraySpec]]] = [] for bb_codec in self.bytes_bytes_codecs: bb_codecs_with_spec.append((bb_codec, chunk_specs)) chunk_specs = [bb_codec.resolve_metadata(chunk_spec) for chunk_spec in chunk_specs] @@ -465,12 +466,14 @@ async def write( def codecs_from_list( codecs: Iterable[Codec], -) -> tuple[tuple[ArrayArrayCodec, ...], ArrayBytesCodec, tuple[BytesBytesCodec, ...]]: +) -> tuple[ + tuple[ArrayArrayCodec[Any], ...], ArrayBytesCodec[Any], tuple[BytesBytesCodec[Any], ...] +]: from zarr.codecs.sharding import ShardingCodec - array_array: tuple[ArrayArrayCodec, ...] = () - array_bytes_maybe: ArrayBytesCodec | None = None - bytes_bytes: tuple[BytesBytesCodec, ...] = () + array_array: tuple[ArrayArrayCodec[Any], ...] = () + array_bytes_maybe: ArrayBytesCodec[Any] | None = None + bytes_bytes: tuple[BytesBytesCodec[Any], ...] 
= () if any(isinstance(codec, ShardingCodec) for codec in codecs) and len(tuple(codecs)) > 1: warn( diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 2181e9eb76..9f193913d9 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -15,6 +15,8 @@ ArrayBytesCodecPartialDecodeMixin, ArrayBytesCodecPartialEncodeMixin, Codec, + CodecConfigDict, + CodecDict, CodecPipeline, ) from zarr.abc.store import ByteGetter, ByteRangeRequest, ByteSetter @@ -320,9 +322,26 @@ async def finalize( return await shard_builder.finalize(index_location, index_encoder) +class ShardingCodecConfigDict(CodecConfigDict): + """A dictionary representing a sharding codec configuration.""" + + chunk_shape: list[int] # TODO: Double check this + codecs: list[CodecDict[Any]] + index_codecs: list[CodecDict[Any]] + index_location: ShardingCodecIndexLocation + + +class ShardingCodecDict(CodecDict[ShardingCodecConfigDict]): + """A dictionary representing a sharding codec.""" + + ... + + @dataclass(frozen=True) class ShardingCodec( - ArrayBytesCodec, ArrayBytesCodecPartialDecodeMixin, ArrayBytesCodecPartialEncodeMixin + ArrayBytesCodec[ShardingCodecDict], + ArrayBytesCodecPartialDecodeMixin, + ArrayBytesCodecPartialEncodeMixin, ): chunk_shape: ChunkCoords codecs: tuple[Codec, ...] 
@@ -353,7 +372,7 @@ def __init__( object.__setattr__(self, "_get_chunks_per_shard", lru_cache()(self._get_chunks_per_shard)) # todo: typedict return type - def __getstate__(self) -> dict[str, Any]: + def __getstate__(self) -> ShardingCodecDict: return self.to_dict() def __setstate__(self, state: dict[str, Any]) -> None: @@ -369,7 +388,7 @@ def __setstate__(self, state: dict[str, Any]) -> None: object.__setattr__(self, "_get_chunks_per_shard", lru_cache()(self._get_chunks_per_shard)) @classmethod - def from_dict(cls, data: dict[str, JSON]) -> Self: + def from_dict(cls, data: ShardingCodecDict) -> Self: _, configuration_parsed = parse_named_configuration(data, "sharding_indexed") return cls(**configuration_parsed) # type: ignore[arg-type] @@ -377,8 +396,8 @@ def from_dict(cls, data: dict[str, JSON]) -> Self: def codec_pipeline(self) -> CodecPipeline: return get_pipeline_class().from_codecs(self.codecs) - def to_dict(self) -> dict[str, JSON]: - return { + def to_dict(self) -> ShardingCodecDict: + out_dict = { "name": "sharding_indexed", "configuration": { "chunk_shape": self.chunk_shape, @@ -388,6 +407,8 @@ def to_dict(self) -> dict[str, JSON]: }, } + return cast(ShardingCodecDict, out_dict) + def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: shard_spec = self._get_chunk_spec(array_spec) evolved_codecs = tuple(c.evolve_from_array_spec(array_spec=shard_spec) for c in self.codecs) diff --git a/src/zarr/codecs/transpose.py b/src/zarr/codecs/transpose.py index 3a471beaf5..0219f434ad 100644 --- a/src/zarr/codecs/transpose.py +++ b/src/zarr/codecs/transpose.py @@ -6,7 +6,7 @@ import numpy as np -from zarr.abc.codec import ArrayArrayCodec +from zarr.abc.codec import ArrayArrayCodec, CodecConfigDict, CodecDict from zarr.core.array_spec import ArraySpec from zarr.core.common import JSON, ChunkCoordsLike, parse_named_configuration from zarr.registry import register_codec @@ -26,8 +26,20 @@ def parse_transpose_order(data: JSON | Iterable[int]) -> tuple[int, ...]: 
return tuple(cast(Iterable[int], data)) +class TransposeCodecConfigDict(CodecConfigDict): + """A dictionary representing a transpose codec configuration.""" + + order: list[int] + + +class TransposeCodecDict(CodecDict[TransposeCodecConfigDict]): + """A dictionary representing a transpose codec.""" + + ... + + @dataclass(frozen=True) -class TransposeCodec(ArrayArrayCodec): +class TransposeCodec(ArrayArrayCodec[TransposeCodecDict]): is_fixed_size = True order: tuple[int, ...] @@ -38,12 +50,13 @@ def __init__(self, *, order: ChunkCoordsLike) -> None: object.__setattr__(self, "order", order_parsed) @classmethod - def from_dict(cls, data: dict[str, JSON]) -> Self: + def from_dict(cls, data: TransposeCodecDict) -> Self: _, configuration_parsed = parse_named_configuration(data, "transpose") return cls(**configuration_parsed) # type: ignore[arg-type] - def to_dict(self) -> dict[str, JSON]: - return {"name": "transpose", "configuration": {"order": tuple(self.order)}} + def to_dict(self) -> TransposeCodecDict: + out_dict = {"name": "transpose", "configuration": {"order": tuple(self.order)}} + return cast(TransposeCodecDict, out_dict) def validate(self, shape: tuple[int, ...], dtype: np.dtype[Any], chunk_grid: ChunkGrid) -> None: if len(self.order) != len(shape): diff --git a/src/zarr/codecs/vlen_utf8.py b/src/zarr/codecs/vlen_utf8.py index 43544e0809..e7390e5082 100644 --- a/src/zarr/codecs/vlen_utf8.py +++ b/src/zarr/codecs/vlen_utf8.py @@ -1,14 +1,14 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import numpy as np from numcodecs.vlen import VLenBytes, VLenUTF8 -from zarr.abc.codec import ArrayBytesCodec +from zarr.abc.codec import ArrayBytesCodec, CodecConfigDict, CodecDict from zarr.core.buffer import Buffer, NDBuffer -from zarr.core.common import JSON, parse_named_configuration +from zarr.core.common import parse_named_configuration from zarr.core.strings import 
cast_to_string_dtype from zarr.registry import register_codec @@ -23,18 +23,25 @@ _vlen_bytes_codec = VLenBytes() +class VLenUTF8CodecConfigDict(CodecConfigDict): ... + + +class VLenUTF8CodecDict(CodecDict[VLenUTF8CodecConfigDict]): ... + + @dataclass(frozen=True) -class VLenUTF8Codec(ArrayBytesCodec): +class VLenUTF8Codec(ArrayBytesCodec[VLenUTF8CodecDict]): @classmethod - def from_dict(cls, data: dict[str, JSON]) -> Self: + def from_dict(cls, data: VLenUTF8CodecDict) -> Self: _, configuration_parsed = parse_named_configuration( data, "vlen-utf8", require_configuration=False ) configuration_parsed = configuration_parsed or {} return cls(**configuration_parsed) - def to_dict(self) -> dict[str, JSON]: - return {"name": "vlen-utf8", "configuration": {}} + def to_dict(self) -> VLenUTF8CodecDict: + out_dict = {"name": "vlen-utf8", "configuration": {}} + return cast(VLenUTF8CodecDict, out_dict) def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: return self @@ -69,18 +76,25 @@ def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) - raise NotImplementedError("compute_encoded_size is not implemented for VLen codecs") +class VLenBytesCodecConfigDict(CodecConfigDict): ... + + +class VLenBytesCodecDict(CodecDict[VLenBytesCodecConfigDict]): ... 
+ + @dataclass(frozen=True) -class VLenBytesCodec(ArrayBytesCodec): +class VLenBytesCodec(ArrayBytesCodec[VLenBytesCodecDict]): @classmethod - def from_dict(cls, data: dict[str, JSON]) -> Self: + def from_dict(cls, data: VLenBytesCodecDict) -> Self: _, configuration_parsed = parse_named_configuration( data, "vlen-bytes", require_configuration=False ) configuration_parsed = configuration_parsed or {} return cls(**configuration_parsed) - def to_dict(self) -> dict[str, JSON]: - return {"name": "vlen-bytes", "configuration": {}} + def to_dict(self) -> VLenBytesCodecDict: + out_dict = {"name": "vlen-bytes", "configuration": {}} + return cast(VLenBytesCodecDict, out_dict) def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: return self diff --git a/src/zarr/codecs/zstd.py b/src/zarr/codecs/zstd.py index 913d0f01c7..0820d4a51d 100644 --- a/src/zarr/codecs/zstd.py +++ b/src/zarr/codecs/zstd.py @@ -3,11 +3,11 @@ from dataclasses import dataclass from functools import cached_property from importlib.metadata import version -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from numcodecs.zstd import Zstd -from zarr.abc.codec import BytesBytesCodec +from zarr.abc.codec import BytesBytesCodec, CodecConfigDict, CodecDict from zarr.core.buffer.cpu import as_numpy_array_wrapper from zarr.core.common import JSON, parse_named_configuration, to_thread from zarr.registry import register_codec @@ -33,8 +33,21 @@ def parse_checksum(data: JSON) -> bool: raise TypeError(f"Expected bool. Got {type(data)}.") +class ZstdCodecConfigDict(CodecConfigDict): + """A dictionary representing a zstd codec configuration.""" + + level: int + checksum: bool + + +class ZstdCodecDict(CodecDict[ZstdCodecConfigDict]): + """A dictionary representing a zstd codec.""" + + ... 
+ + @dataclass(frozen=True) -class ZstdCodec(BytesBytesCodec): +class ZstdCodec(BytesBytesCodec[ZstdCodecDict]): is_fixed_size = True level: int = 0 @@ -56,12 +69,16 @@ def __init__(self, *, level: int = 0, checksum: bool = False) -> None: object.__setattr__(self, "checksum", checksum_parsed) @classmethod - def from_dict(cls, data: dict[str, JSON]) -> Self: + def from_dict(cls, data: ZstdCodecDict) -> Self: _, configuration_parsed = parse_named_configuration(data, "zstd") return cls(**configuration_parsed) # type: ignore[arg-type] - def to_dict(self) -> dict[str, JSON]: - return {"name": "zstd", "configuration": {"level": self.level, "checksum": self.checksum}} + def to_dict(self) -> ZstdCodecDict: + out_dict = { + "name": "zstd", + "configuration": {"level": self.level, "checksum": self.checksum}, + } + return cast(ZstdCodecDict, out_dict) @cached_property def _zstd_codec(self) -> Zstd: diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index 8d63d9c321..eb791fd2cf 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -2,6 +2,7 @@ import json from asyncio import gather +from collections.abc import Mapping from dataclasses import dataclass, field, replace from logging import getLogger from typing import TYPE_CHECKING, Any, Literal, cast @@ -64,8 +65,8 @@ is_scalar, pop_fields, ) -from zarr.core.metadata.v2 import ArrayV2Metadata -from zarr.core.metadata.v3 import ArrayV3Metadata +from zarr.core.metadata.v2 import ArrayV2Metadata, ArrayV2MetadataDict +from zarr.core.metadata.v3 import ArrayV3Metadata, ArrayV3MetadataDict from zarr.core.sync import collect_aiterator, sync from zarr.registry import get_pipeline_class from zarr.storage import StoreLike, make_store_path @@ -87,9 +88,9 @@ def parse_array_metadata(data: Any) -> ArrayV2Metadata | ArrayV3Metadata: if isinstance(data, ArrayV2Metadata | ArrayV3Metadata): return data - elif isinstance(data, dict): + elif isinstance(data, Mapping): if data["zarr_format"] == 3: - meta_out = 
ArrayV3Metadata.from_dict(data) + meta_out = ArrayV3Metadata.from_dict(cast(ArrayV3MetadataDict, data)) if len(meta_out.storage_transformers) > 0: msg = ( f"Array metadata contains storage transformers: {meta_out.storage_transformers}." @@ -98,7 +99,7 @@ def parse_array_metadata(data: Any) -> ArrayV2Metadata | ArrayV3Metadata: raise ValueError(msg) return meta_out elif data["zarr_format"] == 2: - return ArrayV2Metadata.from_dict(data) + return ArrayV2Metadata.from_dict(cast(ArrayV2MetadataDict, data)) raise TypeError @@ -162,23 +163,23 @@ async def get_array_metadata( @dataclass(frozen=True) class AsyncArray: - metadata: ArrayMetadata + metadata: ArrayMetadata[Any] store_path: StorePath codec_pipeline: CodecPipeline = field(init=False) order: Literal["C", "F"] def __init__( self, - metadata: ArrayMetadata | dict[str, Any], + metadata: ArrayMetadata[Any] | Mapping[str, Any], store_path: StorePath, order: Literal["C", "F"] | None = None, ) -> None: - if isinstance(metadata, dict): + if isinstance(metadata, Mapping): zarr_format = metadata["zarr_format"] if zarr_format == 2: - metadata = ArrayV2Metadata.from_dict(metadata) + metadata = ArrayV2Metadata.from_dict(metadata) # type: ignore[arg-type] else: - metadata = ArrayV3Metadata.from_dict(metadata) + metadata = ArrayV3Metadata.from_dict(metadata) # type: ignore[arg-type] metadata_parsed = parse_array_metadata(metadata) order_parsed = parse_indexing_order(order or config.get("array.order")) @@ -635,7 +636,9 @@ async def getitem( ) return await self._get_selection(indexer, prototype=prototype) - async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = False) -> None: + async def _save_metadata( + self, metadata: ArrayMetadata[Any], ensure_parents: bool = False + ) -> None: to_save = metadata.to_buffer_dict(default_buffer_prototype()) awaitables = [set_or_delete(self.store_path / key, value) for key, value in to_save.items()] @@ -883,7 +886,7 @@ def basename(self) -> str | None: return 
self._async_array.basename @property - def metadata(self) -> ArrayMetadata: + def metadata(self) -> ArrayMetadata[Any]: return self._async_array.metadata @property diff --git a/src/zarr/core/chunk_grids.py b/src/zarr/core/chunk_grids.py index 77734056b3..17e34e81e0 100644 --- a/src/zarr/core/chunk_grids.py +++ b/src/zarr/core/chunk_grids.py @@ -7,13 +7,12 @@ from abc import abstractmethod from dataclasses import dataclass from functools import reduce -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypedDict, cast import numpy as np from zarr.abc.metadata import Metadata from zarr.core.common import ( - JSON, ChunkCoords, ChunkCoordsLike, ShapeLike, @@ -141,10 +140,23 @@ def normalize_chunks(chunks: Any, shape: tuple[int, ...], typesize: int) -> tupl return tuple(int(c) for c in chunks) +class ChunkGridConfigDict(TypedDict): + """A dictionary representing a chunk grid configuration.""" + + chunk_shape: tuple[int, ...] + + +class ChunkGridDict(TypedDict): + """A generic dictionary representing a chunk grid.""" + + name: str + configuration: ChunkGridConfigDict + + @dataclass(frozen=True) -class ChunkGrid(Metadata): +class ChunkGrid(Metadata[ChunkGridDict]): @classmethod - def from_dict(cls, data: dict[str, JSON] | ChunkGrid) -> ChunkGrid: + def from_dict(cls, data: ChunkGridDict | ChunkGrid) -> ChunkGrid: if isinstance(data, ChunkGrid): return data @@ -172,13 +184,13 @@ def __init__(self, *, chunk_shape: ChunkCoordsLike) -> None: object.__setattr__(self, "chunk_shape", chunk_shape_parsed) @classmethod - def _from_dict(cls, data: dict[str, JSON]) -> Self: + def _from_dict(cls, data: ChunkGridDict) -> Self: _, configuration_parsed = parse_named_configuration(data, "regular") - return cls(**configuration_parsed) # type: ignore[arg-type] - def to_dict(self) -> dict[str, JSON]: - return {"name": "regular", "configuration": {"chunk_shape": tuple(self.chunk_shape)}} + def to_dict(self) -> ChunkGridDict: + out_dict = {"name": "regular", 
"configuration": {"chunk_shape": tuple(self.chunk_shape)}} + return cast(ChunkGridDict, out_dict) def all_chunk_coords(self, array_shape: ChunkCoords) -> Iterator[ChunkCoords]: return itertools.product( diff --git a/src/zarr/core/chunk_key_encodings.py b/src/zarr/core/chunk_key_encodings.py index ed12ee3065..a86615c55e 100644 --- a/src/zarr/core/chunk_key_encodings.py +++ b/src/zarr/core/chunk_key_encodings.py @@ -2,7 +2,7 @@ from abc import abstractmethod from dataclasses import dataclass -from typing import Literal, cast +from typing import TYPE_CHECKING, Literal, TypedDict, cast from zarr.abc.metadata import Metadata from zarr.core.common import ( @@ -11,6 +11,9 @@ parse_named_configuration, ) +if TYPE_CHECKING: + from collections.abc import Mapping + SeparatorLiteral = Literal[".", "/"] @@ -20,8 +23,15 @@ def parse_separator(data: JSON) -> SeparatorLiteral: return cast(SeparatorLiteral, data) +class ChunkKeyEncodingDict(TypedDict): + """A dictionary representing a chunk key encoding configuration.""" + + name: str + configuration: Mapping[Literal["separator"], SeparatorLiteral] + + @dataclass(frozen=True) -class ChunkKeyEncoding(Metadata): +class ChunkKeyEncoding(Metadata[ChunkKeyEncodingDict]): name: str separator: SeparatorLiteral = "." 
@@ -31,12 +41,14 @@ def __init__(self, *, separator: SeparatorLiteral) -> None: object.__setattr__(self, "separator", separator_parsed) @classmethod - def from_dict(cls, data: dict[str, JSON] | ChunkKeyEncoding) -> ChunkKeyEncoding: + def from_dict(cls, data: ChunkKeyEncodingDict | ChunkKeyEncoding) -> ChunkKeyEncoding: if isinstance(data, ChunkKeyEncoding): return data + _data = dict(data) + # configuration is optional for chunk key encodings - name_parsed, config_parsed = parse_named_configuration(data, require_configuration=False) + name_parsed, config_parsed = parse_named_configuration(_data, require_configuration=False) # type: ignore[arg-type] if name_parsed == "default": if config_parsed is None: # for default, normalize missing configuration to use the "/" separator. @@ -50,8 +62,9 @@ def from_dict(cls, data: dict[str, JSON] | ChunkKeyEncoding) -> ChunkKeyEncoding msg = f"Unknown chunk key encoding. Got {name_parsed}, expected one of ('v2', 'default')." raise ValueError(msg) - def to_dict(self) -> dict[str, JSON]: - return {"name": self.name, "configuration": {"separator": self.separator}} + def to_dict(self) -> ChunkKeyEncodingDict: + out_dict = {"name": self.name, "configuration": {"separator": self.separator}} + return cast(ChunkKeyEncodingDict, out_dict) @abstractmethod def decode_chunk_key(self, chunk_key: str) -> ChunkCoords: diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index bf0e385d06..fa5dc06543 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -3,8 +3,9 @@ import asyncio import json import logging +from collections.abc import Iterator from dataclasses import asdict, dataclass, field, fields, replace -from typing import TYPE_CHECKING, Literal, TypeVar, cast, overload +from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict, TypeVar, cast, overload import numpy as np import numpy.typing as npt @@ -33,7 +34,7 @@ from zarr.storage.common import StorePath, ensure_no_existing_node if TYPE_CHECKING: - from 
collections.abc import AsyncGenerator, Generator, Iterable, Iterator + from collections.abc import AsyncGenerator, Generator, Iterable, Iterator, Mapping from typing import Any from zarr.abc.codec import Codec @@ -82,8 +83,16 @@ def _parse_async_node(node: AsyncArray | AsyncGroup) -> Array | Group: raise TypeError(f"Unknown node type, got {type(node)}") +class GroupMetadataDict(TypedDict): + """A dictionary representing a group metadata.""" + + attributes: Mapping[str, Any] + node_type: NotRequired[Literal["group"]] + + @dataclass(frozen=True) -class GroupMetadata(Metadata): +class GroupMetadata(Metadata[GroupMetadataDict]): + # TODO: Should attributes be a dict[str, JSON] instead? attributes: dict[str, Any] = field(default_factory=dict) zarr_format: ZarrFormat = 3 node_type: Literal["group"] = field(default="group", init=False) @@ -116,8 +125,9 @@ def __init__( object.__setattr__(self, "zarr_format", zarr_format_parsed) @classmethod - def from_dict(cls, data: dict[str, Any]) -> GroupMetadata: - assert data.pop("node_type", None) in ("group", None) + def from_dict(cls, data: GroupMetadataDict) -> GroupMetadata: + _data = dict(data) + assert _data.pop("node_type", None) in ("group", None) zarr_format = data.get("zarr_format") if zarr_format == 2 or zarr_format is None: @@ -125,12 +135,12 @@ def from_dict(cls, data: dict[str, Any]) -> GroupMetadata: # We don't want the GroupMetadata constructor to fail just because someone put an # extra key in the metadata. 
expected = {x.name for x in fields(cls)} - data = {k: v for k, v in data.items() if k in expected} + _data = {k: v for k, v in _data.items() if k in expected} - return cls(**data) + return cls(**_data) # type: ignore[arg-type] - def to_dict(self) -> dict[str, Any]: - return asdict(self) + def to_dict(self) -> GroupMetadataDict: + return cast(GroupMetadataDict, asdict(self)) @dataclass(frozen=True) @@ -198,11 +208,13 @@ async def open( else: raise ValueError(f"unexpected zarr_format: {zarr_format}") + group_metadata: GroupMetadataDict if zarr_format == 2: # V2 groups are comprised of a .zgroup and .zattrs objects assert zgroup_bytes is not None zgroup = json.loads(zgroup_bytes.to_bytes()) zattrs = json.loads(zattrs_bytes.to_bytes()) if zattrs_bytes is not None else {} + # TODO: Mypy Non-required key "node_type" not explicitly found in any ** item group_metadata = {**zgroup, "attributes": zattrs} else: # V3 groups are comprised of a zarr.json object @@ -215,7 +227,7 @@ async def open( def from_dict( cls, store_path: StorePath, - data: dict[str, Any], + data: GroupMetadataDict, ) -> AsyncGroup: return cls( metadata=GroupMetadata.from_dict(data), diff --git a/src/zarr/core/metadata/common.py b/src/zarr/core/metadata/common.py index 7d71455a44..3a56655bf5 100644 --- a/src/zarr/core/metadata/common.py +++ b/src/zarr/core/metadata/common.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Any, Literal, Self + from typing import Any, Literal, NotRequired, Self import numpy as np @@ -14,12 +14,22 @@ from abc import ABC, abstractmethod from dataclasses import dataclass +from typing import TypedDict -from zarr.abc.metadata import Metadata +from zarr.abc.metadata import Metadata, T + + +# TODO: Revisit Optional vs ... 
| None +class ArrayMetadataDict(TypedDict): + """A dictionary representing array metadata common to all Zarr versions.""" + + shape: ChunkCoords + attributes: NotRequired[dict[str, JSON]] # TODO: Double-check if NotRequired is appropriate + zarr_format: Literal[2, 3] @dataclass(frozen=True, kw_only=True) -class ArrayMetadata(Metadata, ABC): +class ArrayMetadata(Metadata[T], ABC): shape: ChunkCoords fill_value: Any chunk_grid: ChunkGrid diff --git a/src/zarr/core/metadata/v2.py b/src/zarr/core/metadata/v2.py index 6d8f2a8ab1..dfd1594336 100644 --- a/src/zarr/core/metadata/v2.py +++ b/src/zarr/core/metadata/v2.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, cast if TYPE_CHECKING: - from typing import Any, Literal, Self + from typing import Any, Literal, NotRequired, Self import numpy.typing as npt @@ -24,11 +24,23 @@ from zarr.core.chunk_key_encodings import parse_separator from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON, parse_shapelike from zarr.core.config import config, parse_indexing_order -from zarr.core.metadata.common import ArrayMetadata, parse_attributes +from zarr.core.metadata.common import ArrayMetadata, ArrayMetadataDict, parse_attributes + + +class ArrayV2MetadataDict(ArrayMetadataDict): + """A dictionary representing array metadata for Zarr version 2.""" + + chunks: RegularChunkGrid + dtype: np.dtype[Any] + fill_value: NotRequired[None | int | float | str | bytes] + order: Literal["C", "F"] + filters: NotRequired[Iterable[numcodecs.abc.Codec | dict[str, JSON]]] + dimension_separator: NotRequired[Literal[".", "/"]] + compressor: NotRequired[numcodecs.abc.Codec] @dataclass(frozen=True, kw_only=True) -class ArrayV2Metadata(ArrayMetadata): +class ArrayV2Metadata(ArrayMetadata[ArrayV2MetadataDict]): shape: ChunkCoords chunk_grid: RegularChunkGrid data_type: np.dtype[Any] @@ -136,9 +148,10 @@ def _json_convert( } @classmethod - def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata: + def from_dict(cls, data: ArrayV2MetadataDict) -> 
ArrayV2Metadata: # make a copy to protect the original from modification - _data = data.copy() + _data = dict(data) + # check that the zarr_format attribute is correct _ = parse_zarr_format(_data.pop("zarr_format")) dtype = parse_dtype(_data["dtype"]) @@ -159,10 +172,10 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata: _data = {k: v for k, v in _data.items() if k in expected} - return cls(**_data) + return cls(**_data) # type: ignore[arg-type] - def to_dict(self) -> dict[str, JSON]: - zarray_dict = super().to_dict() + def to_dict(self) -> ArrayV2MetadataDict: + zarray_dict = dict(super().to_dict()) if self.dtype.kind in "SV" and self.fill_value is not None: # There's a relationship between self.dtype and self.fill_value @@ -177,7 +190,7 @@ def to_dict(self) -> dict[str, JSON]: _ = zarray_dict.pop("data_type") zarray_dict["dtype"] = self.data_type.str - return zarray_dict + return cast(ArrayV2MetadataDict, zarray_dict) def get_chunk_spec( self, _chunk_coords: ChunkCoords, order: Literal["C", "F"], prototype: BufferPrototype diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index 47c6106bfe..6ba145143a 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, overload +from typing import TYPE_CHECKING, NotRequired, overload if TYPE_CHECKING: from typing import Self @@ -23,11 +23,11 @@ from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec from zarr.core.array_spec import ArraySpec from zarr.core.buffer import default_buffer_prototype -from zarr.core.chunk_grids import ChunkGrid, RegularChunkGrid -from zarr.core.chunk_key_encodings import ChunkKeyEncoding +from zarr.core.chunk_grids import ChunkGrid, ChunkGridDict, RegularChunkGrid +from zarr.core.chunk_key_encodings import ChunkKeyEncoding, ChunkKeyEncodingDict from zarr.core.common import ZARR_JSON, parse_named_configuration, 
parse_shapelike from zarr.core.config import config -from zarr.core.metadata.common import ArrayMetadata, parse_attributes +from zarr.core.metadata.common import ArrayMetadata, ArrayMetadataDict, parse_attributes from zarr.core.strings import _STRING_DTYPE as STRING_NP_DTYPE from zarr.registry import get_codec_class @@ -68,7 +68,7 @@ def validate_codecs(codecs: tuple[Codec, ...], dtype: DataType) -> None: """Check that the codecs are valid for the given dtype""" # ensure that we have at least one ArrayBytesCodec - abcs: list[ArrayBytesCodec] = [] + abcs: list[ArrayBytesCodec[Any]] = [] for codec in codecs: if isinstance(codec, ArrayBytesCodec): abcs.append(codec) @@ -179,8 +179,21 @@ def _replace_special_floats(obj: object) -> Any: return obj +class ArrayV3MetadataDict(ArrayMetadataDict): + """A dictionary representing array metadata for Zarr version 3.""" + + chunk_grid: ChunkGrid + data_type: npt.DTypeLike | DataType + chunk_key_encoding: ChunkKeyEncoding + fill_value: Any + codecs: tuple[Codec, ...] + dimension_names: NotRequired[tuple[str, ...]] + node_type: Literal["array"] + storage_transformers: tuple[dict[str, JSON], ...] 
+ + @dataclass(frozen=True, kw_only=True) -class ArrayV3Metadata(ArrayMetadata): +class ArrayV3Metadata(ArrayMetadata[ArrayV3MetadataDict]): shape: ChunkCoords data_type: DataType chunk_grid: ChunkGrid @@ -198,8 +211,8 @@ def __init__( *, shape: Iterable[int], data_type: npt.DTypeLike | DataType, - chunk_grid: dict[str, JSON] | ChunkGrid, - chunk_key_encoding: dict[str, JSON] | ChunkKeyEncoding, + chunk_grid: ChunkGridDict | ChunkGrid, + chunk_key_encoding: ChunkKeyEncodingDict | ChunkKeyEncoding, fill_value: Any, codecs: Iterable[Codec | dict[str, JSON]], attributes: None | dict[str, JSON], @@ -295,9 +308,9 @@ def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]: return {ZARR_JSON: prototype.buffer.from_bytes(json.dumps(d, cls=V3JsonEncoder).encode())} @classmethod - def from_dict(cls, data: dict[str, JSON]) -> Self: + def from_dict(cls, data: ArrayV3MetadataDict) -> ArrayV3Metadata: # make a copy because we are modifying the dict - _data = data.copy() + _data = dict(data) # check that the zarr_format attribute is correct _ = parse_zarr_format(_data.pop("zarr_format")) @@ -311,10 +324,11 @@ def from_dict(cls, data: dict[str, JSON]) -> Self: _data["dimension_names"] = _data.pop("dimension_names", None) # attributes key is optional, normalize missing to `None` _data["attributes"] = _data.pop("attributes", None) + return cls(**_data, data_type=data_type) # type: ignore[arg-type] - def to_dict(self) -> dict[str, JSON]: - out_dict = super().to_dict() + def to_dict(self) -> ArrayV3MetadataDict: + out_dict = dict(super().to_dict()) if not isinstance(out_dict, dict): raise TypeError(f"Expected dict. Got {type(out_dict)}.") @@ -323,7 +337,8 @@ def to_dict(self) -> dict[str, JSON]: # the metadata document if out_dict["dimension_names"] is None: out_dict.pop("dimension_names") - return out_dict + + return cast(ArrayV3MetadataDict, out_dict) def update_shape(self, shape: ChunkCoords) -> Self: return replace(self, shape=shape)