|
5 | 5 | from importlib.metadata import entry_points as get_entry_points |
6 | 6 | from typing import TYPE_CHECKING, Any, Generic, TypeVar |
7 | 7 |
|
| 8 | +from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec |
| 9 | +from zarr.core.common import JSON |
8 | 10 | from zarr.core.config import BadConfigError, config |
9 | 11 |
|
10 | 12 | if TYPE_CHECKING: |
@@ -151,6 +153,62 @@ def get_codec_class(key: str, reload_config: bool = False) -> type[Codec]: |
151 | 153 | raise KeyError(key) |
152 | 154 |
|
153 | 155 |
|
def _resolve_codec(data: dict[str, JSON]) -> Codec:
    """
    Build a codec instance from its dict representation.

    The codec class is looked up with ``get_codec_class`` using the value stored under the
    ``"name"`` key of ``data``; the instance is then created by passing the whole dict to that
    class's ``from_dict`` method.

    Parameters
    ----------
    data : dict[str, JSON]
        Dict representation of a codec. It must contain a ``"name"`` key.

    Returns
    -------
    Codec
        The object returned by ``from_dict`` of the looked-up codec class.

    Raises
    ------
    KeyError
        If ``data`` has no ``"name"`` key. The class lookup performed by ``get_codec_class``
        can raise as well.
    """
    # TODO: narrow the type of the input to only those dicts that map on to codec class instances.
    # ``data["name"]`` is typed as JSON rather than str, hence the ignore on the lookup.
    codec_cls = get_codec_class(data["name"])  # type: ignore[arg-type]
    return codec_cls.from_dict(data)
| 162 | + |
| 163 | + |
def _parse_bytes_bytes_codec(data: dict[str, JSON] | Codec) -> BytesBytesCodec:
    """
    Normalize the input to a ``BytesBytesCodec`` instance.

    If the input is already a ``BytesBytesCodec``, it is returned as is. If the input is a dict, it
    is converted to a ``BytesBytesCodec`` instance via the ``_resolve_codec`` function.

    Parameters
    ----------
    data : dict[str, JSON] | Codec
        Either a ``BytesBytesCodec`` instance or the dict representation of one.

    Returns
    -------
    BytesBytesCodec
        ``data`` itself when it is already a ``BytesBytesCodec``, otherwise the codec resolved
        from the dict.

    Raises
    ------
    ValueError
        If ``data`` is a dict that resolves to something other than a ``BytesBytesCodec``.
    TypeError
        If ``data`` is neither a dict nor a ``BytesBytesCodec``.
    """
    if isinstance(data, dict):
        result = _resolve_codec(data)
        if not isinstance(result, BytesBytesCodec):
            msg = f"Expected a dict representation of a BytesBytesCodec; got a dict representation of a {type(result)} instead."
            raise ValueError(msg)
        return result

    # Annotations are not enforced at runtime, so reject anything that is not the expected codec
    # category here instead of letting it fail later, far away from the cause.
    if not isinstance(data, BytesBytesCodec):
        msg = f"Expected a BytesBytesCodec or a dict representation of one; got {type(data)} instead."
        raise TypeError(msg)
    return data
| 178 | + |
| 179 | + |
def _parse_array_bytes_codec(data: dict[str, JSON] | Codec) -> ArrayBytesCodec:
    """
    Normalize the input to an ``ArrayBytesCodec`` instance.

    If the input is already an ``ArrayBytesCodec``, it is returned as is. If the input is a dict,
    it is converted to an ``ArrayBytesCodec`` instance via the ``_resolve_codec`` function.

    Parameters
    ----------
    data : dict[str, JSON] | Codec
        Either an ``ArrayBytesCodec`` instance or the dict representation of one.

    Returns
    -------
    ArrayBytesCodec
        ``data`` itself when it is already an ``ArrayBytesCodec``, otherwise the codec resolved
        from the dict.

    Raises
    ------
    ValueError
        If ``data`` is a dict that resolves to something other than an ``ArrayBytesCodec``.
    TypeError
        If ``data`` is neither a dict nor an ``ArrayBytesCodec``.
    """
    if isinstance(data, dict):
        result = _resolve_codec(data)
        if not isinstance(result, ArrayBytesCodec):
            msg = f"Expected a dict representation of a ArrayBytesCodec; got a dict representation of a {type(result)} instead."
            raise ValueError(msg)
        return result

    # Annotations are not enforced at runtime, so reject anything that is not the expected codec
    # category here instead of letting it fail later, far away from the cause.
    if not isinstance(data, ArrayBytesCodec):
        msg = f"Expected an ArrayBytesCodec or a dict representation of one; got {type(data)} instead."
        raise TypeError(msg)
    return data
| 194 | + |
| 195 | + |
def _parse_array_array_codec(data: dict[str, JSON] | Codec) -> ArrayArrayCodec:
    """
    Normalize the input to an ``ArrayArrayCodec`` instance.

    If the input is already an ``ArrayArrayCodec``, it is returned as is. If the input is a dict,
    it is converted to an ``ArrayArrayCodec`` instance via the ``_resolve_codec`` function.

    Parameters
    ----------
    data : dict[str, JSON] | Codec
        Either an ``ArrayArrayCodec`` instance or the dict representation of one.

    Returns
    -------
    ArrayArrayCodec
        ``data`` itself when it is already an ``ArrayArrayCodec``, otherwise the codec resolved
        from the dict.

    Raises
    ------
    ValueError
        If ``data`` is a dict that resolves to something other than an ``ArrayArrayCodec``.
    TypeError
        If ``data`` is neither a dict nor an ``ArrayArrayCodec``.
    """
    if isinstance(data, dict):
        result = _resolve_codec(data)
        if not isinstance(result, ArrayArrayCodec):
            msg = f"Expected a dict representation of a ArrayArrayCodec; got a dict representation of a {type(result)} instead."
            raise ValueError(msg)
        return result

    # Annotations are not enforced at runtime, so reject anything that is not the expected codec
    # category here instead of letting it fail later, far away from the cause.
    if not isinstance(data, ArrayArrayCodec):
        msg = f"Expected an ArrayArrayCodec or a dict representation of one; got {type(data)} instead."
        raise TypeError(msg)
    return data
| 210 | + |
| 211 | + |
154 | 212 | def get_pipeline_class(reload_config: bool = False) -> type[CodecPipeline]: |
155 | 213 | if reload_config: |
156 | 214 | _reload_config() |
|
0 commit comments