Skip to content
This repository was archived by the owner on Oct 24, 2024. It is now read-only.

Commit a35218c

Browse files
Joe HammanJustin Magers
andauthored
Fix netcdf encoding (#95)
* the fix that didn't fix * further work on the netcdf encoding issue * use set2 * check for invalid groups in to_zarr * add tests and comments for the future Co-authored-by: Justin Magers <[email protected]>
1 parent 0db84ed commit a35218c

File tree

2 files changed

+62
-12
lines changed

2 files changed

+62
-12
lines changed

datatree/io.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,6 @@ def _open_datatree_zarr(store, **kwargs) -> DataTree:
105105
return tree_root
106106

107107

108-
def _maybe_extract_group_kwargs(enc, group):
109-
try:
110-
return enc[group]
111-
except KeyError:
112-
return None
113-
114-
115108
def _create_empty_netcdf_group(filename, group, mode, engine):
116109
ncDataset = _get_nc_dataset_class(engine)
117110

@@ -146,6 +139,14 @@ def _datatree_to_netcdf(
146139
if encoding is None:
147140
encoding = {}
148141

142+
# In the future, we may want to expand this check to insure all the provided encoding
143+
# options are valid. For now, this simply checks that all provided encoding keys are
144+
# groups in the datatree.
145+
if set(encoding) - set(dt.groups):
146+
raise ValueError(
147+
f"unexpected encoding group name(s) provided: {set(encoding) - set(dt.groups)}"
148+
)
149+
149150
if unlimited_dims is None:
150151
unlimited_dims = {}
151152

@@ -155,16 +156,15 @@ def _datatree_to_netcdf(
155156
if ds is None:
156157
_create_empty_netcdf_group(filepath, group_path, mode, engine)
157158
else:
158-
159159
ds.to_netcdf(
160160
filepath,
161161
group=group_path,
162162
mode=mode,
163-
encoding=_maybe_extract_group_kwargs(encoding, dt.path),
164-
unlimited_dims=_maybe_extract_group_kwargs(unlimited_dims, dt.path),
163+
encoding=encoding.get(node.path),
164+
unlimited_dims=unlimited_dims.get(node.path),
165165
**kwargs,
166166
)
167-
mode = "a"
167+
mode = "r+"
168168

169169

170170
def _create_empty_zarr_group(store, group, mode):
@@ -196,6 +196,14 @@ def _datatree_to_zarr(
196196
if encoding is None:
197197
encoding = {}
198198

199+
# In the future, we may want to expand this check to insure all the provided encoding
200+
# options are valid. For now, this simply checks that all provided encoding keys are
201+
# groups in the datatree.
202+
if set(encoding) - set(dt.groups):
203+
raise ValueError(
204+
f"unexpected encoding group name(s) provided: {set(encoding) - set(dt.groups)}"
205+
)
206+
199207
for node in dt.subtree:
200208
ds = node.ds
201209
group_path = node.path
@@ -206,7 +214,7 @@ def _datatree_to_zarr(
206214
store,
207215
group=group_path,
208216
mode=mode,
209-
encoding=_maybe_extract_group_kwargs(encoding, dt.path),
217+
encoding=encoding.get(node.path),
210218
consolidated=False,
211219
**kwargs,
212220
)

datatree/tests/test_io.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,27 @@ def test_to_netcdf(self, tmpdir):
1818
roundtrip_dt = open_datatree(filepath)
1919
assert_equal(original_dt, roundtrip_dt)
2020

21+
@requires_netCDF4
22+
def test_netcdf_encoding(self, tmpdir):
23+
filepath = str(
24+
tmpdir / "test.nc"
25+
) # casting to str avoids a pathlib bug in xarray
26+
original_dt = create_test_datatree()
27+
28+
# add compression
29+
comp = dict(zlib=True, complevel=9)
30+
enc = {"/set2": {var: comp for var in original_dt["/set2"].ds.data_vars}}
31+
32+
original_dt.to_netcdf(filepath, encoding=enc, engine="netcdf4")
33+
roundtrip_dt = open_datatree(filepath)
34+
35+
assert roundtrip_dt["/set2/a"].encoding["zlib"] == comp["zlib"]
36+
assert roundtrip_dt["/set2/a"].encoding["complevel"] == comp["complevel"]
37+
38+
enc["/not/a/group"] = {"foo": "bar"}
39+
with pytest.raises(ValueError, match="unexpected encoding group.*"):
40+
original_dt.to_netcdf(filepath, encoding=enc, engine="netcdf4")
41+
2142
@requires_h5netcdf
2243
def test_to_h5netcdf(self, tmpdir):
2344
filepath = str(
@@ -40,6 +61,27 @@ def test_to_zarr(self, tmpdir):
4061
roundtrip_dt = open_datatree(filepath, engine="zarr")
4162
assert_equal(original_dt, roundtrip_dt)
4263

64+
@requires_zarr
65+
def test_zarr_encoding(self, tmpdir):
66+
import zarr
67+
68+
filepath = str(
69+
tmpdir / "test.zarr"
70+
) # casting to str avoids a pathlib bug in xarray
71+
original_dt = create_test_datatree()
72+
73+
comp = {"compressor": zarr.Blosc(cname="zstd", clevel=3, shuffle=2)}
74+
enc = {"/set2": {var: comp for var in original_dt["/set2"].ds.data_vars}}
75+
original_dt.to_zarr(filepath, encoding=enc)
76+
roundtrip_dt = open_datatree(filepath, engine="zarr")
77+
78+
print(roundtrip_dt["/set2/a"].encoding)
79+
assert roundtrip_dt["/set2/a"].encoding["compressor"] == comp["compressor"]
80+
81+
enc["/not/a/group"] = {"foo": "bar"}
82+
with pytest.raises(ValueError, match="unexpected encoding group.*"):
83+
original_dt.to_zarr(filepath, encoding=enc, engine="zarr")
84+
4385
@requires_zarr
4486
def test_to_zarr_zip_store(self, tmpdir):
4587
from zarr.storage import ZipStore

0 commit comments

Comments
 (0)