|
1 | | -from typing import Any, List, Optional, Tuple |
| 1 | +from typing import Any, Dict, List, Optional, Tuple |
2 | 2 |
|
3 | 3 | import numpy as np |
4 | 4 | import xarray as xr |
5 | 5 |
|
6 | | -from xcdat.axis import CF_ATTR_MAP, get_dim_keys |
| 6 | +import xcdat as xc |
| 7 | +from xcdat.axis import VAR_NAME_MAP, get_dim_keys |
7 | 8 | from xcdat.regridder.base import BaseRegridder, _preserve_bounds |
8 | 9 |
|
| 10 | +# Spatial axes keys used to map to the axes in an input data variable to build |
| 11 | +# the output variable. |
| 12 | +VALID_SPATIAL_AXES_KEYS = ["X", "Y"] + VAR_NAME_MAP["X"] + VAR_NAME_MAP["Y"] |
| 13 | + |
9 | 14 |
|
10 | 15 | class Regrid2Regridder(BaseRegridder): |
11 | 16 | def __init__( |
@@ -229,48 +234,87 @@ def _build_dataset( |
229 | 234 | input_grid: xr.Dataset, |
230 | 235 | output_grid: xr.Dataset, |
231 | 236 | ) -> xr.Dataset: |
232 | | - input_data_var = ds[data_var] |
| 237 | + """Build a new xarray Dataset with the given output data and coordinates. |
233 | 238 |
|
234 | | - output_coords: dict[str, xr.DataArray] = {} |
235 | | - output_data_vars: dict[str, xr.DataArray] = {} |
| 239 | + Parameters |
| 240 | + ---------- |
| 241 | + ds : xr.Dataset |
| 242 | + The input dataset containing the data variable to be regridded. |
| 243 | + data_var : str |
| 244 | + The name of the data variable in the input dataset to be regridded. |
| 245 | + output_data : np.ndarray |
| 246 | + The regridded data to be included in the output dataset. |
| 247 | + input_grid : xr.Dataset |
| 248 | + The input grid dataset containing the original grid information. |
| 249 | + output_grid : xr.Dataset |
| 250 | + The output grid dataset containing the new grid information. |
236 | 251 |
|
237 | | - for dim in input_data_var.dims: |
238 | | - dim = str(dim) |
| 252 | + Returns |
| 253 | + ------- |
| 254 | + xr.Dataset |
| 255 | + A new dataset containing the regridded data variable with updated |
| 256 | + coordinates and attributes. |
| 257 | + """ |
| 258 | + dv_input = ds[data_var] |
239 | 259 |
|
240 | | - try: |
241 | | - axis_name = [ |
242 | | - cf_axis for cf_axis, dims in ds.cf.axes.items() if dim in dims |
243 | | - ][0] |
244 | | - except IndexError as e: |
245 | | - raise ValueError( |
246 | | - f"Could not determine axis name for dimension {dim}" |
247 | | - ) from e |
248 | | - |
249 | | - if axis_name in ["X", "Y"]: |
250 | | - output_coords[dim] = output_grid.cf[axis_name] |
251 | | - else: |
252 | | - output_coords[dim] = input_data_var.cf[axis_name] |
| 260 | + output_coords = _get_output_coords(dv_input, output_grid) |
253 | 261 |
|
254 | 262 | output_da = xr.DataArray( |
255 | 263 | output_data, |
256 | | - dims=input_data_var.dims, |
| 264 | + dims=dv_input.dims, |
257 | 265 | coords=output_coords, |
258 | 266 | attrs=ds[data_var].attrs.copy(), |
259 | 267 | name=data_var, |
260 | 268 | ) |
261 | 269 |
|
262 | | - output_data_vars[data_var] = output_da |
263 | | - |
264 | | - output_ds = xr.Dataset( |
265 | | - output_data_vars, |
266 | | - attrs=input_grid.attrs.copy(), |
267 | | - ) |
268 | | - |
| 270 | + output_ds = output_da.to_dataset() |
| 271 | + output_ds.attrs = input_grid.attrs.copy() |
269 | 272 | output_ds = _preserve_bounds(ds, output_grid, output_ds, ["X", "Y"]) |
270 | 273 |
|
271 | 274 | return output_ds |
272 | 275 |
|
273 | 276 |
|
| 277 | +def _get_output_coords( |
| 278 | + dv_input: xr.DataArray, output_grid: xr.Dataset |
| 279 | +) -> Dict[str, xr.DataArray]: |
| 280 | + """ |
| 281 | + Generate the output coordinates for regridding based on the input data |
| 282 | + variable and output grid. |
| 283 | +
|
| 284 | + Parameters |
| 285 | + ---------- |
| 286 | + dv_input : xr.DataArray |
| 287 | + The input data variable containing the original coordinates. |
| 288 | + output_grid : xr.Dataset |
| 289 | + The dataset containing the target grid coordinates. |
| 290 | +
|
| 291 | + Returns |
| 292 | + ------- |
| 293 | + Dict[str, xr.DataArray] |
| 294 | + A dictionary where keys are coordinate names and values are the |
| 295 | + corresponding coordinates from the output grid or input data variable, |
| 296 | + aligned with the dimensions of the input data variable. |
| 297 | + """ |
| 298 | + output_coords: Dict[str, xr.DataArray] = {} |
| 299 | + |
| 300 | + # First get the X and Y axes from the output grid. |
| 301 | + for key in ["X", "Y"]: |
| 302 | + input_coord = xc.get_dim_coords(dv_input, key) # type: ignore |
| 303 | + output_coord = xc.get_dim_coords(output_grid, key) # type: ignore |
| 304 | + |
| 305 | + output_coords[str(input_coord.name)] = output_coord # type: ignore |
| 306 | + |
| 307 | + # Get the remaining axes the input data variable (e.g., "time"). |
| 308 | + for dim in dv_input.dims: |
| 309 | + if dim not in output_coords: |
| 310 | + output_coords[str(dim)] = dv_input[dim] |
| 311 | + |
| 312 | + # Sort the coords to align with the input data variable dims. |
| 313 | + output_coords = {str(dim): output_coords[str(dim)] for dim in dv_input.dims} |
| 314 | + |
| 315 | + return output_coords |
| 316 | + |
| 317 | + |
274 | 318 | def _map_latitude( |
275 | 319 | src: np.ndarray, dst: np.ndarray |
276 | 320 | ) -> Tuple[List[np.ndarray], List[np.ndarray]]: |
@@ -564,17 +608,12 @@ def _get_dimension(input_data_var, cf_axis_name): |
564 | 608 |
|
565 | 609 |
|
566 | 610 | def _get_bounds_ensure_dtype(ds, axis): |
567 | | - cf_keys = CF_ATTR_MAP[axis].values() |
568 | | - |
569 | 611 | bounds = None |
570 | 612 |
|
571 | | - for key in cf_keys: |
572 | | - try: |
573 | | - name = ds.cf.bounds[key][0] |
574 | | - except (KeyError, IndexError): |
575 | | - pass |
576 | | - else: |
577 | | - bounds = ds[name] |
| 613 | + try: |
| 614 | + bounds = ds.bounds.get_bounds(axis) |
| 615 | + except KeyError: |
| 616 | + pass |
578 | 617 |
|
579 | 618 | if bounds is None: |
580 | 619 | raise RuntimeError(f"Could not determine {axis!r} bounds") |
|
0 commit comments