Skip to content

Commit 56ad592

Browse files
committed
refactor codec parsing and add tests
1 parent 3fb7126 commit 56ad592

File tree

13 files changed

+240
-218
lines changed

13 files changed

+240
-218
lines changed

src/zarr/codecs/_v2.py

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from zarr.core.common import (
1616
CodecJSON,
1717
CodecJSON_V2,
18-
CodecJSON_V3,
1918
_check_codecjson_v2,
2019
_check_codecjson_v3,
2120
)
@@ -32,26 +31,6 @@
3231
from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType
3332

3433

35-
def codec_json_v2_to_v3(data: CodecJSON_V2) -> CodecJSON_V3:
36-
"""
37-
Convert V2 codec JSON to V3 codec JSON
38-
"""
39-
name = data["id"]
40-
config = {k: v for k, v in data.items() if k != "id"}
41-
return {"name": name, "configuration": config}
42-
43-
44-
def codec_json_v3_to_v2(data: CodecJSON_V3) -> CodecJSON_V2:
45-
"""
46-
Convert V3 codec JSON to V2 codec JSON
47-
"""
48-
if isinstance(data, str):
49-
return {"id": data}
50-
name = data["name"]
51-
config = dict(data.get("configuration", {}))
52-
return {"id": name, **config} # type: ignore[typeddict-item]
53-
54-
5534
@dataclass(frozen=True)
5635
class V2Codec(ArrayBytesCodec):
5736
filters: tuple[Numcodec, ...] | None
@@ -168,8 +147,12 @@ def _from_json_v2(cls, data: CodecJSON) -> Self:
168147
@classmethod
169148
def _from_json_v3(cls, data: CodecJSON) -> Self:
170149
if _check_codecjson_v3(data):
171-
# convert to a v2 codec JSON
172-
codec = get_numcodec(codec_json_v3_to_v2(data))
150+
request: CodecJSON_V2
151+
if isinstance(data, str):
152+
request = {"id": data}
153+
else:
154+
request = {"id": data["name"], **data["configuration"]} # type: ignore[typeddict-item]
155+
codec = get_numcodec(request)
173156
return cls(codec=codec)
174157
msg = (
175158
"Invalid Zarr V3 JSON representation of a codec. "

src/zarr/codecs/numcodecs/_codecs.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
from zarr.codecs.zstd import ZstdConfig_V3, ZstdJSON_V2, ZstdJSON_V3
7676
from zarr.core.array_spec import ArraySpec
7777
from zarr.core.buffer import Buffer, BufferPrototype, NDBuffer
78+
from zarr.core.dtype.common import DTypeSpec_V2, DTypeSpec_V3
7879

7980

8081
# TypedDict definitions for V2 and V3 JSON representations
@@ -154,21 +155,32 @@ class ShuffleJSON_V3(NamedRequiredConfig[Literal["shuffle"], ShuffleConfig]):
154155
"""JSON representation of Shuffle codec for Zarr V3."""
155156

156157

157-
# Array-to-array codec configuration classes
158-
class DeltaConfig(TypedDict):
159-
dtype: str
160-
astype: str
158+
class DeltaConfig_V2(TypedDict):
159+
dtype: DTypeSpec_V2
160+
astype: DTypeSpec_V2
161+
162+
163+
class DeltaConfig_V3(TypedDict):
164+
dtype: DTypeSpec_V3
165+
astype: DTypeSpec_V3
161166

162167

163168
class BitRoundConfig(TypedDict):
164169
keepbits: int
165170

166171

167-
class FixedScaleOffsetConfig(TypedDict):
168-
dtype: NotRequired[str]
172+
class FixedScaleOffsetConfig_V2(TypedDict):
173+
dtype: NotRequired[DTypeSpec_V2]
174+
astype: NotRequired[DTypeSpec_V2]
175+
scale: NotRequired[float]
176+
offset: NotRequired[float]
177+
178+
179+
class FixedScaleOffsetConfig_V3(TypedDict):
180+
dtype: NotRequired[DTypeSpec_V3]
181+
astype: NotRequired[DTypeSpec_V2]
169182
scale: NotRequired[float]
170183
offset: NotRequired[float]
171-
astype: NotRequired[str]
172184

173185

174186
class QuantizeConfig(TypedDict):
@@ -186,13 +198,13 @@ class AsTypeConfig(TypedDict):
186198

187199

188200
# Array-to-array codec JSON representations
189-
class DeltaJSON_V2(DeltaConfig):
201+
class DeltaJSON_V2(DeltaConfig_V2):
190202
"""JSON representation of Delta codec for Zarr V2."""
191203

192204
id: ReadOnly[Literal["delta"]]
193205

194206

195-
class DeltaJSON_V3(NamedRequiredConfig[Literal["delta"], DeltaConfig]):
207+
class DeltaJSON_V3(NamedRequiredConfig[Literal["delta"], DeltaConfig_V3]):
196208
"""JSON representation of Delta codec for Zarr V3."""
197209

198210

@@ -206,14 +218,14 @@ class BitRoundJSON_V3(NamedRequiredConfig[Literal["bitround"], BitRoundConfig]):
206218
"""JSON representation of BitRound codec for Zarr V3."""
207219

208220

209-
class FixedScaleOffsetJSON_V2(FixedScaleOffsetConfig):
221+
class FixedScaleOffsetJSON_V2(FixedScaleOffsetConfig_V2):
210222
"""JSON representation of FixedScaleOffset codec for Zarr V2."""
211223

212224
id: ReadOnly[Literal["fixedscaleoffset"]]
213225

214226

215227
class FixedScaleOffsetJSON_V3(
216-
NamedRequiredConfig[Literal["fixedscaleoffset"], FixedScaleOffsetConfig]
228+
NamedRequiredConfig[Literal["fixedscaleoffset"], FixedScaleOffsetConfig_V3]
217229
):
218230
"""JSON representation of FixedScaleOffset codec for Zarr V3."""
219231

@@ -507,7 +519,11 @@ def to_json(self, zarr_format: Literal[2]) -> ZstdJSON_V2: ...
507519
@overload
508520
def to_json(self, zarr_format: Literal[3]) -> ZstdJSON_V3: ...
509521
def to_json(self, zarr_format: ZarrFormat) -> ZstdJSON_V2 | ZstdJSON_V3:
510-
return super().to_json(zarr_format) # type: ignore[return-value]
522+
res = super().to_json(zarr_format)
523+
if zarr_format == 2 and not res.get("checksum", False): # type: ignore[union-attr]
524+
# https://github.com/zarr-developers/zarr-python/pull/2655
525+
res.pop("checksum") # type: ignore[union-attr, typeddict-item]
526+
return res # type: ignore[return-value]
511527

512528

513529
class Zlib(_NumcodecsBytesBytesCodec):
@@ -595,7 +611,7 @@ def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
595611
class Delta(_NumcodecsArrayArrayCodec):
596612
codec_name = "numcodecs.delta"
597613
_codec_id = "delta"
598-
codec_config: DeltaConfig
614+
codec_config: DeltaConfig_V2 | DeltaConfig_V3
599615

600616
def __init__(self, **codec_config: Any) -> None:
601617
if "codec_config" in codec_config:
@@ -612,7 +628,7 @@ def to_json(self, zarr_format: ZarrFormat) -> DeltaJSON_V2 | DeltaJSON_V3:
612628

613629
def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
614630
if astype := self.codec_config.get("astype"):
615-
dtype = parse_dtype(np.dtype(astype), zarr_format=3)
631+
dtype = parse_dtype(np.dtype(astype), zarr_format=3) # type: ignore[arg-type]
616632
return replace(chunk_spec, dtype=dtype)
617633
return chunk_spec
618634

@@ -634,7 +650,7 @@ def to_json(self, zarr_format: ZarrFormat) -> BitRoundJSON_V2 | BitRoundJSON_V3:
634650
class FixedScaleOffset(_NumcodecsArrayArrayCodec):
635651
codec_name = "numcodecs.fixedscaleoffset"
636652
_codec_id = "fixedscaleoffset"
637-
codec_config: FixedScaleOffsetConfig
653+
codec_config: FixedScaleOffsetConfig_V2
638654

639655
@overload
640656
def to_json(self, zarr_format: Literal[2]) -> FixedScaleOffsetJSON_V2: ...
@@ -646,7 +662,7 @@ def to_json(self, zarr_format: ZarrFormat) -> FixedScaleOffsetJSON_V2 | FixedSca
646662

647663
def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
648664
if astype := self.codec_config.get("astype"):
649-
dtype = parse_dtype(np.dtype(astype), zarr_format=3)
665+
dtype = parse_dtype(np.dtype(astype), zarr_format=3) # type: ignore[arg-type]
650666
return replace(chunk_spec, dtype=dtype)
651667
return chunk_spec
652668

src/zarr/codecs/zstd.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections.abc import Mapping
55
from dataclasses import dataclass
66
from functools import cached_property
7-
from typing import TYPE_CHECKING, Literal, Self, TypedDict, TypeGuard, cast, overload
7+
from typing import TYPE_CHECKING, Literal, NotRequired, Self, TypedDict, TypeGuard, cast, overload
88

99
import numcodecs
1010
from numcodecs.zstd import Zstd
@@ -25,6 +25,7 @@
2525

2626
class ZstdConfig_V2(TypedDict):
2727
level: int
28+
checksum: NotRequired[Literal[True]]
2829

2930

3031
class ZstdConfig_V3(TypedDict):
@@ -104,7 +105,7 @@ def from_dict(cls, data: dict[str, JSON]) -> Self:
104105
def _from_json_v2(cls, data: CodecJSON) -> Self:
105106
if check_json_v2(data):
106107
if "checksum" in data:
107-
return cls(level=data["level"], checksum=data["checksum"]) # type: ignore[typeddict-item]
108+
return cls(level=data["level"], checksum=data["checksum"])
108109
else:
109110
return cls(level=data["level"])
110111

@@ -138,7 +139,10 @@ def to_json(self, zarr_format: Literal[3]) -> ZstdJSON_V3: ...
138139

139140
def to_json(self, zarr_format: ZarrFormat) -> ZstdJSON_V2 | ZstdJSON_V3:
140141
if zarr_format == 2:
141-
return {"id": "zstd", "level": self.level}
142+
if self.checksum is True:
143+
return {"id": "zstd", "level": self.level, "checksum": self.checksum}
144+
else:
145+
return {"id": "zstd", "level": self.level}
142146
else:
143147
return {
144148
"name": "zstd",

src/zarr/core/array.py

Lines changed: 59 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import zarr
2626
from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec
2727
from zarr.abc.numcodec import Numcodec
28+
from zarr.codecs._v2 import NumcodecWrapper
2829
from zarr.codecs.bytes import BytesCodec
2930
from zarr.codecs.transpose import TransposeCodec
3031
from zarr.codecs.vlen_utf8 import VLenBytesCodec, VLenUTF8Codec
@@ -109,6 +110,7 @@
109110
ArrayV3MetadataDict,
110111
T_ArrayMetadata,
111112
)
113+
from zarr.core.metadata.common import _parse_codec
112114
from zarr.core.metadata.io import save_metadata
113115
from zarr.core.metadata.v2 import (
114116
CompressorLike_V2,
@@ -125,9 +127,6 @@
125127
ZarrUserWarning,
126128
)
127129
from zarr.registry import (
128-
_parse_array_array_codec,
129-
_parse_array_bytes_codec,
130-
_parse_bytes_bytes_codec,
131130
get_pipeline_class,
132131
)
133132
from zarr.storage._common import StorePath, ensure_no_existing_node, make_store_path
@@ -141,7 +140,6 @@
141140

142141
from zarr.abc.codec import CodecPipeline
143142
from zarr.abc.store import Store
144-
from zarr.codecs._v2 import NumcodecWrapper
145143
from zarr.codecs.sharding import ShardingCodecIndexLocation
146144
from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar
147145
from zarr.storage import StoreLike
@@ -252,6 +250,60 @@ def create_codec_pipeline(metadata: ArrayMetadata, *, store: Store | None = None
252250
raise TypeError # pragma: no cover
253251

254252

253+
def _parse_bytes_bytes_codec(
254+
data: Mapping[str, JSON] | Codec | Numcodec, *, dtype: ZDType[Any, Any]
255+
) -> BytesBytesCodec:
256+
"""
257+
Normalize the input to a ``BytesBytesCodec`` instance.
258+
If the input is already a ``BytesBytesCodec``, it is returned as is. If the input is a dict, it
259+
is converted to a ``BytesBytesCodec`` instance via the ``_resolve_codec`` function.
260+
"""
261+
262+
_codec_or_numcodec_wrapper = _parse_codec(data, dtype=dtype)
263+
if isinstance(_codec_or_numcodec_wrapper, NumcodecWrapper):
264+
return _codec_or_numcodec_wrapper.to_bytes_bytes()
265+
elif isinstance(_codec_or_numcodec_wrapper, BytesBytesCodec):
266+
return _codec_or_numcodec_wrapper
267+
msg = f"Expected a NumcodecWrapper or ArrayBytesCodec or a dict representation thereof; got {data} instead."
268+
raise ValueError(msg)
269+
270+
271+
def _parse_array_bytes_codec(
272+
data: Mapping[str, JSON] | Codec | Numcodec, *, dtype: ZDType[Any, Any]
273+
) -> ArrayBytesCodec:
274+
"""
275+
Normalize the input to a ``ArrayBytesCodec`` instance.
276+
If the input is already a ``ArrayBytesCodec``, it is returned as is. If the input is a dict, it
277+
is converted to a ``ArrayBytesCodec`` instance via the ``_resolve_codec`` function.
278+
"""
279+
280+
_codec_or_numcodec_wrapper = _parse_codec(data, dtype=dtype)
281+
if isinstance(_codec_or_numcodec_wrapper, NumcodecWrapper):
282+
return _codec_or_numcodec_wrapper.to_array_bytes()
283+
elif isinstance(_codec_or_numcodec_wrapper, ArrayBytesCodec):
284+
return _codec_or_numcodec_wrapper
285+
msg = f"Expected a NumcodecWrapper or ArrayBytesCodec or a dict representation thereof; got {data} instead."
286+
raise ValueError(msg)
287+
288+
289+
def _parse_array_array_codec(
290+
data: Mapping[str, JSON] | Codec | Numcodec, *, dtype: ZDType[Any, Any]
291+
) -> ArrayArrayCodec:
292+
"""
293+
Normalize the input to a ``ArrayArrayCodec`` instance.
294+
If the input is already a ``ArrayArrayCodec``, it is returned as is. If the input is a dict, it
295+
is converted to a ``ArrayArrayCodec`` instance via the ``_resolve_codec`` function.
296+
"""
297+
298+
_codec_or_numcodec_wrapper = _parse_codec(data, dtype=dtype)
299+
if isinstance(_codec_or_numcodec_wrapper, NumcodecWrapper):
300+
return _codec_or_numcodec_wrapper.to_array_array()
301+
elif isinstance(_codec_or_numcodec_wrapper, ArrayArrayCodec):
302+
return _codec_or_numcodec_wrapper
303+
msg = f"Expected a NumcodecWrapper or ArrayArrayCodec or a dict representation thereof; got {data} instead."
304+
raise ValueError(msg)
305+
306+
255307
async def get_array_metadata(
256308
store_path: StorePath, zarr_format: ZarrFormat | None = 3
257309
) -> dict[str, JSON]:
@@ -5177,17 +5229,15 @@ def _parse_chunk_encoding_v3(
51775229
maybe_array_array = (filters,)
51785230
else:
51795231
maybe_array_array = cast("Iterable[Codec | dict[str, JSON]]", filters)
5180-
out_array_array = tuple(
5181-
_parse_array_array_codec(c, zarr_format=3) for c in maybe_array_array
5182-
)
5232+
out_array_array = tuple(_parse_array_array_codec(c, dtype=dtype) for c in maybe_array_array)
51835233

51845234
if serializer == "auto":
51855235
out_array_bytes = default_serializer_v3(dtype)
51865236
else:
51875237
# TODO: ensure that the serializer is compatible with the ndarray produced by the
51885238
# array-array codecs. For example, if a sequence of array-array codecs produces an
51895239
# array with a single-byte data type, then the serializer should not specify endiannesss.
5190-
out_array_bytes = _parse_array_bytes_codec(serializer, zarr_format=3)
5240+
out_array_bytes = _parse_array_bytes_codec(serializer, dtype=dtype)
51915241

51925242
if compressors is None:
51935243
out_bytes_bytes: tuple[BytesBytesCodec, ...] = ()
@@ -5200,12 +5250,7 @@ def _parse_chunk_encoding_v3(
52005250
else:
52015251
maybe_bytes_bytes = compressors # type: ignore[assignment]
52025252

5203-
out_bytes_bytes = tuple(_parse_bytes_bytes_codec(c) for c in maybe_bytes_bytes)
5204-
5205-
# specialize codecs as needed given the dtype
5206-
5207-
# TODO: refactor so that the config only contains the name of the codec, and we use the dtype
5208-
# to create the codec instance, instead of storing a dict representation of a full codec.
5253+
out_bytes_bytes = tuple(_parse_bytes_bytes_codec(c, dtype=dtype) for c in maybe_bytes_bytes)
52095254

52105255
# TODO: ensure that the serializer is compatible with the ndarray produced by the
52115256
# array-array codecs. For example, if a sequence of array-array codecs produces an

0 commit comments

Comments
 (0)