|
1 | 1 | """Test the dataset module.""" |
2 | 2 |
|
| 3 | +import matplotlib.pyplot as plt |
3 | 4 | import pandas as pd |
4 | 5 | import pytest |
| 6 | +from matplotlib.axes import Axes |
| 7 | +from matplotlib.patches import Rectangle |
5 | 8 | from numpy.testing import assert_equal |
6 | 9 | from pandas.testing import assert_frame_equal |
7 | 10 |
|
@@ -108,3 +111,41 @@ def test_repr(self, scale): |
108 | 111 |
|
109 | 112 | dataset = DataLoader.load_dataset('dino', scale=scale) |
110 | 113 | assert repr(dataset) == (f'<Dataset name=dino scaled={scale is not None}>') |
| 114 | + |
| 115 | + @pytest.mark.parametrize( |
| 116 | + ('ax', 'show_bounds', 'title', 'alpha'), |
| 117 | + [ |
| 118 | + (None, True, None, 1), |
| 119 | + (None, True, 'Custom title', 0.75), |
| 120 | + (plt.subplots()[1], False, 'default', 0.5), |
| 121 | + ], |
| 122 | + ) |
| 123 | + def test_plot(self, ax, show_bounds, title, alpha): |
| 124 | + """Test the plot() method.""" |
| 125 | + dataset = DataLoader.load_dataset('dino') |
| 126 | + ax = dataset.plot(ax=ax, show_bounds=show_bounds, title=title, alpha=alpha) |
| 127 | + |
| 128 | + assert isinstance(ax, Axes) |
| 129 | + assert pytest.approx(ax.get_aspect()) == 1.0 |
| 130 | + |
| 131 | + if title is None: |
| 132 | + assert not ax.get_title() |
| 133 | + elif title == 'default': |
| 134 | + assert ax.get_title() == repr(dataset) |
| 135 | + else: |
| 136 | + assert ax.get_title() == title |
| 137 | + |
| 138 | + assert ax.collections[0].get_alpha() == alpha |
| 139 | + |
| 140 | + points_expected = dataset.data.shape[0] |
| 141 | + points_plotted = sum( |
| 142 | + collection.get_offsets().data.shape[0] for collection in ax.collections |
| 143 | + ) |
| 144 | + assert points_expected == points_plotted |
| 145 | + |
| 146 | + if show_bounds: |
| 147 | + assert sum(isinstance(patch, Rectangle) for patch in ax.patches) == 3 |
| 148 | + assert ax.patches[0] != ax.patches[1] != ax.patches[2] |
| 149 | + |
| 150 | + assert len(labels := ax.texts) == 3 |
| 151 | + assert all(label.get_text().endswith('BOUNDS') for label in labels) |
0 commit comments