Skip to content

Commit fa78d47

Browse files
committed
reduction tests
1 parent a210221 commit fa78d47

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

xarray_array_testing/reduction.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from contextlib import nullcontext
2+
from types import ModuleType
3+
4+
import hypothesis.strategies as st
5+
import numpy as np
6+
import xarray.testing.strategies as xrst
7+
from hypothesis import given
8+
9+
10+
class ReductionTests:
11+
xp: ModuleType
12+
13+
@staticmethod
14+
def array_strategy_fn(*, shape, dtype):
15+
raise NotImplementedError
16+
17+
@staticmethod
18+
def assert_equal(a, b):
19+
np.testing.assert_allclose(a, b)
20+
21+
@staticmethod
22+
def expected_errors(op, **parameters):
23+
return nullcontext()
24+
25+
@given(st.data())
26+
def test_variable_mean(self, data):
27+
variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn))
28+
29+
with self.expected_errors("mean", variable=variable):
30+
actual = variable.mean().data
31+
expected = self.xp.mean(variable.data)
32+
33+
self.assert_equal(actual, expected)
34+
35+
@given(st.data())
36+
def test_variable_prod(self, data):
37+
variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn))
38+
39+
with self.expected_errors("prod", variable=variable):
40+
actual = variable.prod().data
41+
expected = self.xp.prod(variable.data)
42+
43+
self.assert_equal(actual, expected)

xarray_array_testing/tests/test_numpy.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33

44
from xarray_array_testing.creation import CreationTests
5+
from xarray_array_testing.reduction import ReductionTests
56

67

78
def create_numpy_array(*, shape, dtype):
@@ -15,3 +16,11 @@ class TestCreationNumpy(CreationTests):
1516
@staticmethod
1617
def array_strategy_fn(*, shape, dtype):
1718
return create_numpy_array(shape=shape, dtype=dtype)
19+
20+
21+
class TestReductionNumpy(ReductionTests):
22+
xp = np
23+
24+
@staticmethod
25+
def array_strategy_fn(*, shape, dtype):
26+
return create_numpy_array(shape=shape, dtype=dtype)

0 commit comments

Comments
 (0)