Skip to content
This repository was archived by the owner on Oct 24, 2024. It is now read-only.

Commit 7c121bb

Browse files
Make create_test_datatree a pytest.fixture (#107)
* Migrate create_test_datatree to pytest.fixture * Move create_test_datatree fixture to conftest.py * black * whatsnew Co-authored-by: Thomas Nicholas <[email protected]>
1 parent 28a79d1 commit 7c121bb

File tree

7 files changed

+112
-97
lines changed

7 files changed

+112
-97
lines changed

datatree/tests/conftest.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import pytest
2+
import xarray as xr
3+
4+
from datatree import DataTree
5+
6+
7+
@pytest.fixture(scope="module")
8+
def create_test_datatree():
9+
"""
10+
Create a test datatree with this structure:
11+
12+
<datatree.DataTree>
13+
|-- set1
14+
| |-- <xarray.Dataset>
15+
| | Dimensions: ()
16+
| | Data variables:
17+
| | a int64 0
18+
| | b int64 1
19+
| |-- set1
20+
| |-- set2
21+
|-- set2
22+
| |-- <xarray.Dataset>
23+
| | Dimensions: (x: 2)
24+
| | Data variables:
25+
| | a (x) int64 2, 3
26+
| | b (x) int64 0.1, 0.2
27+
| |-- set1
28+
|-- set3
29+
|-- <xarray.Dataset>
30+
| Dimensions: (x: 2, y: 3)
31+
| Data variables:
32+
| a (y) int64 6, 7, 8
33+
| set0 (x) int64 9, 10
34+
35+
The structure has deliberately repeated names of tags, variables, and
36+
dimensions in order to better check for bugs caused by name conflicts.
37+
"""
38+
39+
def _create_test_datatree(modify=lambda ds: ds):
40+
set1_data = modify(xr.Dataset({"a": 0, "b": 1}))
41+
set2_data = modify(xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])}))
42+
root_data = modify(xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}))
43+
44+
# Avoid using __init__ so we can independently test it
45+
root = DataTree(data=root_data)
46+
set1 = DataTree(name="set1", parent=root, data=set1_data)
47+
DataTree(name="set1", parent=set1)
48+
DataTree(name="set2", parent=set1)
49+
set2 = DataTree(name="set2", parent=root, data=set2_data)
50+
DataTree(name="set1", parent=set2)
51+
DataTree(name="set3", parent=root)
52+
53+
return root
54+
55+
return _create_test_datatree
56+
57+
58+
@pytest.fixture(scope="module")
59+
def simple_datatree(create_test_datatree):
60+
"""
61+
Invoke create_test_datatree fixture (callback).
62+
63+
Returns a DataTree.
64+
"""
65+
return create_test_datatree()

datatree/tests/test_dataset_api.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
from datatree import DataTree
55
from datatree.testing import assert_equal
66

7-
from .test_datatree import create_test_datatree
8-
97

108
class TestDSMethodInheritance:
119
def test_dataset_method(self):
@@ -93,7 +91,7 @@ def test_binary_op_on_datatree(self):
9391

9492

9593
class TestUFuncs:
96-
def test_tree(self):
94+
def test_tree(self, create_test_datatree):
9795
dt = create_test_datatree()
9896
expected = create_test_datatree(modify=lambda ds: np.sin(ds))
9997
result_tree = np.sin(dt)

datatree/tests/test_datatree.py

Lines changed: 6 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -5,52 +5,6 @@
55
from datatree import DataTree
66

77

