Skip to content

Commit 845286d

Browse files
committed
Vectorize the calculation of interval overlap
1 parent c6048a2 commit 845286d

File tree

2 files changed

+30
-14
lines changed

2 files changed

+30
-14
lines changed

src/xarray_regrid/methods.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from typing import Literal
22

3-
import numpy as np
43
import xarray as xr
54

65
from xarray_regrid import utils
@@ -68,16 +67,16 @@ def conservative_regrid(
6867
source_intervals = utils.to_intervalindex(
6968
source_coords, resolution=source_coords[1]-source_coords[0]
7069
)
71-
overlap = np.zeros((source_intervals.size, target_intervals.size), dtype=float)
72-
for i, source_iv in enumerate(source_intervals):
73-
for j, target_iv in enumerate(target_intervals):
74-
overlap[i,j] = utils.overlaps(source_iv, target_iv)
70+
overlap = utils.overlap(source_intervals, target_intervals)
7571
weights = utils.normalize_overlap(overlap)
7672

77-
dot_array = utils.create_dot_dataarray(weights, coord, target_coords, source_coords)
78-
73+
# TODO: Use `sparse.COO(weights)`. xr.dot does not support this. Much faster!
74+
dot_array = utils.create_dot_dataarray(
75+
weights, coord, target_coords, source_coords
76+
)
77+
# TODO: modify weights to correct for latitude.
7978
dataarrays = [
80-
da.dot(dot_array).rename({f"target_{coord}": coord}).rename(da.name)
79+
xr.dot(da, dot_array).rename({f"target_{coord}": coord}).rename(da.name)
8180
for da in dataarrays
8281
]
8382
return xr.merge(dataarrays) # TODO: add other coordinates/data variables back in.

src/xarray_regrid/utils.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,18 +108,35 @@ def to_intervalindex(coords: np.ndarray, resolution: float) -> pd.IntervalIndex:
108108
)
109109

110110

111-
def overlaps(a: pd.Interval, b: pd.Interval):
112-
"""Return the overlap (fraction) between two Pandas intervals."""
113-
return max(
114-
min(a.right, b.right) - max(a.left, b.left),
115-
0
111+
def overlap(a: pd.IntervalIndex, b: pd.IntervalIndex) -> np.ndarray:
112+
"""Calculate the overlap between two sets of intervals.
113+
114+
Args:
115+
a: Pandas IntervalIndex containing the first set of intervals.
116+
b: Pandas IntervalIndex containing the second set of intervals.
117+
118+
Returns:
119+
2D numpy array containing overlap (as a fraction) between the intervals of a
120+
and b. If there is no overlap, the value will be 0.
121+
"""
122+
# TODO: newaxis on B and transpose is MUCH faster on benchmark.
123+
# likely due to it being the bigger dimension.
124+
# size(a) > size(b) leads to better perf than size(b) > size(a)
125+
mins = np.minimum(
126+
a.right.to_numpy(),
127+
b.right.to_numpy()[:, np.newaxis]
128+
)
129+
maxs = np.maximum(
130+
a.left.to_numpy(),
131+
b.left.to_numpy()[:, np.newaxis]
116132
)
133+
return np.maximum(mins-maxs, 0).T
117134

118135

119136
def normalize_overlap(overlap: np.ndarray) -> np.ndarray:
120137
"""Normalize overlap values so they sum up to 1.0 along the first axis."""
121138
overlap_sum = overlap.sum(axis=0)
122-
overlap_sum[overlap_sum==0] = 1e-6 # Avoid dividing by 0
139+
overlap_sum[overlap_sum==0] = 1e-12 # Avoid dividing by 0.
123140
return (overlap / overlap_sum)
124141

125142

0 commit comments

Comments
 (0)