Skip to content

Commit 4a5d39a

Browse files
committed
Add tests
1 parent 13847b2 commit 4a5d39a

File tree

3 files changed

+94
-1
lines changed

3 files changed

+94
-1
lines changed

cf_xarray/tests/__init__.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import re
2+
from contextlib import contextmanager
3+
4+
import dask
5+
import pytest
6+
7+
8+
@contextmanager
9+
def raises_regex(error, pattern):
10+
__tracebackhide__ = True
11+
with pytest.raises(error) as excinfo:
12+
yield
13+
message = str(excinfo.value)
14+
if not re.search(pattern, message):
15+
raise AssertionError(
16+
f"exception {excinfo.value!r} did not match pattern {pattern!r}"
17+
)
18+
19+
20+
class CountingScheduler:
21+
""" Simple dask scheduler counting the number of computes.
22+
23+
Reference: https://stackoverflow.com/questions/53289286/ """
24+
25+
def __init__(self, max_computes=0):
26+
self.total_computes = 0
27+
self.max_computes = max_computes
28+
29+
def __call__(self, dsk, keys, **kwargs):
30+
self.total_computes += 1
31+
if self.total_computes > self.max_computes:
32+
raise RuntimeError(
33+
"Too many computes. Total: %d > max: %d."
34+
% (self.total_computes, self.max_computes)
35+
)
36+
return dask.get(dsk, keys, **kwargs)
37+
38+
39+
def raise_if_dask_computes(max_computes=0):
40+
scheduler = CountingScheduler(max_computes)
41+
return dask.config.set(scheduler=scheduler)

cf_xarray/tests/test_accessor.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import pytest
2+
import xarray as xr
3+
from xarray.testing import assert_identical
4+
5+
import cf_xarray # noqa
6+
7+
from . import raise_if_dask_computes
8+
9+
ds = xr.tutorial.open_dataset("air_temperature").isel(time=slice(4))
10+
objects = [
11+
pytest.param(ds, marks=pytest.mark.xfail),
12+
ds.air,
13+
pytest.param(ds.chunk({"lat": 5}), marks=pytest.mark.xfail),
14+
ds.air.chunk({"lat": 5}),
15+
]
16+
17+
18+
@pytest.mark.parametrize("obj", objects)
19+
def test_wrapped_classes(obj):
20+
with raise_if_dask_computes():
21+
expected = obj.resample(time="M").mean("lat")
22+
actual = obj.cf.resample(T="M").mean("Y")
23+
assert_identical(expected, actual)
24+
25+
# groupby
26+
# rolling
27+
# coarsen
28+
# weighted
29+
30+
31+
@pytest.mark.parametrize("obj", objects)
32+
def test_other_methods(obj):
33+
with raise_if_dask_computes():
34+
expected = obj.isel(time=slice(2))
35+
actual = obj.cf.isel(T=slice(2))
36+
assert_identical(expected, actual)
37+
38+
with raise_if_dask_computes():
39+
expected = obj.sum("time")
40+
actual = obj.cf.sum("T")
41+
assert_identical(expected, actual)
42+
43+
44+
@pytest.mark.parametrize("obj", objects)
45+
def test_plot(obj):
46+
obj.isel(time=1).cf.plot(x="X", y="Y")
47+
obj.isel(time=1).cf.plot.contourf(x="X", y="Y")
48+
49+
obj.cf.plot(x="X", y="Y", col="T")
50+
obj.cf.plot.contourf(x="X", y="Y", col="T")
51+
52+
obj.isel(lat=[0, 1], lon=1).cf.plot.line(x="T", hue="Y")

ci/environment.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: xcoare-test
1+
name: cf_xarray_test
22
channels:
33
- conda-forge
44
dependencies:

0 commit comments

Comments
 (0)