Skip to content

Commit 2b18796

Browse files
authored
Cleanup shape test files (#271)
1 parent 2f78dfc commit 2b18796

31 files changed

+698
-544
lines changed

tests/shapes/circles/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Test data_morph.shapes.circles subpackage."""

tests/shapes/circles/bases.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""Base test class for circle shapes."""
2+
3+
from __future__ import annotations
4+
5+
import re
6+
from typing import TYPE_CHECKING
7+
8+
import pytest
9+
10+
if TYPE_CHECKING:
11+
from numbers import Number
12+
13+
CIRCLE_REPR = r'<Circle center=\((\d+\.*\d*), (\d+\.*\d*)\) radius=(\d+\.*\d*)>'
14+
15+
16+
class CirclesModuleTestBase:
17+
"""Base for testing circle shapes."""
18+
19+
shape_name: str
20+
distance_test_cases: tuple[tuple[tuple[Number], float]]
21+
repr_regex: str
22+
23+
@pytest.fixture(scope='class')
24+
def shape(self, shape_factory):
25+
"""Fixture to get the shape for testing."""
26+
return shape_factory.generate_shape(self.shape_name)
27+
28+
def test_distance(self, shape, test_point, expected_distance):
29+
"""
30+
Test the distance() method parametrized by distance_test_cases
31+
(see conftest.py).
32+
"""
33+
assert pytest.approx(shape.distance(*test_point)) == expected_distance
34+
35+
def test_repr(self, shape):
36+
"""Test that the __repr__() method is working."""
37+
assert re.match(self.repr_regex, repr(shape)) is not None
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""Test the bullseye module."""
2+
3+
import numpy as np
4+
import pytest
5+
6+
from .bases import CIRCLE_REPR, CirclesModuleTestBase
7+
8+
pytestmark = [pytest.mark.shapes, pytest.mark.circles]
9+
10+
11+
class TestBullseye(CirclesModuleTestBase):
12+
"""Test the Bullseye class."""
13+
14+
shape_name = 'bullseye'
15+
distance_test_cases = (((20, 50), 3.660254), ((10, 25), 9.08004))
16+
repr_regex = (
17+
r'^<Bullseye>\n'
18+
r' circles=\n'
19+
r' ' + CIRCLE_REPR + '\n'
20+
r' ' + CIRCLE_REPR + '$'
21+
)
22+
23+
def test_init(self, shape):
24+
"""Test that the Bullseye contains two concentric circles."""
25+
assert len(shape.circles) == 2
26+
27+
a, b = shape.circles
28+
assert np.array_equal(a.center, b.center)
29+
assert a.radius != b.radius
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""Test the circle module."""
2+
3+
import numpy as np
4+
import pytest
5+
6+
from .bases import CIRCLE_REPR, CirclesModuleTestBase
7+
8+
pytestmark = [pytest.mark.shapes, pytest.mark.circles]
9+
10+
11+
class TestCircle(CirclesModuleTestBase):
12+
"""Test the Circle class."""
13+
14+
shape_name = 'circle'
15+
distance_test_cases = (((20, 50), 10.490381), ((10, 25), 15.910168))
16+
repr_regex = '^' + CIRCLE_REPR + '$'
17+
18+
def test_is_circle(self, shape):
19+
"""Test that the Circle is a valid circle (mathematically)."""
20+
angles = np.arange(0, 361, 45)
21+
cx, cy = shape.center
22+
for x, y in zip(
23+
cx + shape.radius * np.cos(angles),
24+
cy + shape.radius * np.sin(angles),
25+
):
26+
assert pytest.approx(shape.distance(x, y)) == 0

tests/shapes/circles/test_rings.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""Test the rings module."""
2+
3+
import numpy as np
4+
import pytest
5+
6+
from .bases import CIRCLE_REPR, CirclesModuleTestBase
7+
8+
pytestmark = [pytest.mark.shapes, pytest.mark.circles]
9+
10+
11+
class TestRings(CirclesModuleTestBase):
12+
"""Test the Rings class."""
13+
14+
shape_name = 'rings'
15+
distance_test_cases = (((20, 50), 3.16987), ((10, 25), 9.08004))
16+
repr_regex = (
17+
r'^<Rings>\n'
18+
r' circles=\n'
19+
r' ' + CIRCLE_REPR + '\n'
20+
r' ' + CIRCLE_REPR + '\n'
21+
r' ' + CIRCLE_REPR + '\n'
22+
r' ' + CIRCLE_REPR + '$'
23+
)
24+
25+
@pytest.mark.parametrize('num_rings', [3, 5])
26+
def test_init(self, shape_factory, num_rings):
27+
"""Test that the Rings contains multiple concentric circles."""
28+
shape = shape_factory.generate_shape(self.shape_name, num_rings=num_rings)
29+
30+
assert len(shape.circles) == num_rings
31+
assert all(
32+
np.array_equal(circle.center, shape.circles[0].center)
33+
for circle in shape.circles[1:]
34+
)
35+
assert len({circle.radius for circle in shape.circles}) == num_rings
36+
37+
@pytest.mark.parametrize('num_rings', ['3', -5, 1, True])
38+
def test_num_rings_is_valid(self, shape_factory, num_rings):
39+
"""Test that num_rings input validation is working."""
40+
if isinstance(num_rings, int):
41+
with pytest.raises(ValueError, match='num_rings must be greater than 1'):
42+
_ = shape_factory.generate_shape(self.shape_name, num_rings=num_rings)
43+
else:
44+
with pytest.raises(TypeError, match='num_rings must be an integer'):
45+
_ = shape_factory.generate_shape(self.shape_name, num_rings=num_rings)

tests/shapes/lines/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Test data_morph.shapes.lines subpackage."""

tests/shapes/lines/bases.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
"""Base test classes for line shapes."""
2+
3+
from __future__ import annotations
4+
5+
from numbers import Number
6+
7+
import numpy as np
8+
import pytest
9+
10+
11+
class LinesModuleTestBase:
12+
"""Base for testing line-based shapes."""
13+
14+
shape_name: str
15+
distance_test_cases: tuple[tuple[tuple[Number], float]]
16+
expected_line_count: int
17+
expected_slopes: tuple[Number] | Number
18+
19+
@pytest.fixture(scope='class')
20+
def shape(self, shape_factory):
21+
"""Fixture to get the shape for testing."""
22+
return shape_factory.generate_shape(self.shape_name)
23+
24+
@pytest.fixture(scope='class')
25+
def slopes(self, shape):
26+
"""Fixture to get the slopes of the lines."""
27+
xs, ys = np.array(shape.lines).T
28+
runs = np.diff(xs, axis=0)
29+
rises = np.diff(ys, axis=0)
30+
slopes = rises / np.ma.masked_array(runs, mask=runs == 0)
31+
return slopes.filled(np.inf)
32+
33+
def test_init(self, shape):
34+
"""Test that the shape consists of the correct number of distinct lines."""
35+
num_unique_lines, *_ = np.unique(shape.lines, axis=0).shape
36+
assert num_unique_lines == self.expected_line_count
37+
38+
def test_distance(self, shape, test_point, expected_distance):
39+
"""
40+
Test the distance() method parametrized by distance_test_cases
41+
(see conftest.py).
42+
"""
43+
assert pytest.approx(shape.distance(*test_point)) == expected_distance
44+
45+
def test_slopes(self, slopes):
46+
"""Test that the slopes are as expected."""
47+
expected = (
48+
[self.expected_slopes]
49+
if isinstance(self.expected_slopes, Number)
50+
else self.expected_slopes
51+
)
52+
assert np.array_equal(np.unique(slopes), expected)
53+
54+
55+
class ParallelLinesModuleTestBase(LinesModuleTestBase):
56+
"""Base for testing parallel line-based shapes."""
57+
58+
def test_lines_are_parallel(self, slopes):
59+
"""Test that the lines are parallel (slopes are equal)."""
60+
assert np.unique(slopes).size == 1
61+
62+
63+
class PolygonsLineModuleTestBase:
64+
"""Base for testing polygon shapes."""
65+
66+
shape_name: str
67+
distance_test_cases: tuple[tuple[tuple[Number], float]]
68+
expected_line_count: int
69+
70+
@pytest.fixture(scope='class')
71+
def shape(self, shape_factory):
72+
"""Fixture to get the shape for testing."""
73+
return shape_factory.generate_shape(self.shape_name)
74+
75+
@pytest.fixture(scope='class')
76+
def slopes(self, shape):
77+
"""Fixture to get the slopes of the lines."""
78+
xs, ys = np.array(shape.lines).T
79+
runs = np.diff(xs, axis=0)
80+
rises = np.diff(ys, axis=0)
81+
slopes = rises / np.ma.masked_array(runs, mask=runs == 0)
82+
return slopes.filled(np.inf)
83+
84+
def test_init(self, shape):
85+
"""Test that the shape consists of the correct number of distinct lines."""
86+
num_unique_lines, *_ = np.unique(shape.lines, axis=0).shape
87+
assert num_unique_lines == self.expected_line_count
88+
89+
def test_distance(self, shape, test_point, expected_distance):
90+
"""
91+
Test the distance() method parametrized by distance_test_cases
92+
(see conftest.py).
93+
"""
94+
assert pytest.approx(shape.distance(*test_point)) == expected_distance
95+
96+
def test_lines_form_polygon(self, shape):
97+
"""Test that the lines form a polygon."""
98+
endpoints = np.array(shape.lines).reshape(-1, 2)
99+
assert np.unique(endpoints, axis=0).shape[0] == self.expected_line_count

tests/shapes/lines/test_diamond.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""Test the diamond module."""
2+
3+
import numpy as np
4+
import pytest
5+
6+
from .bases import PolygonsLineModuleTestBase
7+
8+
pytestmark = [pytest.mark.shapes, pytest.mark.lines, pytest.mark.polygons]
9+
10+
11+
class TestDiamond(PolygonsLineModuleTestBase):
12+
"""Test the Diamond class."""
13+
14+
shape_name = 'diamond'
15+
distance_test_cases = (((20, 50), 0.0), ((30, 60), 2.773501))
16+
expected_line_count = 4
17+
18+
def test_slopes(self, slopes):
19+
"""Test that the slopes are as expected."""
20+
np.testing.assert_array_equal(np.sort(slopes).flatten(), [-1.5, -1.5, 1.5, 1.5])
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
"""Test the high_lines module."""
2+
3+
import pytest
4+
5+
from .bases import ParallelLinesModuleTestBase
6+
7+
pytestmark = [pytest.mark.shapes, pytest.mark.lines]
8+
9+
10+
class TestHighLines(ParallelLinesModuleTestBase):
11+
"""Test the HighLines class."""
12+
13+
shape_name = 'high_lines'
14+
distance_test_cases = (((20, 50), 6.0), ((30, 60), 4.0))
15+
expected_line_count = 2
16+
expected_slopes = 0
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
"""Test the horizontal_lines module."""
2+
3+
import pytest
4+
5+
from .bases import ParallelLinesModuleTestBase
6+
7+
pytestmark = [pytest.mark.shapes, pytest.mark.lines]
8+
9+
10+
class TestHorizontalLines(ParallelLinesModuleTestBase):
11+
"""Test the HorizontalLines class."""
12+
13+
shape_name = 'h_lines'
14+
distance_test_cases = (((20, 50), 0.0), ((30, 60), 2.5))
15+
expected_line_count = 5
16+
expected_slopes = 0

0 commit comments

Comments
 (0)