Fixes preserving coordinates in regrid2#716
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #716 +/- ##
=========================================
Coverage 100.00% 100.00%
=========================================
Files 15 15
Lines 1609 1621 +12
=========================================
+ Hits 1609 1621 +12 ☔ View full report in Codecov by Sentry. |
lee1043
left a comment
There was a problem hiding this comment.
Thank you for the PR. I confirmed that the minimal code in #709 (comment) returns results as expected, with coordinates properly preserved.
|
I'll carry this forward while @jasonb5 is out. |
2ab5cc4 to
e067a6c
Compare
There was a problem hiding this comment.
Hey @jasonb5, I refactored your logic to handle all axes that are map-able via xCDAT (using CF "axis" attr, "standard_name" attr, and valid dim name). Before it could only map to axes that have the "axis" attr set (via ds.cf.axes).
The GH Actions build still passes.
@lee1043 let's wait for Jason to review my changes when he's back.
xcdat/regridder/regrid2.py
Outdated
| for dim in input_data_var.dims: | ||
| dim = str(dim) | ||
| Returns | ||
| ------- | ||
| xr.Dataset | ||
| A new dataset containing the regridded data variable with updated | ||
| coordinates and attributes. | ||
| """ | ||
| dv_input = ds[data_var] | ||
|
|
||
| try: | ||
| axis_name = [ | ||
| cf_axis for cf_axis, dims in ds.cf.axes.items() if dim in dims | ||
| ][0] | ||
| except IndexError as e: | ||
| raise ValueError( | ||
| f"Could not determine axis name for dimension {dim}" | ||
| ) from e | ||
|
|
||
| if axis_name in ["X", "Y"]: | ||
| output_coords[dim] = output_grid.cf[axis_name] | ||
| else: | ||
| output_coords[dim] = input_data_var.cf[axis_name] |
There was a problem hiding this comment.
Your logic looped over input data variable dims and checks if the dim name is in ds.cf.axes (must have "axis" attr set).
| output_coords: Dict[str, xr.DataArray] = {} | ||
|
|
||
| # First get the X and Y axes from the output grid. | ||
| for key in ["X", "Y"]: | ||
| input_coord = xc.get_dim_coords(dv_input, key) # type: ignore | ||
| output_coord = xc.get_dim_coords(output_grid, key) # type: ignore | ||
|
|
||
| output_coords[str(input_coord.name)] = output_coord # type: ignore | ||
|
|
||
| # Get the remaining axes the input data variable (e.g., "time"). | ||
| for dim in dv_input.dims: | ||
| if dim not in output_coords: | ||
| output_coords[str(dim)] = dv_input[dim] | ||
|
|
||
| # Sort the coords to align with the input data variable dims. | ||
| output_coords = {str(dim): output_coords[str(dim)] for dim in dv_input.dims} | ||
|
|
||
| return output_coords |
There was a problem hiding this comment.
My logic gets the X and Y axes from the output_grid via xc.get_dim_coords(). This function can map to axes via "axis" attr, "standard_name" attr, and accepted dim names (e.g., lat, lon).
For remaining axes, it just gets them directly from the input data variable (dv_input) like in your logic.
|
|
||
| for key in cf_keys: | ||
| try: | ||
| name = ds.cf.bounds[key][0] | ||
| except (KeyError, IndexError): | ||
| pass | ||
| else: | ||
| bounds = ds[name] | ||
| try: | ||
| bounds = ds.bounds.get_bounds(axis) | ||
| except KeyError: | ||
| pass | ||
|
|
There was a problem hiding this comment.
I simplified this logic by using ds.bounds.get_bounds() instead of ds.cf.bounds.
There was a problem hiding this comment.
If I remember correctly, I had switched this to ds.cf.bounds because ds.bounds.get_bounds was slow in comparison and didn't provide any additional benefit. I haven't tested recently so maybe the performance is more inline. Either way we can just leave it and address any issues later.
|
Hey @jasonb5, just pinging you again for review when you have time this week. I am hoping to get this fix in for v0.8.0, which I'm aiming for within the next few weeks. |
- Fix `_get_bounds_ensure_dtype` to determine `bounds` with axis that has `standard_name` attr (in addition to `axis` attr check) - Remove unused `dst_lat_bnds` and `dst_lon_bnds` args for `_build_dataset()` - Add unit test to cover `ValueError` in `regrid2.py` `_build_dataset()`
e149b9d to
83ae517
Compare
Description
When regrid2 constructs the output dataset it did not preserve any coords from the input dataset or the output grid. This PR fixes the issue by populating coords for every dimension from the appropriate source (input dataset, output grid).
Checklist
If applicable: