Skip to content

Commit 6cba003

Browse files
committed
Correctly restore coordinate attributes for the conservative method
1 parent 47ec14f commit 6cba003

File tree

1 file changed

+18
-10
lines changed

1 file changed

+18
-10
lines changed

src/xarray_regrid/methods/conservative.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,10 @@ def conservative_regrid_dataset(
8080
data_vars: list[str] = list(data.data_vars)
8181
data_coords: list[str] = list(data.coords)
8282
dataarrays = [data[var] for var in data_vars]
83-
datacoords = [data[var] for var in data_coords]
83+
8484
attrs = data.attrs
8585
da_attrs = [da.attrs for da in dataarrays]
86-
dcrds_attrs = [dcrds.attrs for dcrds in datacoords]
86+
coord_attrs = [data[coord].attrs for coord in data_coords]
8787

8888
for coord in coords:
8989
target_coords = coords[coord].to_numpy()
@@ -103,12 +103,16 @@ def conservative_regrid_dataset(
103103
da = dataarrays[i].transpose(coord, ...)
104104
dataarrays[i] = apply_weights(da, weights, coord, target_coords)
105105

106-
for da, _attr in zip(dataarrays, da_attrs, strict=True):
107-
da.attrs = _attr
108-
for dcrd, _attr in zip(datacoords, dcrds_attrs, strict=True):
109-
dcrd.attrs = _attr
106+
for da, attr in zip(dataarrays, da_attrs, strict=True):
107+
da.attrs = attr
110108
regridded = xr.merge(dataarrays)
109+
111110
regridded.attrs = attrs
111+
112+
new_coords = [regridded[coord] for coord in data_coords]
113+
for coord, attr in zip(new_coords, coord_attrs, strict=True):
114+
coord.attrs = attr
115+
112116
return regridded # TODO: add other coordinates/data variables back in.
113117

114118

@@ -119,9 +123,10 @@ def conservative_regrid_dataarray(
119123
) -> xr.DataArray:
120124
"""DataArray implementation of the conservative regridding method."""
121125
data_coords: list[str] = list(data.coords)
122-
datacoords = [data[var] for var in data_coords]
126+
123127
attrs = data.attrs
124-
dcrds_attrs = [dcrds.attrs for dcrds in datacoords]
128+
coord_attrs = [data[coord].attrs for coord in data_coords]
129+
125130
for coord in coords:
126131
if coord in data.coords:
127132
target_coords = coords[coord].to_numpy()
@@ -139,9 +144,12 @@ def conservative_regrid_dataarray(
139144

140145
data = data.transpose(coord, ...)
141146
data = apply_weights(data, weights, coord, target_coords)
142-
for dcrd, _attr in zip(datacoords, dcrds_attrs, strict=True):
143-
dcrd.attrs = _attr
147+
148+
new_coords = [data[coord] for coord in data_coords]
149+
for coord, attr in zip(new_coords, coord_attrs, strict=True):
150+
coord.attrs = attr
144151
data.attrs = attrs
152+
145153
return data
146154

147155

0 commit comments

Comments
 (0)