Skip to content

Commit 30b2a0c

Browse files
committed
Add tests for the shape diagnostic function
1 parent 13b4f0e commit 30b2a0c

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

tests/plotting/test_diagnostics.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
"""Test the diagnostics module."""
2+
3+
import pytest
4+
from matplotlib.axes import Axes
5+
from matplotlib.patches import Rectangle
6+
7+
from data_morph.data.loader import DataLoader
8+
from data_morph.plotting.diagnostics import plot_shape_on_dataset
9+
from data_morph.shapes.bases.line_collection import LineCollection
10+
from data_morph.shapes.bases.point_collection import PointCollection
11+
from data_morph.shapes.factory import ShapeFactory
12+
13+
14+
@pytest.mark.parametrize(
15+
('dataset_name', 'shape_name', 'show_bounds', 'alpha'),
16+
[
17+
('panda', 'heart', True, 0.4),
18+
('music', 'rectangle', False, 0.25),
19+
('sheep', 'circle', False, 0.5),
20+
],
21+
)
22+
def test_plot_shape_on_dataset(dataset_name, shape_name, show_bounds, alpha):
23+
"""Test the plot_shape_on_dataset() function."""
24+
dataset = DataLoader.load_dataset(dataset_name)
25+
shape = ShapeFactory(dataset).generate_shape(shape_name)
26+
ax = plot_shape_on_dataset(dataset, shape, show_bounds, alpha)
27+
28+
assert isinstance(ax, Axes)
29+
assert not ax.get_title()
30+
31+
assert ax.collections[0].get_alpha() == alpha
32+
33+
points_expected = dataset.data.shape[0]
34+
if isinstance(shape, PointCollection):
35+
points_expected += shape.points.shape[0]
36+
37+
points_plotted = sum(
38+
collection.get_offsets().data.shape[0] for collection in ax.collections
39+
)
40+
assert points_expected == points_plotted
41+
42+
if isinstance(shape, LineCollection):
43+
assert len(ax.lines) == len(shape.lines)
44+
45+
if show_bounds:
46+
assert sum(isinstance(patch, Rectangle) for patch in ax.patches) == 3

0 commit comments

Comments
 (0)