Skip to content

Commit e3798cd

Browse files
committed
Fix to work with zarr 3.1.0
1 parent b644e9a commit e3798cd

File tree

2 files changed

+42
-13
lines changed

2 files changed

+42
-13
lines changed

numcodecs/zarr3.py

Lines changed: 41 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -29,17 +29,19 @@
2929
import math
3030
from dataclasses import dataclass, replace
3131
from functools import cached_property
32+
from importlib.metadata import version
3233
from typing import Any, Self
3334
from warnings import warn
3435

3536
import numpy as np
37+
from packaging.version import Version
3638

3739
import numcodecs
3840

3941
try:
40-
import zarr
42+
import zarr # noqa: F401
4143

42-
if zarr.__version__ < "3.0.0": # pragma: no cover
44+
if Version(version('zarr')) < Version("3.0.0"): # pragma: no cover
4345
raise ImportError("zarr 3.0.0 or later is required to use the numcodecs zarr integration.")
4446
except ImportError as e: # pragma: no cover
4547
raise ImportError(
@@ -52,10 +54,28 @@
5254
from zarr.core.buffer import Buffer, BufferPrototype, NDBuffer
5355
from zarr.core.buffer.cpu import as_numpy_array_wrapper
5456
from zarr.core.common import JSON, parse_named_configuration, product
57+
from zarr.dtype import ZDType
5558

5659
CODEC_PREFIX = "numcodecs."
5760

5861

62+
def _from_zarr_dtype(dtype: Any) -> np.dtype:
    """
    Convert the ``dtype`` carried by a zarr array spec into a numpy data type.

    With zarr >= 3.1.0 the array spec holds a ``ZDType`` wrapper exposing
    ``to_native_dtype()``; with older zarr it holds the numpy dtype itself.

    Parameters
    ----------
    dtype : Any
        A zarr ``ZDType`` instance, or an object that is already a numpy dtype.

    Returns
    -------
    np.dtype
        ``dtype.to_native_dtype()`` when that method exists, otherwise
        ``dtype`` unchanged.
    """
    # Detect the wrapper by its conversion method instead of re-reading and
    # re-parsing the installed zarr version from package metadata on every
    # call. This also lets an already-native dtype pass straight through.
    to_native = getattr(dtype, "to_native_dtype", None)
    if to_native is None:
        return dtype  # pragma: no cover
    return to_native()
69+
70+
71+
def _to_zarr_dtype(dtype: np.dtype) -> Any:
    """
    Convert a numpy data type into the form zarr expects in an array spec.

    Parameters
    ----------
    dtype : np.dtype
        The numpy data type to convert.

    Returns
    -------
    Any
        ``parse_data_type(dtype, zarr_format=3)`` from ``zarr.dtype`` when that
        module is importable, otherwise ``dtype`` unchanged.
    """
    # Probe for the module rather than re-reading and re-parsing the installed
    # zarr version from package metadata on every call.
    # NOTE(review): assumes ``zarr.dtype`` is absent exactly on the zarr < 3.1.0
    # releases the previous version check targeted -- confirm.
    # Only a missing module triggers the fallback; a missing or renamed
    # ``parse_data_type`` raises a plain ImportError and still surfaces.
    try:
        from zarr.dtype import parse_data_type
    except ModuleNotFoundError:  # pragma: no cover
        return dtype
    return parse_data_type(dtype, zarr_format=3)
77+
78+
5979
def _expect_name_prefix(codec_name: str) -> str:
6080
if not codec_name.startswith(CODEC_PREFIX):
6181
raise ValueError(
@@ -224,15 +244,17 @@ class LZMA(_NumcodecsBytesBytesCodec, codec_name="lzma"):
224244
class Shuffle(_NumcodecsBytesBytesCodec, codec_name="shuffle"):
225245
def evolve_from_array_spec(self, array_spec: ArraySpec) -> Shuffle:
226246
if self.codec_config.get("elementsize") is None:
227-
return Shuffle(**{**self.codec_config, "elementsize": array_spec.dtype.itemsize})
247+
dtype = _from_zarr_dtype(array_spec.dtype)
248+
return Shuffle(**{**self.codec_config, "elementsize": dtype.itemsize})
228249
return self # pragma: no cover
229250

230251

231252
# array-to-array codecs ("filters")
232253
class Delta(_NumcodecsArrayArrayCodec, codec_name="delta"):
233254
def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
234255
if astype := self.codec_config.get("astype"):
235-
return replace(chunk_spec, dtype=np.dtype(astype)) # type: ignore[call-overload]
256+
dtype = _to_zarr_dtype(np.dtype(astype)) # type: ignore[call-overload]
257+
return replace(chunk_spec, dtype=dtype)
236258
return chunk_spec
237259

238260

@@ -243,12 +265,14 @@ class BitRound(_NumcodecsArrayArrayCodec, codec_name="bitround"):
243265
class FixedScaleOffset(_NumcodecsArrayArrayCodec, codec_name="fixedscaleoffset"):
244266
def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
245267
if astype := self.codec_config.get("astype"):
246-
return replace(chunk_spec, dtype=np.dtype(astype)) # type: ignore[call-overload]
268+
dtype = _to_zarr_dtype(np.dtype(astype)) # type: ignore[call-overload]
269+
return replace(chunk_spec, dtype=dtype)
247270
return chunk_spec
248271

249272
def evolve_from_array_spec(self, array_spec: ArraySpec) -> FixedScaleOffset:
250273
if self.codec_config.get("dtype") is None:
251-
return FixedScaleOffset(**{**self.codec_config, "dtype": str(array_spec.dtype)})
274+
dtype = _from_zarr_dtype(array_spec.dtype)
275+
return FixedScaleOffset(**{**self.codec_config, "dtype": str(dtype)})
252276
return self
253277

254278

@@ -258,7 +282,8 @@ def __init__(self, **codec_config: JSON) -> None:
258282

259283
def evolve_from_array_spec(self, array_spec: ArraySpec) -> Quantize:
260284
if self.codec_config.get("dtype") is None:
261-
return Quantize(**{**self.codec_config, "dtype": str(array_spec.dtype)})
285+
dtype = _from_zarr_dtype(array_spec.dtype)
286+
return Quantize(**{**self.codec_config, "dtype": str(dtype)})
262287
return self
263288

264289

@@ -267,21 +292,25 @@ def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
267292
return replace(
268293
chunk_spec,
269294
shape=(1 + math.ceil(product(chunk_spec.shape) / 8),),
270-
dtype=np.dtype("uint8"),
295+
dtype=_to_zarr_dtype(np.dtype("uint8")),
271296
)
272297

273-
def validate(self, *, dtype: np.dtype[Any], **_kwargs) -> None:
274-
if dtype != np.dtype("bool"):
298+
def validate(self, *, shape: tuple[int, ...], dtype: ZDType[Any, Any], **_kwargs) -> None:
299+
_dtype = _from_zarr_dtype(dtype)
300+
if _dtype != np.dtype("bool"):
275301
raise ValueError(f"Packbits filter requires bool dtype. Got {dtype}.")
276302

277303

278304
class AsType(_NumcodecsArrayArrayCodec, codec_name="astype"):
279305
def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
280-
return replace(chunk_spec, dtype=np.dtype(self.codec_config["encode_dtype"])) # type: ignore[arg-type]
306+
dtype = _to_zarr_dtype(np.dtype(self.codec_config["encode_dtype"])) # type: ignore[arg-type]
307+
return replace(chunk_spec, dtype=dtype)
281308

282309
def evolve_from_array_spec(self, array_spec: ArraySpec) -> AsType:
283310
if self.codec_config.get("decode_dtype") is None:
284-
return AsType(**{**self.codec_config, "decode_dtype": str(array_spec.dtype)})
311+
# TODO: remove these coverage exemptions the correct way, i.e. with tests
312+
dtype = _from_zarr_dtype(array_spec.dtype) # pragma: no cover
313+
return AsType(**{**self.codec_config, "decode_dtype": str(dtype)}) # pragma: no cover
285314
return self
286315

287316

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@ description = """
1414
A Python package providing buffer compression and transformation codecs \
1515
for use in data storage and communication applications."""
1616
readme = "README.rst"
17-
dependencies = ["numpy>=1.24", "typing_extensions"]
17+
dependencies = ["numpy>=1.24", "typing_extensions", "packaging"]
1818
requires-python = ">=3.11"
1919
dynamic = [
2020
"version",

0 commit comments

Comments
 (0)