Skip to content

Commit 91f77ae

Browse files
committed
initial roundtrip
1 parent 122760f commit 91f77ae

File tree

2 files changed

+67
-0
lines changed

2 files changed

+67
-0
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ test = [
6767
"mypy",
6868
"hypothesis",
6969
"universal-pathlib",
70+
"xarray",
7071
]
7172

7273
jupyter = [

tests/test_xarray.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import string
2+
3+
import numpy as np
4+
import pandas as pd
5+
import pytest
6+
import xarray as xr
7+
8+
import zarr
9+
10+
# Default lengths of dim1, dim2 and dim3 (in that order) for the `dataset` fixture.
_DEFAULT_TEST_DIM_SIZES = (8, 9, 10)
11+
12+
13+
@pytest.fixture
def store() -> zarr.abc.store.Store:
    """Provide a newly constructed in-memory zarr store."""
    memory_store = zarr.storage.MemoryStore()
    return memory_store
16+
17+
18+
@pytest.fixture
def dataset(
    seed: int = 12345,
    add_attrs: bool = True,
    dim_sizes: tuple[int, int, int] = _DEFAULT_TEST_DIM_SIZES,
    use_extension_array: bool = False,
) -> xr.Dataset:
    """Build a small, seeded xarray dataset for tests.

    The dataset has three random float variables (``var1``..``var3``) laid
    out over ``dim1``/``dim2``/``dim3``, coordinates for ``dim2``, ``dim3``
    and ``time``, and an integer ``numbers`` coordinate along ``dim3``.

    Parameters
    ----------
    seed
        Seed for the NumPy random generator that produces all random data.
    add_attrs
        When true, each data variable gets the attrs ``{"foo": "variable"}``.
    dim_sizes
        Lengths of ``dim1``, ``dim2`` and ``dim3``, in that order.
    use_extension_array
        When true, add ``var4``, a pandas ``Categorical`` along ``dim1``.

    Raises
    ------
    RuntimeError
        If ``dim3`` is longer than 26, since its labels are single
        lowercase ASCII letters.
    """
    rng = np.random.default_rng(seed)
    size_of = {f"dim{i + 1}": dim_sizes[i] for i in range(3)}
    n_dim3 = size_of["dim3"]

    # dim3 is labelled with one lowercase letter per element.
    if n_dim3 > 26:
        raise RuntimeError(f"Not enough letters for filling this dimension size ({n_dim3})")

    ds = xr.Dataset()
    ds["dim2"] = ("dim2", 0.5 * np.arange(size_of["dim2"]))
    ds["dim3"] = ("dim3", list(string.ascii_lowercase[0:n_dim3]))
    ds["time"] = ("time", pd.date_range("2000-01-01", periods=20))

    # Listed in name order so the random draws happen in a fixed sequence.
    layout = (
        ("var1", ["dim1", "dim2"]),
        ("var2", ["dim1", "dim2"]),
        ("var3", ["dim3", "dim1"]),
    )
    for name, var_dims in layout:
        shape = tuple(size_of[d] for d in var_dims)
        values = rng.normal(size=shape)
        ds[name] = (var_dims, values)
        if add_attrs:
            ds[name].attrs = {"foo": "variable"}

    if use_extension_array:
        # Draw the number of categories first, then sample from them.
        categories = list(string.ascii_lowercase[: rng.integers(1, 5)])
        sampled = rng.choice(categories, size=dim_sizes[0])
        ds["var4"] = ("dim1", pd.Categorical(sampled))

    # The default sizes get a fixed, hand-written coordinate; other sizes get random values.
    if dim_sizes == _DEFAULT_TEST_DIM_SIZES:
        numbers_values = np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64")
    else:
        numbers_values = rng.integers(0, 3, n_dim3, dtype="int64")
    ds.coords["numbers"] = ("dim3", numbers_values)

    ds.encoding = {"foo": "bar"}
    return ds
61+
62+
63+
def test_roundtrip(store: zarr.abc.store.Store, dataset: xr.Dataset) -> None:
    """Write ``dataset`` to ``store`` as zarr, read it back, and compare.

    The test passes when the dataset read back from the store is
    ``identical`` to the one that was written.
    """
    dataset.to_zarr(store)
    # Name the backend explicitly: a store object gives xarray no file
    # extension to pick an engine from.
    other_dataset = xr.open_dataset(store, engine="zarr")
    assert dataset.identical(other_dataset)

0 commit comments

Comments
 (0)