Skip to content

Commit a134233

Browse files
committed
improve method annotations for numcodecs codecs
1 parent 122b19e commit a134233

22 files changed

+1239
-353
lines changed

src/zarr/abc/codec.py

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,11 @@
88
Literal,
99
Self,
1010
TypedDict,
11-
TypeGuard,
1211
TypeVar,
1312
overload,
1413
)
1514

16-
from typing_extensions import ReadOnly
15+
from typing_extensions import ReadOnly, TypeIs
1716

1817
from zarr.abc.metadata import Metadata
1918
from zarr.core.buffer import Buffer, NDBuffer
@@ -46,16 +45,14 @@
4645
CodecInput = TypeVar("CodecInput", bound=NDBuffer | Buffer)
4746
CodecOutput = TypeVar("CodecOutput", bound=NDBuffer | Buffer)
4847

49-
TName = TypeVar("TName", bound=str, covariant=True)
5048

51-
52-
class CodecJSON_V2(TypedDict, Generic[TName]):
49+
class CodecJSON_V2(TypedDict):
5350
"""The JSON representation of a codec for Zarr V2"""
5451

55-
id: ReadOnly[TName]
52+
id: ReadOnly[str]
5653

5754

58-
def _check_codecjson_v2(data: object) -> TypeGuard[CodecJSON_V2[str]]:
55+
def _check_codecjson_v2(data: object) -> TypeIs[CodecJSON_V2]:
5956
return isinstance(data, Mapping) and "id" in data and isinstance(data["id"], str)
6057

6158

@@ -64,7 +61,7 @@ def _check_codecjson_v2(data: object) -> TypeGuard[CodecJSON_V2[str]]:
6461

6562
# The widest type we will *accept* for a codec JSON
6663
# This covers v2 and v3
67-
CodecJSON = str | Mapping[str, object]
64+
CodecJSON = CodecJSON_V2 | CodecJSON_V3
6865
"""The widest type of JSON-like input that could specify a codec."""
6966

7067

