Skip to content

Commit f843142

Browse files
committed
dtype adaptor for zarr 3.1
1 parent 998579a commit f843142

File tree

1 file changed

+32
-10
lines changed

1 file changed

+32
-10
lines changed

numcodecs/zarr3.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,14 @@
3333
from warnings import warn
3434

3535
import numpy as np
36-
36+
from packaging.version import Version
37+
from importlib.metadata import version
3738
import numcodecs
3839

3940
try:
4041
import zarr
4142

42-
if zarr.__version__ < "3.0.0": # pragma: no cover
43+
if version('zarr') < Version("3.0.0"): # pragma: no cover
4344
raise ImportError("zarr 3.0.0 or later is required to use the numcodecs zarr integration.")
4445
except ImportError as e: # pragma: no cover
4546
raise ImportError(
@@ -55,6 +56,19 @@
5556

5657
CODEC_PREFIX = "numcodecs."
5758

59+
def from_zarr_dtype(dtype: Any) -> np.dtype:
    """
    Get a numpy data type from an array spec, depending on the zarr version.

    Parameters
    ----------
    dtype
        The ``dtype`` attribute taken from a zarr array spec.

    Returns
    -------
    np.dtype
        ``dtype.to_native_dtype()`` when the installed zarr is 3.1.0 or later,
        otherwise ``dtype`` itself, unchanged.
    """
    # importlib.metadata.version() returns a plain string; it must be parsed
    # into a Version first, because comparing str with Version raises TypeError.
    if Version(version('zarr')) >= Version("3.1.0"):
        return dtype.to_native_dtype()
    return dtype
66+
67+
def to_zarr_dtype(dtype: np.dtype) -> Any:
    """
    Convert a numpy data type into what the installed zarr expects in an array spec.

    Parameters
    ----------
    dtype
        The numpy data type to convert.

    Returns
    -------
    Any
        The result of ``zarr.dtype.parse_data_type`` when the installed zarr is
        3.1.0 or later, otherwise ``dtype`` itself, unchanged.
    """
    # importlib.metadata.version() returns a plain string; it must be parsed
    # into a Version first, because comparing str with Version raises TypeError.
    if Version(version('zarr')) >= Version("3.1.0"):
        # Imported lazily: zarr.dtype is only imported on the >= 3.1.0 branch.
        from zarr.dtype import parse_data_type

        # parse_data_type takes the target format as a keyword-only argument;
        # this module is the Zarr v3 integration, so the format is 3.
        return parse_data_type(dtype, zarr_format=3)
    return dtype
5872

5973
def _expect_name_prefix(codec_name: str) -> str:
6074
if not codec_name.startswith(CODEC_PREFIX):
@@ -224,15 +238,17 @@ class LZMA(_NumcodecsBytesBytesCodec, codec_name="lzma"):
224238
class Shuffle(_NumcodecsBytesBytesCodec, codec_name="shuffle"):
    """Bytes-to-bytes codec declared with ``codec_name="shuffle"``."""

    def evolve_from_array_spec(self, array_spec: ArraySpec) -> Shuffle:
        """
        Fill in a missing ``elementsize`` from the array's data type.

        Returns ``self`` when ``elementsize`` is already configured; otherwise
        returns a new ``Shuffle`` whose ``elementsize`` is the item size of the
        numpy dtype obtained from ``array_spec.dtype``.
        """
        if self.codec_config.get("elementsize") is not None:
            return self  # pragma: no cover
        native_dtype = from_zarr_dtype(array_spec.dtype)
        updated_config = {**self.codec_config, "elementsize": native_dtype.itemsize}
        return Shuffle(**updated_config)
229244

230245

231246
# array-to-array codecs ("filters")
232247
class Delta(_NumcodecsArrayArrayCodec, codec_name="delta"):
    """Array-to-array codec declared with ``codec_name="delta"``."""

    def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
        """
        Return the chunk spec this codec produces.

        When the configuration has a truthy ``astype`` entry, the returned spec
        is a copy of ``chunk_spec`` whose dtype is that type converted with
        ``to_zarr_dtype``; otherwise ``chunk_spec`` is returned unchanged.
        """
        astype = self.codec_config.get("astype")
        if not astype:
            return chunk_spec
        target_dtype = to_zarr_dtype(np.dtype(astype))
        return replace(chunk_spec, dtype=target_dtype)  # type: ignore[call-overload]
237253

238254

@@ -243,12 +259,14 @@ class BitRound(_NumcodecsArrayArrayCodec, codec_name="bitround"):
243259
class FixedScaleOffset(_NumcodecsArrayArrayCodec, codec_name="fixedscaleoffset"):
    """Array-to-array codec declared with ``codec_name="fixedscaleoffset"``."""

    def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
        """
        Return the chunk spec this codec produces.

        When the configuration has a truthy ``astype`` entry, the returned spec
        is a copy of ``chunk_spec`` whose dtype is that type converted with
        ``to_zarr_dtype``; otherwise ``chunk_spec`` is returned unchanged.
        """
        astype = self.codec_config.get("astype")
        if not astype:
            return chunk_spec
        target_dtype = to_zarr_dtype(np.dtype(astype))
        return replace(chunk_spec, dtype=target_dtype)  # type: ignore[call-overload]

    def evolve_from_array_spec(self, array_spec: ArraySpec) -> FixedScaleOffset:
        """
        Fill in a missing ``dtype`` from the array's data type.

        Returns ``self`` when ``dtype`` is already configured; otherwise returns
        a new ``FixedScaleOffset`` whose ``dtype`` is the string form of the
        numpy dtype obtained from ``array_spec.dtype``.
        """
        if self.codec_config.get("dtype") is not None:
            return self
        native_dtype = from_zarr_dtype(array_spec.dtype)
        updated_config = {**self.codec_config, "dtype": str(native_dtype)}
        return FixedScaleOffset(**updated_config)
253271

254272

@@ -258,7 +276,8 @@ def __init__(self, **codec_config: JSON) -> None:
258276

259277
def evolve_from_array_spec(self, array_spec: ArraySpec) -> Quantize:
    """
    Fill in a missing ``dtype`` from the array's data type.

    Returns ``self`` when ``dtype`` is already configured; otherwise returns
    a new ``Quantize`` whose ``dtype`` is the string form of the numpy dtype
    obtained from ``array_spec.dtype``.
    """
    if self.codec_config.get("dtype") is not None:
        return self
    native_dtype = from_zarr_dtype(array_spec.dtype)
    updated_config = {**self.codec_config, "dtype": str(native_dtype)}
    return Quantize(**updated_config)
263282

264283

@@ -271,17 +290,20 @@ def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
271290
)
272291

