Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/data_morph/shapes/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Club,
DotsGrid,
DownParabola,
FigureEight,
Heart,
LeftParabola,
RightParabola,
Expand Down Expand Up @@ -65,6 +66,7 @@ class ShapeFactory:
Diamond,
DotsGrid,
DownParabola,
FigureEight,
Heart,
HighLines,
HorizontalLines,
Expand Down
2 changes: 2 additions & 0 deletions src/data_morph/shapes/points/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from .club import Club
from .dots_grid import DotsGrid
from .figure_eight import FigureEight
from .heart import Heart
from .parabola import DownParabola, LeftParabola, RightParabola, UpParabola
from .scatter import Scatter
Expand All @@ -12,6 +13,7 @@
'Club',
'DotsGrid',
'DownParabola',
'FigureEight',
'Heart',
'LeftParabola',
'RightParabola',
Expand Down
53 changes: 53 additions & 0 deletions src/data_morph/shapes/points/figure_eight.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Figure eight shape."""

import numpy as np

from ...data.dataset import Dataset
from ..bases.point_collection import PointCollection


class FigureEight(PointCollection):
"""
Class for the figure eight shape.

.. plot::
:scale: 75
:caption:
This shape is generated using the panda dataset.

from data_morph.data.loader import DataLoader
from data_morph.shapes.points import FigureEight

_ = FigureEight(DataLoader.load_dataset('panda')).plot()

Parameters
----------
dataset : Dataset
The starting dataset to morph into other shapes. For datasets
with larger *y* ranges than *x* ranges, the figure eight will be
vertical; otherwise, it will be horizontal.

Notes
-----
This shape uses the formula for the `Lemniscate of Bernoulli
<https://en.wikipedia.org/wiki/Lemniscate_of_Bernoulli>`_.
"""

name = 'figure_eight'

def __init__(self, dataset: Dataset) -> None:
x_shift, y_shift = dataset.data_bounds.center
x_range, y_range = dataset.data_bounds.range

t = np.linspace(-3.1, 3.1, num=80)

focal_distance = max(x_range, y_range) * 0.3
half_width = focal_distance * np.sqrt(2)

x = (half_width * np.cos(t)) / (1 + np.square(np.sin(t)))
y = x * np.sin(t)

super().__init__(
*np.stack([x, y] if x_range >= y_range else [y, x], axis=1)
+ np.array([x_shift, y_shift])
)
18 changes: 18 additions & 0 deletions tests/shapes/test_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,21 @@ class TestSpiral(PointsModuleTestBase):
((25, 65), 1.3042797087884075),
((-30, 100), 52.14470630148412),
)


class TestFigureEight(PointsModuleTestBase):
"""Test the FigureEight class."""

shape_name = 'figure_eight'
distance_test_cases = (
((17.79641748, 67.34954701), 0),
((21.71773824, 63.21594749), 0),
((22.20358252, 67.34954701), 0),
((19.26000438, 64.25495015), 0),
((19.50182914, 77.69858052), 0),
((0, 0), 55.70680898398098),
((19, 61), 1.9727377843832639),
((19, 64), 0.34685744033355576),
((25, 65), 3.6523121397065657),
((18, 40), 12.392782544116978),
)