Skip to content

Commit a367268

Browse files
committed
add numcodec protocol
1 parent 4eda04e commit a367268

File tree

9 files changed

+142
-31
lines changed

9 files changed

+142
-31
lines changed

src/zarr/abc/codec.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
from __future__ import annotations
22

33
from abc import abstractmethod
4-
from typing import TYPE_CHECKING, Generic, TypeVar
4+
from collections.abc import Mapping
5+
from typing import TYPE_CHECKING, Generic, TypeGuard, TypeVar
6+
7+
from typing_extensions import ReadOnly, TypedDict
58

69
from zarr.abc.metadata import Metadata
710
from zarr.core.buffer import Buffer, NDBuffer
8-
from zarr.core.common import ChunkCoords, concurrent_map
11+
from zarr.core.common import ChunkCoords, NamedConfig, concurrent_map
912
from zarr.core.config import config
1013

1114
if TYPE_CHECKING:
@@ -34,6 +37,27 @@
3437
CodecInput = TypeVar("CodecInput", bound=NDBuffer | Buffer)
3538
CodecOutput = TypeVar("CodecOutput", bound=NDBuffer | Buffer)
3639

40+
TName = TypeVar("TName", bound=str, covariant=True)
41+
42+
43+
class CodecJSON_V2(TypedDict, Generic[TName]):
    """
    JSON form of a codec as stored in Zarr V2 metadata.

    The type is generic over ``TName``, the (string) type of the codec
    identifier, so that a specific codec can pin its ``id`` to a literal.
    Only the ``id`` key is declared here.
    """

    # Codec identifier; marked read-only so the dict type is covariant in TName.
    id: ReadOnly[TName]
47+
48+
49+
def _check_codecjson_v2(data: object) -> TypeGuard[CodecJSON_V2[str]]:
50+
return isinstance(data, Mapping) and "id" in data and isinstance(data["id"], str)
51+
52+
53+
CodecJSON_V3 = str | NamedConfig[str, Mapping[str, object]]
54+
"""The JSON representation of a codec for Zarr V3."""
55+
56+
# The widest type we will *accept* for a codec JSON
57+
# This covers v2 and v3
58+
CodecJSON = str | Mapping[str, object]
59+
"""The widest type of JSON-like input that could specify a codec."""
60+
3761

