Skip to content

Commit fb286a7

Browse files
committed
fix mypy
1 parent a4ba7db commit fb286a7

File tree

7 files changed

+62
-42
lines changed

7 files changed

+62
-42
lines changed

src/zarr/core/array.py

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@
8989
from zarr.core.metadata.v2 import (
9090
_default_compressor,
9191
_default_filters,
92+
parse_compressor,
93+
parse_filters,
9294
)
9395
from zarr.core.metadata.v3 import DataType, parse_node_type_array
9496
from zarr.core.sync import sync
@@ -164,7 +166,7 @@ async def get_array_metadata(
164166
)
165167
if zarr_json_bytes is not None and zarray_bytes is not None:
166168
# warn and favor v3
167-
msg = f"Both zarr.json (zarr v3) and .zarray (zarr v2) metadata objects exist at {store_path}."
169+
msg = f"Both zarr.json (Zarr v3) and .zarray (Zarr v2) metadata objects exist at {store_path}. Zarr v3 will be used."
168170
warnings.warn(msg, stacklevel=1)
169171
if zarr_json_bytes is None and zarray_bytes is None:
170172
raise FileNotFoundError(store_path)
@@ -667,8 +669,8 @@ async def _create_v2(
667669
config: ArrayConfig,
668670
dimension_separator: Literal[".", "/"] | None = None,
669671
fill_value: float | None = None,
670-
filters: list[dict[str, JSON]] | None = None,
671-
compressor: dict[str, JSON] | None = None,
672+
filters: Iterable[dict[str, JSON] | numcodecs.abc.Codec] | None = None,
673+
compressor: dict[str, JSON] | numcodecs.abc.Codec | None = None,
672674
attributes: dict[str, JSON] | None = None,
673675
overwrite: bool = False,
674676
) -> AsyncArray[ArrayV2Metadata]:
@@ -3492,13 +3494,13 @@ def _get_default_codecs(
34923494
else:
34933495
dtype_key = "numeric"
34943496

3495-
return default_codecs[dtype_key]
3497+
return cast(list[dict[str, JSON]], default_codecs[dtype_key])
34963498

34973499

34983500
FiltersParam: TypeAlias = (
34993501
Iterable[dict[str, JSON] | Codec] | Iterable[numcodecs.abc.Codec] | Literal["auto"]
35003502
)
3501-
CompressionParam: TypeAlias = (
3503+
CompressorsParam: TypeAlias = (
35023504
Iterable[dict[str, JSON] | Codec] | Codec | numcodecs.abc.Codec | Literal["auto"]
35033505
)
35043506

@@ -3512,7 +3514,7 @@ async def create_array(
35123514
chunks: ChunkCoords | Literal["auto"] = "auto",
35133515
shards: ChunkCoords | Literal["auto"] | None = None,
35143516
filters: FiltersParam = "auto",
3515-
compressors: CompressionParam = "auto",
3517+
compressors: CompressorsParam = "auto",
35163518
fill_value: Any | None = 0,
35173519
order: MemoryOrder | None = "C",
35183520
zarr_format: ZarrFormat | None = 3,
@@ -3544,16 +3546,16 @@ async def create_array(
35443546
filters : Iterable[Codec], optional
35453547
Iterable of filters to apply to each chunk of the array, in order, before serializing that
35463548
chunk to bytes.
3547-
For Zarr v3, a "filter" is a transformation that takes an array and returns an array,
3549+
For Zarr v3, a "filter" is a codec that takes an array and returns an array,
35483550
and these values must be instances of ``ArrayArrayCodec``, or dict representations
35493551
of ``ArrayArrayCodec``.
35503552
For Zarr v2, a "filter" can be any numcodecs codec; you should ensure that the
35513553
order of your filters is consistent with the behavior of each filter.
35523554
compressors : Iterable[Codec], optional
35533555
List of compressors to apply to the array. Compressors are applied in order, and after any
35543556
filters are applied (if any are specified).
3555-
For Zarr v3, a "compressor" is a transformation that takes a string of bytes and
3556-
returns another string of bytes.
3557+
For Zarr v3, a "compressor" is a codec that takes a bytestream and
3558+
returns another bytestream.
35573559
For Zarr v2, a "compressor" can be any numcodecs codec.
35583560
fill_value : Any, optional
35593561
Fill value for the array.
@@ -3611,11 +3613,6 @@ async def create_array(
36113613
)
36123614

36133615
raise ValueError(msg)
3614-
if filters != "auto" and not all(isinstance(f, numcodecs.abc.Codec) for f in filters):
3615-
raise TypeError(
3616-
"For Zarr v2 arrays, all elements of `filters` must be numcodecs codecs."
3617-
)
3618-
filters = cast(Iterable[numcodecs.abc.Codec] | Literal["auto"], filters)
36193616
filters_parsed, compressor_parsed = _parse_chunk_encoding_v2(
36203617
compressor=compressors, filters=filters, dtype=dtype_parsed
36213618
)
@@ -3644,7 +3641,7 @@ async def create_array(
36443641
array_array, array_bytes, bytes_bytes = _parse_chunk_encoding_v3(
36453642
compressors=compressors, filters=filters, dtype=dtype_parsed
36463643
)
3647-
sub_codecs = (*array_array, array_bytes, *bytes_bytes)
3644+
sub_codecs = cast(tuple[Codec, ...], (*array_array, array_bytes, *bytes_bytes))
36483645
codecs_out: tuple[Codec, ...]
36493646
if shard_shape_parsed is not None:
36503647
sharding_codec = ShardingCodec(chunk_shape=chunk_shape_parsed, codecs=sub_codecs)
@@ -3688,7 +3685,7 @@ def _parse_chunk_key_encoding(
36883685
"""
36893686
if data is None:
36903687
if zarr_format == 2:
3691-
result = ChunkKeyEncoding.from_dict({"name": "v2", "separator": "/"})
3688+
result = ChunkKeyEncoding.from_dict({"name": "v2", "separator": "."})
36923689
else:
36933690
result = ChunkKeyEncoding.from_dict({"name": "default", "separator": "/"})
36943691
elif isinstance(data, ChunkKeyEncoding):
@@ -3769,46 +3766,56 @@ def _get_default_chunk_encoding_v2(
37693766

37703767
def _parse_chunk_encoding_v2(
37713768
*,
3772-
compressor: numcodecs.abc.Codec | Literal["auto"],
3773-
filters: tuple[numcodecs.abc.Codec, ...] | Literal["auto"],
3769+
compressor: CompressorsParam,
3770+
filters: FiltersParam,
37743771
dtype: np.dtype[Any],
37753772
) -> tuple[tuple[numcodecs.abc.Codec, ...] | None, numcodecs.abc.Codec | None]:
37763773
"""
37773774
Generate chunk encoding classes for v2 arrays with optional defaults.
37783775
"""
37793776
default_filters, default_compressor = _get_default_chunk_encoding_v2(dtype)
3780-
_filters: tuple[numcodecs.abc.Codec, ...] = ()
3777+
3778+
_filters: tuple[numcodecs.abc.Codec, ...] | None = None
3779+
_compressor: numcodecs.abc.Codec | None = None
3780+
37813781
if compressor == "auto":
37823782
_compressor = default_compressor
37833783
else:
3784-
_compressor = compressor
3784+
if isinstance(compressor, Iterable):
3785+
raise TypeError("For Zarr v2 arrays, the `compressor` must be a single codec.")
3786+
_compressor = parse_compressor(compressor)
37853787
if filters == "auto":
37863788
_filters = default_filters
37873789
else:
3788-
_filters = filters
3790+
if not all(isinstance(f, numcodecs.abc.Codec) for f in filters):
3791+
raise TypeError(
3792+
"For Zarr v2 arrays, all elements of `filters` must be numcodecs codecs."
3793+
)
3794+
_filters = parse_filters(filters)
3795+
37893796
return _filters, _compressor
37903797

37913798

37923799
def _parse_chunk_encoding_v3(
37933800
*,
3794-
compressors: Iterable[BytesBytesCodec | dict[str, JSON]] | Literal["auto"],
3795-
filters: Iterable[ArrayArrayCodec | dict[str, JSON]] | Literal["auto"],
3801+
compressors: CompressorsParam,
3802+
filters: FiltersParam,
37963803
dtype: np.dtype[Any],
37973804
) -> tuple[tuple[ArrayArrayCodec, ...], ArrayBytesCodec, tuple[BytesBytesCodec, ...]]:
37983805
"""
37993806
Generate chunk encoding classes for v3 arrays with optional defaults.
38003807
"""
38013808
default_array_array, default_array_bytes, default_bytes_bytes = _get_default_encoding_v3(dtype)
3802-
maybe_bytes_bytes: Iterable[BytesBytesCodec | dict[str, JSON]]
3803-
maybe_array_array: Iterable[ArrayArrayCodec | dict[str, JSON]]
3809+
maybe_bytes_bytes: Iterable[Codec | dict[str, JSON]]
3810+
maybe_array_array: Iterable[Codec | dict[str, JSON]]
38043811

38053812
if compressors == "auto":
38063813
out_bytes_bytes = default_bytes_bytes
38073814
else:
38083815
if isinstance(compressors, dict | Codec):
38093816
maybe_bytes_bytes = (compressors,)
38103817
else:
3811-
maybe_bytes_bytes = compressors
3818+
maybe_bytes_bytes = cast(Iterable[Codec | dict[str, JSON]], compressors)
38123819

38133820
out_bytes_bytes = tuple(_parse_bytes_bytes_codec(c) for c in maybe_bytes_bytes)
38143821

@@ -3818,7 +3825,7 @@ def _parse_chunk_encoding_v3(
38183825
if isinstance(filters, dict | Codec):
38193826
maybe_array_array = (filters,)
38203827
else:
3821-
maybe_array_array = filters
3828+
maybe_array_array = cast(Iterable[Codec | dict[str, JSON]], filters)
38223829
out_array_array = tuple(_parse_array_array_codec(c) for c in maybe_array_array)
38233830

38243831
return out_array_array, default_array_bytes, out_bytes_bytes

src/zarr/core/chunk_key_encodings.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,16 @@ def __init__(self, *, separator: SeparatorLiteral) -> None:
3636
object.__setattr__(self, "separator", separator_parsed)
3737

3838
@classmethod
39-
def from_dict(cls, data: dict[str, JSON] | ChunkKeyEncoding) -> ChunkKeyEncoding:
39+
def from_dict(
40+
cls, data: dict[str, JSON] | ChunkKeyEncoding | ChunkKeyEncodingParams
41+
) -> ChunkKeyEncoding:
4042
if isinstance(data, ChunkKeyEncoding):
4143
return data
4244

45+
# handle ChunkKeyEncodingParams
46+
if "name" in data and "separator" in data:
47+
data = {"name": data["name"], "configuration": {"separator": data["separator"]}}
48+
4349
# configuration is optional for chunk key encodings
4450
name_parsed, config_parsed = parse_named_configuration(data, require_configuration=False)
4551
if name_parsed == "default":

src/zarr/core/group.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from zarr.core.array import (
2222
Array,
2323
AsyncArray,
24-
CompressionParam,
24+
CompressorsParam,
2525
FiltersParam,
2626
_build_parents,
2727
create_array,
@@ -511,7 +511,7 @@ async def open(
511511
)
512512
if zarr_json_bytes is not None and zgroup_bytes is not None:
513513
# warn and favor v3
514-
msg = f"Both zarr.json (zarr v3) and .zgroup (zarr v2) metadata objects exist at {store_path}."
514+
msg = f"Both zarr.json (Zarr v3) and .zgroup (Zarr v2) metadata objects exist at {store_path}. Zarr v3 will be used."
515515
warnings.warn(msg, stacklevel=1)
516516
if zarr_json_bytes is None and zgroup_bytes is None:
517517
raise FileNotFoundError(
@@ -1011,7 +1011,7 @@ async def create_array(
10111011
chunks: ChunkCoords | Literal["auto"] = "auto",
10121012
shards: ChunkCoords | Literal["auto"] | None = None,
10131013
filters: FiltersParam = "auto",
1014-
compressors: CompressionParam = "auto",
1014+
compressors: CompressorsParam = "auto",
10151015
fill_value: Any | None = 0,
10161016
order: MemoryOrder | None = "C",
10171017
attributes: dict[str, JSON] | None = None,
@@ -2539,8 +2539,8 @@ def array(
25392539
dtype: npt.DTypeLike,
25402540
chunks: ChunkCoords | Literal["auto"] = "auto",
25412541
shards: ChunkCoords | Literal["auto"] | None = None,
2542-
filters: Iterable[dict[str, JSON] | Codec] = "auto",
2543-
compressors: Iterable[dict[str, JSON] | Codec] = "auto",
2542+
filters: FiltersParam = "auto",
2543+
compressors: CompressorsParam = "auto",
25442544
fill_value: Any | None = 0,
25452545
order: MemoryOrder | None = "C",
25462546
attributes: dict[str, JSON] | None = None,

src/zarr/core/metadata/v2.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import numpy.typing as npt
1717

1818
from zarr.core.buffer import Buffer, BufferPrototype
19-
from zarr.core.common import JSON, ChunkCoords
19+
from zarr.core.common import ChunkCoords
2020

2121
import json
2222
from dataclasses import dataclass, field, fields, replace
@@ -27,7 +27,7 @@
2727
from zarr.core.array_spec import ArrayConfig, ArraySpec
2828
from zarr.core.chunk_grids import RegularChunkGrid
2929
from zarr.core.chunk_key_encodings import parse_separator
30-
from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON, MemoryOrder, parse_shapelike
30+
from zarr.core.common import JSON, ZARRAY_JSON, ZATTRS_JSON, MemoryOrder, parse_shapelike
3131
from zarr.core.config import config, parse_indexing_order
3232
from zarr.core.metadata.common import parse_attributes
3333

@@ -352,7 +352,7 @@ def _default_compressor(
352352
else:
353353
raise ValueError(f"Unsupported dtype kind {dtype.kind}")
354354

355-
return default_compressor.get(dtype_key, None)
355+
return cast(dict[str, JSON] | None, default_compressor.get(dtype_key, None))
356356

357357

358358
def _default_filters(
@@ -372,4 +372,4 @@ def _default_filters(
372372
else:
373373
raise ValueError(f"Unsupported dtype kind {dtype.kind}")
374374

375-
return default_filters.get(dtype_key, None)
375+
return cast(list[dict[str, JSON]] | None, default_filters.get(dtype_key, None))

src/zarr/core/metadata/v3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,7 @@ def default_fill_value(dtype: DataType) -> str | bytes | np.generic:
548548
else:
549549
np_dtype = dtype.to_numpy()
550550
np_dtype = cast(np.dtype[Any], np_dtype)
551-
return np_dtype.type(0)
551+
return np_dtype.type(0) # type: ignore[misc]
552552

553553

554554
# For type checking

src/zarr/registry.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def _resolve_codec(data: dict[str, JSON]) -> Codec:
161161
return get_codec_class(data["name"]).from_dict(data) # type: ignore[arg-type]
162162

163163

164-
def _parse_bytes_bytes_codec(data: dict[str, JSON] | BytesBytesCodec) -> BytesBytesCodec:
164+
def _parse_bytes_bytes_codec(data: dict[str, JSON] | Codec) -> BytesBytesCodec:
165165
"""
166166
Normalize the input to a ``BytesBytesCodec`` instance.
167167
If the input is already a ``BytesBytesCodec``, it is returned as is. If the input is a dict, it
@@ -173,6 +173,8 @@ def _parse_bytes_bytes_codec(data: dict[str, JSON] | BytesBytesCodec) -> BytesBy
173173
msg = f"Expected a dict representation of a BytesBytesCodec; got a dict representation of a {type(result)} instead."
174174
raise TypeError(msg)
175175
else:
176+
if not isinstance(data, BytesBytesCodec):
177+
raise TypeError(f"Expected a BytesBytesCodec. Got {type(data)} instead.")
176178
result = data
177179
return result
178180

@@ -193,7 +195,7 @@ def _parse_array_bytes_codec(data: dict[str, JSON] | ArrayBytesCodec) -> ArrayBy
193195
return result
194196

195197

196-
def _parse_array_array_codec(data: dict[str, JSON] | ArrayArrayCodec) -> ArrayArrayCodec:
198+
def _parse_array_array_codec(data: dict[str, JSON] | Codec) -> ArrayArrayCodec:
197199
"""
198200
Normalize the input to a ``ArrayArrayCodec`` instance.
199201
If the input is already a ``ArrayArrayCodec``, it is returned as is. If the input is a dict, it
@@ -205,6 +207,8 @@ def _parse_array_array_codec(data: dict[str, JSON] | ArrayArrayCodec) -> ArrayAr
205207
msg = f"Expected a dict representation of a ArrayArrayCodec; got a dict representation of a {type(result)} instead."
206208
raise TypeError(msg)
207209
else:
210+
if not isinstance(data, ArrayArrayCodec):
211+
raise TypeError(f"Expected a ArrayArrayCodec. Got {type(data)} instead.")
208212
result = data
209213
return result
210214

tests/test_array.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
from zarr.core.common import JSON, MemoryOrder, ZarrFormat
2323
from zarr.core.group import AsyncGroup
2424
from zarr.core.indexing import ceildiv
25-
from zarr.core.metadata.v3 import DataType
25+
from zarr.core.metadata.v2 import ArrayV2Metadata
26+
from zarr.core.metadata.v3 import ArrayV3Metadata, DataType
2627
from zarr.core.sync import sync
2728
from zarr.errors import ContainsArrayError, ContainsGroupError
2829
from zarr.storage import LocalStore, MemoryStore
@@ -885,7 +886,9 @@ async def test_nbytes(
885886
assert arr.nbytes == np.prod(arr.shape) * arr.dtype.itemsize
886887

887888

888-
def _get_partitioning(data: AsyncArray) -> tuple[tuple[int, ...], tuple[int, ...] | None]:
889+
def _get_partitioning(
890+
data: AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata],
891+
) -> tuple[tuple[int, ...], tuple[int, ...] | None]:
889892
"""
890893
Get the shard shape and chunk shape of an array. If the array is not sharded, the shard shape
891894
will be None.

0 commit comments

Comments (0)