Skip to content

Commit b819859

Browse files
authored
Update heart and spade logic (#327)
1 parent c2c1251 commit b819859

File tree

6 files changed

+67
-19
lines changed

6 files changed

+67
-19
lines changed

src/data_morph/shapes/bases/point_collection.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
from matplotlib.axes import Axes
1818

19+
from ...bounds.bounding_box import BoundingBox
20+
1921

2022
class PointCollection(Shape):
2123
"""
@@ -37,6 +39,32 @@ def __init__(self, *points: Iterable[Number]) -> None:
3739
def __repr__(self) -> str:
3840
return f'<{self.__class__.__name__} of {len(self.points)} points>'
3941

42+
@staticmethod
43+
def _center(points: np.ndarray, bounds: BoundingBox) -> np.ndarray:
44+
"""
45+
Center the points within the bounding box.
46+
47+
Parameters
48+
----------
49+
points : np.ndarray
50+
The points to center.
51+
bounds : BoundingBox
52+
The bounding box within which to center the points.
53+
54+
Returns
55+
-------
56+
np.ndarray
57+
The centered points.
58+
"""
59+
maxes = points.max(axis=0)
60+
span = maxes - points.min(axis=0)
61+
gap = (np.array(bounds.range) - span) / 2
62+
63+
(_, xmax), (_, ymax) = bounds
64+
shift = np.array([xmax, ymax]) - maxes - gap
65+
66+
return points + shift
67+
4068
def distance(self, x: Number, y: Number) -> float:
4169
"""
4270
Calculate the minimum distance from the points of this shape

src/data_morph/shapes/points/heart.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,20 @@ class Heart(PointCollection):
3535
"""
3636

3737
def __init__(self, dataset: Dataset) -> None:
38-
_, xmax = dataset.data_bounds.x_bounds
39-
x_shift, y_shift = dataset.data_bounds.center
38+
data_bounds = dataset.data_bounds
39+
(_, xmax), (_, ymax) = data_bounds
40+
x_shift, y_shift = data_bounds.center
4041

4142
t = np.linspace(-3, 3, num=80)
4243

4344
x = 16 * np.sin(t) ** 3
4445
y = 13 * np.cos(t) - 5 * np.cos(2 * t) - 2 * np.cos(3 * t) - np.cos(4 * t)
4546

4647
# scale by the half the widest width of the heart
47-
scale_factor = (xmax - x_shift) / 16
48+
scale_factor = min((xmax - x_shift), (ymax - y_shift)) / 16
4849

4950
super().__init__(
50-
*np.stack([x * scale_factor + x_shift, y * scale_factor + y_shift], axis=1)
51+
*self._center(
52+
np.stack([x * scale_factor, y * scale_factor], axis=1), data_bounds
53+
)
5154
)

src/data_morph/shapes/points/spade.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,9 @@ class Spade(PointCollection):
3030
"""
3131

3232
def __init__(self, dataset: Dataset) -> None:
33-
_, xmax = dataset.data_bounds.x_bounds
34-
x_shift, y_shift = dataset.data_bounds.center
33+
data_bounds = dataset.data_bounds
34+
_, xmax = data_bounds.x_bounds
35+
x_shift, y_shift = data_bounds.center
3536

3637
# upside-down heart
3738
heart_points = self._get_inverted_heart(dataset, y_shift)
@@ -43,7 +44,7 @@ def __init__(self, dataset: Dataset) -> None:
4344
x = np.concatenate((heart_points[:, 0], base_x), axis=0)
4445
y = np.concatenate((heart_points[:, 1], base_y), axis=0)
4546

46-
super().__init__(*np.stack([x, y], axis=1))
47+
super().__init__(*self._center(np.stack([x, y], axis=1), data_bounds))
4748

4849
@staticmethod
4950
def _get_inverted_heart(dataset: Dataset, y_shift: Number) -> np.ndarray:

tests/shapes/bases/test_point_collection.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
import re
44

55
import matplotlib.pyplot as plt
6+
import numpy as np
67
import pytest
78

9+
from data_morph.bounds.bounding_box import BoundingBox
810
from data_morph.shapes.bases.point_collection import PointCollection
911

1012

@@ -18,6 +20,19 @@ def point_collection(self):
1820
"""An instance of PointCollection."""
1921
return PointCollection([0, 0], [20, 50])
2022

23+
@pytest.mark.parametrize(
24+
'bounding_box',
25+
[BoundingBox([0, 100], [-50, 50]), BoundingBox([0, 20], [0, 50])],
26+
)
27+
def test_center(self, point_collection, bounding_box):
28+
"""Test that points are centered within the bounding box."""
29+
points = point_collection._center(point_collection.points, bounding_box)
30+
31+
(xmin, xmax), (ymin, ymax) = bounding_box
32+
upper = np.array([xmax, ymax]) - points.max(axis=0)
33+
lower = points.min(axis=0) - np.array([xmin, ymin])
34+
assert np.array_equal(upper, lower)
35+
2136
def test_distance_zero(self, point_collection):
2237
"""Test the distance() method on points in the collection."""
2338
for point in point_collection.points:

tests/shapes/points/test_heart.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ class TestHeart(PointsModuleTestBase):
1212

1313
shape_name = 'heart'
1414
distance_test_cases = (
15-
((19.89946048, 54.82281916), 0.0),
16-
((10.84680454, 70.18556376), 0.0),
17-
((29.9971295, 67.66402445), 0.0),
18-
((27.38657942, 62.417184), 0.0),
19-
((20, 50), 4.567369),
20-
((10, 80), 8.564365),
15+
((22.424114, 59.471779), 0.0),
16+
((10.405462, 70.897342), 0.0),
17+
((21.064032, 72.065253), 0.0),
18+
((16.035166, 60.868470), 0.0),
19+
((20, 50), 6.065782511791651),
20+
((10, 80), 7.173013322704914),
2121
)

tests/shapes/points/test_spade.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@ class TestSpade(PointsModuleTestBase):
1212

1313
shape_name = 'spade'
1414
distance_test_cases = (
15-
((19.97189615, 75.43271708), 0),
16-
((23.75, 55), 0),
17-
((11.42685318, 59.11304904), 0),
18-
((20, 75), 0.2037185),
19-
((0, 0), 57.350348),
20-
((10, 80), 10.968080),
15+
((19.818701, 60.065370), 0),
16+
((23.750000, 55.532859), 0),
17+
((20.067229, 60.463689), 0),
18+
((18.935968, 58.467606), 0),
19+
((20, 75), 0.5335993101603015),
20+
((0, 0), 57.861566654807596),
21+
((10, 80), 11.404000978114487),
2122
)

0 commit comments

Comments
 (0)