Skip to content

Commit 6681cab

Browse files
committed
Format code, improve typing. Add pandas-stubs.
1 parent acf692c commit 6681cab

File tree

4 files changed

+20
-18
lines changed

4 files changed

+20
-18
lines changed

benchmarks/benchmarking_conservative.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1095,7 +1095,8 @@
10951095
],
10961096
"source": [
10971097
"import numpy as np\n",
1098-
"rmse = np.sqrt(np.mean((data_regrid - data_cdo)**2))[\"tp\"].to_numpy() * 1000\n",
1098+
"\n",
1099+
"rmse = np.sqrt(np.mean((data_regrid - data_cdo) ** 2))[\"tp\"].to_numpy() * 1000\n",
10991100
"print(f\"RMSE: {rmse:.4f} mm\")"
11001101
]
11011102
},

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ dev = [
5151
"mypy",
5252
"pytest",
5353
"pytest-cov",
54+
"pandas-stubs", # Adds typing for pandas.
5455
]
5556

5657
[tool.hatch.version]

src/xarray_regrid/methods.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,18 +61,18 @@ def conservative_regrid(
6161
target_coords = coords[coord].to_numpy()
6262
# TODO: better resolution/IntervalIndex inference
6363
target_intervals = utils.to_intervalindex(
64-
target_coords, resolution=target_coords[1]-target_coords[0]
64+
target_coords, resolution=target_coords[1] - target_coords[0]
6565
)
6666
source_coords = data[coord].to_numpy()
6767
source_intervals = utils.to_intervalindex(
68-
source_coords, resolution=source_coords[1]-source_coords[0]
68+
source_coords, resolution=source_coords[1] - source_coords[0]
6969
)
7070
overlap = utils.overlap(source_intervals, target_intervals)
7171
weights = utils.normalize_overlap(overlap)
7272

7373
# TODO: Use `sparse.COO(weights)`. xr.dot does not support this. Much faster!
7474
dot_array = utils.create_dot_dataarray(
75-
weights, coord, target_coords, source_coords
75+
weights, str(coord), target_coords, source_coords
7676
)
7777
# TODO: modify weights to correct for latitude.
7878
dataarrays = [

src/xarray_regrid/utils.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def to_intervalindex(coords: np.ndarray, resolution: float) -> pd.IntervalIndex:
102102
"""
103103
return pd.IntervalIndex(
104104
[
105-
pd.Interval(left=coord - resolution/2, right=coord + resolution/2)
105+
pd.Interval(left=coord - resolution / 2, right=coord + resolution / 2)
106106
for coord in coords
107107
]
108108
)
@@ -122,25 +122,25 @@ def overlap(a: pd.IntervalIndex, b: pd.IntervalIndex) -> np.ndarray:
122122
# TODO: newaxis on B and transpose is MUCH faster on benchmark.
123123
# likely due to it being the bigger dimension.
124124
# size(a) > size(b) leads to better perf than size(b) > size(a)
125-
mins = np.minimum(
126-
a.right.to_numpy(),
127-
b.right.to_numpy()[:, np.newaxis]
128-
)
129-
maxs = np.maximum(
130-
a.left.to_numpy(),
131-
b.left.to_numpy()[:, np.newaxis]
132-
)
133-
return np.maximum(mins-maxs, 0).T
125+
mins = np.minimum(a.right.to_numpy(), b.right.to_numpy()[:, np.newaxis])
126+
maxs = np.maximum(a.left.to_numpy(), b.left.to_numpy()[:, np.newaxis])
127+
overlap: np.ndarray = np.maximum(mins - maxs, 0).T
128+
return overlap
134129

135130

136131
def normalize_overlap(overlap: np.ndarray) -> np.ndarray:
137132
"""Normalize overlap values so they sum up to 1.0 along the first axis."""
138-
overlap_sum = overlap.sum(axis=0)
139-
overlap_sum[overlap_sum==0] = 1e-12 # Avoid dividing by 0.
140-
return (overlap / overlap_sum)
133+
overlap_sum: np.ndarray = overlap.sum(axis=0)
134+
overlap_sum[overlap_sum == 0] = 1e-12 # Avoid dividing by 0.
135+
return overlap / overlap_sum # type: ignore
141136

142137

143-
def create_dot_dataarray(weights, coord, target_coords, source_coords):
138+
def create_dot_dataarray(
139+
weights: np.ndarray,
140+
coord: str,
141+
target_coords: np.ndarray,
142+
source_coords: np.ndarray,
143+
) -> xr.DataArray:
144144
"""Create a DataArray to be used at dot product compatible with xr.dot."""
145145
return xr.DataArray(
146146
data=weights,

0 commit comments

Comments
 (0)