@@ -1,6 +1,6 @@
-from collections.abc import Callable
+from collections.abc import Callable, Hashable
 from dataclasses import dataclass
-from typing import Any, overload
+from typing import Any, TypedDict, overload

 import numpy as np
 import pandas as pd
@@ -10,6 +10,11 @@
 class InvalidBoundsError(Exception): ...


+class CoordHandler(TypedDict):
+    names: list[str]
+    func: Callable
+
+
 @dataclass
 class Grid:
     """Object storing grid information."""
@@ -75,7 +80,7 @@ def create_lat_lon_coords(grid: Grid) -> tuple[np.ndarray, np.ndarray]:
         grid.south, grid.north + grid.resolution_lat, grid.resolution_lat
     )

-    if np.remainder((grid.north - grid.south), grid.resolution_lat) > 0:
+    if np.remainder((grid.east - grid.west), grid.resolution_lon) > 0:
         lon_coords = np.arange(grid.west, grid.east, grid.resolution_lon)
     else:
         lon_coords = np.arange(
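
# A quick sketch of why the remainder check above is needed (bounds and resolution
# are illustrative, not from the PR): np.arange excludes its stop value, so when the
# extent divides evenly by the resolution the stop must be bumped by one step to
# keep the final edge coordinate.
import numpy as np

west, east, resolution_lon = -180.0, 180.0, 1.0
if np.remainder(east - west, resolution_lon) > 0:
    lon_coords = np.arange(west, east, resolution_lon)
else:
    lon_coords = np.arange(west, east + resolution_lon, resolution_lon)
# Here (east - west) % 1.0 == 0, so the second branch runs and lon_coords spans
# -180.0 through 180.0 inclusive (361 values).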
@@ -193,24 +198,6 @@ def common_coords(
     return sorted([str(coord) for coord in coords])


-@overload
-def call_on_dataset(
-    func: Callable[..., xr.Dataset],
-    obj: xr.DataArray,
-    *args: Any,
-    **kwargs: Any,
-) -> xr.DataArray: ...
-
-
-@overload
-def call_on_dataset(
-    func: Callable[..., xr.Dataset],
-    obj: xr.Dataset,
-    *args: Any,
-    **kwargs: Any,
-) -> xr.Dataset: ...
-
-
 def call_on_dataset(
     func: Callable[..., xr.Dataset],
     obj: xr.DataArray | xr.Dataset,
@@ -235,3 +222,193 @@ def call_on_dataset(
         return next(iter(result.data_vars.values())).rename(obj.name)

     return result
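
# Dropping the overloads is a typing-only change; at runtime call_on_dataset still
# lets a Dataset-only routine run on a DataArray by round-tripping through a
# one-variable Dataset. A minimal sketch (the "demean" helper and the "t2m" name are
# illustrative only):
import xarray as xr

def demean(ds: xr.Dataset) -> xr.Dataset:
    return ds - ds.mean()

da = xr.DataArray([1.0, 2.0, 3.0], dims="x", name="t2m")
result = call_on_dataset(demean, da)
# result comes back as a DataArray, renamed to "t2m" as in the branch above.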
+
+
+def format_for_regrid(
+    obj: xr.DataArray | xr.Dataset, target: xr.Dataset
+) -> xr.DataArray | xr.Dataset:
+    """Apply any pre-formatting to the input dataset to prepare for regridding.
+    Currently handles padding of spherical geometry if lat/lon coordinates can
+    be inferred and the domain size requires boundary padding.
+    """
+    orig_chunksizes = obj.chunksizes
+
+    # Special-cased coordinates with accepted names and formatting function
+    coord_handlers: dict[str, CoordHandler] = {
+        "lat": {"names": ["lat", "latitude"], "func": format_lat},
+        "lon": {"names": ["lon", "longitude"], "func": format_lon},
+    }
+    # Identify coordinates that need to be formatted
+    formatted_coords = {}
+    for coord_type, handler in coord_handlers.items():
+        for coord in obj.coords.keys():
+            if str(coord).lower() in handler["names"]:
+                formatted_coords[coord_type] = str(coord)
+
+    # Apply formatting
+    for coord_type, coord in formatted_coords.items():
+        # Make sure formatted coords are sorted
+        obj = ensure_monotonic(obj, coord)
+        target = ensure_monotonic(target, coord)
+        obj = coord_handlers[coord_type]["func"](obj, target, formatted_coords)
+        # Coerce back to a single chunk if that's what was passed
+        if len(orig_chunksizes.get(coord, [])) == 1:
+            obj = obj.chunk({coord: -1})
+
+    return obj
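
# A minimal end-to-end sketch of the dispatch above, with made-up grids: coordinate
# names are matched case-insensitively against the handler table, so "Latitude" /
# "Longitude" are picked up just like "lat" / "lon".
import numpy as np
import xarray as xr

source = xr.Dataset(
    {"tas": (("Latitude", "Longitude"), np.zeros((180, 360)))},
    coords={
        "Latitude": np.arange(-89.5, 90.0, 1.0),
        "Longitude": np.arange(0.5, 360.0, 1.0),
    },
)
target = xr.Dataset(
    coords={
        "Latitude": np.linspace(-90.0, 90.0, 91),
        "Longitude": np.linspace(-180.0, 180.0, 181),
    }
)
padded = format_for_regrid(source, target)
# padded gains pole rows at -90/90 and wraparound longitude columns, so every target
# cell is covered by source grid centers.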
+
+
+def format_lat(
+    obj: xr.DataArray | xr.Dataset,
+    target: xr.Dataset,  # noqa: ARG001
+    formatted_coords: dict[str, str],
+) -> xr.DataArray | xr.Dataset:
+    """If the latitude coordinate is inferred to be global, defined as having
+    a value within one grid spacing of the poles, and the grid does not natively
+    have values at -90 and 90, add a single value at each pole computed as the
+    mean of the first and last latitude bands. This should be roughly equivalent
+    to the `Pole="all"` option in `ESMF`.
+
+    For example, with a grid spacing of 1 degree, and a source grid ranging from
+    -89.5 to 89.5, the poles would be padded with values at -90 and 90. A grid ranging
+    from -88 to 88 would not be padded because coverage does not extend all the way
+    to the poles. A grid ranging from -90 to 90 would also not be padded because the
+    poles will already be covered in the regridding weights.
+    """
+    lat_coord = formatted_coords["lat"]
+    lon_coord = formatted_coords.get("lon")
+
+    # Concat a padded value representing the mean of the first/last lat bands
+    # This should match the Pole="all" option of ESMF
+    # TODO: with cos(90) = 0 weighting, these weights might be 0?
+
+    polar_lat = 90
+    dy = obj.coords[lat_coord].diff(lat_coord).max().values.item()
+
+    # Only pad if global but don't have edge values directly at poles
+    # NOTE: could use xr.pad here instead of xr.concat, but none of the
+    # modes are an exact fit for this scheme
+    lat_vals = obj.coords[lat_coord].values
+    # South pole
+    if dy - polar_lat >= obj.coords[lat_coord].values[0] > -polar_lat:
+        south_pole = obj.isel({lat_coord: 0})
+        if lon_coord is not None:
+            south_pole = south_pole.mean(lon_coord)
+        obj = xr.concat([south_pole, obj], dim=lat_coord)  # type: ignore
+        lat_vals = np.concatenate([[-polar_lat], lat_vals])
+
+    # North pole
+    if polar_lat - dy <= obj.coords[lat_coord].values[-1] < polar_lat:
+        north_pole = obj.isel({lat_coord: -1})
+        if lon_coord is not None:
+            north_pole = north_pole.mean(lon_coord)
+        obj = xr.concat([obj, north_pole], dim=lat_coord)  # type: ignore
+        lat_vals = np.concatenate([lat_vals, [polar_lat]])
+
+    obj = update_coord(obj, lat_coord, lat_vals)
+
+    return obj
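
# Toy illustration of the pole padding described above (1-degree latitudes, a handful
# of longitudes; the empty Dataset stands in for the unused target argument and the
# "t2m" name is made up):
import numpy as np
import xarray as xr

da = xr.DataArray(
    np.ones((180, 8)),
    dims=("lat", "lon"),
    coords={"lat": np.arange(-89.5, 90.0, 1.0), "lon": np.arange(0.5, 360.0, 45.0)},
    name="t2m",
)
padded = format_lat(da, xr.Dataset(), {"lat": "lat", "lon": "lon"})
# padded["lat"] now runs -90.0, -89.5, ..., 89.5, 90.0, and each added pole row holds
# the zonal mean of the adjacent latitude band, mirroring ESMF's Pole="all".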
+
+
+def format_lon(
+    obj: xr.DataArray | xr.Dataset, target: xr.Dataset, formatted_coords: dict[str, str]
+) -> xr.DataArray | xr.Dataset:
+    """Format the longitude coordinate by shifting the source grid to line up with
+    the target anywhere in the range of -360 to 360, and then add a single wraparound
+    padding column if the domain is inferred to be global and the east or west edges
+    of the target lie outside the source grid centers.
+
+    For example, with a source grid ranging from 0.5 to 359.5 and a target grid ranging
+    from -180 to 180, the source grid would be shifted to -179.5 to 179.5 and then
+    padded on both the left and right with wraparound values at -180.5 and 180.5 to
+    provide full coverage for the target edge cells at -180 and 180.
+    """
+    lon_coord = formatted_coords["lon"]
+
+    # Find a wrap point outside of the left and right bounds of the target
+    # This ensures we have coverage on the target and handles global > regional
+    source_vals = obj.coords[lon_coord].values
+    target_vals = target.coords[lon_coord].values
+    wrap_point = (target_vals[-1] + target_vals[0] + 360) / 2
+    source_vals = np.where(
+        source_vals < wrap_point - 360, source_vals + 360, source_vals
+    )
+    source_vals = np.where(source_vals > wrap_point, source_vals - 360, source_vals)
+    obj = update_coord(obj, lon_coord, source_vals)
+
+    obj = ensure_monotonic(obj, lon_coord)
+
+    # Only pad if domain is global in lon
+    source_lon = obj.coords[lon_coord]
+    target_lon = target.coords[lon_coord]
+    dx_s = source_lon.diff(lon_coord).max().values.item()
+    dx_t = target_lon.diff(lon_coord).max().values.item()
+    is_global_lon = source_lon.max().values - source_lon.min().values >= 360 - dx_s
+
+    if is_global_lon:
+        left_pad = (source_lon.values[0] - target_lon.values[0] + dx_t / 2) / dx_s
+        right_pad = (target_lon.values[-1] - source_lon.values[-1] + dx_t / 2) / dx_s
+        left_pad = int(np.ceil(np.max([left_pad, 0])))
+        right_pad = int(np.ceil(np.max([right_pad, 0])))
+        obj = obj.pad({lon_coord: (left_pad, right_pad)}, mode="wrap", keep_attrs=True)
+        lon_vals = obj.coords[lon_coord].values
+        if left_pad:
+            lon_vals[:left_pad] = source_lon.values[-left_pad:] - 360
+        if right_pad:
+            lon_vals[-right_pad:] = source_lon.values[:right_pad] + 360
+        obj = update_coord(obj, lon_coord, lon_vals)
+
+    return obj
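
# Worked numbers for the example given in the docstring above (the grids and the
# "tas" variable are made up): a 0.5..359.5 source regridded to a -180..180 target.
import numpy as np
import xarray as xr

src = xr.Dataset(
    {"tas": ("lon", np.ones(360))}, coords={"lon": np.arange(0.5, 360.0, 1.0)}
)
tgt = xr.Dataset(coords={"lon": np.arange(-180.0, 181.0, 1.0)})
out = format_lon(src, tgt, {"lon": "lon"})
# wrap_point = (180 + -180 + 360) / 2 = 180, so values above 180 are shifted down by
# 360 and the source becomes -179.5..179.5; the domain is global, so one wraparound
# column is added on each side and out["lon"] runs -180.5..180.5.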
+
+
+def coord_is_covered(
+    obj: xr.DataArray | xr.Dataset, target: xr.Dataset, coord: Hashable
+) -> bool:
+    """Check if the source coord fully covers the target coord."""
+    pad = target[coord].diff(coord).max().values
+    left_covered = obj[coord].min() <= target[coord].min() - pad
+    right_covered = obj[coord].max() >= target[coord].max() + pad
+    return bool(left_covered.item() and right_covered.item())
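
# Illustration with hypothetical bounds: the source has to extend at least one target
# grid spacing beyond both target edges to count as covering it.
import numpy as np
import xarray as xr

src = xr.Dataset(coords={"lon": np.arange(-180.0, 181.0, 1.0)})
tgt = xr.Dataset(coords={"lon": np.arange(-150.0, 151.0, 2.0)})
coord_is_covered(src, tgt, "lon")  # True: -180 <= -150 - 2 and 180 >= 150 + 2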
+
+
+@overload
+def ensure_monotonic(obj: xr.DataArray, coord: Hashable) -> xr.DataArray: ...
+
+
+@overload
+def ensure_monotonic(obj: xr.Dataset, coord: Hashable) -> xr.Dataset: ...
+
+
+def ensure_monotonic(
+    obj: xr.DataArray | xr.Dataset, coord: Hashable
+) -> xr.DataArray | xr.Dataset:
+    """Ensure that an object has monotonically increasing indexes for a
+    given coordinate. Only sort and drop duplicates if needed because this
+    requires reindexing which can be expensive."""
+    if not obj.indexes[coord].is_monotonic_increasing:
+        obj = obj.sortby(coord)
+    if not obj.indexes[coord].is_unique:
+        obj = obj.drop_duplicates(coord)
+    return obj
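
# For example, a latitude axis stored north-to-south is flipped once, while an
# already monotonic, unique index passes through without a reindex (values made up):
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"tas": ("lat", np.array([3.0, 2.0, 1.0]))}, coords={"lat": [90.0, 0.0, -90.0]}
)
ensure_monotonic(ds, "lat")["lat"].values  # array([-90., 0., 90.])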
+
+
+@overload
+def update_coord(
+    obj: xr.DataArray, coord: Hashable, coord_vals: np.ndarray
+) -> xr.DataArray: ...
+
+
+@overload
+def update_coord(
+    obj: xr.Dataset, coord: Hashable, coord_vals: np.ndarray
+) -> xr.Dataset: ...
+
+
+def update_coord(
+    obj: xr.DataArray | xr.Dataset, coord: Hashable, coord_vals: np.ndarray
+) -> xr.DataArray | xr.Dataset:
+    """Update the values of a coordinate, ensuring indexes stay in sync."""
+    attrs = obj.coords[coord].attrs
+    obj = obj.assign_coords({coord: coord_vals})
+    obj.coords[coord].attrs = attrs
+    return obj
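
# Sketch of why the attrs round-trip above matters (the metadata is illustrative):
# assign_coords rebuilds the coordinate variable, which keeps the index in sync with
# the new values but discards its attributes, so they are saved and restored here.
import numpy as np
import xarray as xr

ds = xr.Dataset(
    coords={"lon": ("lon", np.arange(0.5, 360.0, 1.0), {"units": "degrees_east"})}
)
shifted = update_coord(ds, "lon", ds["lon"].values - 180.0)
shifted["lon"].attrs  # {'units': 'degrees_east'} survives the relabeling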