Skip to content

Commit 7806563

Browse files
committed
add regression testing against v2.18
1 parent d80d565 commit 7806563

File tree

3 files changed

+206
-0
lines changed

3 files changed

+206
-0
lines changed

tests/test_regression/__init__.py

Whitespace-only changes.
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import subprocess
2+
from dataclasses import asdict, dataclass
3+
from itertools import product
4+
from pathlib import Path
5+
6+
import numcodecs
7+
import numpy as np
8+
import pytest
9+
from numcodecs import LZ4, LZMA, Blosc, GZip, VLenUTF8, Zstd
10+
11+
import zarr
12+
from zarr.core.array import Array
13+
from zarr.core.dtype.npy.string import VariableLengthString
14+
from zarr.core.metadata.v2 import ArrayV2Metadata
15+
from zarr.storage import LocalStore
16+
17+
18+
def runner_installed() -> bool:
    """
    Report whether the ``uv`` script runner can be invoked.

    Runs ``uv --version`` and returns ``True`` only if the command starts and
    exits successfully.

    Returns
    -------
    bool
        ``True`` if ``uv --version`` succeeded, ``False`` otherwise.
    """
    try:
        subprocess.check_output(["uv", "--version"])
    except (OSError, subprocess.CalledProcessError):
        # This function is evaluated inside a ``skipif`` decorator at import
        # time, so it must never raise. OSError covers a missing
        # (FileNotFoundError) or non-executable (PermissionError) binary;
        # CalledProcessError covers a ``uv`` that exits non-zero.
        return False
    return True
24+
25+
26+
def array_metadata_equals(a: ArrayV2Metadata, b: ArrayV2Metadata) -> bool:
    """
    Compare two metadata objects, treating matching NaN fill values as equal.

    Both objects are converted to dictionaries with ``dataclasses.asdict``.
    The ``fill_value`` entries are compared separately from the rest because
    NaN never compares equal to itself.

    Parameters
    ----------
    a, b
        Dataclass instances that have a ``fill_value`` field.

    Returns
    -------
    bool
        ``True`` if every field other than ``fill_value`` is equal and the
        fill values are either equal or both NaN.
    """
    dict_a, dict_b = asdict(a), asdict(b)
    fill_value_a, fill_value_b = dict_a.pop("fill_value"), dict_b.pop("fill_value")

    # np.float64 subclasses float, but np.float32 / np.float16 and the complex
    # scalar types do not, so they are listed explicitly.
    nan_capable = (float, complex, np.floating, np.complexfloating)
    both_nan = (
        isinstance(fill_value_a, nan_capable)
        and isinstance(fill_value_b, nan_capable)
        and bool(np.isnan(fill_value_a))
        and bool(np.isnan(fill_value_b))
    )
    if both_nan:
        return dict_a == dict_b
    return bool(fill_value_a == fill_value_b and dict_a == dict_b)
38+
39+
40+
@dataclass(kw_only=True)
41+
class ArrayParams:
42+
values: np.ndarray[tuple[int], np.dtype[np.generic]]
43+
fill_value: np.generic | str
44+
compressor: numcodecs.abc.Codec
45+
46+
47+
# Compressors paired with every fixed-width dtype listed below.
basic_codecs = (GZip(), Blosc(), LZ4(), LZMA(), Zstd())
basic_dtypes = (
    "|b",
    ">i2",
    ">i4",
    ">f4",
    ">f8",
    "<f4",
    "<f8",
    ">c8",
    "<c8",
    ">c16",
    "<c16",
)
datetime_dtypes = ("<M8[10ns]", ">M8[10us]", "<m8[2ms]", ">m8[4ps]")
string_dtypes = (">S1", "<S4", "<U1", ">U4")

# One case per (codec, dtype) pair; codecs vary slowest, as produced by product().
basic_array_cases = [
    ArrayParams(values=np.arange(4, dtype=dt), fill_value=1, compressor=cdc)
    for cdc, dt in product(basic_codecs, basic_dtypes)
]
datetime_array_cases = [
    ArrayParams(values=np.ones((4,), dtype=dt), fill_value=1, compressor=cdc)
    for cdc, dt in product(basic_codecs, datetime_dtypes)
]
string_array_cases = [
    ArrayParams(
        values=np.array(["aaaa", "bbbb", "ccccc", "dddd"], dtype=dt),
        fill_value="foo",
        compressor=cdc,
    )
    for cdc, dt in product(basic_codecs, string_dtypes)
]
# Single object-dtype case; the fixture maps object dtype to VariableLengthString.
vlen_string_cases = [
    ArrayParams(
        values=np.array(["a", "bb", "ccc", "dddd"], dtype="O"),
        fill_value="1",
        compressor=VLenUTF8(),
    )
]
array_cases = [
    *basic_array_cases,
    *datetime_array_cases,
    *string_array_cases,
    *vlen_string_cases,
]
76+
77+
78+
@pytest.fixture
def source_array(tmp_path: Path, request: pytest.FixtureRequest) -> Array:
    """
    Create and populate the Zarr format 2 array described by ``request.param``.

    The array is written to a ``LocalStore`` rooted at ``tmp_path / "in"``,
    uses a single chunk spanning the whole array, and is filled with the
    parametrized values before being returned.
    """
    params: ArrayParams = request.param
    values = params.values
    # Object-dtype numpy arrays are stored with the variable-length string dtype.
    dtype = VariableLengthString() if values.dtype == np.dtype("|O") else values.dtype
    arr = zarr.create_array(
        LocalStore(tmp_path / "in"),
        shape=values.shape,
        dtype=dtype,
        chunks=values.shape,
        compressors=params.compressor,
        fill_value=params.fill_value,
        order="C",
        filters=None,
        chunk_key_encoding={"name": "v2", "configuration": {"separator": "/"}},
        write_data=True,
        zarr_format=2,
    )
    arr[:] = values
    return arr
