|
1 | | -from typing import Any, List, Optional, Tuple |
| 1 | +from typing import Any, Dict, List, Optional, Tuple |
2 | 2 |
|
3 | 3 | import numpy as np |
4 | 4 | import xarray as xr |
5 | 5 |
|
| 6 | +import xcdat as xc |
6 | 7 | from xcdat.axis import get_dim_keys |
7 | 8 | from xcdat.regridder.base import BaseRegridder, _preserve_bounds |
8 | 9 |
|
@@ -105,8 +106,6 @@ def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset: |
105 | 106 | ds, |
106 | 107 | data_var, |
107 | 108 | output_data, |
108 | | - dst_lat_bnds, |
109 | | - dst_lon_bnds, |
110 | 109 | self._input_grid, |
111 | 110 | self._output_grid, |
112 | 111 | ) |
@@ -228,38 +227,90 @@ def _build_dataset( |
228 | 227 | ds: xr.Dataset, |
229 | 228 | data_var: str, |
230 | 229 | output_data: np.ndarray, |
231 | | - dst_lat_bnds, |
232 | | - dst_lon_bnds, |
233 | 230 | input_grid: xr.Dataset, |
234 | 231 | output_grid: xr.Dataset, |
235 | 232 | ) -> xr.Dataset: |
236 | | - input_data_var = ds[data_var] |
| 233 | + """Build a new xarray Dataset with the given output data and coordinates. |
| 234 | +
|
| 235 | + Parameters |
| 236 | + ---------- |
| 237 | + ds : xr.Dataset |
| 238 | + The input dataset containing the data variable to be regridded. |
| 239 | + data_var : str |
| 240 | + The name of the data variable in the input dataset to be regridded. |
| 241 | + output_data : np.ndarray |
| 242 | + The regridded data to be included in the output dataset. |
| 243 | + input_grid : xr.Dataset |
| 244 | + The input grid dataset containing the original grid information. |
| 245 | + output_grid : xr.Dataset |
| 246 | + The output grid dataset containing the new grid information. |
237 | 247 |
|
238 | | - output_coords: dict[str, xr.DataArray] = {} |
239 | | - output_data_vars: dict[str, xr.DataArray] = {} |
| 248 | + Returns |
| 249 | + ------- |
| 250 | + xr.Dataset |
| 251 | + A new dataset containing the regridded data variable with updated |
| 252 | + coordinates and attributes. |
| 253 | + """ |
| 254 | + dv_input = ds[data_var] |
240 | 255 |
|
241 | | - dims = list(input_data_var.dims) |
| 256 | + output_coords = _get_output_coords(dv_input, output_grid) |
242 | 257 |
|
243 | 258 | output_da = xr.DataArray( |
244 | 259 | output_data, |
245 | | - dims=dims, |
| 260 | + dims=dv_input.dims, |
246 | 261 | coords=output_coords, |
247 | 262 | attrs=ds[data_var].attrs.copy(), |
248 | 263 | name=data_var, |
249 | 264 | ) |
250 | 265 |
|
251 | | - output_data_vars[data_var] = output_da |
252 | | - |
253 | | - output_ds = xr.Dataset( |
254 | | - output_data_vars, |
255 | | - attrs=input_grid.attrs.copy(), |
256 | | - ) |
257 | | - |
| 266 | + output_ds = output_da.to_dataset() |
| 267 | + output_ds.attrs = input_grid.attrs.copy() |
258 | 268 | output_ds = _preserve_bounds(ds, output_grid, output_ds, ["X", "Y"]) |
259 | 269 |
|
260 | 270 | return output_ds |
261 | 271 |
|
262 | 272 |
|
def _get_output_coords(
    dv_input: xr.DataArray, output_grid: xr.Dataset
) -> Dict[str, xr.DataArray]:
    """
    Collect the coordinates to attach to a regridded data variable.

    The "X" and "Y" dimension coordinates are looked up on the output grid
    and stored under the names of the matching coordinates on the input data
    variable. Every other dimension of the input data variable keeps its own
    coordinate.

    Parameters
    ----------
    dv_input : xr.DataArray
        The input data variable whose dimensions define the keys and the
        ordering of the result.
    output_grid : xr.Dataset
        The dataset that provides the target "X" and "Y" coordinates.

    Returns
    -------
    Dict[str, xr.DataArray]
        A mapping of dimension name to coordinate, ordered like
        ``dv_input.dims``.
    """
    coords: Dict[str, xr.DataArray] = {}

    # Horizontal axes: keyed by the input coordinate name, valued with the
    # coordinate found on the output grid.
    for axis in ("X", "Y"):
        src_coord = xc.get_dim_coords(dv_input, axis)  # type: ignore
        dst_coord = xc.get_dim_coords(output_grid, axis)  # type: ignore

        coords[str(src_coord.name)] = dst_coord  # type: ignore

    # Remaining axes come from the input data variable (e.g., "time").
    for dim in dv_input.dims:
        if dim in coords:
            continue
        coords[str(dim)] = dv_input[dim]

    # Rebuild the mapping so its key order follows the input variable's dims.
    return {name: coords[name] for name in map(str, dv_input.dims)}
| 312 | + |
| 313 | + |
263 | 314 | def _map_latitude( |
264 | 315 | src: np.ndarray, dst: np.ndarray |
265 | 316 | ) -> Tuple[List[np.ndarray], List[np.ndarray]]: |
@@ -553,12 +604,15 @@ def _get_dimension(input_data_var, cf_axis_name): |
553 | 604 |
|
554 | 605 |
|
555 | 606 | def _get_bounds_ensure_dtype(ds, axis): |
| 607 | + bounds = None |
| 608 | + |
556 | 609 | try: |
557 | | - name = ds.cf.bounds[axis][0] |
558 | | - except (KeyError, IndexError) as e: |
559 | | - raise RuntimeError(f"Could not determine {axis!r} bounds") from e |
560 | | - else: |
561 | | - bounds = ds[name] |
| 610 | + bounds = ds.bounds.get_bounds(axis) |
| 611 | + except KeyError: |
| 612 | + pass |
| 613 | + |
| 614 | + if bounds is None: |
| 615 | + raise RuntimeError(f"Could not determine {axis!r} bounds") |
562 | 616 |
|
563 | 617 | if bounds.dtype != np.float32: |
564 | 618 | bounds = bounds.astype(np.float32) |
|
0 commit comments