Skip to content

Commit 02053e9

Browse files
committed
modify _default_compressor to _default_filters_and_compressor
1 parent 1fa42d9 commit 02053e9

File tree

5 files changed

+73
-63
lines changed

5 files changed

+73
-63
lines changed

src/zarr/core/array.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -492,14 +492,6 @@ async def create(
492492
order=order,
493493
)
494494
elif zarr_format == 2:
495-
if dtype is str or dtype == "str":
496-
# another special case: zarr v2 added the vlen-utf8 codec
497-
vlen_codec: dict[str, JSON] = {"id": "vlen-utf8"}
498-
if filters and not any(x["id"] == "vlen-utf8" for x in filters):
499-
filters = list(filters) + [vlen_codec]
500-
else:
501-
filters = [vlen_codec]
502-
503495
if codecs is not None:
504496
raise ValueError(
505497
"codecs cannot be used for arrays with version 2. Use filters and compressor instead."

src/zarr/core/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@ def reset(self) -> None:
6464
},
6565
"buffer": "zarr.core.buffer.cpu.Buffer",
6666
"ndbuffer": "zarr.core.buffer.cpu.NDBuffer",
67-
"v2_dtype_kind_to_default_compressor": {
68-
"biufcmM": "zstd",
69-
"OSUV": "vlen-bytes",
67+
"v2_dtype_kind_to_default_filters_and_compressor": {
68+
"biufcmM": ["zstd"],
69+
"OSUV": ["vlen-utf8"],
7070
},
7171
}
7272
],

src/zarr/core/metadata/v2.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections.abc import Iterable
55
from enum import Enum
66
from functools import cached_property
7-
from typing import TYPE_CHECKING, TypedDict, cast
7+
from typing import TYPE_CHECKING, Any, TypedDict, cast
88

99
from zarr.abc.metadata import Metadata
1010

