diff --git a/src/data_morph/shapes/factory.py b/src/data_morph/shapes/factory.py index d377df0c..5e4b205d 100644 --- a/src/data_morph/shapes/factory.py +++ b/src/data_morph/shapes/factory.py @@ -28,6 +28,7 @@ Club, DotsGrid, DownParabola, + FigureEight, Heart, LeftParabola, RightParabola, @@ -65,6 +66,7 @@ class ShapeFactory: Diamond, DotsGrid, DownParabola, + FigureEight, Heart, HighLines, HorizontalLines, diff --git a/src/data_morph/shapes/points/__init__.py b/src/data_morph/shapes/points/__init__.py index 373961a1..094eddf7 100644 --- a/src/data_morph/shapes/points/__init__.py +++ b/src/data_morph/shapes/points/__init__.py @@ -2,6 +2,7 @@ from .club import Club from .dots_grid import DotsGrid +from .figure_eight import FigureEight from .heart import Heart from .parabola import DownParabola, LeftParabola, RightParabola, UpParabola from .scatter import Scatter @@ -12,6 +13,7 @@ 'Club', 'DotsGrid', 'DownParabola', + 'FigureEight', 'Heart', 'LeftParabola', 'RightParabola', diff --git a/src/data_morph/shapes/points/figure_eight.py b/src/data_morph/shapes/points/figure_eight.py new file mode 100644 index 00000000..c143380f --- /dev/null +++ b/src/data_morph/shapes/points/figure_eight.py @@ -0,0 +1,53 @@ +"""Figure eight shape.""" + +import numpy as np + +from ...data.dataset import Dataset +from ..bases.point_collection import PointCollection + + +class FigureEight(PointCollection): + """ + Class for the figure eight shape. + + .. plot:: + :scale: 75 + :caption: + This shape is generated using the panda dataset. + + from data_morph.data.loader import DataLoader + from data_morph.shapes.points import FigureEight + + _ = FigureEight(DataLoader.load_dataset('panda')).plot() + + Parameters + ---------- + dataset : Dataset + The starting dataset to morph into other shapes. For datasets + with larger *y* ranges than *x* ranges, the figure eight will be + vertical; otherwise, it will be horizontal. + + Notes + ----- + This shape uses the formula for the `Lemniscate of Bernoulli + `_. + """ + + name = 'figure_eight' + + def __init__(self, dataset: Dataset) -> None: + x_shift, y_shift = dataset.data_bounds.center + x_range, y_range = dataset.data_bounds.range + + t = np.linspace(-3.1, 3.1, num=80) + + focal_distance = max(x_range, y_range) * 0.3 + half_width = focal_distance * np.sqrt(2) + + x = (half_width * np.cos(t)) / (1 + np.square(np.sin(t))) + y = x * np.sin(t) + + super().__init__( + *np.stack([x, y] if x_range >= y_range else [y, x], axis=1) + + np.array([x_shift, y_shift]) + ) diff --git a/tests/shapes/test_points.py b/tests/shapes/test_points.py index a9b705ff..11c70b59 100644 --- a/tests/shapes/test_points.py +++ b/tests/shapes/test_points.py @@ -190,3 +190,21 @@ class TestSpiral(PointsModuleTestBase): ((25, 65), 1.3042797087884075), ((-30, 100), 52.14470630148412), ) + + +class TestFigureEight(PointsModuleTestBase): + """Test the FigureEight class.""" + + shape_name = 'figure_eight' + distance_test_cases = ( + ((17.79641748, 67.34954701), 0), + ((21.71773824, 63.21594749), 0), + ((22.20358252, 67.34954701), 0), + ((19.26000438, 64.25495015), 0), + ((19.50182914, 77.69858052), 0), + ((0, 0), 55.70680898398098), + ((19, 61), 1.9727377843832639), + ((19, 64), 0.34685744033355576), + ((25, 65), 3.6523121397065657), + ((18, 40), 12.392782544116978), + )