Skip to content

Commit 0e9ae08

Browse files
committed
Make Metadata abstract base class generic
Redefine `Metadata` class to be generic and refactor downstream classes
1 parent 0c65e56 commit 0e9ae08

File tree

18 files changed

+124
-114
lines changed

18 files changed

+124
-114
lines changed

src/zarr/abc/codec.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from abc import abstractmethod
44
from typing import TYPE_CHECKING, Any, Generic, NotRequired, TypedDict, TypeVar
55

6-
from zarr.abc.metadata import Metadata
6+
from zarr.abc.metadata import Metadata, T
77
from zarr.core.buffer import Buffer, NDBuffer
88
from zarr.core.common import ChunkCoords, concurrent_map
99
from zarr.core.config import config
@@ -35,7 +35,7 @@
3535
CodecOutput = TypeVar("CodecOutput", bound=NDBuffer | Buffer)
3636

3737

38-
class BaseCodec(Metadata, Generic[CodecInput, CodecOutput]):
38+
class BaseCodec(Generic[CodecInput, CodecOutput, T], Metadata[T]):
3939
"""Generic base class for codecs.
4040
4141
Codecs can be registered via zarr.codecs.registry.
@@ -153,25 +153,25 @@ async def encode(
153153
return await _batching_helper(self._encode_single, chunks_and_specs)
154154

155155

156-
class ArrayArrayCodec(BaseCodec[NDBuffer, NDBuffer]):
156+
class ArrayArrayCodec(BaseCodec[NDBuffer, NDBuffer, T]):
157157
"""Base class for array-to-array codecs."""
158158

159159
...
160160

161161

162-
class ArrayBytesCodec(BaseCodec[NDBuffer, Buffer]):
162+
class ArrayBytesCodec(BaseCodec[NDBuffer, Buffer, T]):
163163
"""Base class for array-to-bytes codecs."""
164164

165165
...
166166

167167

168-
class BytesBytesCodec(BaseCodec[Buffer, Buffer]):
168+
class BytesBytesCodec(BaseCodec[Buffer, Buffer, T]):
169169
"""Base class for bytes-to-bytes codecs."""
170170

171171
...
172172

173173

174-
Codec = ArrayArrayCodec | ArrayBytesCodec | BytesBytesCodec
174+
Codec = ArrayArrayCodec[Any] | ArrayBytesCodec[Any] | BytesBytesCodec[Any]
175175

176176

177177
class CodecConfigDict(TypedDict):
@@ -180,14 +180,14 @@ class CodecConfigDict(TypedDict):
180180
...
181181

182182

183-
T = TypeVar("T", bound=CodecConfigDict)
183+
CodecConfigDictType = TypeVar("CodecConfigDictType", bound=CodecConfigDict)
184184

185185

186-
class CodecDict(TypedDict, Generic[T]):
186+
class CodecDict(Generic[CodecConfigDictType], TypedDict):
187187
"""A generic dictionary representing a codec."""
188188

189189
name: str
190-
configuration: NotRequired[T]
190+
configuration: NotRequired[CodecConfigDictType]
191191

192192

193193
class ArrayBytesCodecPartialDecodeMixin:

src/zarr/abc/metadata.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,22 @@
11
from __future__ import annotations
22

3-
from collections.abc import Sequence
4-
from typing import TYPE_CHECKING
3+
from collections.abc import Mapping, Sequence
4+
from typing import TYPE_CHECKING, Generic, TypeVar, cast
55

66
if TYPE_CHECKING:
77
from typing import Self
88

9-
from zarr.core.common import JSON
109

1110
from dataclasses import dataclass, fields
1211

1312
__all__ = ["Metadata"]
1413

14+
T = TypeVar("T", bound=Mapping[str, object])
15+
1516

1617
@dataclass(frozen=True)
17-
class Metadata:
18-
def to_dict(self) -> dict[str, JSON]:
18+
class Metadata(Generic[T]):
19+
def to_dict(self) -> T:
1920
"""
2021
Recursively serialize this model to a dictionary.
2122
This method inspects the fields of self and calls `x.to_dict()` for any fields that
@@ -36,10 +37,10 @@ def to_dict(self) -> dict[str, JSON]:
3637
else:
3738
out_dict[key] = value
3839

