Skip to content

Commit f5d74f0

Browse files
committed
Keep attributes upon conservative regrid
1 parent 804b989 commit f5d74f0

File tree

3 files changed

+11
-6
lines changed

3 files changed

+11
-6
lines changed

src/xarray_regrid/methods/conservative.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ def conservative_regrid_dataset(
7979
"""Dataset implementation of the conservative regridding method."""
8080
data_vars: list[str] = list(data.data_vars)
8181
dataarrays = [data[var] for var in data_vars]
82+
attrs = data.attrs
83+
da_attrs = [da.attrs for da in dataarrays]
8284

8385
for coord in coords:
8486
target_coords = coords[coord].to_numpy()
@@ -98,7 +100,11 @@ def conservative_regrid_dataset(
98100
da = dataarrays[i].transpose(coord, ...)
99101
dataarrays[i] = apply_weights(da, weights, coord, target_coords)
100102

101-
return xr.merge(dataarrays) # TODO: add other coordinates/data variables back in.
103+
for da, _attr in zip(dataarrays, da_attrs, strict=True):
104+
da.attrs = _attr
105+
regridded = xr.merge(dataarrays)
106+
regridded.attrs = attrs
107+
return regridded # TODO: add other coordinates/data variables back in.
102108

103109

104110
def conservative_regrid_dataarray(
@@ -107,6 +113,7 @@ def conservative_regrid_dataarray(
107113
latitude_coord: str,
108114
) -> xr.DataArray:
109115
"""DataArray implementation of the conservative regridding method."""
116+
attrs = data.attrs
110117
for coord in coords:
111118
if coord in data.coords:
112119
target_coords = coords[coord].to_numpy()
@@ -124,7 +131,7 @@ def conservative_regrid_dataarray(
124131

125132
data = data.transpose(coord, ...)
126133
data = apply_weights(data, weights, coord, target_coords)
127-
134+
data.attrs = attrs
128135
return data
129136

130137

tests/test_most_common.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,7 @@ def test_most_common(dummy_lc_data, dummy_target_grid):
8282

8383
def test_attrs_dataarray(dummy_lc_data, dummy_target_grid):
8484
dummy_lc_data["lc"].attrs = {"test": "testing"}
85-
da_regrid = dummy_lc_data["lc"].regrid.most_common(
86-
dummy_target_grid
87-
)
85+
da_regrid = dummy_lc_data["lc"].regrid.most_common(dummy_target_grid)
8886
assert da_regrid.attrs != {}
8987
assert da_regrid.attrs == dummy_lc_data["lc"].attrs
9088

tests/test_regrid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from copy import deepcopy
21
from pathlib import Path
32

43
import pytest
@@ -150,3 +149,4 @@ def test_attrs_dataset_conservative(sample_input_data, sample_grid_ds):
150149
sample_grid_ds, latitude_coord="latitude"
151150
)
152151
assert ds_regrid.attrs == sample_input_data.attrs
152+
assert ds_regrid["d2m"].attrs == sample_input_data["d2m"].attrs

0 commit comments

Comments
 (0)