Skip to content

Commit 7d2b741

Browse files
authored
Add diagnostic function to visualize shapes superimposed on datasets (#331)
- Add diagnostic function to visualize shapes superimposed on datasets - Update shape creation tutorial to show example of using the shape diagnostic function - Update all shape docs to use the diagnostic function - Add tests for `LineCollection` and `Dataset` `plot()` methods
1 parent 96689c6 commit 7d2b741

28 files changed

+320
-29
lines changed

docs/tutorials/shape-creation.rst

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,28 @@ shape inherits from :class:`.LineCollection` and uses the morph bounds
6666
Since we inherit from :class:`.LineCollection` here, we don't need to define
6767
the ``distance()`` and ``plot()`` methods (unless we want to override them).
6868

69+
.. tip::
70+
You can use the :func:`.plot_shape_on_dataset` function to visualize your
71+
shape's positioning relative to a given dataset. Your shape can exceed the data
72+
bounds (:attr:`.Dataset.data_bounds`); however, it should not exceed the morph
73+
bounds (:attr:`.Dataset.morph_bounds`):
74+
75+
.. plot::
76+
:scale: 75
77+
:include-source:
78+
:caption:
79+
Visualization of the :class:`.XLines` shape when calculated based on the
80+
music :class:`.Dataset`, with the dataset's bounds.
81+
82+
from data_morph.data.loader import DataLoader
83+
from data_morph.plotting.diagnostics import plot_shape_on_dataset
84+
from data_morph.shapes.lines import XLines
85+
86+
87+
dataset = DataLoader.load_dataset('music')
88+
shape = XLines(dataset)
89+
plot_shape_on_dataset(dataset, shape, show_bounds=True, alpha=0.1)
90+
6991
Test out the shape
7092
------------------
7193

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,6 @@ exclude = [ # don't report on checks for these
187187
'\.__repr__$',
188188
'\.__str__$',
189189
]
190+
override_SS05 = [ # allow docstrings to start with these words
191+
'^Unambiguous ',
192+
]

src/data_morph/data/dataset.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ def plot(
201201
ax: Axes | None = None,
202202
show_bounds: bool = True,
203203
title: str | None = 'default',
204+
alpha: Number = 1,
204205
) -> Axes:
205206
"""
206207
Plot the dataset and its bounds.
@@ -214,6 +215,8 @@ def plot(
214215
title : str | ``None``, optional
215216
Title to use for the plot. The default will call ``str()`` on the
216217
Dataset. Pass ``None`` to leave the plot untitled.
218+
alpha : Number, default ``1``
219+
The transparency to use for the points in the plot.
217220
218221
Returns
219222
-------
@@ -225,7 +228,7 @@ def plot(
225228
fig.get_layout_engine().set(w_pad=0.2, h_pad=0.2)
226229

227230
ax.axis('equal')
228-
ax.scatter(self.data.x, self.data.y, s=2, color='black')
231+
ax.scatter(self.data.x, self.data.y, s=2, color='black', alpha=alpha)
229232
ax.set(xlabel='', ylabel='', title=self if title == 'default' else title)
230233

231234
if show_bounds:
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
"""Diagnostic plot to visualize a shape superimposed on the dataset."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING
6+
7+
from ..plotting.style import plot_with_custom_style
8+
9+
if TYPE_CHECKING:
10+
from numbers import Number
11+
12+
from matplotlib.axes import Axes
13+
14+
from ..data.dataset import Dataset
15+
from ..shapes.bases.shape import Shape
16+
17+
18+
@plot_with_custom_style
19+
def plot_shape_on_dataset(
20+
dataset: Dataset,
21+
shape: Shape,
22+
show_bounds: bool = False,
23+
alpha: Number = 0.25,
24+
) -> Axes:
25+
"""
26+
Plot a shape superimposed on a dataset to evaluate heuristics.
27+
28+
Parameters
29+
----------
30+
dataset : Dataset
31+
The dataset that ``shape`` was instantiated with.
32+
shape : Shape
33+
The shape that was instantiated with ``dataset``.
34+
show_bounds : bool, default ``False``
35+
Whether to include the dataset's bounds in the plot.
36+
alpha : Number, default ``0.25``
37+
The transparency to use for the dataset's points.
38+
39+
Returns
40+
-------
41+
matplotlib.axes.Axes
42+
The :class:`~matplotlib.axes.Axes` object containing the plot.
43+
44+
Examples
45+
--------
46+
47+
.. plot::
48+
:scale: 75
49+
:include-source:
50+
:caption:
51+
Visualization of the :class:`.Star` shape when calculated based on the
52+
music :class:`.Dataset`, with the dataset's bounds.
53+
54+
from data_morph.data.loader import DataLoader
55+
from data_morph.plotting.diagnostics import plot_shape_on_dataset
56+
from data_morph.shapes.lines import Star
57+
58+
dataset = DataLoader.load_dataset('music')
59+
shape = Star(dataset)
60+
plot_shape_on_dataset(dataset, shape, show_bounds=True, alpha=0.1)
61+
62+
.. plot::
63+
:scale: 75
64+
:include-source:
65+
:caption:
66+
Visualization of the :class:`.Heart` shape when calculated based on the
67+
music :class:`.Dataset`, without the dataset's bounds.
68+
69+
from data_morph.data.loader import DataLoader
70+
from data_morph.plotting.diagnostics import plot_shape_on_dataset
71+
from data_morph.shapes.points import Heart
72+
73+
dataset = DataLoader.load_dataset('music')
74+
shape = Heart(dataset)
75+
plot_shape_on_dataset(dataset, shape, alpha=0.1)
76+
"""
77+
ax = dataset.plot(show_bounds=show_bounds, title=None, alpha=alpha)
78+
shape.plot(ax=ax)
79+
return ax

src/data_morph/shapes/circles/bullseye.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,12 @@ class Bullseye(Rings):
2222
This shape is generated using the panda dataset.
2323
2424
from data_morph.data.loader import DataLoader
25+
from data_morph.plotting.diagnostics import plot_shape_on_dataset
2526
from data_morph.shapes.circles import Bullseye
2627
27-
_ = Bullseye(DataLoader.load_dataset('panda')).plot()
28+
dataset = DataLoader.load_dataset('panda')
29+
shape = Bullseye(dataset)
30+
plot_shape_on_dataset(dataset, shape, show_bounds=False, alpha=0.25)
2831
2932
See Also
3033
--------

src/data_morph/shapes/circles/circle.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,12 @@ class Circle(Shape):
2828
This shape is generated using the panda dataset.
2929
3030
from data_morph.data.loader import DataLoader
31+
from data_morph.plotting.diagnostics import plot_shape_on_dataset
3132
from data_morph.shapes.circles import Circle
3233
33-
_ = Circle(DataLoader.load_dataset('panda')).plot()
34+
dataset = DataLoader.load_dataset('panda')
35+
shape = Circle(dataset)
36+
plot_shape_on_dataset(dataset, shape, show_bounds=False, alpha=0.25)
3437
3538
Parameters
3639
----------

src/data_morph/shapes/circles/rings.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,12 @@ class Rings(Shape):
2828
This shape is generated using the panda dataset.
2929
3030
from data_morph.data.loader import DataLoader
31+
from data_morph.plotting.diagnostics import plot_shape_on_dataset
3132
from data_morph.shapes.circles import Rings
3233
33-
_ = Rings(DataLoader.load_dataset('panda')).plot()
34+
dataset = DataLoader.load_dataset('panda')
35+
shape = Rings(dataset)
36+
plot_shape_on_dataset(dataset, shape, show_bounds=False, alpha=0.25)
3437
3538
Parameters
3639
----------

src/data_morph/shapes/lines/diamond.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@ class Diamond(LineCollection):
1313
:caption:
1414
This shape is generated using the panda dataset.
1515
16-
import matplotlib.pyplot as plt
1716
from data_morph.data.loader import DataLoader
17+
from data_morph.plotting.diagnostics import plot_shape_on_dataset
1818
from data_morph.shapes.lines import Diamond
1919
20-
_ = Diamond(DataLoader.load_dataset('panda')).plot()
20+
dataset = DataLoader.load_dataset('panda')
21+
shape = Diamond(dataset)
22+
plot_shape_on_dataset(dataset, shape, show_bounds=False, alpha=0.25)
2123
2224
Parameters
2325
----------

src/data_morph/shapes/lines/high_lines.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,12 @@ class HighLines(LineCollection):
1414
This shape is generated using the panda dataset.
1515
1616
from data_morph.data.loader import DataLoader
17+
from data_morph.plotting.diagnostics import plot_shape_on_dataset
1718
from data_morph.shapes.lines import HighLines
1819
19-
_ = HighLines(DataLoader.load_dataset('panda')).plot()
20+
dataset = DataLoader.load_dataset('panda')
21+
shape = HighLines(dataset)
22+
plot_shape_on_dataset(dataset, shape, show_bounds=False, alpha=0.25)
2023
2124
Parameters
2225
----------

src/data_morph/shapes/lines/horizontal_lines.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,12 @@ class HorizontalLines(LineCollection):
1616
This shape is generated using the panda dataset.
1717
1818
from data_morph.data.loader import DataLoader
19+
from data_morph.plotting.diagnostics import plot_shape_on_dataset
1920
from data_morph.shapes.lines import HorizontalLines
2021
21-
_ = HorizontalLines(DataLoader.load_dataset('panda')).plot()
22+
dataset = DataLoader.load_dataset('panda')
23+
shape = HorizontalLines(dataset)
24+
plot_shape_on_dataset(dataset, shape, show_bounds=False, alpha=0.25)
2225
2326
Parameters
2427
----------

0 commit comments

Comments
 (0)