22
33import asyncio
44from dataclasses import dataclass
5- from typing import TYPE_CHECKING
5+ from typing import TYPE_CHECKING , ClassVar , Self , TypeGuard
66
7- import numcodecs
87import numpy as np
98from numcodecs .compat import ensure_bytes , ensure_ndarray_like
9+ from typing_extensions import Protocol
1010
11- from zarr .abc .codec import ArrayBytesCodec
11+ from zarr .abc .codec import ArrayBytesCodec , CodecJSON_V2
1212from zarr .registry import get_ndbuffer_class
1313
1414if TYPE_CHECKING :
15- import numcodecs .abc
16-
1715 from zarr .core .array_spec import ArraySpec
1816 from zarr .core .buffer import Buffer , NDBuffer
1917
2018
19+ class Numcodec (Protocol ):
20+ """
21+ A protocol that models the ``numcodecs.abc.Codec`` interface.
22+ """
23+
24+ codec_id : ClassVar [str ]
25+
26+ def encode (self , buf : Buffer | NDBuffer ) -> Buffer | NDBuffer : ...
27+
28+ def decode (
29+ self , buf : Buffer | NDBuffer , out : Buffer | NDBuffer | None = None
30+ ) -> Buffer | NDBuffer : ...
31+
32+ def get_config (self ) -> CodecJSON_V2 [str ]: ...
33+
34+ @classmethod
35+ def from_config (cls , config : CodecJSON_V2 [str ]) -> Self : ...
36+
37+
38+ def _is_numcodec (obj : object ) -> TypeGuard [Numcodec ]:
39+ """
40+ Check if the given object implements the Numcodec protocol.
41+
42+ The @runtime_checkable decorator does not allow issubclass checks for protocols with non-method
43+ members (i.e., attributes), so we use this function to manually check for the presence of the
44+ required attributes and methods on a given object.
45+ """
46+ return _is_numcodec_cls (type (obj ))
47+
48+
49+ def _is_numcodec_cls (obj : object ) -> TypeGuard [type [Numcodec ]]:
50+ """
51+ Check if the given object is a class implements the Numcodec protocol.
52+
53+ The @runtime_checkable decorator does not allow issubclass checks for protocols with non-method
54+ members (i.e., attributes), so we use this function to manually check for the presence of the
55+ required attributes and methods on a given object.
56+ """
57+ return (
58+ isinstance (obj , type )
59+ and hasattr (obj , "codec_id" )
60+ and isinstance (obj .codec_id , str )
61+ and hasattr (obj , "encode" )
62+ and callable (obj .encode )
63+ and hasattr (obj , "decode" )
64+ and callable (obj .decode )
65+ and hasattr (obj , "get_config" )
66+ and callable (obj .get_config )
67+ and hasattr (obj , "from_config" )
68+ and callable (obj .from_config )
69+ )
70+
71+
2172@dataclass (frozen = True )
2273class V2Codec (ArrayBytesCodec ):
23- filters : tuple [numcodecs . abc . Codec , ...] | None
24- compressor : numcodecs . abc . Codec | None
74+ filters : tuple [Numcodec , ...] | None
75+ compressor : Numcodec | None
2576
2677 is_fixed_size = False
2778
@@ -33,9 +84,9 @@ async def _decode_single(
3384 cdata = chunk_bytes .as_array_like ()
3485 # decompress
3586 if self .compressor :
36- chunk = await asyncio .to_thread (self .compressor .decode , cdata )
87+ chunk = await asyncio .to_thread (self .compressor .decode , cdata ) # type: ignore[arg-type]
3788 else :
38- chunk = cdata
89+ chunk = cdata # type: ignore[assignment]
3990
4091 # apply filters
4192 if self .filters :
@@ -56,7 +107,7 @@ async def _decode_single(
56107 # is an object array. In this case, we need to convert the object
57108 # array to the correct dtype.
58109
59- chunk = np .array (chunk ).astype (chunk_spec .dtype .to_native_dtype ())
110+ chunk = np .array (chunk ).astype (chunk_spec .dtype .to_native_dtype ()) # type: ignore[assignment]
60111
61112 elif chunk .dtype != object :
62113 # If we end up here, someone must have hacked around with the filters.
@@ -85,17 +136,17 @@ async def _encode_single(
85136 # apply filters
86137 if self .filters :
87138 for f in self .filters :
88- chunk = await asyncio .to_thread (f .encode , chunk )
139+ chunk = await asyncio .to_thread (f .encode , chunk ) # type: ignore[arg-type]
89140
90141 # check object encoding
91142 if ensure_ndarray_like (chunk ).dtype == object :
92143 raise RuntimeError ("cannot write object array without object codec" )
93144
94145 # compress
95146 if self .compressor :
96- cdata = await asyncio .to_thread (self .compressor .encode , chunk )
147+ cdata = await asyncio .to_thread (self .compressor .encode , chunk ) # type: ignore[arg-type]
97148 else :
98- cdata = chunk
149+ cdata = chunk # type: ignore[assignment]
99150
100151 cdata = ensure_bytes (cdata )
101152 return chunk_spec .prototype .buffer .from_bytes (cdata )
0 commit comments