Skip to content

Commit e2d259e

Browse files
committed
Add test and code cleanup
- Fix `_get_bounds_ensure_dtype` to determine `bounds` with axis that has `standard_name` attr (in addition to `axis` attr check) - Remove unused `dst_lat_bnds` and `dst_lon_bnds` args for `_build_dataset()` - Add unit test to cover `ValueError` in `regrid2.py` `_build_dataset()`
1 parent e067a6c commit e2d259e

File tree

2 files changed

+32
-14
lines changed

2 files changed

+32
-14
lines changed

tests/test_regrid.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,18 @@ def test_unknown_variable(self):
517517
with pytest.raises(KeyError):
518518
regridder.horizontal("unknown", self.coarse_2d_ds)
519519

520+
def test_raises_error_if_axis_name_for_dim_cannot_be_determined(self):
521+
ds = self.coarse_2d_ds.copy()
522+
ds["lat"].attrs["standard_name"] = "latitude"
523+
ds["lat"].attrs.pop("axis")
524+
525+
regridder = regrid2.Regrid2Regridder(ds, self.fine_2d_ds)
526+
527+
with pytest.raises(
528+
ValueError, match="Could not determine axis name for dimension"
529+
):
530+
regridder.horizontal("ts", ds)
531+
520532
@pytest.mark.filterwarnings("ignore:.*invalid value.*true_divide.*:RuntimeWarning")
521533
def test_regrid_input_mask(self):
522534
regridder = regrid2.Regrid2Regridder(self.coarse_2d_ds, self.fine_2d_ds)

xcdat/regridder/regrid2.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import xarray as xr
55

6-
from xcdat.axis import get_dim_keys
6+
from xcdat.axis import CF_ATTR_MAP, get_dim_keys
77
from xcdat.regridder.base import BaseRegridder, _preserve_bounds
88

99

@@ -105,8 +105,6 @@ def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset:
105105
ds,
106106
data_var,
107107
output_data,
108-
dst_lat_bnds,
109-
dst_lon_bnds,
110108
self._input_grid,
111109
self._output_grid,
112110
)
@@ -228,8 +226,6 @@ def _build_dataset(
228226
ds: xr.Dataset,
229227
data_var: str,
230228
output_data: np.ndarray,
231-
dst_lat_bnds,
232-
dst_lon_bnds,
233229
input_grid: xr.Dataset,
234230
output_grid: xr.Dataset,
235231
) -> xr.Dataset:
@@ -242,11 +238,13 @@ def _build_dataset(
242238
dim = str(dim)
243239

244240
try:
245-
axis_name = [x for x, y in ds.cf.axes.items() if dim in y][0]
246-
except Exception:
241+
axis_name = [
242+
cf_axis for cf_axis, dims in ds.cf.axes.items() if dim in dims
243+
][0]
244+
except IndexError as e:
247245
raise ValueError(
248246
f"Could not determine axis name for dimension {dim}"
249-
) from None
247+
) from e
250248

251249
if axis_name in ["X", "Y"]:
252250
output_coords[dim] = output_grid.cf[axis_name]
@@ -566,12 +564,20 @@ def _get_dimension(input_data_var, cf_axis_name):
566564

567565

568566
def _get_bounds_ensure_dtype(ds, axis):
569-
try:
570-
name = ds.cf.bounds[axis][0]
571-
except (KeyError, IndexError) as e:
572-
raise RuntimeError(f"Could not determine {axis!r} bounds") from e
573-
else:
574-
bounds = ds[name]
567+
cf_keys = CF_ATTR_MAP[axis].values()
568+
569+
bounds = None
570+
571+
for key in cf_keys:
572+
try:
573+
name = ds.cf.bounds[key][0]
574+
except (KeyError, IndexError):
575+
pass
576+
else:
577+
bounds = ds[name]
578+
579+
if bounds is None:
580+
raise RuntimeError(f"Could not determine {axis!r} bounds")
575581

576582
if bounds.dtype != np.float32:
577583
bounds = bounds.astype(np.float32)

0 commit comments

Comments
 (0)