273292
def validate(self, *, dtype: np.dtype[Any], **_kwargs) -> None:
    """
    Check that the array data type is boolean.

    ``dtype`` is first converted with ``from_zarr_dtype`` and then compared
    against ``np.dtype("bool")``. Any other keyword arguments are ignored.

    Raises
    ------
    ValueError
        If the converted data type is not ``bool``.
    """
    native_dtype = from_zarr_dtype(dtype)
    bool_dtype = np.dtype("bool")
    if native_dtype != bool_dtype:
        # The message reports the dtype exactly as it was passed in.
        raise ValueError(f"Packbits filter requires bool dtype. Got {dtype}.")
276296

277297

278298
class AsType(_NumcodecsArrayArrayCodec, codec_name="astype"):
    """Array-to-array codec declared with ``codec_name="astype"``."""

    def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
        """
        Return a copy of ``chunk_spec`` carrying the configured ``encode_dtype``.

        The ``encode_dtype`` entry is required (a missing key raises
        ``KeyError``) and is converted with ``to_zarr_dtype`` before being
        placed on the returned spec.
        """
        encode_dtype = np.dtype(self.codec_config["encode_dtype"])  # type: ignore[arg-type]
        return replace(chunk_spec, dtype=to_zarr_dtype(encode_dtype))

    def evolve_from_array_spec(self, array_spec: ArraySpec) -> AsType:
        """
        Fill in a missing ``decode_dtype`` from the array's data type.

        Returns ``self`` when ``decode_dtype`` is already configured; otherwise
        returns a new ``AsType`` whose ``decode_dtype`` is the string form of
        the numpy dtype obtained from ``array_spec.dtype``.
        """
        if self.codec_config.get("decode_dtype") is not None:
            return self
        native_dtype = from_zarr_dtype(array_spec.dtype)
        updated_config = {**self.codec_config, "decode_dtype": str(native_dtype)}
        return AsType(**updated_config)
286308

287309

0 commit comments

Comments
 (0)