|
| 1 | +import subprocess |
| 2 | +from dataclasses import asdict, dataclass |
| 3 | +from itertools import product |
| 4 | +from pathlib import Path |
| 5 | + |
| 6 | +import numcodecs |
| 7 | +import numpy as np |
| 8 | +import pytest |
| 9 | +from numcodecs import LZ4, LZMA, Blosc, GZip, VLenUTF8, Zstd |
| 10 | + |
| 11 | +import zarr |
| 12 | +from zarr.core.array import Array |
| 13 | +from zarr.core.dtype.npy.string import VariableLengthString |
| 14 | +from zarr.core.metadata.v2 import ArrayV2Metadata |
| 15 | +from zarr.storage import LocalStore |
| 16 | + |
| 17 | + |
| 18 | +def runner_installed() -> bool: |
| 19 | + try: |
| 20 | + subprocess.check_output(["uv", "--version"]) |
| 21 | + return True |
| 22 | + except FileNotFoundError: |
| 23 | + return False |
| 24 | + |
| 25 | + |
| 26 | +def array_metadata_equals(a: ArrayV2Metadata, b: ArrayV2Metadata) -> bool: |
| 27 | + dict_a, dict_b = asdict(a), asdict(b) |
| 28 | + fill_value_a, fill_value_b = dict_a.pop("fill_value"), dict_b.pop("fill_value") |
| 29 | + if ( |
| 30 | + isinstance(fill_value_a, float) |
| 31 | + and isinstance(fill_value_b, float) |
| 32 | + and np.isnan(fill_value_a) |
| 33 | + and np.isnan(fill_value_b) |
| 34 | + ): |
| 35 | + return dict_a == dict_b |
| 36 | + else: |
| 37 | + return fill_value_a == fill_value_b and dict_a == dict_b |
| 38 | + |
| 39 | + |
| 40 | +@dataclass(kw_only=True) |
| 41 | +class ArrayParams: |
| 42 | + values: np.ndarray[tuple[int], np.dtype[np.generic]] |
| 43 | + fill_value: np.generic | str |
| 44 | + compressor: numcodecs.abc.Codec |
| 45 | + |
| 46 | + |
| 47 | +basic_codecs = GZip(), Blosc(), LZ4(), LZMA(), Zstd() |
| 48 | +basic_dtypes = "|b", ">i2", ">i4", ">f4", ">f8", "<f4", "<f8", ">c8", "<c8", ">c16", "<c16" |
| 49 | +datetime_dtypes = "<M8[10ns]", ">M8[10us]", "<m8[2ms]", ">m8[4ps]" |
| 50 | +string_dtypes = ">S1", "<S4", "<U1", ">U4" |
| 51 | + |
| 52 | +basic_array_cases = [ |
| 53 | + ArrayParams(values=np.arange(4, dtype=dtype), fill_value=1, compressor=codec) |
| 54 | + for codec, dtype in product(basic_codecs, basic_dtypes) |
| 55 | +] |
| 56 | +datetime_array_cases = [ |
| 57 | + ArrayParams(values=np.ones((4,), dtype=dtype), fill_value=1, compressor=codec) |
| 58 | + for codec, dtype in product(basic_codecs, datetime_dtypes) |
| 59 | +] |
| 60 | +string_array_cases = [ |
| 61 | + ArrayParams( |
| 62 | + values=np.array(["aaaa", "bbbb", "ccccc", "dddd"], dtype=dtype), |
| 63 | + fill_value="foo", |
| 64 | + compressor=codec, |
| 65 | + ) |
| 66 | + for codec, dtype in product(basic_codecs, string_dtypes) |
| 67 | +] |
| 68 | +vlen_string_cases = [ |
| 69 | + ArrayParams( |
| 70 | + values=np.array(["a", "bb", "ccc", "dddd"], dtype="O"), |
| 71 | + fill_value="1", |
| 72 | + compressor=VLenUTF8(), |
| 73 | + ) |
| 74 | +] |
| 75 | +array_cases = basic_array_cases + datetime_array_cases + string_array_cases + vlen_string_cases |
| 76 | + |
| 77 | + |
| 78 | +@pytest.fixture |
| 79 | +def source_array(tmp_path: Path, request: pytest.FixtureRequest) -> Array: |
| 80 | + dest = tmp_path / "in" |
| 81 | + store = LocalStore(dest) |
| 82 | + array_params: ArrayParams = request.param |
| 83 | + compressor = array_params.compressor |
| 84 | + if array_params.values.dtype == np.dtype("|O"): |
| 85 | + dtype = VariableLengthString() |
| 86 | + else: |
| 87 | + dtype = array_params.values.dtype |
| 88 | + z = zarr.create_array( |
| 89 | + store, |
| 90 | + shape=array_params.values.shape, |
| 91 | + dtype=dtype, |
| 92 | + chunks=array_params.values.shape, |
| 93 | + compressors=compressor, |
| 94 | + fill_value=array_params.fill_value, |
| 95 | + order="C", |
| 96 | + filters=None, |
| 97 | + chunk_key_encoding={"name": "v2", "configuration": {"separator": "/"}}, |
| 98 | + write_data=True, |
| 99 | + zarr_format=2, |
| 100 | + ) |
| 101 | + z[:] = array_params.values |
| 102 | + return z |
| 103 | + |
| 104 | + |
| 105 | +@pytest.mark.skipif(not runner_installed(), reason="no python script runner installed") |
| 106 | +@pytest.mark.parametrize( |
| 107 | + "source_array", array_cases, indirect=True, ids=tuple(map(str, array_cases)) |
| 108 | +) |
| 109 | +def test_roundtrip(source_array: Array, tmp_path: Path) -> None: |
| 110 | + out_path = tmp_path / "out" |
| 111 | + copy_op = subprocess.run( |
| 112 | + [ |
| 113 | + "uv", |
| 114 | + "run", |
| 115 | + Path(__file__).resolve().parent / "v2.18.py", |
| 116 | + str(source_array.store).removeprefix("file://"), |
| 117 | + str(out_path), |
| 118 | + ], |
| 119 | + capture_output=True, |
| 120 | + text=True, |
| 121 | + ) |
| 122 | + assert copy_op.returncode == 0 |
| 123 | + out_array = zarr.open_array(store=out_path, mode="r", zarr_format=2) |
| 124 | + assert array_metadata_equals(source_array.metadata, out_array.metadata) |
| 125 | + assert np.array_equal(source_array[:], out_array[:]) |
0 commit comments