Skip to content

Commit 1fa42d9

Browse files
committed
add default compressor to config
1 parent 680142f commit 1fa42d9

File tree

4 files changed

+46
-5
lines changed

4 files changed

+46
-5
lines changed

src/zarr/core/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ def reset(self) -> None:
6464
},
6565
"buffer": "zarr.core.buffer.cpu.Buffer",
6666
"ndbuffer": "zarr.core.buffer.cpu.NDBuffer",
67+
"v2_dtype_kind_to_default_compressor": {
68+
"biufcmM": "zstd",
69+
"OSUV": "vlen-bytes",
70+
},
6771
}
6872
],
6973
)

src/zarr/core/metadata/v2.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ def __init__(
7171
shape_parsed = parse_shapelike(shape)
7272
dtype_parsed = parse_dtype(dtype)
7373
chunks_parsed = parse_shapelike(chunks)
74+
if compressor is None:
75+
compressor = _default_compressor(dtype_parsed)
7476
compressor_parsed = parse_compressor(compressor)
7577
order_parsed = parse_indexing_order(order)
7678
dimension_separator_parsed = parse_separator(dimension_separator)
@@ -238,15 +240,15 @@ def parse_filters(data: object) -> tuple[numcodecs.abc.Codec, ...] | None:
238240
raise TypeError(msg)
239241

240242

241-
def parse_compressor(data: object) -> numcodecs.abc.Codec | None:
243+
def parse_compressor(data: object) -> numcodecs.abc.Codec:
242244
"""
243245
Parse a potential compressor.
244246
"""
245-
if data is None or isinstance(data, numcodecs.abc.Codec):
247+
if isinstance(data, numcodecs.abc.Codec):
246248
return data
247249
if isinstance(data, dict):
248250
return numcodecs.get_codec(data)
249-
msg = f"Invalid compressor. Expected None, a numcodecs.abc.Codec, or a dict representation of a numcodecs.abc.Codec. Got {type(data)} instead."
251+
msg = f"Invalid compressor. Expected a numcodecs.abc.Codec, or a dict representation of a numcodecs.abc.Codec. Got {type(data)} instead."
250252
raise ValueError(msg)
251253

252254

@@ -326,3 +328,16 @@ def _default_fill_value(dtype: np.dtype[Any]) -> Any:
326328
return ""
327329
else:
328330
return dtype.type(0)
331+
332+
333+
def _default_compressor(dtype: np.dtype[Any]) -> numcodecs.abc.Codec:
    """Look up the configured default compressor for ``dtype``.

    The ``v2_dtype_kind_to_default_compressor`` config entry maps strings of
    numpy dtype kind characters to a codec id. The first entry (in mapping
    order) whose key contains ``dtype.kind`` is used to build the codec.
    https://numpy.org/doc/2.1/reference/generated/numpy.dtype.kind.html

    Raises
    ------
    ValueError
        If no entry in the mapping contains the kind of ``dtype``.
    """
    kinds_to_codec_id = config.get("v2_dtype_kind_to_default_compressor")
    kind = dtype.kind

    # A private sentinel keeps "no match" distinct from any configured value.
    not_found = object()
    codec_id = next(
        (cid for kinds, cid in kinds_to_codec_id.items() if kind in kinds),
        not_found,
    )
    if codec_id is not_found:
        raise ValueError(f"No default compressor found for dtype {dtype} of kind {dtype.kind}")
    return numcodecs.get_codec({"id": codec_id})

tests/test_config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ def test_config_defaults_set() -> None:
6363
"vlen-utf8": "zarr.codecs.vlen_utf8.VLenUTF8Codec",
6464
"vlen-bytes": "zarr.codecs.vlen_utf8.VLenBytesCodec",
6565
},
66+
"v2_dtype_kind_to_default_compressor": {
67+
"biufcmM": "zstd",
68+
"OSUV": "vlen-bytes",
69+
},
6670
}
6771
]
6872
assert config.get("array.order") == "C"

tests/test_v2.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import zarr
1212
import zarr.storage
13-
from zarr import Array
13+
from zarr import Array, config
1414
from zarr.storage import MemoryStore, StorePath
1515

1616

@@ -96,7 +96,6 @@ async def test_v2_encode_decode(dtype):
9696
serialized = json.loads(result.to_bytes())
9797
expected = {
9898
"chunks": [3],
99-
"compressor": None,
10099
"dtype": f"{dtype}0",
101100
"fill_value": "WA==",
102101
"filters": None,
@@ -105,6 +104,7 @@ async def test_v2_encode_decode(dtype):
105104
"zarr_format": 2,
106105
"dimension_separator": ".",
107106
}
107+
del serialized["compressor"]
108108
assert serialized == expected
109109

110110
data = zarr.open_array(store=store, path="foo")[:]
@@ -130,3 +130,21 @@ def test_v2_filters_codecs(filters: Any) -> None:
130130
arr[:] = array_fixture
131131
result = arr[:]
132132
np.testing.assert_array_equal(result, array_fixture)
133+
134+
135+
@pytest.mark.parametrize(
    "dtype_compressor",
    [
        ["b", "zstd"],
        ["i", "zstd"],
        ["f", "zstd"],
        ["|S1", "vlen-bytes"],
        ["|U1", "vlen-bytes"],
    ],
)
def test_default_compressors(dtype_compressor: Any) -> None:
    """Check that a v2 array created without a compressor gets the configured default.

    Each parameter is a ``[dtype, codec_id]`` pair; the codec id is the one the
    config mapping below assigns to that dtype's kind.
    """
    dtype, expected_compressor = dtype_compressor
    kind_defaults = {
        "biufcmM": "zstd",
        "OSUV": "vlen-bytes",
    }
    with config.set({"v2_dtype_kind_to_default_compressor": kind_defaults}):
        # No compressor is passed, so the array metadata must fall back to the config.
        arr = zarr.create(shape=(10,), path="foo", store={}, zarr_format=2, dtype=dtype)
        actual_codec_id = arr.metadata.compressor.codec_id
    assert actual_codec_id == expected_compressor

0 commit comments

Comments
 (0)