Skip to content

Commit 43877c0

Browse files
committed
compressors -> compression, auto chunking, auto sharding, auto compression, auto filters
1 parent 74f731a commit 43877c0

File tree

10 files changed

+328
-147
lines changed

10 files changed

+328
-147
lines changed

src/zarr/core/array.py

Lines changed: 192 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,18 @@
77
from dataclasses import dataclass, field
88
from itertools import starmap
99
from logging import getLogger
10-
from typing import TYPE_CHECKING, Any, Generic, Literal, cast, overload
10+
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, cast, overload
1111
from warnings import warn
1212

13+
import numcodecs
1314
import numpy as np
1415
import numpy.typing as npt
1516

1617
from zarr._compat import _deprecate_positional_args
18+
from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec
1719
from zarr.abc.store import Store, set_or_delete
18-
from zarr.core.common import _default_zarr_version
1920
from zarr.codecs._v2 import V2Codec
21+
from zarr.codecs.zstd import ZstdCodec
2022
from zarr.core._info import ArrayInfo
2123
from zarr.core.array_spec import ArrayConfig, ArrayConfigParams, parse_array_config
2224
from zarr.core.attributes import Attributes
@@ -26,9 +28,10 @@
2628
NDBuffer,
2729
default_buffer_prototype,
2830
)
29-
from zarr.core.chunk_grids import RegularChunkGrid, normalize_chunks
31+
from zarr.core.chunk_grids import RegularChunkGrid, _auto_partition, normalize_chunks
3032
from zarr.core.chunk_key_encodings import (
3133
ChunkKeyEncoding,
34+
ChunkKeyEncodingParams,
3235
DefaultChunkKeyEncoding,
3336
V2ChunkKeyEncoding,
3437
)
@@ -41,6 +44,7 @@
4144
MemoryOrder,
4245
ShapeLike,
4346
ZarrFormat,
47+
_default_zarr_version,
4448
_warn_order_kwarg,
4549
concurrent_map,
4650
parse_dtype,
@@ -87,15 +91,15 @@
8791
from zarr.core.metadata.v3 import DataType, parse_node_type_array
8892
from zarr.core.sync import sync
8993
from zarr.errors import MetadataValidationError
90-
from zarr.registry import get_pipeline_class
94+
from zarr.registry import get_codec_class, get_pipeline_class
9195
from zarr.storage import StoreLike, make_store_path
9296
from zarr.storage.common import StorePath, ensure_no_existing_node
9397

9498
if TYPE_CHECKING:
95-
from collections.abc import Iterable, Iterator, Sequence
99+
from collections.abc import Iterator, Sequence
96100
from typing import Self
97101

98-
from zarr.abc.codec import Codec, CodecPipeline
102+
from zarr.abc.codec import CodecPipeline
99103
from zarr.core.group import AsyncGroup
100104

