Skip to content

Commit ce83f63

Browse files
committed
Some zarr3 typing fixes
1 parent 05254a6 commit ce83f63

File tree

3 files changed

+42
-32
lines changed

3 files changed

+42
-32
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,4 @@ repos:
3131
hooks:
3232
- id: mypy
3333
args: [--config-file, pyproject.toml]
34-
additional_dependencies: [numpy, pytest, zfpy]
34+
additional_dependencies: [numpy, pytest, zfpy, 'zarr==3.0.0b1']

numcodecs/tests/test_zarr3.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
11
from __future__ import annotations
22

3+
from typing import TYPE_CHECKING
4+
35
import numpy as np
46
import pytest
57

6-
zarr = pytest.importorskip("zarr")
8+
if not TYPE_CHECKING:
9+
zarr = pytest.importorskip("zarr")
10+
else:
11+
import zarr
12+
13+
import zarr.storage
14+
from zarr.core.common import JSON
715

8-
import numcodecs.zarr3 # noqa: E402
16+
import numcodecs.zarr3
917

1018
pytestmark = [
1119
pytest.mark.skipif(zarr.__version__ < "3.0.0", reason="zarr 3.0.0 or later is required"),
@@ -17,7 +25,6 @@
1725

1826
get_codec_class = zarr.registry.get_codec_class
1927
Array = zarr.Array
20-
JSON = zarr.core.common.JSON
2128
BytesCodec = zarr.codecs.BytesCodec
2229
Store = zarr.abc.store.Store
2330
MemoryStore = zarr.storage.MemoryStore
@@ -28,7 +35,7 @@
2835

2936

3037
@pytest.fixture
31-
def store() -> Store:
38+
def store() -> StorePath:
3239
return StorePath(MemoryStore(mode="w"))
3340

3441

@@ -43,6 +50,8 @@ def test_entry_points(codec_class: type[numcodecs.zarr3._NumcodecsCodec]):
4350

4451
@pytest.mark.parametrize("codec_class", ALL_CODECS)
4552
def test_docstring(codec_class: type[numcodecs.zarr3._NumcodecsCodec]):
53+
if codec_class.__doc__ is None:
54+
pytest.skip()
4655
assert "See :class:`numcodecs." in codec_class.__doc__
4756

4857

@@ -59,7 +68,7 @@ def test_docstring(codec_class: type[numcodecs.zarr3._NumcodecsCodec]):
5968
numcodecs.zarr3.Shuffle,
6069
],
6170
)
62-
def test_generic_codec_class(store: Store, codec_class: type[numcodecs.zarr3._NumcodecsCodec]):
71+
def test_generic_codec_class(store: StorePath, codec_class: type[numcodecs.zarr3._NumcodecsCodec]):
6372
data = np.arange(0, 256, dtype="uint16").reshape((16, 16))
6473

6574
with pytest.warns(UserWarning, match=EXPECTED_WARNING_STR):
@@ -92,7 +101,9 @@ def test_generic_codec_class(store: Store, codec_class: type[numcodecs.zarr3._Nu
92101
],
93102
)
94103
def test_generic_filter(
95-
store: Store, codec_class: type[numcodecs.zarr3._NumcodecsCodec], codec_config: dict[str, JSON]
104+
store: StorePath,
105+
codec_class: type[numcodecs.zarr3._NumcodecsCodec],
106+
codec_config: dict[str, JSON],
96107
):
97108
data = np.linspace(0, 10, 256, dtype="float32").reshape((16, 16))
98109

@@ -114,7 +125,7 @@ def test_generic_filter(
114125
np.testing.assert_array_equal(data, a[:, :])
115126

116127

117-
def test_generic_filter_bitround(store: Store):
128+
def test_generic_filter_bitround(store: StorePath):
118129
data = np.linspace(0, 1, 256, dtype="float32").reshape((16, 16))
119130

120131
with pytest.warns(UserWarning, match=EXPECTED_WARNING_STR):
@@ -132,7 +143,7 @@ def test_generic_filter_bitround(store: Store):
132143
assert np.allclose(data, a[:, :], atol=0.1)
133144

134145

135-
def test_generic_filter_quantize(store: Store):
146+
def test_generic_filter_quantize(store: StorePath):
136147
data = np.linspace(0, 10, 256, dtype="float32").reshape((16, 16))
137148

138149
with pytest.warns(UserWarning, match=EXPECTED_WARNING_STR):
@@ -150,7 +161,7 @@ def test_generic_filter_quantize(store: Store):
150161
assert np.allclose(data, a[:, :], atol=0.001)
151162

152163

153-
def test_generic_filter_packbits(store: Store):
164+
def test_generic_filter_packbits(store: StorePath):
154165
data = np.zeros((16, 16), dtype="bool")
155166
data[0:4, :] = True
156167

@@ -189,7 +200,7 @@ def test_generic_filter_packbits(store: Store):
189200
numcodecs.zarr3.JenkinsLookup3,
190201
],
191202
)
192-
def test_generic_checksum(store: Store, codec_class: type[numcodecs.zarr3._NumcodecsCodec]):
203+
def test_generic_checksum(store: StorePath, codec_class: type[numcodecs.zarr3._NumcodecsCodec]):
193204
data = np.linspace(0, 10, 256, dtype="float32").reshape((16, 16))
194205

