
Commit bd5ce89

Support writing xarray.DataTree to Icechunk store (#538)

Fixes #244

1 parent 7a97287 · commit bd5ce89
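In practice, the change in the diffs below boils down to one new accessor method. A minimal sketch of the headline capability, assuming `vdt` is an `xarray.DataTree` backed by virtual references and `session` is a writable Icechunk session (both hypothetical names):

    import virtualizarr  # noqa: F401  (importing registers the .virtualize accessor)

    # write the whole tree of virtual references into the Icechunk store
    vdt.virtualize.to_icechunk(session.store)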

File tree: 6 files changed (+315 / -27 lines)

docs/api.rst

Lines changed: 1 addition & 0 deletions

@@ -30,6 +30,7 @@ Serialization
 
    VirtualiZarrDatasetAccessor.to_kerchunk
    VirtualiZarrDatasetAccessor.to_icechunk
+   VirtualiZarrDataTreeAccessor.to_icechunk
 
 Information
 -----------

docs/releases.rst

Lines changed: 6 additions & 0 deletions

@@ -14,6 +14,12 @@ New Features
   By `Tom Nicholas <https://github.com/TomNicholas>`_.
 - Added experimental :py:func:`open_virtual_mfdataset` function (:issue:`345`, :pull:`349`).
   By `Tom Nicholas <https://github.com/TomNicholas>`_.
+- Added :py:func:`datatree_to_icechunk` function for writing an ``xarray.DataTree`` to
+  an Icechunk store (:issue:`244`). By `Chuck Daniels <https://github.com/chuckwondo>`_.
+- Added a ``.virtualize`` custom accessor to ``xarray.DataTree``, exposing the method
+  :py:meth:`xarray.DataTree.virtualize.to_icechunk()` for writing an ``xarray.DataTree``
+  to an Icechunk store (:issue:`244`). By
+  `Chuck Daniels <https://github.com/chuckwondo>`_.
 
 Breaking changes
 ~~~~~~~~~~~~~~~~
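The release notes describe two equivalent entry points. A minimal sketch of the functional one, using the import path shown in the accessor diff below (and assuming `vdt` is a virtual `xarray.DataTree` and `store` is a writable `icechunk.IcechunkStore`):

    from virtualizarr.writers.icechunk import datatree_to_icechunk

    # write every node of the tree into the store as virtual references
    datatree_to_icechunk(vdt, store, write_inherited_coords=False)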

virtualizarr/__init__.py

Lines changed: 16 additions & 4 deletions

@@ -1,12 +1,24 @@
-from virtualizarr.manifests import ChunkManifest, ManifestArray  # type: ignore # noqa
-from virtualizarr.accessor import VirtualiZarrDatasetAccessor  # type: ignore # noqa
-from virtualizarr.backend import open_virtual_dataset, open_virtual_mfdataset  # noqa: F401
-
 from importlib.metadata import version as _version
 
+from virtualizarr.accessor import (
+    VirtualiZarrDatasetAccessor,
+    VirtualiZarrDataTreeAccessor,
+)
+from virtualizarr.backend import open_virtual_dataset, open_virtual_mfdataset
+from virtualizarr.manifests import ChunkManifest, ManifestArray
+
 try:
     __version__ = _version("virtualizarr")
 except Exception:
     # Local copy or not installed with setuptools.
     # Disable minimum version checks on downstream libraries.
     __version__ = "9999"
+
+__all__ = [
+    "ChunkManifest",
+    "ManifestArray",
+    "VirtualiZarrDatasetAccessor",
+    "VirtualiZarrDataTreeAccessor",
+    "open_virtual_dataset",
+    "open_virtual_mfdataset",
+]
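Since accessor registration happens at import time, a plain `import virtualizarr` is enough to make both `.virtualize` accessors available, and the names in `__all__` can also be imported directly. A brief sketch (file name hypothetical):

    import xarray as xr

    from virtualizarr import open_virtual_dataset

    vds = open_virtual_dataset("air.nc")        # virtual xr.Dataset
    vdt = xr.DataTree.from_dict({"/air": vds})  # wrap it in a DataTree
    # vds.virtualize and vdt.virtualize are now both registered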

virtualizarr/accessor.py

Lines changed: 78 additions & 5 deletions

@@ -2,7 +2,7 @@
 from pathlib import Path
 from typing import TYPE_CHECKING, Callable, Literal, overload
 
-from xarray import Dataset, register_dataset_accessor
+import xarray as xr
 
 from virtualizarr.manifests import ManifestArray
 from virtualizarr.types.kerchunk import KerchunkStoreRefs
@@ -12,16 +12,16 @@
     from icechunk import IcechunkStore  # type: ignore[import-not-found]
 
 
-@register_dataset_accessor("virtualize")
+@xr.register_dataset_accessor("virtualize")
 class VirtualiZarrDatasetAccessor:
     """
     Xarray accessor for writing out virtual datasets to disk.
 
     Methods on this object are called via `ds.virtualize.{method}`.
     """
 
-    def __init__(self, ds: Dataset):
-        self.ds: Dataset = ds
+    def __init__(self, ds: xr.Dataset):
+        self.ds: xr.Dataset = ds
 
     def to_icechunk(
         self,
@@ -171,7 +171,7 @@ def to_kerchunk(
     def rename_paths(
         self,
         new: str | Callable[[str], str],
-    ) -> Dataset:
+    ) -> xr.Dataset:
         """
         Rename paths to chunks in every ManifestArray in this dataset.
 
@@ -232,3 +232,76 @@ def nbytes(self) -> int:
             else var.nbytes
             for var in self.ds.variables.values()
         )
+
+
+@xr.register_datatree_accessor("virtualize")
+class VirtualiZarrDataTreeAccessor:
+    """
+    Xarray accessor for writing out virtual datatrees to disk.
+
+    Methods on this object are called via `dt.virtualize.{method}`.
+    """
+
+    def __init__(self, dt: xr.DataTree):
+        self.dt = dt
+
+    def to_icechunk(
+        self,
+        store: "IcechunkStore",
+        *,
+        write_inherited_coords: bool = False,
+        last_updated_at: datetime | None = None,
+    ) -> None:
+        """
+        Write an xarray DataTree to an Icechunk store.
+
+        Any variables backed by ManifestArray objects will be written as virtual
+        references. Any other variables will be loaded into memory before their binary
+        chunk data is written into the store.
+
+        If ``last_updated_at`` is provided, it will be used as a checksum for any
+        virtual chunks written to the store with this operation. At read time, if any
+        of the virtual chunks have been updated since this provided datetime, an error
+        will be raised. This protects against reading outdated virtual chunks that have
+        been updated since the last read. When not provided, no check is performed.
+        This value is stored in Icechunk with seconds precision, so be sure to take
+        that into account when providing this value.
+
+        Parameters
+        ----------
+        store: IcechunkStore
+            Store to write dataset into.
+        write_inherited_coords : bool, default: False
+            If ``True``, replicate inherited coordinates on all descendant nodes.
+            Otherwise, only write coordinates at the level at which they are
+            originally defined. This saves disk space, but requires opening the
+            full tree to load inherited coordinates.
+        last_updated_at: datetime, optional
+            Datetime to use as a checksum for any virtual chunks written to the store
+            with this operation. When not provided, no check is performed.
+
+        Raises
+        ------
+        ValueError
+            If the store is read-only.
+
+        Examples
+        --------
+        To ensure an error is raised if the files containing referenced virtual chunks
+        are modified at any time from now on, pass the current time to
+        ``last_updated_at``.
+
+        >>> from datetime import datetime
+        >>> vdt.virtualize.to_icechunk(  # doctest: +SKIP
+        ...     icechunkstore,
+        ...     last_updated_at=datetime.now(),
+        ... )
+        """
+        from virtualizarr.writers.icechunk import datatree_to_icechunk
+
+        datatree_to_icechunk(
+            self.dt,
+            store,
+            write_inherited_coords=write_inherited_coords,
+            last_updated_at=last_updated_at,
+        )
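Taken together with an Icechunk session, the docstring above suggests a workflow like the following sketch. The file names are hypothetical, and `Storage.new_in_memory` / `Repository.create` / `session.commit` are assumptions about the icechunk API (the visible test diff only confirms `Repository`, `Storage`, and `writable_session`):

    from datetime import datetime

    import xarray as xr
    from icechunk import Repository, Storage

    from virtualizarr import open_virtual_dataset

    # build a tree of virtual datasets (paths hypothetical)
    vdt = xr.DataTree.from_dict(
        {
            "/2023": open_virtual_dataset("air_2023.nc"),
            "/2024": open_virtual_dataset("air_2024.nc"),
        }
    )

    # open a writable session on an in-memory repository
    repo = Repository.create(Storage.new_in_memory())
    session = repo.writable_session("main")

    # write virtual references; later reads will error if the source files
    # are modified after this timestamp (stored with seconds precision)
    vdt.virtualize.to_icechunk(session.store, last_updated_at=datetime.now())
    session.commit("add virtual references for 2023-2024")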

virtualizarr/tests/test_integration.py

Lines changed: 135 additions & 13 deletions

@@ -1,6 +1,7 @@
+from collections.abc import Mapping
 from os.path import relpath
 from pathlib import Path
-from typing import Callable, Concatenate, TypeAlias
+from typing import Any, Callable, Concatenate, TypeAlias, overload
 
 import numpy as np
 import pytest
@@ -23,7 +24,9 @@
     dataset_from_kerchunk_refs,
 )
 
-RoundtripFunction: TypeAlias = Callable[Concatenate[xr.Dataset, Path, ...], xr.Dataset]
+RoundtripFunction: TypeAlias = Callable[
+    Concatenate[xr.Dataset | xr.DataTree, Path, ...], xr.Dataset | xr.DataTree
+]
 
 
 def test_kerchunk_roundtrip_in_memory_no_concat(array_v3_metadata):
@@ -111,7 +114,22 @@ def roundtrip_as_kerchunk_parquet(vds: xr.Dataset, tmpdir, **kwargs):
     return xr.open_dataset(f"{tmpdir}/refs.parquet", engine="kerchunk", **kwargs)
 
 
-def roundtrip_as_in_memory_icechunk(vds: xr.Dataset, tmpdir, **kwargs):
+@overload
+def roundtrip_as_in_memory_icechunk(
+    vdata: xr.Dataset, tmp_path: Path, **kwargs
+) -> xr.Dataset: ...
+@overload
+def roundtrip_as_in_memory_icechunk(
+    vdata: xr.DataTree, tmp_path: Path, **kwargs
+) -> xr.DataTree: ...
+
+
+def roundtrip_as_in_memory_icechunk(
+    vdata: xr.Dataset | xr.DataTree,
+    tmp_path: Path,
+    virtualize_kwargs: Mapping[str, Any] | None = None,
+    **kwargs,
+) -> xr.Dataset | xr.DataTree:
     from icechunk import Repository, Storage
 
     # create an in-memory icechunk store
@@ -120,7 +138,17 @@ def roundtrip_as_in_memory_icechunk(vds: xr.Dataset, tmpdir, **kwargs):
     session = repo.writable_session("main")
 
     # write those references to an icechunk store
-    vds.virtualize.to_icechunk(session.store)
+    vdata.virtualize.to_icechunk(session.store, **(virtualize_kwargs or {}))
+
+    if isinstance(vdata, xr.DataTree):
+        # read the datatree from icechunk
+        return xr.open_datatree(
+            session.store,  # type: ignore
+            engine="zarr",
+            zarr_format=3,
+            consolidated=False,
+            **kwargs,
+        )
 
     # read the dataset from icechunk
     return xr.open_zarr(session.store, zarr_format=3, consolidated=False, **kwargs)
@@ -219,16 +247,14 @@ def test_kerchunk_roundtrip_concat(
 
     roundtrip = roundtrip_func(vds, tmp_path, decode_times=decode_times)
 
-    if decode_times is False:
-        # assert all_close to original dataset
-        xrt.assert_allclose(roundtrip, ds)
+    # assert all_close to original dataset
+    xrt.assert_allclose(roundtrip, ds)
 
-        # assert coordinate attributes are maintained
-        for coord in ds.coords:
-            assert ds.coords[coord].attrs == roundtrip.coords[coord].attrs
-    else:
-        # they are very very close! But assert_allclose doesn't seem to work on datetimes
-        assert (roundtrip.time - ds.time).sum() == 0
+    # assert coordinate attributes are maintained
+    for coord in ds.coords:
+        assert ds.coords[coord].attrs == roundtrip.coords[coord].attrs
+
+    if decode_times:
         assert roundtrip.time.dtype == ds.time.dtype
         assert roundtrip.time.encoding["units"] == ds.time.encoding["units"]
         assert (
@@ -303,6 +329,102 @@ def test_datetime64_dtype_fill_value(
     assert roundtrip.a.attrs == vds.a.attrs
 
 
+@parametrize_over_hdf_backends
+@pytest.mark.parametrize(
+    "roundtrip_func", [roundtrip_as_in_memory_icechunk] if has_icechunk else []
+)
+@pytest.mark.parametrize("decode_times", (False, True))
+@pytest.mark.parametrize("time_vars", ([], ["time"]))
+@pytest.mark.parametrize("inherit", (False, True))
+def test_datatree_roundtrip(
+    tmp_path: Path,
+    roundtrip_func: RoundtripFunction,
+    hdf_backend: type[VirtualBackend],
+    decode_times: bool,
+    time_vars: list[str],
+    inherit: bool,
+):
+    # set up example xarray dataset
+    with xr.tutorial.open_dataset("air_temperature", decode_times=decode_times) as ds:
+        # split into two datasets
+        ds1 = ds.isel(time=slice(None, 1460))
+        ds2 = ds.isel(time=slice(1460, None))
+
+        # save it to disk as netCDF (in temporary directory)
+        air1_nc_path = tmp_path / "air1.nc"
+        air2_nc_path = tmp_path / "air2.nc"
+        ds1.to_netcdf(air1_nc_path)
+        ds2.to_netcdf(air2_nc_path)
+
+        # use open_virtual_dataset to read the files as virtual references
+        with (
+            open_virtual_dataset(
+                str(air1_nc_path),
+                loadable_variables=time_vars,
+                decode_times=decode_times,
+                backend=hdf_backend,
+            ) as vds1,
+            open_virtual_dataset(
+                str(air2_nc_path),
+                loadable_variables=time_vars,
+                decode_times=decode_times,
+                backend=hdf_backend,
+            ) as vds2,
+        ):
+            if not decode_times or not time_vars:
+                assert vds1.time.dtype == np.dtype("float32")
+                assert vds2.time.dtype == np.dtype("float32")
+            else:
+                assert vds1.time.dtype == np.dtype("<M8[ns]")
+                assert vds2.time.dtype == np.dtype("<M8[ns]")
+                assert "units" in vds1.time.encoding
+                assert "units" in vds2.time.encoding
+                assert "calendar" in vds1.time.encoding
+                assert "calendar" in vds2.time.encoding
+
+            vdt = xr.DataTree.from_dict({"/vds1": vds1, "/nested/vds2": vds2})
+
+            with roundtrip_func(
+                vdt,
+                tmp_path,
+                virtualize_kwargs=dict(write_inherited_coords=inherit),
+                decode_times=decode_times,
+            ) as roundtrip:
+                assert isinstance(roundtrip, xr.DataTree)
+
+                # assert all_close to original dataset
+                roundtrip_vds1 = roundtrip["/vds1"].to_dataset()
+                roundtrip_vds2 = roundtrip["/nested/vds2"].to_dataset()
+                xrt.assert_allclose(roundtrip_vds1, ds1)
+                xrt.assert_allclose(roundtrip_vds2, ds2)
+
+                # assert coordinate attributes are maintained
+                for coord in ds1.coords:
+                    assert ds1.coords[coord].attrs == roundtrip_vds1.coords[coord].attrs
+                for coord in ds2.coords:
+                    assert ds2.coords[coord].attrs == roundtrip_vds2.coords[coord].attrs
+
+                if decode_times:
+                    assert roundtrip_vds1.time.dtype == ds1.time.dtype
+                    assert roundtrip_vds2.time.dtype == ds2.time.dtype
+                    assert (
+                        roundtrip_vds1.time.encoding["units"]
+                        == ds1.time.encoding["units"]
+                    )
+                    assert (
+                        roundtrip_vds2.time.encoding["units"]
+                        == ds2.time.encoding["units"]
+                    )
+                    assert (
+                        roundtrip_vds1.time.encoding["calendar"]
+                        == ds1.time.encoding["calendar"]
+                    )
+                    assert (
+                        roundtrip_vds2.time.encoding["calendar"]
+                        == ds2.time.encoding["calendar"]
+                    )
+
+
 @parametrize_over_hdf_backends
 def test_open_scalar_variable(tmp_path: Path, hdf_backend: type[VirtualBackend]):
     # regression test for GH issue #100
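For completeness, a sketch of reading a written tree back, mirroring what `roundtrip_as_in_memory_icechunk` does for the DataTree case (reusing the `session` from the earlier sketch):

    import xarray as xr

    # Icechunk exposes a Zarr v3 store; metadata is not consolidated
    roundtrip = xr.open_datatree(
        session.store,
        engine="zarr",
        zarr_format=3,
        consolidated=False,
    )
    print(roundtrip["/2023"])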
