diff --git a/src/data_morph/shapes/factory.py b/src/data_morph/shapes/factory.py index d5975c63..d377df0c 100644 --- a/src/data_morph/shapes/factory.py +++ b/src/data_morph/shapes/factory.py @@ -33,6 +33,7 @@ RightParabola, Scatter, Spade, + Spiral, UpParabola, ) @@ -75,6 +76,7 @@ class ShapeFactory: SlantDownLines, SlantUpLines, Spade, + Spiral, Star, UpParabola, VerticalLines, diff --git a/src/data_morph/shapes/points/__init__.py b/src/data_morph/shapes/points/__init__.py index 8609f4d4..373961a1 100644 --- a/src/data_morph/shapes/points/__init__.py +++ b/src/data_morph/shapes/points/__init__.py @@ -6,6 +6,7 @@ from .parabola import DownParabola, LeftParabola, RightParabola, UpParabola from .scatter import Scatter from .spade import Spade +from .spiral import Spiral __all__ = [ 'Club', @@ -16,5 +17,6 @@ 'RightParabola', 'Scatter', 'Spade', + 'Spiral', 'UpParabola', ] diff --git a/src/data_morph/shapes/points/spiral.py b/src/data_morph/shapes/points/spiral.py new file mode 100644 index 00000000..8deceaf2 --- /dev/null +++ b/src/data_morph/shapes/points/spiral.py @@ -0,0 +1,56 @@ +"""Spiral shape.""" + +import numpy as np + +from ...data.dataset import Dataset +from ..bases.point_collection import PointCollection + + +class Spiral(PointCollection): + """ + Class for the spiral shape. + + .. plot:: + :scale: 75 + :caption: + This shape is generated using the panda dataset. + + from data_morph.data.loader import DataLoader + from data_morph.shapes.points import Spiral + + _ = Spiral(DataLoader.load_dataset('panda')).plot() + + Parameters + ---------- + dataset : Dataset + The starting dataset to morph into other shapes. + + Notes + ----- + This shape uses the formula for an `Archimedean spiral + `_. + """ + + def __init__(self, dataset: Dataset) -> None: + max_radius = min(*dataset.morph_bounds.range) / 2 + + x_center, y_center = dataset.data_bounds.center + x_range, y_range = dataset.data_bounds.range + num_rotations = 3 if x_range >= y_range else 3.25 + + # progress of the spiral growing wider (0 to 100%) + t = np.concatenate( + [ + np.linspace(0, 0.1, 3, endpoint=False), + np.linspace(0.1, 0.2, 5, endpoint=False), + np.linspace(0.2, 0.5, 25, endpoint=False), + np.linspace(0.5, 0.75, 30, endpoint=False), + np.linspace(0.75, 1, 35, endpoint=True), + ] + ) + + # x and y calculations for a spiral + x = (t * max_radius) * np.cos(2 * num_rotations * np.pi * t) + x_center + y = (t * max_radius) * np.sin(2 * num_rotations * np.pi * t) + y_center + + super().__init__(*np.stack([x, y], axis=1)) diff --git a/tests/shapes/test_points.py b/tests/shapes/test_points.py index 61a109ff..a9b705ff 100644 --- a/tests/shapes/test_points.py +++ b/tests/shapes/test_points.py @@ -172,3 +172,21 @@ class TestSpade(PointsModuleTestBase): ((0, 0), 57.350348), ((10, 80), 10.968080), ) + + +class TestSpiral(PointsModuleTestBase): + """Test the Spiral class.""" + + shape_name = 'spiral' + distance_test_cases = ( + ((10.862675, 65.846698), 0), + ((29.280789, 59.546024), 0), + ((16.022152, 68.248880), 0), + ((20.310858, 65.251728), 0), + ((22.803548, 72.599350), 0), + ((0, 0), 58.03780546896006), + ((10, 50), 8.239887412781957), + ((30, 70), 0.6642518196535838), + ((25, 65), 1.3042797087884075), + ((-30, 100), 52.14470630148412), + )