3862
class BaseCodec(Metadata, Generic[CodecInput, CodecOutput]):
3963
"""Generic base class for codecs.

src/zarr/api/asynchronous.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,8 @@
4646
if TYPE_CHECKING:
4747
from collections.abc import Iterable
4848

49-
import numcodecs.abc
50-
5149
from zarr.abc.codec import Codec
50+
from zarr.codecs._v2 import Numcodec
5251
from zarr.core.buffer import NDArrayLikeOrScalar
5352
from zarr.core.chunk_key_encodings import ChunkKeyEncoding
5453
from zarr.storage import StoreLike
@@ -871,7 +870,7 @@ async def create(
871870
overwrite: bool = False,
872871
path: PathLike | None = None,
873872
chunk_store: StoreLike | None = None,
874-
filters: Iterable[dict[str, JSON] | numcodecs.abc.Codec] | None = None,
873+
filters: Iterable[dict[str, JSON] | Numcodec] | None = None,
875874
cache_metadata: bool | None = None,
876875
cache_attrs: bool | None = None,
877876
read_only: bool | None = None,

src/zarr/api/synchronous.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
if TYPE_CHECKING:
1515
from collections.abc import Iterable
1616

17-
import numcodecs.abc
1817
import numpy as np
1918
import numpy.typing as npt
2019

2120
from zarr.abc.codec import Codec
2221
from zarr.api.asynchronous import ArrayLike, PathLike
22+
from zarr.codecs._v2 import Numcodec
2323
from zarr.core.array import (
2424
CompressorsLike,
2525
FiltersLike,
@@ -609,7 +609,7 @@ def create(
609609
overwrite: bool = False,
610610
path: PathLike | None = None,
611611
chunk_store: StoreLike | None = None,
612-
filters: Iterable[dict[str, JSON] | numcodecs.abc.Codec] | None = None,
612+
filters: Iterable[dict[str, JSON] | Numcodec] | None = None,
613613
cache_metadata: bool | None = None,
614614
cache_attrs: bool | None = None,
615615
read_only: bool | None = None,

src/zarr/codecs/_numcodecs.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import numcodecs.registry as numcodecs_registry
2+
3+
from zarr.abc.codec import CodecJSON_V2
4+
from zarr.codecs._v2 import Numcodec
5+
6+
7+
def get_numcodec(data: CodecJSON_V2[str]) -> Numcodec:
    """
    Resolve a numcodec codec from the numcodecs registry.

    This requires the Numcodecs package to be installed.

    The codec class is looked up by ``data["id"]``, first among the classes
    already registered with numcodecs and then among the lazily-loaded entry
    points. A class found via an entry point is loaded and registered before
    use. The codec is then built from the remaining keys of ``data``.

    Parameters
    ----------
    data : CodecJSON_V2
        The JSON metadata for the codec.

    Returns
    -------
    codec : Numcodec

    Raises
    ------
    KeyError
        If no codec class is known for ``data["id"]``.

    Examples
    --------

    >>> codec = get_numcodec({'id': 'zlib', 'level': 1})
    >>> codec
    Zlib(level=1)
    """

    codec_id = data["id"]
    cls = numcodecs_registry.codec_registry.get(codec_id)
    # The entry-point table is keyed by codec id. Looking up the whole ``data``
    # mapping here would raise TypeError (dicts are unhashable) instead of
    # loading the codec, so the id must be used for every registry access.
    if cls is None and codec_id in numcodecs_registry.entries:
        cls = numcodecs_registry.entries[codec_id].load()
        numcodecs_registry.register_codec(cls, codec_id=codec_id)
    if cls is not None:
        # The "id" key only selects the class; it is not passed to the constructor.
        return cls.from_config({k: v for k, v in data.items() if k != "id"})  # type: ignore[no-any-return]
    raise KeyError(data)

src/zarr/codecs/_v2.py

Lines changed: 64 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,77 @@
22

33
import asyncio
44
from dataclasses import dataclass
5-
from typing import TYPE_CHECKING
5+
from typing import TYPE_CHECKING, ClassVar, Self, TypeGuard
66

7-
import numcodecs
87
import numpy as np
98
from numcodecs.compat import ensure_bytes, ensure_ndarray_like
9+
from typing_extensions import Protocol
1010

11-
from zarr.abc.codec import ArrayBytesCodec
11+
from zarr.abc.codec import ArrayBytesCodec, CodecJSON_V2
1212
from zarr.registry import get_ndbuffer_class
1313

1414
if TYPE_CHECKING:
15-
import numcodecs.abc
16-
1715
from zarr.core.array_spec import ArraySpec
1816
from zarr.core.buffer import Buffer, NDBuffer
1917

2018

19+
class Numcodec(Protocol):
    """
    Structural type that models the ``numcodecs.abc.Codec`` interface.

    Any class exposing a string ``codec_id`` class attribute together with
    ``encode``, ``decode``, ``get_config`` and ``from_config`` satisfies this
    protocol; inheriting from a numcodecs base class is not required.
    """

    # Identifier of the codec, declared at class level.
    codec_id: ClassVar[str]

    def encode(self, buf: Buffer | NDBuffer) -> Buffer | NDBuffer:
        """Encode ``buf`` and return the encoded data."""
        ...

    def decode(
        self, buf: Buffer | NDBuffer, out: Buffer | NDBuffer | None = None
    ) -> Buffer | NDBuffer:
        """Decode ``buf`` and return the decoded data; ``out`` is optional."""
        ...

    def get_config(self) -> CodecJSON_V2[str]:
        """Return the codec's configuration in its Zarr V2 JSON form."""
        ...

    @classmethod
    def from_config(cls, config: CodecJSON_V2[str]) -> Self:
        """Build a codec instance from ``config``."""
        ...