@@ -191,34 +188,28 @@ async def encode(
191188
return await _batching_helper(self._encode_single, chunks_and_specs)
192189

193190
@overload
194-
def to_json(self, zarr_format: Literal[2]) -> CodecJSON_V2[str]: ...
191+
def to_json(self, zarr_format: Literal[2]) -> CodecJSON_V2: ...
195192
@overload
196-
def to_json(self, zarr_format: Literal[3]) -> NamedConfig[str, Mapping[str, object]]: ...
193+
def to_json(self, zarr_format: Literal[3]) -> CodecJSON_V3: ...
197194

198-
def to_json(
199-
self, zarr_format: ZarrFormat
200-
) -> CodecJSON_V2[str] | NamedConfig[str, Mapping[str, object]]:
195+
def to_json(self, zarr_format: ZarrFormat) -> CodecJSON_V2 | CodecJSON_V3:
201196
raise NotImplementedError
202197

203198
@classmethod
204-
def _from_json_v2(cls, data: CodecJSON_V2[str]) -> Self:
199+
def _from_json_v2(cls, data: CodecJSON_V2) -> Self:
205200
return cls(**{k: v for k, v in data.items() if k != "id"})
206201

207202
@classmethod
208203
def _from_json_v3(cls, data: CodecJSON_V3) -> Self:
209204
if isinstance(data, str):
210205
return cls()
211-
return cls(**data["configuration"])
206+
return cls(**data.get("configuration", {}))
212207

213208
@classmethod
214-
def from_json(cls, data: CodecJSON, zarr_format: ZarrFormat) -> Self:
215-
if zarr_format == 2:
209+
def from_json(cls, data: CodecJSON) -> Self:
210+
if _check_codecjson_v2(data):
216211
return cls._from_json_v2(data)
217-
elif zarr_format == 3:
218-
return cls._from_json_v3(data)
219-
raise ValueError(
220-
f"Unsupported Zarr format {zarr_format}. Expected 2 or 3."
221-
) # pragma: no cover
212+
return cls._from_json_v3(data)
222213

223214

224215
class ArrayArrayCodec(BaseCodec[NDBuffer, NDBuffer]):

src/zarr/abc/numcodec.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from typing import Any, Self, TypeGuard
22

3-
from typing_extensions import Protocol
3+
from typing_extensions import Protocol, runtime_checkable
44

55

6+
@runtime_checkable
67
class Numcodec(Protocol):
78
"""
89
A protocol that models the ``numcodecs.abc.Codec`` interface.
@@ -88,14 +89,3 @@ def _is_numcodec_cls(obj: object) -> TypeGuard[type[Numcodec]]:
8889
and hasattr(obj, "from_config")
8990
and callable(obj.from_config)
9091
)
91-
92-
93-
def _is_numcodec(obj: object) -> TypeGuard[Numcodec]:
94-
"""
95-
Check if the given object implements the Numcodec protocol.
96-
97-
The @runtime_checkable decorator does not allow issubclass checks for protocols with non-method
98-
members (i.e., attributes), so we use this function to manually check for the presence of the
99-
required attributes and methods on a given object.
100-
"""
101-
return _is_numcodec_cls(type(obj))

src/zarr/codecs/_v2.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
import asyncio
44
from dataclasses import dataclass
5-
from functools import cached_property
6-
from typing import TYPE_CHECKING, Literal, Self, overload
5+
from typing import TYPE_CHECKING, Literal, Self, cast, overload
76

87
import numpy as np
98
from numcodecs.compat import ensure_bytes, ensure_ndarray_like
@@ -15,11 +14,9 @@
1514
CodecJSON,
1615
CodecJSON_V2,
1716
)
18-
from zarr.registry import get_ndbuffer_class
17+
from zarr.registry import _get_codec_v2, _get_codec_v3, get_ndbuffer_class
1918

2019
if TYPE_CHECKING:
21-
from collections.abc import Mapping
22-
2320
from zarr.abc.numcodec import Numcodec
2421
from zarr.core.array_spec import ArraySpec
2522
from zarr.core.buffer import Buffer, NDBuffer
@@ -115,21 +112,16 @@ def compute_encoded_size(self, _input_byte_length: int, _chunk_spec: ArraySpec)
115112

116113
@dataclass(frozen=True, kw_only=True)
117114
class NumcodecsWrapper:
118-
codec_cls: type[Numcodec]
119-
config: Mapping[str, object]
120-
121-
@cached_property
122-
def codec(self) -> Numcodec:
123-
return self.codec_cls(**self.config)
115+
codec: Numcodec
124116

125117
@overload
126-
def to_json(self, zarr_format: Literal[2]) -> CodecJSON_V2[str]: ...
118+
def to_json(self, zarr_format: Literal[2]) -> CodecJSON_V2: ...
127119
@overload
128120
def to_json(self, zarr_format: Literal[3]) -> NamedConfig[str, BaseConfig]: ...
129121

130-
def to_json(self, zarr_format: ZarrFormat) -> CodecJSON_V2[str] | NamedConfig[str, BaseConfig]:
122+
def to_json(self, zarr_format: ZarrFormat) -> CodecJSON_V2 | NamedConfig[str, BaseConfig]:
131123
if zarr_format == 2:
132-
return self.config
124+
return cast(CodecJSON_V2, self.codec.get_config())
133125
elif zarr_format == 3:
134126
config = self.codec.get_config()
135127
config_no_id = {k: v for k, v in config.items() if k != "id"}
@@ -138,11 +130,13 @@ def to_json(self, zarr_format: ZarrFormat) -> CodecJSON_V2[str] | NamedConfig[st
138130

139131
@classmethod
140132
def _from_json_v2(cls, data: CodecJSON) -> Self:
141-
return cls(config=data)
133+
codec = _get_codec_v2(data)
134+
return cls(codec=codec)
142135

143136
@classmethod
144137
def _from_json_v3(cls, data: CodecJSON) -> Self:
145-
return cls(config=data.get("configuration", {}))
138+
codec = _get_codec_v3(data)
139+
return cls(codec=codec)
146140

147141
def compute_encoded_size(self, input_byte_length: int, chunk_spec: ArraySpec) -> int:
148142
raise NotImplementedError
@@ -185,19 +179,19 @@ def to_array_array(self) -> NumcodecsArrayArrayCodec:
185179
"""
186180
Use the ``_codec`` attribute to create a NumcodecsArrayArrayCodec.
187181
"""
188-
return NumcodecsArrayArrayCodec(cls=self.codec_cls, config=self.config)
182+
return NumcodecsArrayArrayCodec(codec=self.codec)
189183

190184
def to_bytes_bytes(self) -> NumcodecsBytesBytesCodec:
191185
"""
192186
Use the ``_codec`` attribute to create a NumcodecsBytesBytesCodec.
193187
"""
194-
return NumcodecsBytesBytesCodec(cls=self.codec_cls, config=self.config)
188+
return NumcodecsBytesBytesCodec(codec=self.codec)
195189

196190
def to_array_bytes(self) -> NumcodecsArrayBytesCodec:
197191
"""
198192
Use the ``_codec`` attribute to create a NumcodecsArrayBytesCodec.
199193
"""
200-
return NumcodecsArrayBytesCodec(codec_cls=self.codec_cls, config=self.config)
194+
return NumcodecsArrayBytesCodec(codec=self.codec)
201195

202196

203197
class NumcodecsBytesBytesCodec(NumcodecsWrapper, BytesBytesCodec):
@@ -226,12 +220,12 @@ class NumcodecsArrayArrayCodec(NumcodecsWrapper, ArrayArrayCodec):
226220
async def _decode_single(self, chunk_data: NDBuffer, chunk_spec: ArraySpec) -> NDBuffer:
227221
chunk_ndarray = chunk_data.as_ndarray_like()
228222
out = await asyncio.to_thread(self.codec.decode, chunk_ndarray)
229-
return chunk_spec.prototype.nd_buffer.from_ndarray_like(out.reshape(chunk_spec.shape)) # type: ignore[union-attr]
223+
return chunk_spec.prototype.nd_buffer.from_ndarray_like(out.reshape(chunk_spec.shape))
230224

231225
async def _encode_single(self, chunk_data: NDBuffer, chunk_spec: ArraySpec) -> NDBuffer:
232226
chunk_ndarray = chunk_data.as_ndarray_like()
233227
out = await asyncio.to_thread(self.codec.encode, chunk_ndarray)
234-
return chunk_spec.prototype.nd_buffer.from_ndarray_like(out) # type: ignore[arg-type]
228+
return chunk_spec.prototype.nd_buffer.from_ndarray_like(out)
235229

236230

237231
@dataclass(kw_only=True, frozen=True)

src/zarr/codecs/blosc.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,17 +75,17 @@ def check_json_v2(data: CodecJSON) -> TypeGuard[BloscJSON_V2]:
7575
return (
7676
isinstance(data, Mapping)
7777
and set(data.keys()) == {"id", "clevel", "cname", "shuffle", "blocksize"}
78-
and data["id"] == "blosc"
78+
and data["id"] == "blosc" # type: ignore[typeddict-item]
7979
)
8080

8181

8282
def check_json_v3(data: CodecJSON) -> TypeGuard[BloscJSON_V3]:
8383
return (
8484
isinstance(data, Mapping)
8585
and set(data.keys()) == {"name", "configuration"}
86-
and data["name"] == "blosc"
87-
and isinstance(data["configuration"], Mapping)
88-
and set(data["configuration"].keys())
86+
and data["name"] == "blosc" # type: ignore[typeddict-item]
87+
and isinstance(data["configuration"], Mapping) # type: ignore[typeddict-item]
88+
and set(data["configuration"].keys()) # type: ignore[typeddict-item]
8989
== {"cname", "clevel", "shuffle", "blocksize", "typesize"}
9090
)
9191

src/zarr/codecs/bytes.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections.abc import Mapping
55
from dataclasses import dataclass, replace
66
from enum import Enum
7-
from typing import TYPE_CHECKING, Final, Literal, NotRequired, TypedDict, TypeGuard, overload
7+
from typing import TYPE_CHECKING, Final, Literal, NotRequired, TypedDict, TypeGuard, cast, overload
88

99
import numpy as np
1010
from typing_extensions import ReadOnly
@@ -61,7 +61,7 @@ def check_json_v2(data: CodecJSON) -> TypeGuard[BytesJSON_V2]:
6161
return (
6262
isinstance(data, Mapping)
6363
and set(data.keys()) in ({"id", "endian"}, {"id"})
64-
and data["id"] == "bytes"
64+
and data["id"] == "bytes" # type: ignore[typeddict-item]
6565
)
6666

6767

@@ -70,10 +70,10 @@ def check_json_v3(data: CodecJSON) -> TypeGuard[BytesJSON_V3]:
7070
(
7171
isinstance(data, Mapping)
7272
and set(data.keys()) in ({"name"}, {"name", "configuration"})
73-
and data["name"] == "bytes"
73+
and data["name"] == "bytes" # type: ignore[typeddict-item]
7474
)
7575
and isinstance(data.get("configuration", {}), Mapping)
76-
and set(data.get("configuration", {}).keys()) in ({"endian"}, set())
76+
and set(data.get("configuration", {}).keys()) in ({"endian"}, set()) # type: ignore[attr-defined]
7777
)
7878

7979

@@ -90,10 +90,10 @@ def __init__(self, *, endian: EndiannessStr | str | None = default_system_endian
9090

9191
@classmethod
9292
def from_dict(cls, data: dict[str, JSON]) -> Self:
93-
return cls.from_json(data, zarr_format=3)
93+
return cls.from_json(data) # type: ignore[arg-type]
9494

9595
def to_dict(self) -> dict[str, JSON]:
96-
return self.to_json(zarr_format=3)
96+
return cast(dict[str, JSON], self.to_json(zarr_format=3))
9797

9898
@classmethod
9999
def _from_json_v2(cls, data: CodecJSON) -> Self:
@@ -108,7 +108,7 @@ def _from_json_v3(cls, data: CodecJSON) -> Self:
108108
if data in ("bytes", {"name": "bytes"}, {"name": "bytes", "configuration": {}}):
109109
return cls()
110110
else:
111-
return cls(endian=data["configuration"].get("endian", None))
111+
return cls(endian=data["configuration"].get("endian", None)) # type: ignore[union-attr, index]
112112
raise ValueError(f"Invalid JSON: {data}")
113113

114114
@overload
@@ -182,11 +182,11 @@ async def _encode_single(
182182
if (
183183
chunk_array.dtype.itemsize > 1
184184
and self.endian is not None
185-
and self.endian != chunk_array.byteorder
185+
and self.endian != chunk_array.byteorder.value
186186
):
187187
# type-ignore is a numpy bug
188188
# see https://github.com/numpy/numpy/issues/26473
189-
new_dtype = chunk_array.dtype.newbyteorder(self.endian) # type: ignore[arg-type]
189+
new_dtype = chunk_array.dtype.newbyteorder(self.endian)
190190
chunk_array = chunk_array.astype(new_dtype)
191191

192192
nd_array = chunk_array.as_ndarray_like()

src/zarr/codecs/crc32c_.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
import numpy as np
88
import typing_extensions
99
from crc32c import crc32c
10+
from typing_extensions import ReadOnly
1011

11-
from zarr.abc.codec import BytesBytesCodec, CodecJSON, CodecJSON_V2
12+
from zarr.abc.codec import BytesBytesCodec, CodecJSON
1213
from zarr.core.common import JSON, NamedConfig, ZarrFormat, parse_named_configuration
1314
from zarr.errors import CodecValidationError
1415

@@ -19,13 +20,14 @@
1920
from zarr.core.buffer import Buffer
2021

2122

22-
class Crc32Config(TypedDict): ...
23+
class Crc32cConfig(TypedDict): ...
2324

2425

25-
class Crc32cJSON_V2(CodecJSON_V2[Literal["crc32c"]]): ...
26+
class Crc32cJSON_V2(TypedDict):
27+
id: ReadOnly[Literal["crc32c"]]
2628

2729

28-
class Crc32cJSON_V3(NamedConfig[Literal["crc32c"], Crc32Config]): ...
30+
class Crc32cJSON_V3(NamedConfig[Literal["crc32c"], Crc32cConfig]): ...
2931

3032

3133
def check_json_v2(data: CodecJSON) -> TypeGuard[Crc32cJSON_V2]:
@@ -47,7 +49,6 @@ class Crc32cCodec(BytesBytesCodec):
4749

4850
@classmethod
4951
def from_dict(cls, data: dict[str, JSON]) -> Self:
50-
return cls.from_json(data, zarr_format=3)
5152
parse_named_configuration(data, "crc32c", require_configuration=False)
5253
return cls()
5354

@@ -72,7 +73,6 @@ def _from_json_v3(cls, data: CodecJSON) -> Self:
7273
raise CodecValidationError(msg)
7374

7475
def to_dict(self) -> dict[str, JSON]:
75-
return self.to_json(zarr_format=3)
7676
return {"name": "crc32c"}
7777

7878
@overload

src/zarr/codecs/gzip.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import asyncio
44
from collections.abc import Mapping
55
from dataclasses import dataclass
6-
from typing import TYPE_CHECKING, Literal, TypedDict, TypeGuard, overload
6+
from typing import TYPE_CHECKING, Literal, TypedDict, TypeGuard, cast, overload
77

88
from numcodecs.gzip import GZip
99
from typing_extensions import ReadOnly
@@ -64,10 +64,10 @@ def __init__(self, *, level: int = 5) -> None:
6464

6565
@classmethod
6666
def from_dict(cls, data: dict[str, JSON]) -> Self:
67-
return cls.from_json(data, zarr_format=3)
67+
return cls.from_json(data) # type: ignore[arg-type]
6868

6969
def to_dict(self) -> dict[str, JSON]:
70-
return self.to_json(zarr_format=3)
70+
return cast(dict[str, JSON], self.to_json(zarr_format=3))
7171

7272
@overload
7373
def to_json(self, zarr_format: Literal[2]) -> GZipJSON_V2: ...
@@ -88,19 +88,19 @@ def _check_json_v2(cls, data: CodecJSON) -> TypeGuard[GZipJSON_V2]:
8888
return (
8989
isinstance(data, Mapping)
9090
and set(data.keys()) == {"id", "level"}
91-
and data["id"] == "gzip"
92-
and isinstance(data["level"], int)
91+
and data["id"] == "gzip" # type: ignore[typeddict-item]
92+
and isinstance(data["level"], int) # type: ignore[typeddict-item]
9393
)
9494

9595
@classmethod
9696
def _check_json_v3(cls, data: CodecJSON) -> TypeGuard[GZipJSON_V3]:
9797
return (
9898
isinstance(data, Mapping)
9999
and set(data.keys()) == {"name", "configuration"}
100-
and data["name"] == "gzip"
101-
and isinstance(data["configuration"], dict)
102-
and "level" in data["configuration"]
103-
and isinstance(data["configuration"]["level"], int)
100+
and data["name"] == "gzip" # type: ignore[typeddict-item]
101+
and isinstance(data["configuration"], Mapping) # type: ignore[typeddict-item]
102+
and "level" in data["configuration"] # type: ignore[typeddict-item]
103+
and isinstance(data["configuration"]["level"], int) # type: ignore[typeddict-item]
104104
)
105105

106106
@classmethod

0 commit comments

Comments
 (0)