101105
# Array and AsyncArray are defined in the base ``zarr`` namespace
@@ -3454,26 +3458,29 @@ def _get_default_codecs(
34543458
return [{"name": codec_id, "configuration": {}} for codec_id in default_codecs[dtype_key]]
34553459

34563460

3461+
# Values accepted for the ``filters`` argument of ``create_array``:
# an iterable of codec objects / codec dicts, an iterable of numcodecs
# codecs, or the literal ``"auto"`` to request the defaults.
FiltersParam: TypeAlias = (
    Iterable[dict[str, JSON] | Codec]
    | Iterable[numcodecs.abc.Codec]
    | Literal["auto"]
)

# Values accepted for the ``compression`` argument of ``create_array``:
# an iterable of codec objects / codec dicts, a single codec, a single
# numcodecs codec, or the literal ``"auto"`` to request the defaults.
CompressionParam: TypeAlias = (
    Iterable[dict[str, JSON] | Codec]
    | Codec
    | numcodecs.abc.Codec
    | Literal["auto"]
)
3467+
3468+
34573469
async def create_array(
34583470
store: str | StoreLike,
34593471
*,
3460-
path: str | None = None,
3472+
name: str | None = None,
34613473
shape: ShapeLike,
34623474
dtype: npt.DTypeLike,
3463-
chunk_shape: ChunkCoords,
3475+
chunk_shape: ChunkCoords | Literal["auto"] = "auto",
34643476
shard_shape: ChunkCoords | None = None,
3465-
filters: Iterable[dict[str, JSON] | Codec] = (),
3466-
compressors: Iterable[dict[str, JSON] | Codec] = (),
3477+
filters: FiltersParam = "auto",
3478+
compression: CompressionParam = "auto",
34673479
fill_value: Any | None = 0,
34683480
order: MemoryOrder | None = "C",
34693481
zarr_format: ZarrFormat | None = 3,
34703482
attributes: dict[str, JSON] | None = None,
3471-
chunk_key_encoding: (
3472-
ChunkKeyEncoding
3473-
| tuple[Literal["default"], Literal[".", "/"]]
3474-
| tuple[Literal["v2"], Literal[".", "/"]]
3475-
| None
3476-
) = ("default", "/"),
3483+
chunk_key_encoding: ChunkKeyEncoding | ChunkKeyEncodingParams | None = None,
34773484
dimension_names: Iterable[str] | None = None,
34783485
storage_options: dict[str, Any] | None = None,
34793486
overwrite: bool = False,
@@ -3486,8 +3493,8 @@ async def create_array(
34863493
----------
34873494
store : str or Store
34883495
Store or path to directory in file system or name of zip file.
3489-
path : str or None, optional
3490-
The name of the array within the store. If ``path`` is ``None``, the array will be located
3496+
name : str or None, optional
3497+
The name of the array within the store. If ``name`` is ``None``, the array will be located
34913498
at the root of the store.
34923499
shape : ChunkCoords
34933500
Shape of the array.
@@ -3499,7 +3506,7 @@ async def create_array(
34993506
Shard shape of the array. The default value of ``None`` results in no sharding at all.
35003507
filters : Iterable[Codec], optional
35013508
List of filters to apply to the array.
3502-
compressors : Iterable[Codec], optional
3509+
compression : Iterable[Codec], optional
35033510
List of compressors to apply to the array.
35043511
fill_value : Any, optional
35053512
Fill value for the array.
@@ -3533,75 +3540,85 @@ async def create_array(
35333540
zarr_format = _default_zarr_version()
35343541

35353542
# TODO: figure out why putting these imports at top-level causes circular imports
3536-
from zarr.codecs.bytes import BytesCodec
35373543
from zarr.codecs.sharding import ShardingCodec
35383544

35393545
# TODO: fix this when modes make sense. It should be `w` for overwriting, `w-` otherwise
35403546
mode: Literal["a"] = "a"
3541-
3542-
store_path = await make_store_path(store, path=path, mode=mode, storage_options=storage_options)
3543-
sub_codecs = (*filters, BytesCodec(), *compressors)
3544-
_dtype_parsed = parse_dtype(dtype, zarr_format=zarr_format)
3547+
dtype_parsed = parse_dtype(dtype, zarr_format=zarr_format)
35453548
config_parsed = parse_array_config(config)
35463549
shape_parsed = parse_shapelike(shape)
3550+
chunk_key_encoding_parsed = _parse_chunk_key_encoding(
3551+
chunk_key_encoding, zarr_format=zarr_format
3552+
)
3553+
store_path = await make_store_path(store, path=name, mode=mode, storage_options=storage_options)
3554+
shard_shape_parsed, chunk_shape_parsed = _auto_partition(
3555+
shape_parsed, shard_shape, chunk_shape, dtype_parsed
3556+
)
35473557
result: AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata]
3558+
35483559
if zarr_format == 2:
3549-
if shard_shape is not None:
3560+
if shard_shape_parsed is not None:
35503561
msg = (
35513562
'Zarr v2 arrays can only be created with `shard_shape` set to `None` or `"auto"`.'
35523563
f"Got `shard_shape={shard_shape}` instead."
35533564
)
35543565

35553566
raise ValueError(msg)
3556-
if len(tuple(compressors)) > 1:
3557-
compressor, *rest = compressors
3558-
else:
3559-
compressor = None
3560-
rest = []
3561-
filters = (*filters, *rest)
3567+
if filters != "auto" and not all(isinstance(f, numcodecs.abc.Codec) for f in filters):
3568+
raise TypeError(
3569+
"For Zarr v2 arrays, all elements of `filters` must be numcodecs codecs."
3570+
)
3571+
filters = cast(Iterable[numcodecs.abc.Codec] | Literal["auto"], filters)
3572+
filters_parsed, compressor_parsed = _parse_chunk_encoding_v2(
3573+
compression=compression, filters=filters, dtype=dtype_parsed
3574+
)
35623575
if dimension_names is not None:
35633576
raise ValueError("Zarr v2 arrays do not support dimension names.")
35643577
if order is None:
35653578
order_parsed = zarr_config.get("array.order")
35663579
else:
35673580
order_parsed = order
3581+
35683582
result = await AsyncArray._create_v2(
35693583
store_path=store_path,
35703584
shape=shape_parsed,
3571-
dtype=_dtype_parsed,
3572-
chunks=chunk_shape,
3573-
dimension_separator="/",
3585+
dtype=dtype_parsed,
3586+
chunks=chunk_shape_parsed,
3587+
dimension_separator=chunk_key_encoding_parsed.separator,
35743588
fill_value=fill_value,
35753589
order=order_parsed,
3576-
filters=filters,
3577-
compressor=compressor,
3590+
filters=filters_parsed,
3591+
compressor=compressor_parsed,
35783592
attributes=attributes,
35793593
overwrite=overwrite,
35803594
config=config_parsed,
35813595
)
35823596
else:
3583-
if shard_shape is not None:
3584-
sharding_codec = ShardingCodec(chunk_shape=chunk_shape, codecs=sub_codecs)
3597+
array_array, array_bytes, bytes_bytes = _get_default_encoding_v3(dtype_parsed)
3598+
sub_codecs = (*array_array, array_bytes, *bytes_bytes)
3599+
codecs_out: tuple[Codec, ...]
3600+
if shard_shape_parsed is not None:
3601+
sharding_codec = ShardingCodec(chunk_shape=chunk_shape_parsed, codecs=sub_codecs)
35853602
sharding_codec.validate(
3586-
shape=chunk_shape,
3587-
dtype=dtype,
3603+
shape=chunk_shape_parsed,
3604+
dtype=dtype_parsed,
35883605
chunk_grid=RegularChunkGrid(chunk_shape=shard_shape),
35893606
)
3590-
codecs = (sharding_codec,)
3607+
codecs_out = (sharding_codec,)
35913608
chunks_out = shard_shape
35923609
else:
3593-
chunks_out = chunk_shape
3594-
codecs = sub_codecs
3610+
chunks_out = chunk_shape_parsed
3611+
codecs_out = sub_codecs
35953612

35963613
result = await AsyncArray._create_v3(
35973614
store_path=store_path,
35983615
shape=shape_parsed,
3599-
dtype=_dtype_parsed,
3616+
dtype=dtype_parsed,
36003617
fill_value=fill_value,
36013618
attributes=attributes,
36023619
chunk_shape=chunks_out,
3603-
chunk_key_encoding=chunk_key_encoding,
3604-
codecs=codecs,
3620+
chunk_key_encoding=chunk_key_encoding_parsed,
3621+
codecs=codecs_out,
36053622
dimension_names=dimension_names,
36063623
overwrite=overwrite,
36073624
config=config_parsed,
@@ -3612,3 +3629,132 @@ async def create_array(
36123629
selection=slice(None), value=data, prototype=default_buffer_prototype()
36133630
)
36143631
return result
3632+
3633+
3634+
def _parse_chunk_key_encoding(
    data: ChunkKeyEncoding | ChunkKeyEncodingParams | None, zarr_format: ZarrFormat
) -> ChunkKeyEncoding:
    """
    Normalize a chunk key encoding specification into a ``ChunkKeyEncoding``.

    Parameters
    ----------
    data : ChunkKeyEncoding or ChunkKeyEncodingParams or None
        An existing ``ChunkKeyEncoding`` is used as-is. ``None`` selects a
        default that depends on ``zarr_format``: name ``"v2"`` for format 2,
        name ``"default"`` otherwise, both with separator ``"/"``. Any other
        value is handed to ``ChunkKeyEncoding.from_dict``.
    zarr_format : ZarrFormat
        The Zarr format the encoding will be used with.

    Returns
    -------
    ChunkKeyEncoding

    Raises
    ------
    ValueError
        If ``zarr_format`` is 2 and the resulting encoding's ``name`` is not
        ``"v2"``.
    """
    targets_v2 = zarr_format == 2

    encoding: ChunkKeyEncoding
    if isinstance(data, ChunkKeyEncoding):
        encoding = data
    elif data is None:
        # No explicit request: pick the default that matches the format.
        default_name: Literal["v2", "default"] = "v2" if targets_v2 else "default"
        encoding = ChunkKeyEncoding.from_dict({"name": default_name, "separator": "/"})
    else:
        encoding = ChunkKeyEncoding.from_dict(data)

    # The name check applies to every input path, including pre-built objects.
    if targets_v2 and encoding.name != "v2":
        raise ValueError(
            "Invalid chunk key encoding. For Zarr v2 arrays, the `name` field of the "
            f"chunk key encoding must be 'v2'. Got `name` = {encoding.name} instead."
        )
    return encoding
3656+
3657+
3658+
def _get_default_encoding_v3(
    np_dtype: np.dtype[Any],
) -> tuple[tuple[ArrayArrayCodec, ...], ArrayBytesCodec, tuple[BytesBytesCodec, ...]]:
    """
    Build the default v3 codecs for a numpy dtype from the ``array.v3_default_codecs`` config.

    Parameters
    ----------
    np_dtype : np.dtype
        The numpy dtype of the array being created.

    Returns
    -------
    tuple
        ``(array_array, array_bytes, bytes_bytes)``. ``array_array`` is always
        empty. ``array_bytes`` is an instance of the first configured codec.
        ``bytes_bytes`` holds instances of the remaining configured codecs, or
        a single ``ZstdCodec()`` when no further codecs are configured.
    """
    configured = zarr_config.get("array.v3_default_codecs")
    data_type = DataType.from_numpy(np_dtype)

    # The config is keyed by dtype category rather than by concrete dtype.
    if data_type == DataType.string:
        category = "string"
    elif data_type == DataType.bytes:
        category = "bytes"
    else:
        category = "numeric"

    # Resolve every codec class before instantiating any of them. The first
    # entry is treated as the array -> bytes codec; the rest follow it.
    array_bytes_cls, *trailing_classes = [
        get_codec_class(codec_name) for codec_name in configured[category]
    ]
    array_bytes = cast(ArrayBytesCodec, array_bytes_cls())

    # No array -> array codecs are applied by default.
    array_array: tuple[ArrayArrayCodec, ...] = ()

    # TODO: we should compress bytes and strings by default!
    # The configured defaults only list codec names. When nothing follows the
    # array -> bytes codec, the data would otherwise not be compressed at all,
    # so fall back to a ZstdCodec.
    bytes_bytes: tuple[BytesBytesCodec, ...]
    if trailing_classes:
        bytes_bytes = cast(
            tuple[BytesBytesCodec, ...], tuple(codec_cls() for codec_cls in trailing_classes)
        )
    else:
        bytes_bytes = (ZstdCodec(),)

    return array_array, array_bytes, bytes_bytes
3687+
3688+
3689+
def _get_default_chunk_encoding_v2(
    dtype: np.dtype[np.generic],
) -> tuple[tuple[numcodecs.abc.Codec, ...], numcodecs.abc.Codec]:
    """
    Get the default chunk encoding for Zarr v2 arrays, given a dtype.

    The codec id is read from the ``array.v2_default_compressor`` config,
    keyed by dtype category (``"numeric"``, ``"string"`` or ``"bytes"``).

    Parameters
    ----------
    dtype : np.dtype
        The numpy dtype of the array being created.

    Returns
    -------
    tuple
        ``(filters, compressor)``. For numeric dtypes (kinds ``biufcmM``) the
        configured codec is the compressor and ``filters`` is empty. For
        string (kind ``U``) and bytes-like (kinds ``OSV``) dtypes the
        configured codec is the single filter and ``numcodecs.Zstd()`` is the
        compressor.

    Raises
    ------
    ValueError
        If ``dtype.kind`` is not one of the supported kinds.
    """
    codec_id_dict = zarr_config.get("array.v2_default_compressor")

    kind = dtype.kind
    if kind in "biufcmM":
        dtype_key, configured_codec_is_filter = "numeric", False
    elif kind == "U":
        dtype_key, configured_codec_is_filter = "string", True
    elif kind in "OSV":
        dtype_key, configured_codec_is_filter = "bytes", True
    else:
        raise ValueError(f"Unsupported dtype kind {dtype.kind}")

    codec_instance = numcodecs.get_codec({"id": codec_id_dict[dtype_key]})

    if configured_codec_is_filter:
        # Wrap the codec in a tuple so the filters element is always a
        # tuple of codecs, matching the return annotation.
        return (codec_instance,), numcodecs.Zstd()
    return (), codec_instance
3716+
3717+
3718+
def _parse_chunk_encoding_v2(
    *,
    compression: numcodecs.abc.Codec | Literal["auto"],
    filters: tuple[numcodecs.abc.Codec, ...] | Literal["auto"],
    dtype: np.dtype[np.generic],
) -> tuple[tuple[numcodecs.abc.Codec, ...], numcodecs.abc.Codec]:
    """
    Generate chunk encoding classes for v2 arrays with optional defaults.

    Parameters
    ----------
    compression : numcodecs.abc.Codec or "auto"
        The compressor to use, or ``"auto"`` to use the default for ``dtype``.
    filters : tuple of numcodecs.abc.Codec or "auto"
        The filters to use, or ``"auto"`` to use the defaults for ``dtype``.
    dtype : np.dtype
        The numpy dtype of the array; used to look up the defaults.

    Returns
    -------
    tuple
        ``(filters, compressor)``, where ``filters`` is a tuple.
    """
    # Defaults are resolved unconditionally, so an unsupported dtype is
    # reported even when both arguments are given explicitly.
    default_filters, default_compressor = _get_default_chunk_encoding_v2(dtype)

    _compressor = default_compressor if compression == "auto" else compression

    _filters: tuple[numcodecs.abc.Codec, ...]
    if filters == "auto":
        _filters = default_filters
    else:
        # Coerce to a tuple so the result matches the return annotation
        # whatever iterable the caller passed (as the v3 parser does).
        _filters = tuple(filters)

    return _filters, _compressor
3738+
3739+
3740+
def _parse_chunk_encoding_v3(
    *,
    compression: Iterable[BytesBytesCodec] | Literal["auto"],
    filters: Iterable[ArrayArrayCodec] | Literal["auto"],
    dtype: np.dtype[np.generic],
) -> tuple[tuple[ArrayArrayCodec, ...], ArrayBytesCodec, tuple[BytesBytesCodec, ...]]:
    """
    Resolve the chunk encoding for a v3 array, filling in defaults where requested.

    Parameters
    ----------
    compression : Iterable[BytesBytesCodec] or "auto"
        The bytes -> bytes codecs to use, or ``"auto"`` for the defaults for ``dtype``.
    filters : Iterable[ArrayArrayCodec] or "auto"
        The array -> array codecs to use, or ``"auto"`` for the defaults for ``dtype``.
    dtype : np.dtype
        The numpy dtype of the array; used to look up the defaults.

    Returns
    -------
    tuple
        ``(array_array, array_bytes, bytes_bytes)``. The array -> bytes codec
        is always the default for ``dtype``; the other two are tuples.
    """
    fallback_filters, array_bytes, fallback_compressors = _get_default_encoding_v3(dtype)

    bytes_bytes: tuple[BytesBytesCodec, ...] = (
        fallback_compressors if compression == "auto" else tuple(compression)
    )
    array_array: tuple[ArrayArrayCodec, ...] = (
        fallback_filters if filters == "auto" else tuple(filters)
    )

    return array_array, array_bytes, bytes_bytes

0 commit comments

Comments
 (0)