1
+ from collections .abc import Mapping
1
2
from os .path import relpath
2
3
from pathlib import Path
3
- from typing import Callable , Concatenate , TypeAlias
4
+ from typing import Any , Callable , Concatenate , TypeAlias , overload
4
5
5
6
import numpy as np
6
7
import pytest
23
24
dataset_from_kerchunk_refs ,
24
25
)
25
26
26
# Signature shared by all roundtrip helpers in this module: take a virtual
# Dataset or DataTree plus a target path (and arbitrary extra keyword
# arguments), write it out, and return the re-opened Dataset/DataTree.
RoundtripFunction: TypeAlias = Callable[
    Concatenate[xr.Dataset | xr.DataTree, Path, ...], xr.Dataset | xr.DataTree
]
27
30
28
31
29
32
def test_kerchunk_roundtrip_in_memory_no_concat (array_v3_metadata ):
@@ -111,7 +114,22 @@ def roundtrip_as_kerchunk_parquet(vds: xr.Dataset, tmpdir, **kwargs):
111
114
return xr .open_dataset (f"{ tmpdir } /refs.parquet" , engine = "kerchunk" , ** kwargs )
112
115
113
116
114
- def roundtrip_as_in_memory_icechunk (vds : xr .Dataset , tmpdir , ** kwargs ):
117
+ @overload
118
+ def roundtrip_as_in_memory_icechunk (
119
+ vdata : xr .Dataset , tmp_path : Path , ** kwargs
120
+ ) -> xr .Dataset : ...
121
+ @overload
122
+ def roundtrip_as_in_memory_icechunk (
123
+ vdata : xr .DataTree , tmp_path : Path , ** kwargs
124
+ ) -> xr .DataTree : ...
125
+
126
+
127
+ def roundtrip_as_in_memory_icechunk (
128
+ vdata : xr .Dataset | xr .DataTree ,
129
+ tmp_path : Path ,
130
+ virtualize_kwargs : Mapping [str , Any ] | None = None ,
131
+ ** kwargs ,
132
+ ) -> xr .Dataset | xr .DataTree :
115
133
from icechunk import Repository , Storage
116
134
117
135
# create an in-memory icechunk store
@@ -120,7 +138,17 @@ def roundtrip_as_in_memory_icechunk(vds: xr.Dataset, tmpdir, **kwargs):
120
138
session = repo .writable_session ("main" )
121
139
122
140
# write those references to an icechunk store
123
- vds .virtualize .to_icechunk (session .store )
141
+ vdata .virtualize .to_icechunk (session .store , ** (virtualize_kwargs or {}))
142
+
143
+ if isinstance (vdata , xr .DataTree ):
144
+ # read the dataset from icechunk
145
+ return xr .open_datatree (
146
+ session .store , # type: ignore
147
+ engine = "zarr" ,
148
+ zarr_format = 3 ,
149
+ consolidated = False ,
150
+ ** kwargs ,
151
+ )
124
152
125
153
# read the dataset from icechunk
126
154
return xr .open_zarr (session .store , zarr_format = 3 , consolidated = False , ** kwargs )
@@ -219,16 +247,14 @@ def test_kerchunk_roundtrip_concat(
219
247
220
248
roundtrip = roundtrip_func (vds , tmp_path , decode_times = decode_times )
221
249
222
- if decode_times is False :
223
- # assert all_close to original dataset
224
- xrt .assert_allclose (roundtrip , ds )
250
+ # assert all_close to original dataset
251
+ xrt .assert_allclose (roundtrip , ds )
225
252
226
- # assert coordinate attributes are maintained
227
- for coord in ds .coords :
228
- assert ds .coords [coord ].attrs == roundtrip .coords [coord ].attrs
229
- else :
230
- # they are very very close! But assert_allclose doesn't seem to work on datetimes
231
- assert (roundtrip .time - ds .time ).sum () == 0
253
+ # assert coordinate attributes are maintained
254
+ for coord in ds .coords :
255
+ assert ds .coords [coord ].attrs == roundtrip .coords [coord ].attrs
256
+
257
+ if decode_times :
232
258
assert roundtrip .time .dtype == ds .time .dtype
233
259
assert roundtrip .time .encoding ["units" ] == ds .time .encoding ["units" ]
234
260
assert (
@@ -303,6 +329,102 @@ def test_datetime64_dtype_fill_value(
303
329
assert roundtrip .a .attrs == vds .a .attrs
304
330
305
331
332
@parametrize_over_hdf_backends
@pytest.mark.parametrize(
    "roundtrip_func", [roundtrip_as_in_memory_icechunk] if has_icechunk else []
)
@pytest.mark.parametrize("decode_times", (False, True))
@pytest.mark.parametrize("time_vars", ([], ["time"]))
@pytest.mark.parametrize("inherit", (False, True))
def test_datatree_roundtrip(
    tmp_path: Path,
    roundtrip_func: RoundtripFunction,
    hdf_backend: type[VirtualBackend],
    decode_times: bool,
    time_vars: list[str],
    inherit: bool,
):
    """Roundtrip a two-node DataTree of virtual datasets through a store.

    Splits the tutorial ``air_temperature`` dataset into two netCDF files,
    opens each as a virtual dataset, assembles them into a DataTree (one node
    nested one level deep), writes the tree via ``roundtrip_func``, and checks
    that the re-opened tree's node datasets match the originals in values,
    coordinate attributes, and (when ``decode_times`` is on) time dtype and
    encoding.

    Parameters are all supplied by pytest parametrization / fixtures:
    ``roundtrip_func`` is only the icechunk writer (skipped entirely when
    icechunk is unavailable, since the parametrize list is then empty);
    ``time_vars`` toggles whether ``time`` is loaded eagerly; ``inherit`` is
    forwarded as ``write_inherited_coords`` to the writer.
    """
    # set up example xarray dataset
    # NOTE(review): xr.tutorial.open_dataset downloads data on first use —
    # this test assumes network access or a primed cache.
    with xr.tutorial.open_dataset("air_temperature", decode_times=decode_times) as ds:
        # split into two datasets
        ds1 = ds.isel(time=slice(None, 1460))
        ds2 = ds.isel(time=slice(1460, None))

        # save it to disk as netCDF (in temporary directory)
        air1_nc_path = tmp_path / "air1.nc"
        air2_nc_path = tmp_path / "air2.nc"
        ds1.to_netcdf(air1_nc_path)
        ds2.to_netcdf(air2_nc_path)

        # use open_dataset_via_kerchunk to read it as references
        with (
            open_virtual_dataset(
                str(air1_nc_path),
                loadable_variables=time_vars,
                decode_times=decode_times,
                backend=hdf_backend,
            ) as vds1,
            open_virtual_dataset(
                str(air2_nc_path),
                loadable_variables=time_vars,
                decode_times=decode_times,
                backend=hdf_backend,
            ) as vds2,
        ):
            # time stays as raw float32 unless it was both loaded
            # (time_vars non-empty) and decoded; otherwise it becomes
            # datetime64[ns] with CF units/calendar kept in encoding.
            if not decode_times or not time_vars:
                assert vds1.time.dtype == np.dtype("float32")
                assert vds2.time.dtype == np.dtype("float32")
            else:
                assert vds1.time.dtype == np.dtype("<M8[ns]")
                assert vds2.time.dtype == np.dtype("<M8[ns]")
                assert "units" in vds1.time.encoding
                assert "units" in vds2.time.encoding
                assert "calendar" in vds1.time.encoding
                assert "calendar" in vds2.time.encoding

            # one node at the root level, one nested a level deeper, to
            # exercise hierarchical writing
            vdt = xr.DataTree.from_dict({"/vds1": vds1, "/nested/vds2": vds2})

            with roundtrip_func(
                vdt,
                tmp_path,
                virtualize_kwargs=dict(write_inherited_coords=inherit),
                decode_times=decode_times,
            ) as roundtrip:
                assert isinstance(roundtrip, xr.DataTree)

                # assert all_close to original dataset
                roundtrip_vds1 = roundtrip["/vds1"].to_dataset()
                roundtrip_vds2 = roundtrip["/nested/vds2"].to_dataset()
                xrt.assert_allclose(roundtrip_vds1, ds1)
                xrt.assert_allclose(roundtrip_vds2, ds2)

                # assert coordinate attributes are maintained
                for coord in ds1.coords:
                    assert ds1.coords[coord].attrs == roundtrip_vds1.coords[coord].attrs
                for coord in ds2.coords:
                    assert ds2.coords[coord].attrs == roundtrip_vds2.coords[coord].attrs

                # when times were decoded, the CF time encoding must survive
                # the roundtrip, not just the values
                if decode_times:
                    assert roundtrip_vds1.time.dtype == ds1.time.dtype
                    assert roundtrip_vds2.time.dtype == ds2.time.dtype
                    assert (
                        roundtrip_vds1.time.encoding["units"]
                        == ds1.time.encoding["units"]
                    )
                    assert (
                        roundtrip_vds2.time.encoding["units"]
                        == ds2.time.encoding["units"]
                    )
                    assert (
                        roundtrip_vds1.time.encoding["calendar"]
                        == ds1.time.encoding["calendar"]
                    )
                    assert (
                        roundtrip_vds2.time.encoding["calendar"]
                        == ds2.time.encoding["calendar"]
                    )
306
428
@parametrize_over_hdf_backends
307
429
def test_open_scalar_variable (tmp_path : Path , hdf_backend : type [VirtualBackend ]):
308
430
# regression test for GH issue #100
0 commit comments