Commits
1c199ec
Add support for aggregation and scan elements in `Dataset`.
huard May 23, 2023
f878749
include note about add_variable_agg
huard May 23, 2023
7b32261
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 23, 2023
b3114d3
black
huard May 23, 2023
7d5f3c3
Merge branch 'fix-40' of github.com:xarray-contrib/xncml into fix-40
huard May 23, 2023
d01446a
Set the close function so that underlying files aggregated by NcML ar…
huard May 23, 2023
f9c348a
add psutil to requirements
huard May 23, 2023
b76121d
merge
huard May 23, 2023
d88f463
implement parallel reads
huard May 23, 2023
108b414
merge
huard May 24, 2023
5171094
read_netcdf returns Delayed object. Add test confirming read_scan and…
huard May 24, 2023
b4513cd
merge
huard May 24, 2023
8ed12df
mention the parallel parameter in the tutorial
huard May 24, 2023
669c6cd
remove print statement
huard May 24, 2023
c653fc7
Merge branch 'main' into fix-42
huard May 25, 2023
259e376
Merge branch 'main' into fix-42
huard Jul 17, 2023
f88b50f
Merge branch 'main' into fix-42
huard Sep 20, 2023
cc6279f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 20, 2023
f29c1eb
Merge branch 'main' into fix-42
huard Nov 7, 2023
f4e47cc
Merge branch 'main' into fix-42
huard Dec 12, 2023
76421a0
Merge branch 'main' into fix-42
huard Jan 8, 2024
f4b6611
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 8, 2024
ddc10a8
trying to put lock on netcdf operations, but tests still segfault
huard Apr 10, 2025
CHANGELOG.md (3 changes: 2 additions & 1 deletion)
@@ -1,9 +1,10 @@
0.3 (unreleased)
================

- Support parallel reads in `open_ncml` using `dask`. By @huard
- Closing the dataset returned by `open_ncml` will close the underlying opened files. By @huard
- Add `add_aggregation` and `add_variable_agg` to `Dataset` class. By @huard
- Add `add_scan` to `Dataset` class. By @huard
- Closing the dataset returned by `open_ncml` will close the underlying opened files. By @huard


0.2 (2023-02-23)
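The changelog entries above reference new `Dataset` editing methods. A rough, hypothetical sketch of how they might be called — the method names come from the changelog, but every argument name below is an assumption modeled on the NcML `<aggregation>` and `<scan>` attributes; the actual signatures live elsewhere in xncml and are not part of this diff:

```python
import xncml

# Hypothetical usage sketch; argument names are assumptions, not the
# actual xncml signatures.
ds = xncml.Dataset('example.ncml')

# <aggregation dimName="time" type="joinExisting">
ds.add_aggregation('time', 'joinExisting')

# <variableAgg name="T"> inside that aggregation
ds.add_variable_agg('time', name='T')

# <scan location="data/" suffix=".nc" subdirs="true">
ds.add_scan('time', location='data/', suffix='.nc', subdirs=True)
```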
docs/source/tutorial.ipynb (4 changes: 2 additions & 2 deletions)
@@ -1819,7 +1819,7 @@
"source": [
"## Open an NcML document as an ``xarray.Dataset``\n",
"\n",
"``xncml`` can parse NcML instructions to create an ``xarray.Dataset``. Calling the `close` method on the returned dataset will close all underlying netCDF files referred to by the NcML document. Note that a few NcML instructions are not yet supported."
"``xncml`` can parse NcML instructions to create an ``xarray.Dataset``. Calling the `close` method on the returned dataset will close all underlying netCDF files referred to by the NcML document. The `parallel` argument will open underlying files in parallel using `dask`. Note that a few NcML instructions are not yet supported."
]
},
{
@@ -2257,7 +2257,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.15"
"version": "3.10.4"
}
},
"nbformat": 4,
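To make the tutorial note above concrete, a minimal usage sketch of the `parallel` flag, mirroring the new tests in this PR (the fixture path is assumed):

```python
from pathlib import Path

import xncml

data = Path('tests/data')  # assumed location of the NcML test fixtures

# With parallel=True, the underlying netCDF files referenced by the NcML
# document are opened through dask.delayed and materialized in one compute.
ds = xncml.open_ncml(data / 'aggExisting.xml', parallel=True)

# Closing the returned dataset closes every underlying file the
# aggregation opened.
ds.close()
```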
setup.cfg (2 changes: 1 addition & 1 deletion)
@@ -11,7 +11,7 @@ select = B,C,E,F,W,T4,B9

[isort]
known_first_party=xncml
known_third_party=numpy,pkg_resources,psutil,pytest,setuptools,xarray,xmltodict,xsdata
known_third_party=dask,numpy,pkg_resources,psutil,pytest,setuptools,xarray,xmltodict,xsdata
multi_line_output=3
include_trailing_comma=True
force_grid_wrap=0
tests/test_parser.py (55 changes: 52 additions & 3 deletions)
@@ -1,9 +1,12 @@
import datetime as dt
from pathlib import Path

import dask
import numpy as np
import psutil
import pytest
import xarray as xr
from dask.delayed import Delayed

import xncml

@@ -45,6 +48,17 @@ def test_aggexisting():
ds.close()


def test_parallel_agg_existing():
with CheckClose():
ds = xncml.open_ncml(data / 'aggExisting.xml', parallel=True)
check_dimension(ds)
check_coord_var(ds)
check_agg_coord_var(ds)
check_read_data(ds)
assert ds['time'].attrs['ncmlAdded'] == 'timeAtt'
ds.close()


def test_aggexisting_w_coords():
with CheckClose():
ds = xncml.open_ncml(data / 'aggExistingWcoords.xml')
@@ -160,9 +174,11 @@ def test_agg_syn_no_coords_dir():


def test_agg_synthetic():
ds = xncml.open_ncml(data / 'aggSynthetic.xml')
assert len(ds.time) == 3
assert all(ds.time == [0, 10, 99])
with CheckClose():
ds = xncml.open_ncml(data / 'aggSynthetic.xml')
assert len(ds.time) == 3
assert all(ds.time == [0, 10, 99])
ds.close()


def test_agg_synthetic_2():
@@ -185,6 +201,14 @@ def test_agg_syn_scan():
ds.close()


def test_parallel_agg_syn_scan():
with CheckClose():
ds = xncml.open_ncml(data / 'aggSynScan.xml', parallel=True)
assert len(ds.time) == 3
assert all(ds.time == [0, 10, 20])
ds.close()


def test_agg_syn_rename():
ds = xncml.open_ncml(data / 'aggSynRename.xml')
assert len(ds.time) == 3
@@ -333,3 +357,28 @@ def check_read_data(ds):
assert t.shape == (59, 3, 4)
assert t.dtype == float
assert 'T' in ds.data_vars


def test_read_scan_parallel():
"""Confirm that read_scan returns a list of dask.delayed objects."""
ncml = data / 'aggSynScan.xml'
obj = xncml.parser.parse(ncml)
agg = obj.choice[1]
scan = agg.scan[0]
datasets = xncml.parser.read_scan(scan, ncml, parallel=True)
assert type(datasets[0]) == Delayed
assert len(datasets) == 3
(datasets,) = dask.compute(datasets)
assert len(datasets) == 3


def test_read_netcdf_parallel():
"""Confirm that read_netcdf returns a dask.delayed object."""
ncml = data / 'aggExisting.xml'
obj = xncml.parser.parse(ncml)
agg = obj.choice[1]
nc = agg.netcdf[0]
datasets = xncml.parser.read_netcdf(
xr.Dataset(), ref=xr.Dataset(), obj=nc, ncml=ncml, parallel=True
)
assert type(datasets) == Delayed
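`CheckClose`, used as a context manager throughout these tests, is defined earlier in the test module and is not shown in this diff. A plausible sketch of what it checks, assuming it uses `psutil` (which this PR adds to the requirements) to verify that no file handles leak — the actual implementation may differ:

```python
import psutil


class CheckClose:
    """Assert that a block does not leak open file handles (sketch only).

    The real CheckClose in tests/test_parser.py is not part of this diff;
    this version simply compares psutil open-file counts on enter and exit.
    """

    def __enter__(self):
        self.before = len(psutil.Process().open_files())
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is None:
            after = len(psutil.Process().open_files())
            if after > self.before:
                raise AssertionError(f'{after - self.before} file handle(s) left open')
```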
xncml/parser.py (77 changes: 54 additions & 23 deletions)
@@ -38,9 +38,12 @@
from functools import partial
from itertools import chain
from pathlib import Path
from typing import Union

import dask
import numpy as np
import xarray as xr
from dask.delayed import Delayed
from xsdata.formats.dataclass.parsers import XmlParser

from .generated import (
@@ -81,14 +84,16 @@ def parse(path: Path) -> Netcdf:
return parser.from_path(path, Netcdf)


def open_ncml(ncml: str | Path) -> xr.Dataset:
def open_ncml(ncml: str | Path, parallel: bool = False) -> xr.Dataset:
"""
Convert NcML document to a dataset.

Parameters
----------
ncml : str | Path
Path to NcML file.
parallel : bool
If True, use dask to read data in parallel.

Returns
-------
@@ -99,10 +104,12 @@ def open_ncml(ncml: str | Path) -> xr.Dataset:
ncml = Path(ncml)
obj = parse(ncml)

return read_netcdf(xr.Dataset(), xr.Dataset(), obj, ncml)
return read_netcdf(xr.Dataset(), xr.Dataset(), obj, ncml, parallel=parallel)


def read_netcdf(target: xr.Dataset, ref: xr.Dataset, obj: Netcdf, ncml: Path) -> xr.Dataset:
def read_netcdf(
target: xr.Dataset, ref: xr.Dataset, obj: Netcdf, ncml: Path, parallel: bool
) -> xr.Dataset:
"""
Return content of <netcdf> element.

@@ -116,14 +123,17 @@ def read_netcdf(target: xr.Dataset, ref: xr.Dataset, obj: Netcdf, ncml: Path) ->
<netcdf> object description.
ncml : Path
Path to NcML document, sometimes required to follow relative links.
parallel : bool
If True, use dask to read data in parallel.

Returns
-------
xr.Dataset
Dataset holding variables and attributes defined in <netcdf> element.
"""
# Open location if any
ref = read_ds(obj, ncml) or ref
if obj.location:
ref = read_ds(obj, ncml, parallel=parallel)

# <explicit/> element means that only content specifically mentioned in NcML document is included in dataset.
if obj.explicit is not None:
@@ -133,15 +143,17 @@ def read_netcdf(target: xr.Dataset, ref: xr.Dataset, obj: Netcdf, ncml: Path) ->
target = ref

for item in filter_by_class(obj.choice, Aggregation):
target = read_aggregation(target, item, ncml)
target = read_aggregation(target, item, ncml, parallel=parallel)

# Handle <variable>, <attribute> and <remove> elements
target = read_group(target, ref, obj)

return target


def read_aggregation(target: xr.Dataset, obj: Aggregation, ncml: Path) -> xr.Dataset:
def read_aggregation(
target: xr.Dataset, obj: Aggregation, ncml: Path, parallel: bool
) -> xr.Dataset:
"""
Return merged or concatenated content of <aggregation> element.

@@ -153,6 +165,8 @@ def read_aggregation(target: xr.Dataset, obj: Aggregation, ncml: Path) -> xr.Dat
<aggregation> object description.
ncml : Path
Path to NcML document, sometimes required to follow relative links.
parallel : bool
If True, use dask to read data in parallel.

Returns
-------
@@ -167,14 +181,16 @@ def read_aggregation(target: xr.Dataset, obj: Aggregation, ncml: Path) -> xr.Dat
for attr in obj.promote_global_attribute:
raise NotImplementedError

getattr_ = dask.delayed(getattr) if parallel else getattr

# Create list of datasets to aggregate.
datasets = []
closers = []

for item in obj.netcdf:
# Open dataset defined in <netcdf>'s `location` attribute
tar = read_netcdf(xr.Dataset(), ref=xr.Dataset(), obj=item, ncml=ncml)
closers.append(getattr(tar, '_close'))
tar = read_netcdf(xr.Dataset(), ref=xr.Dataset(), obj=item, ncml=ncml, parallel=parallel)
closers.append(getattr_(tar, '_close'))

# Select variables
if names:
@@ -189,9 +205,12 @@ def read_aggregation(target: xr.Dataset, obj: Aggregation, ncml: Path) -> xr.Dat

# Handle <scan> element
for item in obj.scan:
dss = read_scan(item, ncml)
dss = read_scan(item, ncml, parallel=parallel)
datasets.extend([ds.chunk() for ds in dss])
closers.extend([getattr(ds, '_close') for ds in dss])
closers.extend([getattr_(ds, '_close') for ds in dss])

if parallel:
datasets, closers = dask.compute(datasets, closers)

# Need to decode time variable
if obj.time_units_change:
@@ -210,38 +229,46 @@ def read_aggregation(target: xr.Dataset, obj: Aggregation, ncml: Path) -> xr.Dat
else:
raise NotImplementedError

# Merge aggregated dataset into target dataset
agg = read_group(agg, None, obj)
out = target.merge(agg, combine_attrs='no_conflicts')

# Set close method to close all opened datasets
out.set_close(partial(_multi_file_closer, closers))
return out


def read_ds(obj: Netcdf, ncml: Path) -> xr.Dataset:
def read_ds(obj: Netcdf, ncml: Path, parallel: bool) -> Union[xr.Dataset, Delayed]:
"""
Return dataset defined in <netcdf> element.

Parameters
----------
obj : Netcdf
<netcdf> object description.
<netcdf> object description. Must have `location` attribute.
ncml : Path
Path to NcML document, sometimes required to follow relative links.
parallel : bool
If True, use dask to read data in parallel.

Returns
-------
xr.Dataset
Dataset defined by the <netcdf> element's `location` attribute.
"""
if obj.location:
try:
# Python >= 3.9
location = obj.location.removeprefix('file:')
except AttributeError:
location = obj.location.strip('file:')

if not Path(location).is_absolute():
location = ncml.parent / location
return xr.open_dataset(location, decode_times=False)
try:
# Python >= 3.9
location = obj.location.removeprefix('file:')
except AttributeError:
location = obj.location.strip('file:')

if not Path(location).is_absolute():
location = ncml.parent / location

open_dataset_ = dask.delayed(xr.open_dataset) if parallel else xr.open_dataset

return open_dataset_(location, decode_times=False)


def read_group(target: xr.Dataset, ref: xr.Dataset, obj: Group | Netcdf) -> xr.Dataset:
@@ -284,7 +311,7 @@ def read_group(target: xr.Dataset, ref: xr.Dataset, obj: Group | Netcdf) -> xr.D
return target


def read_scan(obj: Aggregation.Scan, ncml: Path) -> [xr.Dataset]:
def read_scan(obj: Aggregation.Scan, ncml: Path, parallel: bool) -> [xr.Dataset]:
"""
Return list of datasets defined in <scan> element.

@@ -294,6 +321,8 @@ def read_scan(obj: Aggregation.Scan, ncml: Path) -> [xr.Dataset]:
<scan> object description.
ncml : Path
Path to NcML document, sometimes required to follow relative links.
parallel : bool
If True, use dask to read data in parallel.

Returns
-------
@@ -328,7 +357,9 @@ def read_scan(obj: Aggregation.Scan, ncml: Path) -> [xr.Dataset]:

files.sort()

return [xr.open_dataset(f, decode_times=False) for f in files]
open_dataset_ = dask.delayed(xr.open_dataset) if parallel else xr.open_dataset

return [open_dataset_(f, decode_times=False) for f in files]


def read_coord_value(nc: Netcdf, agg: Aggregation, dtypes: list = ()):
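Taken together, the parser changes above implement one pattern: wrap each `xr.open_dataset` call in `dask.delayed` when `parallel=True`, materialize the delayed datasets and their close callbacks in a single `dask.compute`, and register a combined close function on the result. A standalone sketch of that pattern (`open_many` and its file list are illustrative helpers, not part of xncml):

```python
from functools import partial

import dask
import xarray as xr


def _multi_file_closer(closers):
    # Call every stored close callback; mirrors the helper used in parser.py.
    for close in closers:
        close()


def open_many(paths, parallel=False):
    """Open several netCDF files, optionally in parallel (illustrative helper)."""
    # Same trick as read_ds/read_scan: swap in delayed variants when parallel.
    open_ = dask.delayed(xr.open_dataset) if parallel else xr.open_dataset
    getattr_ = dask.delayed(getattr) if parallel else getattr

    datasets = [open_(p, decode_times=False) for p in paths]
    # Delayed objects refuse plain underscore-attribute access, hence the
    # explicit (possibly delayed) getattr for each dataset's _close callback.
    closers = [getattr_(ds, '_close') for ds in datasets]

    if parallel:
        # A single scheduler pass opens all files concurrently.
        datasets, closers = dask.compute(datasets, closers)

    combined = xr.concat(datasets, dim='time')
    # Closing the combined dataset then closes every member file.
    combined.set_close(partial(_multi_file_closer, closers))
    return combined
```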