36+
37+
38+
def _is_numcodec(obj: object) -> TypeGuard[Numcodec]:
    """
    Report whether ``obj`` is an instance of a class implementing the Numcodec protocol.

    ``@runtime_checkable`` cannot be used for this: it does not allow issubclass checks
    on protocols with non-method members (i.e., attributes). The check is therefore
    delegated to ``_is_numcodec_cls``, applied to the type of ``obj``.
    """
    obj_type = type(obj)
    return _is_numcodec_cls(obj_type)
47+
48+
49+
def _is_numcodec_cls(obj: object) -> TypeGuard[type[Numcodec]]:
50+
"""
51+
Check if the given object is a class implements the Numcodec protocol.
52+
53+
The @runtime_checkable decorator does not allow issubclass checks for protocols with non-method
54+
members (i.e., attributes), so we use this function to manually check for the presence of the
55+
required attributes and methods on a given object.
56+
"""
57+
return (
58+
isinstance(obj, type)
59+
and hasattr(obj, "codec_id")
60+
and isinstance(obj.codec_id, str)
61+
and hasattr(obj, "encode")
62+
and callable(obj.encode)
63+
and hasattr(obj, "decode")
64+
and callable(obj.decode)
65+
and hasattr(obj, "get_config")
66+
and callable(obj.get_config)
67+
and hasattr(obj, "from_config")
68+
and callable(obj.from_config)
69+
)
70+
71+
2172
@dataclass(frozen=True)
2273
class V2Codec(ArrayBytesCodec):
23-
filters: tuple[numcodecs.abc.Codec, ...] | None
24-
compressor: numcodecs.abc.Codec | None
74+
filters: tuple[Numcodec, ...] | None
75+
compressor: Numcodec | None
2576

2677
is_fixed_size = False
2778

