|
| 1 | +"""Test the diagnostics module.""" |
| 2 | + |
| 3 | +import pytest |
| 4 | +from matplotlib.axes import Axes |
| 5 | +from matplotlib.patches import Rectangle |
| 6 | + |
| 7 | +from data_morph.data.loader import DataLoader |
| 8 | +from data_morph.plotting.diagnostics import plot_shape_on_dataset |
| 9 | +from data_morph.shapes.bases.line_collection import LineCollection |
| 10 | +from data_morph.shapes.bases.point_collection import PointCollection |
| 11 | +from data_morph.shapes.factory import ShapeFactory |
| 12 | + |
| 13 | + |
| 14 | +@pytest.mark.parametrize( |
| 15 | + ('dataset_name', 'shape_name', 'show_bounds', 'alpha'), |
| 16 | + [ |
| 17 | + ('panda', 'heart', True, 0.4), |
| 18 | + ('music', 'rectangle', False, 0.25), |
| 19 | + ('sheep', 'circle', False, 0.5), |
| 20 | + ], |
| 21 | +) |
| 22 | +def test_plot_shape_on_dataset(dataset_name, shape_name, show_bounds, alpha): |
| 23 | + """Test the plot_shape_on_dataset() function.""" |
| 24 | + dataset = DataLoader.load_dataset(dataset_name) |
| 25 | + shape = ShapeFactory(dataset).generate_shape(shape_name) |
| 26 | + ax = plot_shape_on_dataset(dataset, shape, show_bounds, alpha) |
| 27 | + |
| 28 | + assert isinstance(ax, Axes) |
| 29 | + assert not ax.get_title() |
| 30 | + |
| 31 | + assert ax.collections[0].get_alpha() == alpha |
| 32 | + |
| 33 | + points_expected = dataset.data.shape[0] |
| 34 | + if isinstance(shape, PointCollection): |
| 35 | + points_expected += shape.points.shape[0] |
| 36 | + |
| 37 | + points_plotted = sum( |
| 38 | + collection.get_offsets().data.shape[0] for collection in ax.collections |
| 39 | + ) |
| 40 | + assert points_expected == points_plotted |
| 41 | + |
| 42 | + if isinstance(shape, LineCollection): |
| 43 | + assert len(ax.lines) == len(shape.lines) |
| 44 | + |
| 45 | + if show_bounds: |
| 46 | + assert sum(isinstance(patch, Rectangle) for patch in ax.patches) == 3 |
0 commit comments