Skip to content

Commit 84e4a3f

Browse files
Fixes preserving coordinates in regrid2 (#716)
Co-authored-by: Tom Vo <tomvothecoder@gmail.com>
1 parent 5cc9d23 commit 84e4a3f

File tree

1 file changed

+76
-22
lines changed

1 file changed

+76
-22
lines changed

xcdat/regridder/regrid2.py

Lines changed: 76 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from typing import Any, List, Optional, Tuple
1+
from typing import Any, Dict, List, Optional, Tuple
22

33
import numpy as np
44
import xarray as xr
55

6+
import xcdat as xc
67
from xcdat.axis import get_dim_keys
78
from xcdat.regridder.base import BaseRegridder, _preserve_bounds
89

@@ -105,8 +106,6 @@ def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset:
105106
ds,
106107
data_var,
107108
output_data,
108-
dst_lat_bnds,
109-
dst_lon_bnds,
110109
self._input_grid,
111110
self._output_grid,
112111
)
@@ -228,38 +227,90 @@ def _build_dataset(
228227
ds: xr.Dataset,
229228
data_var: str,
230229
output_data: np.ndarray,
231-
dst_lat_bnds,
232-
dst_lon_bnds,
233230
input_grid: xr.Dataset,
234231
output_grid: xr.Dataset,
235232
) -> xr.Dataset:
236-
input_data_var = ds[data_var]
233+
"""Build a new xarray Dataset with the given output data and coordinates.
234+
235+
Parameters
236+
----------
237+
ds : xr.Dataset
238+
The input dataset containing the data variable to be regridded.
239+
data_var : str
240+
The name of the data variable in the input dataset to be regridded.
241+
output_data : np.ndarray
242+
The regridded data to be included in the output dataset.
243+
input_grid : xr.Dataset
244+
The input grid dataset containing the original grid information.
245+
output_grid : xr.Dataset
246+
The output grid dataset containing the new grid information.
237247
238-
output_coords: dict[str, xr.DataArray] = {}
239-
output_data_vars: dict[str, xr.DataArray] = {}
248+
Returns
249+
-------
250+
xr.Dataset
251+
A new dataset containing the regridded data variable with updated
252+
coordinates and attributes.
253+
"""
254+
dv_input = ds[data_var]
240255

241-
dims = list(input_data_var.dims)
256+
output_coords = _get_output_coords(dv_input, output_grid)
242257

243258
output_da = xr.DataArray(
244259
output_data,
245-
dims=dims,
260+
dims=dv_input.dims,
246261
coords=output_coords,
247262
attrs=ds[data_var].attrs.copy(),
248263
name=data_var,
249264
)
250265

251-
output_data_vars[data_var] = output_da
252-
253-
output_ds = xr.Dataset(
254-
output_data_vars,
255-
attrs=input_grid.attrs.copy(),
256-
)
257-
266+
output_ds = output_da.to_dataset()
267+
output_ds.attrs = input_grid.attrs.copy()
258268
output_ds = _preserve_bounds(ds, output_grid, output_ds, ["X", "Y"])
259269

260270
return output_ds
261271

262272

273+
def _get_output_coords(
274+
dv_input: xr.DataArray, output_grid: xr.Dataset
275+
) -> Dict[str, xr.DataArray]:
276+
"""
277+
Generate the output coordinates for regridding based on the input data
278+
variable and output grid.
279+
280+
Parameters
281+
----------
282+
dv_input : xr.DataArray
283+
The input data variable containing the original coordinates.
284+
output_grid : xr.Dataset
285+
The dataset containing the target grid coordinates.
286+
287+
Returns
288+
-------
289+
Dict[str, xr.DataArray]
290+
A dictionary where keys are coordinate names and values are the
291+
corresponding coordinates from the output grid or input data variable,
292+
aligned with the dimensions of the input data variable.
293+
"""
294+
output_coords: Dict[str, xr.DataArray] = {}
295+
296+
# First get the X and Y axes from the output grid.
297+
for key in ["X", "Y"]:
298+
input_coord = xc.get_dim_coords(dv_input, key) # type: ignore
299+
output_coord = xc.get_dim_coords(output_grid, key) # type: ignore
300+
301+
output_coords[str(input_coord.name)] = output_coord # type: ignore
302+
303+
# Get the remaining axes the input data variable (e.g., "time").
304+
for dim in dv_input.dims:
305+
if dim not in output_coords:
306+
output_coords[str(dim)] = dv_input[dim]
307+
308+
# Sort the coords to align with the input data variable dims.
309+
output_coords = {str(dim): output_coords[str(dim)] for dim in dv_input.dims}
310+
311+
return output_coords
312+
313+
263314
def _map_latitude(
264315
src: np.ndarray, dst: np.ndarray
265316
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
@@ -553,12 +604,15 @@ def _get_dimension(input_data_var, cf_axis_name):
553604

554605

555606
def _get_bounds_ensure_dtype(ds, axis):
607+
bounds = None
608+
556609
try:
557-
name = ds.cf.bounds[axis][0]
558-
except (KeyError, IndexError) as e:
559-
raise RuntimeError(f"Could not determine {axis!r} bounds") from e
560-
else:
561-
bounds = ds[name]
610+
bounds = ds.bounds.get_bounds(axis)
611+
except KeyError:
612+
pass
613+
614+
if bounds is None:
615+
raise RuntimeError(f"Could not determine {axis!r} bounds")
562616

563617
if bounds.dtype != np.float32:
564618
bounds = bounds.astype(np.float32)

0 commit comments

Comments
 (0)