Skip to content

Commit 7d43594

Browse files
committed
Simplify accessors structure
1 parent 61bea34 commit 7d43594

File tree

2 files changed

+46
-65
lines changed

2 files changed

+46
-65
lines changed

src/xarray_regrid/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
from xarray_regrid.regrid import DataArrayRegridder, DatasetRegridder
1+
from xarray_regrid.regrid import Regridder
22
from xarray_regrid.utils import Grid, create_regridding_dataset
33

44
__all__ = [
55
"Grid",
6-
"DataArrayRegridder",
7-
"DatasetRegridder",
6+
"Regridder",
87
"create_regridding_dataset",
98
]

src/xarray_regrid/regrid.py

Lines changed: 44 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,37 @@
11
import xarray as xr
22

3-
from xarray_regrid import methods
3+
from xarray_regrid import methods, most_common
44

55

66
@xr.register_dataarray_accessor("regrid")
7-
class DataArrayRegridder:
7+
@xr.register_dataset_accessor("regrid")
8+
class Regridder:
89
"""Regridding xarray dataarrays.
910
1011
Available methods:
1112
linear: linear, bilinear, or higher dimensional linear interpolation.
1213
nearest: nearest-neighbor regridding.
1314
cubic: cubic spline regridding.
1415
conservative: conservative regridding.
16+
most_common: most common value regridder
1517
"""
1618

17-
def __init__(self, xarray_obj: xr.DataArray):
19+
def __init__(self, xarray_obj: xr.DataArray | xr.Dataset):
1820
self._obj = xarray_obj
1921

2022
def linear(
2123
self,
2224
ds_target_grid: xr.Dataset,
2325
time_dim: str = "time",
24-
) -> xr.DataArray:
25-
"""Return a dataset regridded linearily to the coords of the target dataset.
26+
) -> xr.DataArray | xr.Dataset:
27+
"""Regrid to the coords of the target dataset with linear interpolation.
2628
2729
Args:
2830
ds_target_grid: Dataset containing the target coordinates.
2931
time_dim: The name of the time dimension/coordinate
3032
3133
Returns:
32-
Dataset regridded to the target dataset coordinates.
34+
Data regridded to the target dataset coordinates.
3335
"""
3436
ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
3537
return methods.interp_regrid(self._obj, ds_target_grid, "linear")
@@ -38,15 +40,15 @@ def nearest(
3840
self,
3941
ds_target_grid: xr.Dataset,
4042
time_dim: str = "time",
41-
) -> xr.DataArray:
42-
"""Return a dataset regridded by taking the values of the nearest target coords.
43+
) -> xr.DataArray | xr.Dataset:
44+
"""Regrid to the coords of the target with nearest-neighbor interpolation.
4345
4446
Args:
4547
ds_target_grid: Dataset containing the target coordinates.
4648
time_dim: The name of the time dimension/coordinate
4749
4850
Returns:
49-
Dataset regridded to the target dataset coordinates.
51+
Data regridded to the target dataset coordinates.
5052
"""
5153
ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
5254
return methods.interp_regrid(self._obj, ds_target_grid, "nearest")
@@ -55,84 +57,64 @@ def cubic(
5557
self,
5658
ds_target_grid: xr.Dataset,
5759
time_dim: str = "time",
58-
) -> xr.DataArray:
60+
) -> xr.DataArray | xr.Dataset:
5961
ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
60-
return methods.interp_regrid(self._obj, ds_target_grid, "cubic")
61-
62-
def conservative(
63-
self,
64-
ds_target_grid: xr.Dataset,
65-
latitude_coord: str | None,
66-
time_dim: str = "time",
67-
) -> xr.DataArray:
68-
ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
69-
return methods.conservative_regrid(self._obj, ds_target_grid, latitude_coord)
70-
71-
72-
@xr.register_dataset_accessor("regrid")
73-
class DatasetRegridder:
74-
"""Regridding xarray datasets.
75-
76-
Available methods:
77-
linear: linear, bilinear, or higher dimensional linear interpolation.
78-
nearest: nearest-neighbor regridding.
79-
cubic: cubic spline regridding.
80-
conservative: conservative regridding.
81-
"""
82-
83-
def __init__(self, xarray_obj: xr.Dataset):
84-
self._obj = xarray_obj
85-
86-
def linear(
87-
self,
88-
ds_target_grid: xr.Dataset,
89-
time_dim: str = "time",
90-
) -> xr.Dataset:
91-
"""Return a dataset regridded linearily to the coords of the target dataset.
62+
"""Regrid to the coords of the target dataset with cubic interpolation.
9263
9364
Args:
9465
ds_target_grid: Dataset containing the target coordinates.
9566
time_dim: The name of the time dimension/coordinate
9667
9768
Returns:
98-
Dataset regridded to the target dataset coordinates.
69+
Data regridded to the target dataset coordinates.
9970
"""
100-
ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
101-
return methods.interp_regrid(self._obj, ds_target_grid, "linear")
71+
return methods.interp_regrid(self._obj, ds_target_grid, "cubic")
10272

103-
def nearest(
73+
def conservative(
10474
self,
10575
ds_target_grid: xr.Dataset,
76+
latitude_coord: str | None,
10677
time_dim: str = "time",
107-
) -> xr.Dataset:
108-
"""Return a dataset regridded by taking the values of the nearest target coords.
78+
) -> xr.DataArray | xr.Dataset:
79+
"""Regrid to the coords of the target dataset with a conservative scheme.
10980
11081
Args:
11182
ds_target_grid: Dataset containing the target coordinates.
112-
time_dim: The name of the time dimension/coordinate
83+
latitude_coord: Name of the latitude coord, to be used for applying the
84+
spherical correction.
85+
time_dim: The name of the time dimension/coordinate.
11386
11487
Returns:
115-
Dataset regridded to the target dataset coordinates.
88+
Data regridded to the target dataset coordinates.
11689
"""
117-
ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
118-
return methods.interp_regrid(self._obj, ds_target_grid, "nearest")
11990

120-
def cubic(
121-
self,
122-
ds_target_grid: xr.Dataset,
123-
time_dim: str = "time",
124-
) -> xr.Dataset:
12591
ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
126-
return methods.interp_regrid(self._obj, ds_target_grid, "cubic")
92+
return methods.conservative_regrid(self._obj, ds_target_grid, latitude_coord)
12793

128-
def conservative(
94+
def most_common(
12995
self,
13096
ds_target_grid: xr.Dataset,
131-
latitude_coord: str | None,
13297
time_dim: str = "time",
133-
) -> xr.Dataset:
98+
max_mem: int = 1e9
99+
) -> xr.DataArray | xr.Dataset:
100+
"""Regrid by taking the most common value within the new grid cells.
101+
102+
To be used for regridding data to a much coarser resolution.
103+
104+
Args:
105+
ds_target_grid: Target grid dataset
106+
time_dim: Name of the time dimension. Defaults to "time".
107+
max_mem: (Approximate) maximum memory in bytes that the regridding routine
108+
can use. Note that this is not the total memory consumption and does not
109+
include the size of the final dataset. Defaults to 1e9 (1 GB).
110+
111+
Returns:
112+
Regridded data.
113+
"""
134114
ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
135-
return methods.conservative_regrid(self._obj, ds_target_grid, latitude_coord)
115+
return most_common.most_common_wrapper(
116+
self._obj, ds_target_grid, time_dim, max_mem
117+
)
136118

137119

138120
def validate_input(

0 commit comments

Comments
 (0)