39-
return out_dict
40+
return cast(T, out_dict)
4041

4142
@classmethod
42-
def from_dict(cls, data: dict[str, JSON]) -> Self:
43+
def from_dict(cls, data: T) -> Self:
4344
"""
4445
Create an instance of the model from a dictionary
4546
"""

src/zarr/codecs/_v2.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from dataclasses import dataclass
4-
from typing import TYPE_CHECKING
4+
from typing import TYPE_CHECKING, Any
55

66
import numcodecs
77
from numcodecs.compat import ensure_bytes, ensure_ndarray
@@ -18,7 +18,7 @@
1818

1919

2020
@dataclass(frozen=True)
21-
class V2Compressor(ArrayBytesCodec):
21+
class V2Compressor(ArrayBytesCodec[Any]):
2222
compressor: numcodecs.abc.Codec | None
2323

2424
is_fixed_size = False
@@ -66,7 +66,7 @@ def compute_encoded_size(self, _input_byte_length: int, _chunk_spec: ArraySpec)
6666

6767

6868
@dataclass(frozen=True)
69-
class V2Filters(ArrayArrayCodec):
69+
class V2Filters(ArrayArrayCodec[Any]):
7070
filters: tuple[numcodecs.abc.Codec, ...] | None
7171

7272
is_fixed_size = False

src/zarr/codecs/blosc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def parse_blocksize(data: JSON) -> int:
9999

100100

101101
@dataclass(frozen=True)
102-
class BloscCodec(BytesBytesCodec):
102+
class BloscCodec(BytesBytesCodec[BloscCodecDict]):
103103
is_fixed_size = False
104104

