Skip to content

Commit 1fbc074

Browse files
authored
Explicit support for curvefit (#337)
1 parent f24c1c3 commit 1fbc074

File tree

5 files changed

+77
-3
lines changed

5 files changed

+77
-3
lines changed

cf_xarray/accessor.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
List,
1616
Mapping,
1717
MutableMapping,
18+
Sequence,
1819
TypeVar,
1920
Union,
2021
cast,
@@ -1059,6 +1060,52 @@ def _get_all_cell_measures(self):
10591060

10601061
return self._all_cell_measures
10611062

1063+
def curvefit(
1064+
self,
1065+
coords: str | DataArray | Iterable[str | DataArray],
1066+
func: Callable[..., Any],
1067+
reduce_dims: Hashable | Iterable[Hashable] = None,
1068+
skipna: bool = True,
1069+
p0: dict[str, Any] = None,
1070+
bounds: dict[str, Any] = None,
1071+
param_names: Sequence[str] = None,
1072+
kwargs: dict[str, Any] = None,
1073+
):
1074+
1075+
if coords is not None:
1076+
if isinstance(coords, str):
1077+
coords = (coords,)
1078+
coords = [
1079+
apply_mapper( # type: ignore
1080+
[_single(_get_coords)], self._obj, v, error=False, default=[v] # type: ignore
1081+
)[
1082+
0
1083+
] # type: ignore
1084+
for v in coords
1085+
]
1086+
if reduce_dims is not None:
1087+
if isinstance(reduce_dims, Hashable):
1088+
reduce_dims: Iterable[Hashable] = (reduce_dims,) # type: ignore
1089+
reduce_dims = [
1090+
apply_mapper( # type: ignore
1091+
[_single(_get_dims)], self._obj, v, error=False, default=[v] # type: ignore
1092+
)[
1093+
0
1094+
] # type: ignore
1095+
for v in reduce_dims # type: ignore
1096+
]
1097+
1098+
return self._obj.curvefit(
1099+
coords=coords,
1100+
func=func,
1101+
reduce_dims=reduce_dims,
1102+
skipna=skipna,
1103+
p0=p0,
1104+
bounds=bounds,
1105+
param_names=param_names,
1106+
kwargs=kwargs,
1107+
)
1108+
10621109
def _process_signature(
10631110
self,
10641111
func: Callable,

cf_xarray/tests/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def LooseVersion(vstring):
6363
return version.LooseVersion(vstring)
6464

6565

66-
has_pint, requires_pint = _importorskip("pint")
67-
has_shapely, requires_shapely = _importorskip("shapely")
6866
has_cftime, requires_cftime = _importorskip("cftime")
67+
has_scipy, requires_scipy = _importorskip("scipy")
68+
has_shapely, requires_shapely = _importorskip("shapely")
69+
has_pint, requires_pint = _importorskip("pint")

cf_xarray/tests/test_accessor.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
romsds,
3131
vert,
3232
)
33-
from . import raise_if_dask_computes, requires_cftime, requires_pint
33+
from . import raise_if_dask_computes, requires_cftime, requires_pint, requires_scipy
3434

3535
mpl.use("Agg")
3636

@@ -1619,3 +1619,26 @@ def test_cf_role():
16191619

16201620
dsg.foo.cf.plot(x="profile_id")
16211621
dsg.foo.cf.plot(x="trajectory_id")
1622+
1623+
1624+
@requires_scipy
1625+
def test_curvefit():
1626+
from cf_xarray.datasets import airds
1627+
1628+
def line(time, slope):
1629+
t = (time - time[0]).astype(float)
1630+
return slope * t
1631+
1632+
actual = airds.air.cf.isel(lat=4, lon=5).curvefit(coords=("time",), func=line)
1633+
expected = airds.air.cf.isel(lat=4, lon=5).cf.curvefit(coords="T", func=line)
1634+
assert_identical(expected, actual)
1635+
1636+
def plane(coords, slopex, slopey):
1637+
x, y = coords
1638+
return slopex * (x - x.mean()) + slopey * (y - y.mean())
1639+
1640+
actual = airds.air.isel(time=0).curvefit(coords=("lat", "lon"), func=plane)
1641+
expected = airds.air.isel(time=0).cf.curvefit(
1642+
coords=("latitude", "longitude"), func=plane
1643+
)
1644+
assert_identical(expected, actual)

ci/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@ dependencies:
1111
- pandas
1212
- pint
1313
- pooch
14+
- scipy
1415
- shapely
1516
- xarray

setup.cfg

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ known_first_party = cf_xarray
4545
known_third_party = dask,matplotlib,numpy,pandas,pint,pytest,setuptools,sphinx_autosummary_accessors,xarray
4646

4747
# Most of the numerical computing stack doesn't have type annotations yet.
48+
[mypy]
49+
allow_redefinition = True
4850
[mypy-affine.*]
4951
ignore_missing_imports = True
5052
[mypy-bottleneck.*]

0 commit comments

Comments
 (0)