Skip to content

Commit 4e978f9

Browse files
committed
add (typed) functions for resolving codecs
1 parent eab46a2 commit 4e978f9

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed

src/zarr/registry.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from importlib.metadata import entry_points as get_entry_points
66
from typing import TYPE_CHECKING, Any, Generic, TypeVar
77

8+
from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec
9+
from zarr.core.common import JSON
810
from zarr.core.config import BadConfigError, config
911

1012
if TYPE_CHECKING:
@@ -151,6 +153,62 @@ def get_codec_class(key: str, reload_config: bool = False) -> type[Codec]:
151153
raise KeyError(key)
152154

153155

156+
def _resolve_codec(data: dict[str, JSON]) -> Codec:
    """
    Build a codec instance from its dict representation.

    The codec class is looked up in the registry under ``data["name"]`` and
    the full dict is then handed to that class's ``from_dict`` constructor.

    Parameters
    ----------
    data : dict[str, JSON]
        Dict representation of a codec. Must contain a ``"name"`` key.

    Returns
    -------
    Codec
        The object returned by ``from_dict`` of the looked-up codec class.
    """
    # TODO: narrow the type of the input to only those dicts that map on to codec class instances.
    codec_name = data["name"]
    codec_cls = get_codec_class(codec_name)  # type: ignore[arg-type]
    return codec_cls.from_dict(data)
162+
163+
164+
def _parse_bytes_bytes_codec(data: dict[str, JSON] | BytesBytesCodec) -> BytesBytesCodec:
    """
    Normalize the input to a ``BytesBytesCodec`` instance.

    If the input is already a ``BytesBytesCodec``, it is returned as is. If the input is a dict, it
    is converted to a ``BytesBytesCodec`` instance via the ``_resolve_codec`` function.

    Parameters
    ----------
    data : dict[str, JSON] | BytesBytesCodec
        Either a dict representation of a codec or a codec instance.

    Returns
    -------
    BytesBytesCodec
        The resolved (or passed-through) codec instance.

    Raises
    ------
    ValueError
        If ``data`` is a dict that resolves to a codec that is not a ``BytesBytesCodec``.
    TypeError
        If ``data`` is neither a dict nor a ``BytesBytesCodec`` instance.
    """
    if isinstance(data, dict):
        result = _resolve_codec(data)
        if not isinstance(result, BytesBytesCodec):
            msg = f"Expected a dict representation of a BytesBytesCodec; got a dict representation of a {type(result)} instead."
            raise ValueError(msg)
        return result

    # Type annotations are not enforced at runtime, so validate non-dict input
    # as well; otherwise a codec of the wrong category (or any other object)
    # would be returned unchanged and fail far from the call site.
    if not isinstance(data, BytesBytesCodec):
        msg = f"Expected a BytesBytesCodec or a dict representation of one; got {type(data)} instead."
        raise TypeError(msg)
    return data
178+
179+
180+
def _parse_array_bytes_codec(data: dict[str, JSON] | ArrayBytesCodec) -> ArrayBytesCodec:
    """
    Normalize the input to an ``ArrayBytesCodec`` instance.

    If the input is already an ``ArrayBytesCodec``, it is returned as is. If the input is a dict, it
    is converted to an ``ArrayBytesCodec`` instance via the ``_resolve_codec`` function.

    Parameters
    ----------
    data : dict[str, JSON] | ArrayBytesCodec
        Either a dict representation of a codec or a codec instance.

    Returns
    -------
    ArrayBytesCodec
        The resolved (or passed-through) codec instance.

    Raises
    ------
    ValueError
        If ``data`` is a dict that resolves to a codec that is not an ``ArrayBytesCodec``.
    TypeError
        If ``data`` is neither a dict nor an ``ArrayBytesCodec`` instance.
    """
    if isinstance(data, dict):
        result = _resolve_codec(data)
        if not isinstance(result, ArrayBytesCodec):
            msg = f"Expected a dict representation of a ArrayBytesCodec; got a dict representation of a {type(result)} instead."
            raise ValueError(msg)
        return result

    # Type annotations are not enforced at runtime, so validate non-dict input
    # as well; otherwise a codec of the wrong category (or any other object)
    # would be returned unchanged and fail far from the call site.
    if not isinstance(data, ArrayBytesCodec):
        msg = f"Expected an ArrayBytesCodec or a dict representation of one; got {type(data)} instead."
        raise TypeError(msg)
    return data
194+
195+
196+
def _parse_array_array_codec(data: dict[str, JSON] | ArrayArrayCodec) -> ArrayArrayCodec:
    """
    Normalize the input to an ``ArrayArrayCodec`` instance.

    If the input is already an ``ArrayArrayCodec``, it is returned as is. If the input is a dict, it
    is converted to an ``ArrayArrayCodec`` instance via the ``_resolve_codec`` function.

    Parameters
    ----------
    data : dict[str, JSON] | ArrayArrayCodec
        Either a dict representation of a codec or a codec instance.

    Returns
    -------
    ArrayArrayCodec
        The resolved (or passed-through) codec instance.

    Raises
    ------
    ValueError
        If ``data`` is a dict that resolves to a codec that is not an ``ArrayArrayCodec``.
    TypeError
        If ``data`` is neither a dict nor an ``ArrayArrayCodec`` instance.
    """
    if isinstance(data, dict):
        result = _resolve_codec(data)
        if not isinstance(result, ArrayArrayCodec):
            msg = f"Expected a dict representation of a ArrayArrayCodec; got a dict representation of a {type(result)} instead."
            raise ValueError(msg)
        return result

    # Type annotations are not enforced at runtime, so validate non-dict input
    # as well; otherwise a codec of the wrong category (or any other object)
    # would be returned unchanged and fail far from the call site.
    if not isinstance(data, ArrayArrayCodec):
        msg = f"Expected an ArrayArrayCodec or a dict representation of one; got {type(data)} instead."
        raise TypeError(msg)
    return data
210+
211+
154212
def get_pipeline_class(reload_config: bool = False) -> type[CodecPipeline]:
155213
if reload_config:
156214
_reload_config()

0 commit comments

Comments
 (0)