105105
typesize: int | None
@@ -130,7 +130,7 @@ def __init__(
130130
object.__setattr__(self, "blocksize", blocksize_parsed)
131131

132132
@classmethod
133-
def from_dict(cls, data: dict[str, JSON]) -> Self:
133+
def from_dict(cls, data: BloscCodecDict) -> Self:
134134
_, configuration_parsed = parse_named_configuration(data, "blosc")
135135
return cls(**configuration_parsed) # type: ignore[arg-type]
136136

src/zarr/codecs/bytes.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99

1010
from zarr.abc.codec import ArrayBytesCodec, CodecConfigDict, CodecDict
1111
from zarr.core.buffer import Buffer, NDArrayLike, NDBuffer
12-
from zarr.core.common import JSON, parse_enum, parse_named_configuration
12+
from zarr.core.common import parse_enum, parse_named_configuration
1313
from zarr.registry import register_codec
1414

1515
if TYPE_CHECKING:
16-
from typing import Self
16+
from typing import Literal, Self
1717

1818
from zarr.core.array_spec import ArraySpec
1919

@@ -33,7 +33,8 @@ class Endian(Enum):
3333
class BytesCodecConfigDict(CodecConfigDict):
3434
"""A dictionary representing a bytes codec configuration."""
3535

36-
endian: Endian
36+
# TODO: Why not type this with the `Endian` enum?
37+
endian: Literal["big", "little"]
3738

3839

3940
class BytesCodecDict(CodecDict[BytesCodecConfigDict]):
@@ -43,7 +44,7 @@ class BytesCodecDict(CodecDict[BytesCodecConfigDict]):
4344

4445

4546
@dataclass(frozen=True)
46-
class BytesCodec(ArrayBytesCodec):
47+
class BytesCodec(ArrayBytesCodec[BytesCodecDict]):
4748
is_fixed_size = True
4849

4950
endian: Endian | None
@@ -54,10 +55,11 @@ def __init__(self, *, endian: Endian | str | None = default_system_endian) -> No
5455
object.__setattr__(self, "endian", endian_parsed)
5556

5657
@classmethod
57-
def from_dict(cls, data: dict[str, JSON]) -> Self:
58+
def from_dict(cls, data: BytesCodecDict) -> Self:
5859
_, configuration_parsed = parse_named_configuration(
5960
data, "bytes", require_configuration=False
6061
)
62+
6163
configuration_parsed = configuration_parsed or {}
6264
return cls(**configuration_parsed) # type: ignore[arg-type]
6365

src/zarr/codecs/crc32c_.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from crc32c import crc32c
99

1010
from zarr.abc.codec import BytesBytesCodec, CodecConfigDict, CodecDict
11-
from zarr.core.common import JSON, parse_named_configuration
11+
from zarr.core.common import parse_named_configuration
1212
from zarr.registry import register_codec
1313

1414
if TYPE_CHECKING:
@@ -25,11 +25,11 @@ class Crc32cCodecDict(CodecDict[CodecConfigDict]):
2525

2626

2727
@dataclass(frozen=True)
28-
class Crc32cCodec(BytesBytesCodec):
28+
class Crc32cCodec(BytesBytesCodec[Crc32cCodecDict]):
2929
is_fixed_size = True
3030

3131
@classmethod
32-
def from_dict(cls, data: dict[str, JSON]) -> Self:
32+
def from_dict(cls, data: Crc32cCodecDict) -> Self:
3333
parse_named_configuration(data, "crc32c", require_configuration=False)
3434
return cls()
3535

src/zarr/codecs/gzip.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def parse_gzip_level(data: JSON) -> int:
4040

4141

4242
@dataclass(frozen=True)
43-
class GzipCodec(BytesBytesCodec):
43+
class GzipCodec(BytesBytesCodec[GzipCodecDict]):
4444
is_fixed_size = False
4545

4646
level: int = 5
@@ -51,7 +51,7 @@ def __init__(self, *, level: int = 5) -> None:
5151
object.__setattr__(self, "level", level_parsed)
5252

5353
@classmethod
54-
def from_dict(cls, data: dict[str, JSON]) -> Self:
54+
def from_dict(cls, data: GzipCodecDict) -> Self:
5555
_, configuration_parsed = parse_named_configuration(data, "gzip")
5656
return cls(**configuration_parsed) # type: ignore[arg-type]
5757

src/zarr/codecs/pipeline.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def resolve_batched(codec: Codec, chunk_specs: Iterable[ArraySpec]) -> Iterable[
5555
return [codec.resolve_metadata(chunk_spec) for chunk_spec in chunk_specs]
5656

5757

58+
# TODO: Double-check whether `CodecDict[Any]` is appropriate
5859
@dataclass(frozen=True)
5960
class BatchedCodecPipeline(CodecPipeline):
6061
"""Default codec pipeline.
@@ -64,9 +65,9 @@ class BatchedCodecPipeline(CodecPipeline):
6465
lock step for each mini-batch. Multiple mini-batches are processing concurrently.
6566
"""
6667

67-
array_array_codecs: tuple[ArrayArrayCodec, ...]
68-
array_bytes_codec: ArrayBytesCodec
69-
bytes_bytes_codecs: tuple[BytesBytesCodec, ...]
68+
array_array_codecs: tuple[ArrayArrayCodec[Any], ...]
69+
array_bytes_codec: ArrayBytesCodec[Any]
70+
bytes_bytes_codecs: tuple[BytesBytesCodec[Any], ...]
7071
batch_size: int
7172

7273
def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
@@ -133,11 +134,11 @@ def compute_encoded_size(self, byte_length: int, array_spec: ArraySpec) -> int:
133134
def _codecs_with_resolved_metadata_batched(
134135
self, chunk_specs: Iterable[ArraySpec]
135136
) -> tuple[
136-
list[tuple[ArrayArrayCodec, list[ArraySpec]]],
137-
tuple[ArrayBytesCodec, list[ArraySpec]],
138-
list[tuple[BytesBytesCodec, list[ArraySpec]]],
137+
list[tuple[ArrayArrayCodec[Any], list[ArraySpec]]],
138+
tuple[ArrayBytesCodec[Any], list[ArraySpec]],
139+
list[tuple[BytesBytesCodec[Any], list[ArraySpec]]],
139140
]:
140-
aa_codecs_with_spec: list[tuple[ArrayArrayCodec, list[ArraySpec]]] = []
141+
aa_codecs_with_spec: list[tuple[ArrayArrayCodec[Any], list[ArraySpec]]] = []
141142
chunk_specs = list(chunk_specs)
142143
for aa_codec in self.array_array_codecs:
143144
aa_codecs_with_spec.append((aa_codec, chunk_specs))
@@ -148,7 +149,7 @@ def _codecs_with_resolved_metadata_batched(
148149
self.array_bytes_codec.resolve_metadata(chunk_spec) for chunk_spec in chunk_specs
149150
]
150151

151-
bb_codecs_with_spec: list[tuple[BytesBytesCodec, list[ArraySpec]]] = []
152+
bb_codecs_with_spec: list[tuple[BytesBytesCodec[Any], list[ArraySpec]]] = []
152153
for bb_codec in self.bytes_bytes_codecs:
153154
bb_codecs_with_spec.append((bb_codec, chunk_specs))
154155
chunk_specs = [bb_codec.resolve_metadata(chunk_spec) for chunk_spec in chunk_specs]
@@ -451,12 +452,14 @@ async def write(
451452

452453
def codecs_from_list(
453454
codecs: Iterable[Codec],
454-
) -> tuple[tuple[ArrayArrayCodec, ...], ArrayBytesCodec, tuple[BytesBytesCodec, ...]]:
455+
) -> tuple[
456+
tuple[ArrayArrayCodec[Any], ...], ArrayBytesCodec[Any], tuple[BytesBytesCodec[Any], ...]
457+
]:
455458
from zarr.codecs.sharding import ShardingCodec
456459

457-
array_array: tuple[ArrayArrayCodec, ...] = ()
458-
array_bytes_maybe: ArrayBytesCodec | None = None
459-
bytes_bytes: tuple[BytesBytesCodec, ...] = ()
460+
array_array: tuple[ArrayArrayCodec[Any], ...] = ()
461+
array_bytes_maybe: ArrayBytesCodec[Any] | None = None
462+
bytes_bytes: tuple[BytesBytesCodec[Any], ...] = ()
460463

461464
if any(isinstance(codec, ShardingCodec) for codec in codecs) and len(tuple(codecs)) > 1:
462465
warn(

src/zarr/codecs/sharding.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,9 @@ class ShardingCodecDict(CodecDict[ShardingCodecConfigDict]):
339339

340340
@dataclass(frozen=True)
341341
class ShardingCodec(
342-
ArrayBytesCodec, ArrayBytesCodecPartialDecodeMixin, ArrayBytesCodecPartialEncodeMixin
342+
ArrayBytesCodec[ShardingCodecDict],
343+
ArrayBytesCodecPartialDecodeMixin,
344+
ArrayBytesCodecPartialEncodeMixin,
343345
):
344346
chunk_shape: ChunkCoords
345347
codecs: tuple[Codec, ...]
@@ -370,7 +372,7 @@ def __init__(
370372
object.__setattr__(self, "_get_chunks_per_shard", lru_cache()(self._get_chunks_per_shard))
371373

372374
# todo: typedict return type
373-
def __getstate__(self) -> dict[str, Any]:
375+
def __getstate__(self) -> ShardingCodecDict:
374376
return self.to_dict()
375377

376378
def __setstate__(self, state: dict[str, Any]) -> None:
@@ -386,7 +388,7 @@ def __setstate__(self, state: dict[str, Any]) -> None:
386388
object.__setattr__(self, "_get_chunks_per_shard", lru_cache()(self._get_chunks_per_shard))
387389

388390
@classmethod
389-
def from_dict(cls, data: dict[str, JSON]) -> Self:
391+
def from_dict(cls, data: ShardingCodecDict) -> Self:
390392
_, configuration_parsed = parse_named_configuration(data, "sharding_indexed")
391393
return cls(**configuration_parsed) # type: ignore[arg-type]
392394

src/zarr/codecs/transpose.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class TransposeCodecDict(CodecDict[TransposeCodecConfigDict]):
3939

4040

4141
@dataclass(frozen=True)
42-
class TransposeCodec(ArrayArrayCodec):
42+
class TransposeCodec(ArrayArrayCodec[TransposeCodecDict]):
4343
is_fixed_size = True
4444

4545
order: tuple[int, ...]
@@ -50,7 +50,7 @@ def __init__(self, *, order: ChunkCoordsLike) -> None:
5050
object.__setattr__(self, "order", order_parsed)
5151

5252
@classmethod
53-
def from_dict(cls, data: dict[str, JSON]) -> Self:
53+
def from_dict(cls, data: TransposeCodecDict) -> Self:
5454
_, configuration_parsed = parse_named_configuration(data, "transpose")
5555
return cls(**configuration_parsed) # type: ignore[arg-type]
5656

0 commit comments

Comments
 (0)