@@ -33,9 +84,9 @@ async def _decode_single(
3384
cdata = chunk_bytes.as_array_like()
3485
# decompress
3586
if self.compressor:
36-
chunk = await asyncio.to_thread(self.compressor.decode, cdata)
87+
chunk = await asyncio.to_thread(self.compressor.decode, cdata) # type: ignore[arg-type]
3788
else:
38-
chunk = cdata
89+
chunk = cdata # type: ignore[assignment]
3990

4091
# apply filters
4192
if self.filters:
@@ -56,7 +107,7 @@ async def _decode_single(
56107
# is an object array. In this case, we need to convert the object
57108
# array to the correct dtype.
58109

59-
chunk = np.array(chunk).astype(chunk_spec.dtype.to_native_dtype())
110+
chunk = np.array(chunk).astype(chunk_spec.dtype.to_native_dtype()) # type: ignore[assignment]
60111

61112
elif chunk.dtype != object:
62113
# If we end up here, someone must have hacked around with the filters.
@@ -85,17 +136,17 @@ async def _encode_single(
85136
# apply filters
86137
if self.filters:
87138
for f in self.filters:
88-
chunk = await asyncio.to_thread(f.encode, chunk)
139+
chunk = await asyncio.to_thread(f.encode, chunk) # type: ignore[arg-type]
89140

90141
# check object encoding
91142
if ensure_ndarray_like(chunk).dtype == object:
92143
raise RuntimeError("cannot write object array without object codec")
93144

94145
# compress
95146
if self.compressor:
96-
cdata = await asyncio.to_thread(self.compressor.encode, chunk)
147+
cdata = await asyncio.to_thread(self.compressor.encode, chunk) # type: ignore[arg-type]
97148
else:
98-
cdata = chunk
149+
cdata = chunk # type: ignore[assignment]
99150

100151
cdata = ensure_bytes(cdata)
101152
return chunk_spec.prototype.buffer.from_bytes(cdata)

src/zarr/core/array.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import zarr
2828
from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec
2929
from zarr.abc.store import Store, set_or_delete
30-
from zarr.codecs._v2 import V2Codec
30+
from zarr.codecs._v2 import Numcodec, V2Codec
3131
from zarr.codecs.bytes import BytesCodec
3232
from zarr.codecs.vlen_utf8 import VLenBytesCodec, VLenUTF8Codec
3333
from zarr.codecs.zstd import ZstdCodec
@@ -607,7 +607,7 @@ async def _create(
607607
chunks: ShapeLike | None = None,
608608
dimension_separator: Literal[".", "/"] | None = None,
609609
order: MemoryOrder | None = None,
610-
filters: Iterable[dict[str, JSON] | numcodecs.abc.Codec] | None = None,
610+
filters: Iterable[dict[str, JSON] | Numcodec] | None = None,
611611
compressor: CompressorLike = "auto",
612612
# runtime
613613
overwrite: bool = False,
@@ -818,7 +818,7 @@ def _create_metadata_v2(
818818
order: MemoryOrder,
819819
dimension_separator: Literal[".", "/"] | None = None,
820820
fill_value: Any | None = DEFAULT_FILL_VALUE,
821-
filters: Iterable[dict[str, JSON] | numcodecs.abc.Codec] | None = None,
821+
filters: Iterable[dict[str, JSON] | Numcodec] | None = None,
822822
compressor: CompressorLikev2 = None,
823823
attributes: dict[str, JSON] | None = None,
824824
) -> ArrayV2Metadata:
@@ -856,7 +856,7 @@ async def _create_v2(
856856
config: ArrayConfig,
857857
dimension_separator: Literal[".", "/"] | None = None,
858858
fill_value: Any | None = DEFAULT_FILL_VALUE,
859-
filters: Iterable[dict[str, JSON] | numcodecs.abc.Codec] | None = None,
859+
filters: Iterable[dict[str, JSON] | Numcodec] | None = None,
860860
compressor: CompressorLike = "auto",
861861
attributes: dict[str, JSON] | None = None,
862862
overwrite: bool = False,
@@ -3898,7 +3898,7 @@ def _build_parents(
38983898

38993899

39003900
FiltersLike: TypeAlias = (
3901-
Iterable[dict[str, JSON] | ArrayArrayCodec | numcodecs.abc.Codec]
3901+
Iterable[dict[str, JSON] | ArrayArrayCodec | Numcodec]
39023902
| ArrayArrayCodec
39033903
| Iterable[numcodecs.abc.Codec]
39043904
| numcodecs.abc.Codec
@@ -3911,10 +3911,10 @@ def _build_parents(
39113911
)
39123912

39133913
CompressorsLike: TypeAlias = (
3914-
Iterable[dict[str, JSON] | BytesBytesCodec | numcodecs.abc.Codec]
3914+
Iterable[dict[str, JSON] | BytesBytesCodec | Numcodec]
39153915
| dict[str, JSON]
39163916
| BytesBytesCodec
3917-
| numcodecs.abc.Codec
3917+
| Numcodec
39183918
| Literal["auto"]
39193919
| None
39203920
)
@@ -4944,7 +4944,7 @@ def _parse_deprecated_compressor(
49444944
# "no compression"
49454945
compressors = ()
49464946
else:
4947-
compressors = (compressor,)
4947+
compressors = (compressor,) # type: ignore[assignment]
49484948
elif zarr_format == 2 and compressor == compressors == "auto":
49494949
compressors = ({"id": "blosc"},)
49504950
return compressors

tests/test_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1282,7 +1282,7 @@ def test_gpu_basic(store: Store, zarr_format: ZarrFormat | None) -> None:
12821282
dtype=src.dtype,
12831283
overwrite=True,
12841284
zarr_format=zarr_format,
1285-
compressors=compressors,
1285+
compressors=compressors, # type: ignore[arg-type]
12861286
)
12871287
z[:10, :10] = src[:10, :10]
12881288

tests/test_array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1684,7 +1684,7 @@ def test_roundtrip_numcodecs() -> None:
16841684
shape=(720, 1440),
16851685
chunks=(720, 1440),
16861686
dtype="float64",
1687-
compressors=compressors,
1687+
compressors=compressors, # type: ignore[arg-type]
16881688
filters=filters,
16891689
fill_value=-9.99,
16901690
dimension_names=["lat", "lon"],

tests/test_codecs/test_vlen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def test_vlen_string(
4040
chunks=data.shape,
4141
dtype=data.dtype,
4242
fill_value="",
43-
compressors=compressor,
43+
compressors=compressor, # type: ignore[arg-type]
4444
)
4545
assert isinstance(a.metadata, ArrayV3Metadata) # needed for mypy
4646

0 commit comments

Comments
 (0)