Skip to content

Commit 9391f31

Browse files
stefmolinont-ncole
andauthored
Add spiral shape (#264)
Co-authored-by: Nick Cole <[email protected]>
1 parent c82a381 commit 9391f31

File tree

4 files changed

+78
-0
lines changed

4 files changed

+78
-0
lines changed

src/data_morph/shapes/factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
RightParabola,
3434
Scatter,
3535
Spade,
36+
Spiral,
3637
UpParabola,
3738
)
3839

@@ -75,6 +76,7 @@ class ShapeFactory:
7576
SlantDownLines,
7677
SlantUpLines,
7778
Spade,
79+
Spiral,
7880
Star,
7981
UpParabola,
8082
VerticalLines,

src/data_morph/shapes/points/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .parabola import DownParabola, LeftParabola, RightParabola, UpParabola
77
from .scatter import Scatter
88
from .spade import Spade
9+
from .spiral import Spiral
910

1011
__all__ = [
1112
'Club',
@@ -16,5 +17,6 @@
1617
'RightParabola',
1718
'Scatter',
1819
'Spade',
20+
'Spiral',
1921
'UpParabola',
2022
]
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
"""Spiral shape."""
2+
3+
import numpy as np
4+
5+
from ...data.dataset import Dataset
6+
from ..bases.point_collection import PointCollection
7+
8+
9+
class Spiral(PointCollection):
10+
"""
11+
Class for the spiral shape.
12+
13+
.. plot::
14+
:scale: 75
15+
:caption:
16+
This shape is generated using the panda dataset.
17+
18+
from data_morph.data.loader import DataLoader
19+
from data_morph.shapes.points import Spiral
20+
21+
_ = Spiral(DataLoader.load_dataset('panda')).plot()
22+
23+
Parameters
24+
----------
25+
dataset : Dataset
26+
The starting dataset to morph into other shapes.
27+
28+
Notes
29+
-----
30+
This shape uses the formula for an `Archimedean spiral
31+
<https://en.wikipedia.org/wiki/Archimedean_spiral>`_.
32+
"""
33+
34+
def __init__(self, dataset: Dataset) -> None:
35+
max_radius = min(*dataset.morph_bounds.range) / 2
36+
37+
x_center, y_center = dataset.data_bounds.center
38+
x_range, y_range = dataset.data_bounds.range
39+
num_rotations = 3 if x_range >= y_range else 3.25
40+
41+
# progress of the spiral growing wider (0 to 100%)
42+
t = np.concatenate(
43+
[
44+
np.linspace(0, 0.1, 3, endpoint=False),
45+
np.linspace(0.1, 0.2, 5, endpoint=False),
46+
np.linspace(0.2, 0.5, 25, endpoint=False),
47+
np.linspace(0.5, 0.75, 30, endpoint=False),
48+
np.linspace(0.75, 1, 35, endpoint=True),
49+
]
50+
)
51+
52+
# x and y calculations for a spiral
53+
x = (t * max_radius) * np.cos(2 * num_rotations * np.pi * t) + x_center
54+
y = (t * max_radius) * np.sin(2 * num_rotations * np.pi * t) + y_center
55+
56+
super().__init__(*np.stack([x, y], axis=1))

tests/shapes/test_points.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,3 +172,21 @@ class TestSpade(PointsModuleTestBase):
172172
((0, 0), 57.350348),
173173
((10, 80), 10.968080),
174174
)
175+
176+
177+
class TestSpiral(PointsModuleTestBase):
178+
"""Test the Spiral class."""
179+
180+
shape_name = 'spiral'
181+
distance_test_cases = (
182+
((10.862675, 65.846698), 0),
183+
((29.280789, 59.546024), 0),
184+
((16.022152, 68.248880), 0),
185+
((20.310858, 65.251728), 0),
186+
((22.803548, 72.599350), 0),
187+
((0, 0), 58.03780546896006),
188+
((10, 50), 8.239887412781957),
189+
((30, 70), 0.6642518196535838),
190+
((25, 65), 1.3042797087884075),
191+
((-30, 100), 52.14470630148412),
192+
)

0 commit comments

Comments
 (0)