103+
104+
105+
@pytest.mark.skipif(not runner_installed(), reason="no python script runner installed")
@pytest.mark.parametrize(
    "source_array", array_cases, indirect=True, ids=tuple(map(str, array_cases))
)
def test_roundtrip(source_array: Array, tmp_path: Path) -> None:
    """
    Copy ``source_array`` with the ``v2.18.py`` script and compare the result.

    The script is launched through ``uv run`` and writes its copy to
    ``tmp_path / "out"``. The copy is then opened read-only, and its metadata
    and contents must match the source array.
    """
    out_path = tmp_path / "out"
    copy_op = subprocess.run(
        [
            "uv",
            "run",
            str(Path(__file__).resolve().parent / "v2.18.py"),
            # NOTE(review): presumably str(store) is a "file://" URL; the
            # scheme is stripped to leave a plain filesystem path.
            str(source_array.store).removeprefix("file://"),
            str(out_path),
        ],
        capture_output=True,
        text=True,
    )
    # The child's output is captured, so attach stderr to the assertion;
    # otherwise a failing copy would report nothing but the exit status.
    assert copy_op.returncode == 0, copy_op.stderr
    out_array = zarr.open_array(store=out_path, mode="r", zarr_format=2)
    assert array_metadata_equals(source_array.metadata, out_array.metadata)
    assert np.array_equal(source_array[:], out_array[:])

tests/test_regression/v2.18.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# /// script
2+
# requires-python = ">=3.11"
3+
# dependencies = [
4+
# "zarr==2.18",
5+
# "numcodecs==0.15"
6+
# ]
7+
# ///
8+
9+
import argparse
10+
11+
import zarr
12+
from zarr._storage.store import BaseStore
13+
14+
15+
def copy_group(
    *, node: zarr.hierarchy.Group, store: zarr.storage.BaseStore, path: str, overwrite: bool
) -> zarr.hierarchy.Group:
    """
    Recursively copy a group, its attributes and its members into ``store``.

    A new group is created at ``path`` and given a copy of ``node``'s
    attributes. Every member of ``node`` is then copied beneath it at
    ``"{path}/{key}"``: sub-groups through this function, arrays through
    ``copy_array``.

    Returns the newly created group.
    """
    copied = zarr.group(store=store, path=path, overwrite=overwrite)
    copied.attrs.put(node.attrs.asdict())
    for key, member in node.items():
        member_path = f"{path}/{key}"
        if isinstance(member, zarr.hierarchy.Group):
            copy_group(node=member, store=store, path=member_path, overwrite=overwrite)
            continue
        if isinstance(member, zarr.core.Array):
            copy_array(node=member, store=store, overwrite=overwrite, path=member_path)
    return copied
27+
28+
29+
def copy_array(
    *, node: zarr.core.Array, store: BaseStore, path: str, overwrite: bool
) -> zarr.core.Array:
    """
    Copy an array, its attributes and its data into ``store`` at ``path``.

    The new array is created with ``node``'s shape, dtype, fill value, chunks,
    compressor, filters, memory order and dimension separator.

    Returns the newly created array.
    """
    result = zarr.create(
        shape=node.shape,
        dtype=node.dtype,
        fill_value=node.fill_value,
        chunks=node.chunks,
        compressor=node.compressor,
        filters=node.filters,
        order=node.order,
        dimension_separator=node._dimension_separator,
        store=store,
        path=path,
        overwrite=overwrite,
    )
    result.attrs.put(node.attrs.asdict())
    # Ellipsis rather than ``:`` so that zero-dimensional arrays, which accept
    # no slice index, are copied as well; for other ranks the two are equivalent.
    result[...] = node[...]
    return result
48+
49+
50+
def copy_node(
    node: zarr.hierarchy.Group | zarr.core.Array, store: BaseStore, path: str, overwrite: bool
) -> zarr.hierarchy.Group | zarr.core.Array:
    """
    Copy ``node`` into ``store`` at ``path``, dispatching on its type.

    Groups go to ``copy_group`` and arrays to ``copy_array``; any other type
    raises ``TypeError``. Returns whatever the selected copier returns.
    """
    if isinstance(node, zarr.hierarchy.Group):
        copier = copy_group
    elif isinstance(node, zarr.core.Array):
        copier = copy_array
    else:
        raise TypeError(f"Unexpected node type: {type(node)}")  # pragma: no cover
    return copier(node=node, store=store, path=path, overwrite=overwrite)
59+
60+
61+
def cli() -> None:
    """
    Parse ``source`` and ``destination`` from the command line and run the copy.

    The source is opened read-only and copied to the root of a
    ``NestedDirectoryStore`` at the destination, overwriting existing content.
    A confirmation line is printed afterwards.
    """
    parser = argparse.ArgumentParser(
        description="Copy a zarr hierarchy from one location to another"
    )
    positional_args = (
        ("source", "Path to the source zarr hierarchy"),
        ("destination", "Path to the destination zarr hierarchy"),
    )
    for arg_name, arg_help in positional_args:
        parser.add_argument(arg_name, type=str, help=arg_help)
    args = parser.parse_args()

    dst = args.destination
    result = copy_node(
        node=zarr.open(args.source, mode="r"),
        store=zarr.NestedDirectoryStore(dst),
        path="",
        overwrite=True,
    )

    print(f"successfully created {result} at {dst}")
74+
75+
76+
def main() -> None:
    """Script entry point; delegates to ``cli``."""
    cli()


# Run the copy only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)