Skip to content

Commit a8732a8

Browse files
committed
Refactor logic for preserving coordinates with regrid2
1 parent e2d259e commit a8732a8

File tree

2 files changed

+76
-49
lines changed

2 files changed

+76
-49
lines changed

tests/test_regrid.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -517,18 +517,6 @@ def test_unknown_variable(self):
517517
with pytest.raises(KeyError):
518518
regridder.horizontal("unknown", self.coarse_2d_ds)
519519

520-
def test_raises_error_if_axis_name_for_dim_cannot_be_determined(self):
521-
ds = self.coarse_2d_ds.copy()
522-
ds["lat"].attrs["standard_name"] = "latitude"
523-
ds["lat"].attrs.pop("axis")
524-
525-
regridder = regrid2.Regrid2Regridder(ds, self.fine_2d_ds)
526-
527-
with pytest.raises(
528-
ValueError, match="Could not determine axis name for dimension"
529-
):
530-
regridder.horizontal("ts", ds)
531-
532520
@pytest.mark.filterwarnings("ignore:.*invalid value.*true_divide.*:RuntimeWarning")
533521
def test_regrid_input_mask(self):
534522
regridder = regrid2.Regrid2Regridder(self.coarse_2d_ds, self.fine_2d_ds)

xcdat/regridder/regrid2.py

Lines changed: 76 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
1-
from typing import Any, List, Optional, Tuple
1+
from typing import Any, Dict, List, Optional, Tuple
22

33
import numpy as np
44
import xarray as xr
55

6-
from xcdat.axis import CF_ATTR_MAP, get_dim_keys
6+
import xcdat as xc
7+
from xcdat.axis import VAR_NAME_MAP, get_dim_keys
78
from xcdat.regridder.base import BaseRegridder, _preserve_bounds
89

10+
# Spatial axes keys used to map to the axes in an input data variable to build
11+
# the output variable.
12+
VALID_SPATIAL_AXES_KEYS = ["X", "Y"] + VAR_NAME_MAP["X"] + VAR_NAME_MAP["Y"]
13+
914

1015
class Regrid2Regridder(BaseRegridder):
1116
def __init__(
@@ -229,48 +234,87 @@ def _build_dataset(
229234
input_grid: xr.Dataset,
230235
output_grid: xr.Dataset,
231236
) -> xr.Dataset:
232-
input_data_var = ds[data_var]
237+
"""Build a new xarray Dataset with the given output data and coordinates.
233238
234-
output_coords: dict[str, xr.DataArray] = {}
235-
output_data_vars: dict[str, xr.DataArray] = {}
239+
Parameters
240+
----------
241+
ds : xr.Dataset
242+
The input dataset containing the data variable to be regridded.
243+
data_var : str
244+
The name of the data variable in the input dataset to be regridded.
245+
output_data : np.ndarray
246+
The regridded data to be included in the output dataset.
247+
input_grid : xr.Dataset
248+
The input grid dataset containing the original grid information.
249+
output_grid : xr.Dataset
250+
The output grid dataset containing the new grid information.
236251
237-
for dim in input_data_var.dims:
238-
dim = str(dim)
252+
Returns
253+
-------
254+
xr.Dataset
255+
A new dataset containing the regridded data variable with updated
256+
coordinates and attributes.
257+
"""
258+
dv_input = ds[data_var]
239259

240-
try:
241-
axis_name = [
242-
cf_axis for cf_axis, dims in ds.cf.axes.items() if dim in dims
243-
][0]
244-
except IndexError as e:
245-
raise ValueError(
246-
f"Could not determine axis name for dimension {dim}"
247-
) from e
248-
249-
if axis_name in ["X", "Y"]:
250-
output_coords[dim] = output_grid.cf[axis_name]
251-
else:
252-
output_coords[dim] = input_data_var.cf[axis_name]
260+
output_coords = _get_output_coords(dv_input, output_grid)
253261

254262
output_da = xr.DataArray(
255263
output_data,
256-
dims=input_data_var.dims,
264+
dims=dv_input.dims,
257265
coords=output_coords,
258266
attrs=ds[data_var].attrs.copy(),
259267
name=data_var,
260268
)
261269

262-
output_data_vars[data_var] = output_da
263-
264-
output_ds = xr.Dataset(
265-
output_data_vars,
266-
attrs=input_grid.attrs.copy(),
267-
)
268-
270+
output_ds = output_da.to_dataset()
271+
output_ds.attrs = input_grid.attrs.copy()
269272
output_ds = _preserve_bounds(ds, output_grid, output_ds, ["X", "Y"])
270273

271274
return output_ds
272275

273276

277+
def _get_output_coords(
278+
dv_input: xr.DataArray, output_grid: xr.Dataset
279+
) -> Dict[str, xr.DataArray]:
280+
"""
281+
Generate the output coordinates for regridding based on the input data
282+
variable and output grid.
283+
284+
Parameters
285+
----------
286+
dv_input : xr.DataArray
287+
The input data variable containing the original coordinates.
288+
output_grid : xr.Dataset
289+
The dataset containing the target grid coordinates.
290+
291+
Returns
292+
-------
293+
Dict[str, xr.DataArray]
294+
A dictionary where keys are coordinate names and values are the
295+
corresponding coordinates from the output grid or input data variable,
296+
aligned with the dimensions of the input data variable.
297+
"""
298+
output_coords: Dict[str, xr.DataArray] = {}
299+
300+
# First get the X and Y axes from the output grid.
301+
for key in ["X", "Y"]:
302+
input_coord = xc.get_dim_coords(dv_input, key) # type: ignore
303+
output_coord = xc.get_dim_coords(output_grid, key) # type: ignore
304+
305+
output_coords[str(input_coord.name)] = output_coord # type: ignore
306+
307+
# Get the remaining axes the input data variable (e.g., "time").
308+
for dim in dv_input.dims:
309+
if dim not in output_coords:
310+
output_coords[str(dim)] = dv_input[dim]
311+
312+
# Sort the coords to align with the input data variable dims.
313+
output_coords = {str(dim): output_coords[str(dim)] for dim in dv_input.dims}
314+
315+
return output_coords
316+
317+
274318
def _map_latitude(
275319
src: np.ndarray, dst: np.ndarray
276320
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
@@ -564,17 +608,12 @@ def _get_dimension(input_data_var, cf_axis_name):
564608

565609

566610
def _get_bounds_ensure_dtype(ds, axis):
567-
cf_keys = CF_ATTR_MAP[axis].values()
568-
569611
bounds = None
570612

571-
for key in cf_keys:
572-
try:
573-
name = ds.cf.bounds[key][0]
574-
except (KeyError, IndexError):
575-
pass
576-
else:
577-
bounds = ds[name]
613+
try:
614+
bounds = ds.bounds.get_bounds(axis)
615+
except KeyError:
616+
pass
578617

579618
if bounds is None:
580619
raise RuntimeError(f"Could not determine {axis!r} bounds")

0 commit comments

Comments
 (0)