195206
with pytest.warns(UserWarning, match=EXPECTED_WARNING_STR):
@@ -208,7 +219,7 @@ def test_generic_checksum(store: Store, codec_class: type[numcodecs.zarr3._Numco
208219

209220

210221
@pytest.mark.parametrize("codec_class", [numcodecs.zarr3.PCodec, numcodecs.zarr3.ZFPY])
211-
def test_generic_bytes_codec(store: Store, codec_class: type[numcodecs.zarr3._NumcodecsCodec]):
222+
def test_generic_bytes_codec(store: StorePath, codec_class: type[numcodecs.zarr3._NumcodecsCodec]):
212223
try:
213224
codec_class()._codec # noqa: B018
214225
except ValueError as e:

numcodecs/zarr3.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626

2727
import asyncio
2828
import math
29-
from collections.abc import Callable
3029
from dataclasses import dataclass, replace
3130
from functools import cached_property, partial
3231
from typing import Any, Self, TypeVar
@@ -76,7 +75,7 @@ class _NumcodecsCodec:
7675
codec_name: str
7776
codec_config: dict[str, JSON]
7877

79-
def __init__(self, **codec_config: dict[str, JSON]) -> None:
78+
def __init__(self, **codec_config: JSON) -> None:
8079
if not self.codec_name:
8180
raise ValueError(
8281
"The codec name needs to be supplied through the `codec_name` attribute."
@@ -106,7 +105,7 @@ def from_dict(cls, data: dict[str, JSON]) -> Self:
106105
codec_config = _parse_codec_configuration(data)
107106
return cls(**codec_config)
108107

109-
def to_dict(self) -> JSON:
108+
def to_dict(self) -> dict[str, JSON]:
110109
codec_config = self.codec_config.copy()
111110
return {
112111
"name": self.codec_name,
@@ -118,7 +117,7 @@ def compute_encoded_size(self, input_byte_length: int, chunk_spec: ArraySpec) ->
118117

119118

120119
class _NumcodecsBytesBytesCodec(_NumcodecsCodec, BytesBytesCodec):
121-
def __init__(self, **codec_config: dict[str, JSON]) -> None:
120+
def __init__(self, **codec_config: JSON) -> None:
122121
super().__init__(**codec_config)
123122

124123
async def _decode_single(self, chunk_bytes: Buffer, chunk_spec: ArraySpec) -> Buffer:
@@ -140,7 +139,7 @@ async def _encode_single(self, chunk_bytes: Buffer, chunk_spec: ArraySpec) -> Bu
140139

141140

142141
class _NumcodecsArrayArrayCodec(_NumcodecsCodec, ArrayArrayCodec):
143-
def __init__(self, **codec_config: dict[str, JSON]) -> None:
142+
def __init__(self, **codec_config: JSON) -> None:
144143
super().__init__(**codec_config)
145144

146145
async def _decode_single(self, chunk_array: NDBuffer, chunk_spec: ArraySpec) -> NDBuffer:
@@ -155,7 +154,7 @@ async def _encode_single(self, chunk_array: NDBuffer, chunk_spec: ArraySpec) ->
155154

156155

157156
class _NumcodecsArrayBytesCodec(_NumcodecsCodec, ArrayBytesCodec):
158-
def __init__(self, **codec_config: dict[str, JSON]) -> None:
157+
def __init__(self, **codec_config: JSON) -> None:
159158
super().__init__(**codec_config)
160159

161160
async def _decode_single(self, chunk_buffer: Buffer, chunk_spec: ArraySpec) -> NDBuffer:
@@ -179,7 +178,7 @@ def _add_docstring(cls: type[T], ref_class_name: str) -> type[T]:
179178
return cls
180179

181180

182-
def _add_docstring_wrapper(ref_class_name: str) -> Callable[[type[T]], type[T]]:
181+
def _add_docstring_wrapper(ref_class_name: str) -> partial:
183182
return partial(_add_docstring, ref_class_name=ref_class_name)
184183

185184

@@ -190,7 +189,7 @@ def _make_bytes_bytes_codec(codec_name: str, cls_name: str) -> type[_NumcodecsBy
190189
class _Codec(_NumcodecsBytesBytesCodec):
191190
codec_name = _codec_name
192191

193-
def __init__(self, **codec_config: dict[str, JSON]) -> None:
192+
def __init__(self, **codec_config: JSON) -> None:
194193
super().__init__(**codec_config)
195194

196195
_Codec.__name__ = cls_name
@@ -204,7 +203,7 @@ def _make_array_array_codec(codec_name: str, cls_name: str) -> type[_NumcodecsAr
204203
class _Codec(_NumcodecsArrayArrayCodec):
205204
codec_name = _codec_name
206205

207-
def __init__(self, **codec_config: dict[str, JSON]) -> None:
206+
def __init__(self, **codec_config: JSON) -> None:
208207
super().__init__(**codec_config)
209208

210209
_Codec.__name__ = cls_name
@@ -218,7 +217,7 @@ def _make_array_bytes_codec(codec_name: str, cls_name: str) -> type[_NumcodecsAr
218217
class _Codec(_NumcodecsArrayBytesCodec):
219218
codec_name = _codec_name
220219

221-
def __init__(self, **codec_config: dict[str, JSON]) -> None:
220+
def __init__(self, **codec_config: JSON) -> None:
222221
super().__init__(**codec_config)
223222

224223
_Codec.__name__ = cls_name
@@ -232,7 +231,7 @@ def _make_checksum_codec(codec_name: str, cls_name: str) -> type[_NumcodecsBytes
232231
class _ChecksumCodec(_NumcodecsBytesBytesCodec):
233232
codec_name = _codec_name
234233

235-
def __init__(self, **codec_config: dict[str, JSON]) -> None:
234+
def __init__(self, **codec_config: JSON) -> None:
236235
super().__init__(**codec_config)
237236

238237
def compute_encoded_size(self, input_byte_length: int, chunk_spec: ArraySpec) -> int:
@@ -256,10 +255,10 @@ def compute_encoded_size(self, input_byte_length: int, chunk_spec: ArraySpec) ->
256255
class Shuffle(_NumcodecsBytesBytesCodec):
257256
codec_name = f"{CODEC_PREFIX}shuffle"
258257

259-
def __init__(self, **codec_config: dict[str, JSON]) -> None:
258+
def __init__(self, **codec_config: JSON) -> None:
260259
super().__init__(**codec_config)
261260

262-
def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
261+
def evolve_from_array_spec(self, array_spec: ArraySpec) -> Shuffle:
263262
if array_spec.dtype.itemsize != self.codec_config.get("elementsize"):
264263
return Shuffle(**{**self.codec_config, "elementsize": array_spec.dtype.itemsize})
265264
return self # pragma: no cover
@@ -276,15 +275,15 @@ def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
276275
class FixedScaleOffset(_NumcodecsArrayArrayCodec):
277276
codec_name = f"{CODEC_PREFIX}fixedscaleoffset"
278277

279-
def __init__(self, **codec_config: dict[str, JSON]) -> None:
278+
def __init__(self, **codec_config: JSON) -> None:
280279
super().__init__(**codec_config)
281280

282281
def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
283282
if astype := self.codec_config.get("astype"):
284283
return replace(chunk_spec, dtype=np.dtype(astype))
285284
return chunk_spec
286285

287-
def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
286+
def evolve_from_array_spec(self, array_spec: ArraySpec) -> FixedScaleOffset:
288287
if str(array_spec.dtype) != self.codec_config.get("dtype"):
289288
return FixedScaleOffset(**{**self.codec_config, "dtype": str(array_spec.dtype)})
290289
return self
@@ -294,10 +293,10 @@ def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
294293
class Quantize(_NumcodecsArrayArrayCodec):
295294
codec_name = f"{CODEC_PREFIX}quantize"
296295

297-
def __init__(self, **codec_config: dict[str, JSON]) -> None:
296+
def __init__(self, **codec_config: JSON) -> None:
298297
super().__init__(**codec_config)
299298

300-
def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
299+
def evolve_from_array_spec(self, array_spec: ArraySpec) -> Quantize:
301300
if str(array_spec.dtype) != self.codec_config.get("dtype"):
302301
return Quantize(**{**self.codec_config, "dtype": str(array_spec.dtype)})
303302
return self
@@ -307,7 +306,7 @@ def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
307306
class PackBits(_NumcodecsArrayArrayCodec):
308307
codec_name = f"{CODEC_PREFIX}packbits"
309308

310-
def __init__(self, **codec_config: dict[str, JSON]) -> None:
309+
def __init__(self, **codec_config: JSON) -> None:
311310
super().__init__(**codec_config)
312311

313312
def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
@@ -326,13 +325,13 @@ def validate(self, *, dtype: np.dtype[Any], **_kwargs) -> None:
326325
class AsType(_NumcodecsArrayArrayCodec):
327326
codec_name = f"{CODEC_PREFIX}astype"
328327

329-
def __init__(self, **codec_config: dict[str, JSON]) -> None:
328+
def __init__(self, **codec_config: JSON) -> None:
330329
super().__init__(**codec_config)
331330

332331
def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
333332
return replace(chunk_spec, dtype=np.dtype(self.codec_config["encode_dtype"]))
334333

335-
def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
334+
def evolve_from_array_spec(self, array_spec: ArraySpec) -> AsType:
336335
decode_dtype = self.codec_config.get("decode_dtype")
337336
if str(array_spec.dtype) != decode_dtype:
338337
return AsType(**{**self.codec_config, "decode_dtype": str(array_spec.dtype)})

0 commit comments

Comments
 (0)