1- from typing import Literal
1+ from collections .abc import Hashable
2+ from typing import Literal , overload
23
4+ import dask .array
5+ import numpy as np
36import xarray as xr
47
58from xarray_regrid import utils
69
710
@overload
def interp_regrid(
    data: xr.DataArray,
    target_ds: xr.Dataset,
    method: Literal["linear", "nearest", "cubic"],
) -> xr.DataArray:
    # Typing-only signature: a DataArray input is declared to return a DataArray.
    ...
18+
19+
@overload
def interp_regrid(
    data: xr.Dataset,
    target_ds: xr.Dataset,
    method: Literal["linear", "nearest", "cubic"],
) -> xr.Dataset:
    # Typing-only signature: a Dataset input is declared to return a Dataset.
    ...
27+
28+
29+ def interp_regrid (
30+ data : xr .DataArray | xr .Dataset ,
31+ target_ds : xr .Dataset ,
32+ method : Literal ["linear" , "nearest" , "cubic" ],
33+ ) -> xr .DataArray | xr .Dataset :
1334 """Refine a dataset using xarray's interp method.
1435
1536 Args:
@@ -29,10 +50,29 @@ def interp_regrid(
2950 )
3051
3152
@overload
def conservative_regrid(
    data: xr.DataArray,
    target_ds: xr.Dataset,
    latitude_coord: str | None,
) -> xr.DataArray:
    # Typing-only signature: a DataArray input is declared to return a DataArray.
    ...
60+
61+
@overload
def conservative_regrid(
    data: xr.Dataset,
    target_ds: xr.Dataset,
    latitude_coord: str | None,
) -> xr.Dataset:
    # Typing-only signature: a Dataset input is declared to return a Dataset.
    ...
69+
70+
71+ def conservative_regrid (
72+ data : xr .DataArray | xr .Dataset ,
73+ target_ds : xr .Dataset ,
74+ latitude_coord : str | None ,
75+ ) -> xr .DataArray | xr .Dataset :
3676 """Refine a dataset using conservative regridding.
3777
3878 The method implementation is based on a post by Stephan Hoyer; "For the case of
@@ -49,34 +89,156 @@ def conservative_regrid(
4989 Returns:
5090 Regridded input dataset
5191 """
92+ if latitude_coord is not None :
93+ if latitude_coord not in data .coords :
94+ msg = "Latitude coord not in input data!"
95+ raise ValueError (msg )
96+ else :
97+ latitude_coord = ""
98+
99+ dim_order = list (target_ds .dims )
100+
52101 coord_names = set (target_ds .coords ).intersection (set (data .coords ))
53102 coords = {name : target_ds [name ] for name in coord_names }
54103 data = data .sortby (list (coord_names ))
55104
56- # TODO: filter out data vars lacking the target coordinates
105+ if isinstance (data , xr .Dataset ):
106+ return conservative_regrid_dataset (data , coords , latitude_coord ).transpose (
107+ * dim_order , ...
108+ )
109+ else :
110+ return conservative_regrid_dataarray (data , coords , latitude_coord ).transpose (
111+ * dim_order , ...
112+ )
113+
114+
def conservative_regrid_dataset(
    data: xr.Dataset,
    coords: dict[Hashable, xr.DataArray],
    latitude_coord: str,
) -> xr.Dataset:
    """Conservatively regrid each data variable of a Dataset.

    Args:
        data: Dataset whose data variables are regridded.
        coords: Mapping of coordinate name to the target coordinate values.
        latitude_coord: Name of the coordinate that receives the spherical
            correction. A name matching no key in ``coords`` (e.g. an empty
            string) means no correction is applied.

    Returns:
        The regridded variables merged into a single Dataset.
    """
    regridded = [data[name] for name in data.data_vars]

    for coord, target in coords.items():
        target_coords = target.to_numpy()
        source_coords = data[coord].to_numpy()
        weights = get_weights(source_coords, target_coords)

        # Modify weights to correct for latitude distortion
        if str(coord) == latitude_coord:
            corrected = apply_spherical_correction(
                utils.create_dot_dataarray(
                    weights, str(coord), target_coords, source_coords
                ),
                latitude_coord,
            )
            weights = corrected.to_numpy()

        # Only variables that carry this coordinate are regridded along it;
        # apply_weights expects the regridded coordinate on the leading axis.
        regridded = [
            apply_weights(da.transpose(coord, ...), weights, coord, target_coords)
            if coord in da.coords
            else da
            for da in regridded
        ]

    return xr.merge(regridded)  # TODO: add other coordinates/data variables back in.
143+
144+
def conservative_regrid_dataarray(
    data: xr.DataArray,
    coords: dict[Hashable, xr.DataArray],
    latitude_coord: str,
) -> xr.DataArray:
    """Conservatively regrid a single DataArray.

    Args:
        data: DataArray to regrid.
        coords: Mapping of coordinate name to the target coordinate values.
        latitude_coord: Name of the coordinate that receives the spherical
            correction. A name matching no key in ``coords`` (e.g. an empty
            string) means no correction is applied.

    Returns:
        The DataArray regridded along every coordinate in ``coords`` that it has.
    """
    for coord, target in coords.items():
        if coord not in data.coords:
            continue

        target_coords = target.to_numpy()
        source_coords = data[coord].to_numpy()
        weights = get_weights(source_coords, target_coords)

        # Modify weights to correct for latitude distortion
        if str(coord) == latitude_coord:
            corrected = apply_spherical_correction(
                utils.create_dot_dataarray(
                    weights, str(coord), target_coords, source_coords
                ),
                latitude_coord,
            )
            weights = corrected.to_numpy()

        # apply_weights expects the regridded coordinate on the leading axis.
        data = apply_weights(
            data.transpose(coord, ...), weights, coord, target_coords
        )

    return data
170+
171+
def apply_weights(
    da: xr.DataArray, weights: np.ndarray, coord_name: Hashable, new_coords: np.ndarray
) -> xr.DataArray:
    """Contract the leading axis of ``da`` with ``weights`` to regrid it.

    Args:
        da: Input data; its first axis is the one contracted with ``weights``.
        weights: 2D array whose first axis is contracted with the first axis
            of the data; its second axis becomes the new leading axis.
        coord_name: Name of the coordinate being regridded.
        new_coords: Coordinate values assigned to ``coord_name`` in the result.

    Returns:
        A new DataArray built from the contracted data, with ``new_coords`` for
        ``coord_name``, the values of the remaining dimensions taken from
        ``da``, and the name of ``da``.
    """
    einsum_spec = "i...,ij->j..."
    if da.chunks is None:
        # numpy routine
        new_data = np.einsum(einsum_spec, da.data, weights)
    else:
        # Dask routine
        new_data = dask.array.einsum(
            einsum_spec, da.data, weights, optimize="greedy"
        )

    # The regridded coordinate goes first, followed by the untouched dimensions.
    coord_mapping = {coord_name: new_coords}
    other_dims = list(da.dims)
    other_dims.remove(coord_name)
    coord_mapping.update({dim: da[dim].to_numpy() for dim in other_dims})

    return xr.DataArray(data=new_data, coords=coord_mapping, name=da.name)
196+
197+
def get_weights(source_coords: np.ndarray, target_coords: np.ndarray) -> np.ndarray:
    """Compute the weights that map the source coordinates onto the target ones.

    The resolution of each grid is taken as the difference between its first two
    coordinate values, so both inputs need at least two elements.

    Args:
        source_coords: Source coordinates (center points)
        target_coords: Target coordinates (center points)

    Returns:
        Weights, which can be used with a dot product to apply the conservative regrid.
    """
    # TODO: better resolution/IntervalIndex inference
    target_resolution = target_coords[1] - target_coords[0]
    target_intervals = utils.to_intervalindex(
        target_coords, resolution=target_resolution
    )

    source_resolution = source_coords[1] - source_coords[0]
    source_intervals = utils.to_intervalindex(
        source_coords, resolution=source_resolution
    )

    return utils.normalize_overlap(utils.overlap(source_intervals, target_intervals))
218+
219+
def apply_spherical_correction(
    dot_array: xr.DataArray, latitude_coord: str
) -> xr.DataArray:
    """Apply a spherical earth correction to the prepared dot product weights.

    Args:
        dot_array: Weights array that has ``latitude_coord`` as a coordinate.
        latitude_coord: Name of the latitude coordinate in ``dot_array``.

    Returns:
        A copy of ``dot_array`` whose values are scaled by the per-latitude
        weights and then re-normalized.
    """
    latitudes = dot_array[latitude_coord].to_numpy()
    # The median spacing is used as the width of the latitude cells.
    resolution = np.median(np.diff(latitudes, 1))
    cell_weights = lat_weight(latitudes, resolution)

    corrected = dot_array.copy()
    # NOTE(review): the broadcast assumes latitude is the first axis — confirm
    # against utils.create_dot_dataarray.
    corrected.values = utils.normalize_overlap(
        dot_array.values * cell_weights[:, np.newaxis]
    )
    return corrected
229+
230+
def lat_weight(latitude: np.ndarray, latitude_res: float) -> np.ndarray:
    """Return the weight of gridcells based on their latitude.

    The weight is proportional to the area of the latitude band covered by
    each cell, i.e. to the difference of the sines of its two edges. Cell edges
    are clipped to the poles (+/- 90 degrees), so a cell centered on a pole is
    weighted by the part of it that lies on the sphere instead of receiving a
    zero (or slightly negative, from rounding) weight.

    Args:
        latitude: (Center) latitude values of the gridcells, in degrees.
        latitude_res: Resolution/width of the grid cells, in degrees. May be
            negative for descending latitudes.

    Returns:
        Weights, same shape as latitude input.
    """
    dlat: float = np.radians(latitude_res)
    lat = np.radians(latitude)
    half_pi = np.pi / 2
    # Clip both edges so no part of a cell extends beyond a pole.
    edge_a = np.clip(lat + dlat / 2, -half_pi, half_pi)
    edge_b = np.clip(lat - dlat / 2, -half_pi, half_pi)
    h = np.sin(edge_a) - np.sin(edge_b)
    # h and dlat share their sign, so the product is positive for ascending
    # and descending grids alike.
    return h * dlat / (np.pi * 4)  # type: ignore
0 commit comments