8-
def create_test_datatree(modify=lambda ds: ds):
9-
"""
10-
Create a test datatree with this structure:
11-
12-
<datatree.DataTree>
13-
|-- set1
14-
| |-- <xarray.Dataset>
15-
| | Dimensions: ()
16-
| | Data variables:
17-
| | a int64 0
18-
| | b int64 1
19-
| |-- set1
20-
| |-- set2
21-
|-- set2
22-
| |-- <xarray.Dataset>
23-
| | Dimensions: (x: 2)
24-
| | Data variables:
25-
| | a (x) int64 2, 3
26-
| | b (x) int64 0.1, 0.2
27-
| |-- set1
28-
|-- set3
29-
|-- <xarray.Dataset>
30-
| Dimensions: (x: 2, y: 3)
31-
| Data variables:
32-
| a (y) int64 6, 7, 8
33-
| set0 (x) int64 9, 10
34-
35-
The structure has deliberately repeated names of tags, variables, and
36-
dimensions in order to better check for bugs caused by name conflicts.
37-
"""
38-
set1_data = modify(xr.Dataset({"a": 0, "b": 1}))
39-
set2_data = modify(xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])}))
40-
root_data = modify(xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}))
41-
42-
# Avoid using __init__ so we can independently test it
43-
root = DataTree(data=root_data)
44-
set1 = DataTree(name="set1", parent=root, data=set1_data)
45-
DataTree(name="set1", parent=set1)
46-
DataTree(name="set2", parent=set1)
47-
set2 = DataTree(name="set2", parent=root, data=set2_data)
48-
DataTree(name="set1", parent=set2)
49-
DataTree(name="set3", parent=root)
50-
51-
return root
52-
53-
548
class TestTreeCreation:
559
def test_empty(self):
5610
dt = DataTree(name="root")
@@ -322,8 +276,8 @@ def test_nones(self):
322276
assert [node.path for node in dt.subtree] == ["/", "/d", "/d/e"]
323277
xrt.assert_equal(dt["d/e"].ds, xr.Dataset())
324278

325-
def test_full(self):
326-
dt = create_test_datatree()
279+
def test_full(self, simple_datatree):
280+
dt = simple_datatree
327281
paths = list(node.path for node in dt.subtree)
328282
assert paths == [
329283
"/",
@@ -335,16 +289,16 @@ def test_full(self):
335289
"/set3",
336290
]
337291

338-
def test_roundtrip(self):
339-
dt = create_test_datatree()
292+
def test_roundtrip(self, simple_datatree):
293+
dt = simple_datatree
340294
roundtrip = DataTree.from_dict(dt.to_dict())
341295
assert roundtrip.equals(dt)
342296

343297
@pytest.mark.xfail
344-
def test_roundtrip_unnamed_root(self):
298+
def test_roundtrip_unnamed_root(self, simple_datatree):
345299
# See GH81
346300

347-
dt = create_test_datatree()
301+
dt = simple_datatree
348302
dt.name = "root"
349303
roundtrip = DataTree.from_dict(dt.to_dict())
350304
assert roundtrip.equals(dt)

datatree/tests/test_formatting.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
from datatree import DataTree
66
from datatree.formatting import diff_tree_repr
77

8-
from .test_datatree import create_test_datatree
9-
108

119
class TestRepr:
1210
def test_print_empty_node(self):
@@ -50,8 +48,8 @@ def test_nested_node(self):
5048
printout = root.__str__()
5149
assert printout.splitlines()[2].startswith(" ")
5250

53-
def test_print_datatree(self):
54-
dt = create_test_datatree()
51+
def test_print_datatree(self, simple_datatree):
52+
dt = simple_datatree
5553
print(dt)
5654

5755
# TODO work out how to test something complex like this

datatree/tests/test_io.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,26 @@
33
from datatree.io import open_datatree
44
from datatree.testing import assert_equal
55
from datatree.tests import requires_h5netcdf, requires_netCDF4, requires_zarr
6-
from datatree.tests.test_datatree import create_test_datatree
76

87

98
class TestIO:
109
@requires_netCDF4
11-
def test_to_netcdf(self, tmpdir):
10+
def test_to_netcdf(self, tmpdir, simple_datatree):
1211
filepath = str(
1312
tmpdir / "test.nc"
1413
) # casting to str avoids a pathlib bug in xarray
15-
original_dt = create_test_datatree()
14+
original_dt = simple_datatree
1615
original_dt.to_netcdf(filepath, engine="netcdf4")
1716

1817
roundtrip_dt = open_datatree(filepath)
1918
assert_equal(original_dt, roundtrip_dt)
2019