@@ -71,8 +71,14 @@ def __init__(
7171
shape_parsed = parse_shapelike(shape)
7272
dtype_parsed = parse_dtype(dtype)
7373
chunks_parsed = parse_shapelike(chunks)
74-
if compressor is None:
75-
compressor = _default_compressor(dtype_parsed)
74+
if not filters and not compressor:
75+
filters, compressor = _default_filters_and_compressor(dtype_parsed)
76+
if dtype is str or dtype == "str":
77+
vlen_codec: dict[str, JSON] = {"id": "vlen-utf8"}
78+
if filters and not any(x["id"] == "vlen-utf8" for x in filters):
79+
filters = list(filters) + [vlen_codec]
80+
else:
81+
filters = [vlen_codec]
7682
compressor_parsed = parse_compressor(compressor)
7783
order_parsed = parse_indexing_order(order)
7884
dimension_separator_parsed = parse_separator(dimension_separator)
@@ -240,15 +246,15 @@ def parse_filters(data: object) -> tuple[numcodecs.abc.Codec, ...] | None:
240246
raise TypeError(msg)
241247

242248

243-
def parse_compressor(data: object) -> numcodecs.abc.Codec:
249+
def parse_compressor(data: object) -> numcodecs.abc.Codec | None:
244250
"""
245251
Parse a potential compressor.
246252
"""
247-
if isinstance(data, numcodecs.abc.Codec):
253+
if data is None or isinstance(data, numcodecs.abc.Codec):
248254
return data
249255
if isinstance(data, dict):
250256
return numcodecs.get_codec(data)
251-
msg = f"Invalid compressor. Expected a numcodecs.abc.Codec, or a dict representation of a numcodecs.abc.Codec. Got {type(data)} instead."
257+
msg = f"Invalid compressor. Expected None, a numcodecs.abc.Codec, or a dict representation of a numcodecs.abc.Codec. Got {type(data)} instead."
252258
raise ValueError(msg)
253259

254260

@@ -330,14 +336,18 @@ def _default_fill_value(dtype: np.dtype[Any]) -> Any:
330336
return dtype.type(0)
331337

332338

333-
def _default_compressor(dtype: np.dtype[Any]) -> numcodecs.abc.Codec:
334-
"""Get the default compressor for a type.
339+
def _default_filters_and_compressor(
340+
dtype: np.dtype[Any],
341+
) -> tuple[list[dict[str, str]], dict[str, str] | None]:
342+
"""Get the default filters and compressor for a dtype.
335343
336344
The config contains a mapping from numpy dtype kind to the default compressor.
337345
https://numpy.org/doc/2.1/reference/generated/numpy.dtype.kind.html
338346
"""
339-
dtype_kind_to_default_compressor = config.get("v2_dtype_kind_to_default_compressor")
340-
for dtype_kinds, compressor in dtype_kind_to_default_compressor.items():
347+
dtype_kind_to_default_compressor = config.get("v2_dtype_kind_to_default_filters_and_compressor")
348+
for dtype_kinds, filters_and_compressor in dtype_kind_to_default_compressor.items():
341349
if dtype.kind in dtype_kinds:
342-
return numcodecs.get_codec({"id": compressor})
343-
raise ValueError(f"No default compressor found for dtype {dtype} of kind {dtype.kind}")
350+
filters = [{"id": f} for f in filters_and_compressor]
351+
compressor = None
352+
return filters, compressor
353+
return [], None

tests/test_config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,9 @@ def test_config_defaults_set() -> None:
6363
"vlen-utf8": "zarr.codecs.vlen_utf8.VLenUTF8Codec",
6464
"vlen-bytes": "zarr.codecs.vlen_utf8.VLenBytesCodec",
6565
},
66-
"v2_dtype_kind_to_default_compressor": {
67-
"biufcmM": "zstd",
68-
"OSUV": "vlen-bytes",
66+
"v2_dtype_kind_to_default_filters_and_compressor": {
67+
"biufcmM": ["zstd"],
68+
"OSUV": ["vlen-utf8"],
6969
},
7070
}
7171
]

tests/test_v2.py

Lines changed: 45 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -80,36 +80,43 @@ def test_codec_pipeline() -> None:
8080

8181
@pytest.mark.parametrize("dtype", ["|S", "|V"])
8282
async def test_v2_encode_decode(dtype):
83-
store = zarr.storage.MemoryStore(mode="w")
84-
g = zarr.group(store=store, zarr_format=2)
85-
g.create_array(
86-
name="foo",
87-
shape=(3,),
88-
chunks=(3,),
89-
dtype=dtype,
90-
fill_value=b"X",
91-
)
92-
93-
result = await store.get("foo/.zarray", zarr.core.buffer.default_buffer_prototype())
94-
assert result is not None
95-
96-
serialized = json.loads(result.to_bytes())
97-
expected = {
98-
"chunks": [3],
99-
"dtype": f"{dtype}0",
100-
"fill_value": "WA==",
101-
"filters": None,
102-
"order": "C",
103-
"shape": [3],
104-
"zarr_format": 2,
105-
"dimension_separator": ".",
106-
}
107-
del serialized["compressor"]
108-
assert serialized == expected
83+
with config.set(
84+
{
85+
"v2_dtype_kind_to_default_filters_and_compressor": {
86+
"OSUV": ["vlen-bytes"],
87+
},
88+
}
89+
):
90+
store = zarr.storage.MemoryStore(mode="w")
91+
g = zarr.group(store=store, zarr_format=2)
92+
g.create_array(
93+
name="foo",
94+
shape=(3,),
95+
chunks=(3,),
96+
dtype=dtype,
97+
fill_value=b"X",
98+
)
99+
100+
result = await store.get("foo/.zarray", zarr.core.buffer.default_buffer_prototype())
101+
assert result is not None
102+
103+
serialized = json.loads(result.to_bytes())
104+
expected = {
105+
"chunks": [3],
106+
"compressor": None,
107+
"dtype": f"{dtype}0",
108+
"fill_value": "WA==",
109+
"filters": [{"id": "vlen-bytes"}],
110+
"order": "C",
111+
"shape": [3],
112+
"zarr_format": 2,
113+
"dimension_separator": ".",
114+
}
115+
assert serialized == expected
109116

110-
data = zarr.open_array(store=store, path="foo")[:]
111-
expected = np.full((3,), b"X", dtype=dtype)
112-
np.testing.assert_equal(data, expected)
117+
data = zarr.open_array(store=store, path="foo")[:]
118+
expected = np.full((3,), b"X", dtype=dtype)
119+
np.testing.assert_equal(data, expected)
113120

114121

115122
@pytest.mark.parametrize("dtype", [str, "str"])
@@ -133,18 +140,19 @@ def test_v2_filters_codecs(filters: Any) -> None:
133140

134141

135142
@pytest.mark.parametrize(
136-
"dtype_compressor",
137-
[["b", "zstd"], ["i", "zstd"], ["f", "zstd"], ["|S1", "vlen-bytes"], ["|U1", "vlen-bytes"]],
143+
"dtype_expected",
144+
[["b", "zstd"], ["i", "zstd"], ["f", "zstd"], ["|S1", "vlen-utf8"], ["|U1", "vlen-utf8"]],
138145
)
139-
def test_default_compressors(dtype_compressor: Any) -> None:
146+
def test_default_filters_and_compressor(dtype_expected: Any) -> None:
140147
with config.set(
141148
{
142-
"v2_dtype_kind_to_default_compressor": {
143-
"biufcmM": "zstd",
144-
"OSUV": "vlen-bytes",
149+
"v2_dtype_kind_to_default_filters_and_compressor": {
150+
"biufcmM": ["zstd"],
151+
"OSUV": ["vlen-utf8"],
145152
},
146153
}
147154
):
148-
dtype, expected_compressor = dtype_compressor
155+
dtype, expected = dtype_expected
149156
arr = zarr.create(shape=(10,), path="foo", store={}, zarr_format=2, dtype=dtype)
150-
assert arr.metadata.compressor.codec_id == expected_compressor
157+
assert arr.metadata.filters[0].codec_id == expected
158+
print(arr.metadata)

0 commit comments

Comments
 (0)