Skip to content

Commit a9850bf

Browse files
committed
better codec parsing
1 parent 4e978f9 commit a9850bf

File tree

1 file changed

+25 -25 lines changed

1 file changed

+25 -25 lines changed

src/zarr/core/array.py

Lines changed: 25 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
import json
44
import warnings
55
from asyncio import gather
6-
from collections.abc import Iterable, Mapping
6+
from collections.abc import Iterable
77
from dataclasses import dataclass, field
88
from itertools import starmap
99
from logging import getLogger
@@ -93,7 +93,12 @@
9393
from zarr.core.metadata.v3 import DataType, parse_node_type_array
9494
from zarr.core.sync import sync
9595
from zarr.errors import MetadataValidationError
96-
from zarr.registry import get_codec_class, get_pipeline_class
96+
from zarr.registry import (
97+
_parse_array_array_codec,
98+
_parse_bytes_bytes_codec,
99+
_resolve_codec,
100+
get_pipeline_class,
101+
)
97102
from zarr.storage import StoreLike, make_store_path
98103
from zarr.storage.common import StorePath, ensure_no_existing_node
99104

@@ -3546,7 +3551,6 @@ async def create_array(
35463551
# TODO: figure out why putting these imports at top-level causes circular imports
35473552
from zarr.codecs.sharding import ShardingCodec
35483553

3549-
# TODO: fix this when modes make sense. It should be `w` for overwriting, `w-` otherwise
35503554
mode: Literal["a"] = "a"
35513555
dtype_parsed = parse_dtype(dtype, zarr_format=zarr_format)
35523556
config_parsed = parse_array_config(config)
@@ -3678,7 +3682,7 @@ def _get_default_encoding_v3(
36783682
dtype_key = "numeric"
36793683

36803684
codec_dicts = default_codecs[dtype_key]
3681-
codecs = tuple(get_codec_class(c["name"]).from_dict(c) for c in codec_dicts)
3685+
codecs = tuple(_resolve_codec(c) for c in codec_dicts)
36823686
array_bytes_maybe = None
36833687
array_array: list[ArrayArrayCodec] = []
36843688
bytes_bytes: list[BytesBytesCodec] = []
@@ -3710,21 +3714,11 @@ def _get_default_chunk_encoding_v2(
37103714
"""
37113715
Get the default chunk encoding for zarr v2 arrays, given a dtype
37123716
"""
3713-
if dtype.kind in "biufcmM":
3714-
dtype_key = "numeric"
3715-
elif dtype.kind in "U":
3716-
dtype_key = "string"
3717-
elif dtype.kind in "OSV":
3718-
dtype_key = "bytes"
3719-
else:
3720-
raise ValueError(f"Unsupported dtype kind {dtype.kind}")
37213717

3722-
compressor_dict = zarr_config.get("array.v2_default_compressor").get(dtype_key, None)
3723-
filter_dicts = zarr_config.get("array.v2_default_filters").get(dtype_key, [])
3718+
compressor_dict = _default_compressor(dtype)
3719+
filter_dicts = _default_filters(dtype)
37243720

3725-
compressor = None
3726-
if compressor_dict is not None:
3727-
compressor = numcodecs.get_codec(compressor_dict)
3721+
compressor = numcodecs.get_codec(compressor_dict)
37283722
filters = tuple(numcodecs.get_codec(f) for f in filter_dicts)
37293723
return filters, compressor
37303724

@@ -3753,28 +3747,34 @@ def _parse_chunk_encoding_v2(
37533747

37543748
def _parse_chunk_encoding_v3(
37553749
*,
3756-
compression: Iterable[BytesBytesCodec] | Literal["auto"],
3757-
filters: Iterable[ArrayArrayCodec] | Literal["auto"],
3750+
compression: Iterable[BytesBytesCodec | dict[str, JSON]] | Literal["auto"],
3751+
filters: Iterable[ArrayArrayCodec | dict[str, JSON]] | Literal["auto"],
37583752
dtype: np.dtype[Any],
37593753
) -> tuple[tuple[ArrayArrayCodec, ...], ArrayBytesCodec, tuple[BytesBytesCodec, ...]]:
37603754
"""
37613755
Generate chunk encoding classes for v3 arrays with optional defaults.
37623756
"""
37633757
default_array_array, default_array_bytes, default_bytes_bytes = _get_default_encoding_v3(dtype)
3758+
maybe_bytes_bytes: Iterable[BytesBytesCodec | dict[str, JSON]]
3759+
maybe_array_array: Iterable[ArrayArrayCodec | dict[str, JSON]]
37643760

37653761
if compression == "auto":
37663762
out_bytes_bytes = default_bytes_bytes
37673763
else:
3768-
if isinstance(compression, Mapping | Codec):
3769-
out_bytes_bytes = (compression,)
3764+
if isinstance(compression, dict | Codec):
3765+
maybe_bytes_bytes = (compression,)
37703766
else:
3771-
out_bytes_bytes = tuple(compression)
3767+
maybe_bytes_bytes = compression
3768+
3769+
out_bytes_bytes = tuple(_parse_bytes_bytes_codec(c) for c in maybe_bytes_bytes)
3770+
37723771
if filters == "auto":
37733772
out_array_array = default_array_array
37743773
else:
3775-
if isinstance(filters, Mapping | Codec):
3776-
out_array_array = (filters,)
3774+
if isinstance(filters, dict | Codec):
3775+
maybe_array_array = (filters,)
37773776
else:
3778-
out_array_array = tuple(filters)
3777+
maybe_array_array = filters
3778+
out_array_array = tuple(_parse_array_array_codec(c) for c in maybe_array_array)
37793779

37803780
return out_array_array, default_array_bytes, out_bytes_bytes

0 commit comments

Comments
 (0)