2120
@requires_netCDF4
22-
def test_netcdf_encoding(self, tmpdir):
21+
def test_netcdf_encoding(self, tmpdir, simple_datatree):
2322
filepath = str(
2423
tmpdir / "test.nc"
2524
) # casting to str avoids a pathlib bug in xarray
26-
original_dt = create_test_datatree()
25+
original_dt = simple_datatree
2726

2827
# add compression
2928
comp = dict(zlib=True, complevel=9)
@@ -40,35 +39,35 @@ def test_netcdf_encoding(self, tmpdir):
4039
original_dt.to_netcdf(filepath, encoding=enc, engine="netcdf4")
4140

4241
@requires_h5netcdf
43-
def test_to_h5netcdf(self, tmpdir):
42+
def test_to_h5netcdf(self, tmpdir, simple_datatree):
4443
filepath = str(
4544
tmpdir / "test.nc"
4645
) # casting to str avoids a pathlib bug in xarray
47-
original_dt = create_test_datatree()
46+
original_dt = simple_datatree
4847
original_dt.to_netcdf(filepath, engine="h5netcdf")
4948

5049
roundtrip_dt = open_datatree(filepath)
5150
assert_equal(original_dt, roundtrip_dt)
5251

5352
@requires_zarr
54-
def test_to_zarr(self, tmpdir):
53+
def test_to_zarr(self, tmpdir, simple_datatree):
5554
filepath = str(
5655
tmpdir / "test.zarr"
5756
) # casting to str avoids a pathlib bug in xarray
58-
original_dt = create_test_datatree()
57+
original_dt = simple_datatree
5958
original_dt.to_zarr(filepath)
6059

6160
roundtrip_dt = open_datatree(filepath, engine="zarr")
6261
assert_equal(original_dt, roundtrip_dt)
6362

6463
@requires_zarr
65-
def test_zarr_encoding(self, tmpdir):
64+
def test_zarr_encoding(self, tmpdir, simple_datatree):
6665
import zarr
6766

6867
filepath = str(
6968
tmpdir / "test.zarr"
7069
) # casting to str avoids a pathlib bug in xarray
71-
original_dt = create_test_datatree()
70+
original_dt = simple_datatree
7271

7372
comp = {"compressor": zarr.Blosc(cname="zstd", clevel=3, shuffle=2)}
7473
enc = {"/set2": {var: comp for var in original_dt["/set2"].ds.data_vars}}
@@ -83,26 +82,26 @@ def test_zarr_encoding(self, tmpdir):
8382
original_dt.to_zarr(filepath, encoding=enc, engine="zarr")
8483

8584
@requires_zarr
86-
def test_to_zarr_zip_store(self, tmpdir):
85+
def test_to_zarr_zip_store(self, tmpdir, simple_datatree):
8786
from zarr.storage import ZipStore
8887

8988
filepath = str(
9089
tmpdir / "test.zarr.zip"
9190
) # casting to str avoids a pathlib bug in xarray
92-
original_dt = create_test_datatree()
91+
original_dt = simple_datatree
9392
store = ZipStore(filepath)
9493
original_dt.to_zarr(store)
9594

9695
roundtrip_dt = open_datatree(store, engine="zarr")
9796
assert_equal(original_dt, roundtrip_dt)
9897

9998
@requires_zarr
100-
def test_to_zarr_not_consolidated(self, tmpdir):
99+
def test_to_zarr_not_consolidated(self, tmpdir, simple_datatree):
101100
filepath = tmpdir / "test.zarr"
102101
zmetadata = filepath / ".zmetadata"
103102
s1zmetadata = filepath / "set1" / ".zmetadata"
104103
filepath = str(filepath) # casting to str avoids a pathlib bug in xarray
105-
original_dt = create_test_datatree()
104+
original_dt = simple_datatree
106105
original_dt.to_zarr(filepath, consolidated=False)
107106
assert not zmetadata.exists()
108107
assert not s1zmetadata.exists()

0 commit comments

Comments
 (0)