diff --git a/src/data_morph/bounds/bounding_box.py b/src/data_morph/bounds/bounding_box.py index 60cb9bd7..e137aa73 100644 --- a/src/data_morph/bounds/bounding_box.py +++ b/src/data_morph/bounds/bounding_box.py @@ -167,3 +167,15 @@ def range(self) -> Iterable[Number]: The range covered by the x and y bounds, respectively. """ return self.x_bounds.range, self.y_bounds.range + + @property + def center(self) -> Iterable[Number]: + """ + Calculate the center of the bounding box. + + Returns + ------- + Iterable[numbers.Number] + The center of the x and y bounds, respectively. + """ + return self.x_bounds.center, self.y_bounds.center diff --git a/src/data_morph/bounds/interval.py b/src/data_morph/bounds/interval.py index 72d0eebe..b49926f7 100644 --- a/src/data_morph/bounds/interval.py +++ b/src/data_morph/bounds/interval.py @@ -164,3 +164,15 @@ def range(self) -> Number: The range covered by the interval. """ return abs(self._bounds[1] - self._bounds[0]) + + @property + def center(self) -> Number: + """ + Calculate the center of the interval. + + Returns + ------- + numbers.Number + The center of the interval. + """ + return sum(self) / 2 diff --git a/src/data_morph/shapes/points/club.py b/src/data_morph/shapes/points/club.py index 1ce579fb..6dfea021 100644 --- a/src/data_morph/shapes/points/club.py +++ b/src/data_morph/shapes/points/club.py @@ -29,18 +29,14 @@ class Club(PointCollection): """ def __init__(self, dataset: Dataset) -> None: - x_bounds = dataset.data_bounds.x_bounds - y_bounds = dataset.data_bounds.y_bounds - - x_shift = sum(x_bounds) / 2 - y_shift = sum(y_bounds) / 2 - scale_factor = min(x_bounds.range, y_bounds.range) / 75 + scale_factor = min(*dataset.data_bounds.range) / 75 x_lobes, y_lobes = self._get_lobes(scale_factor) x_stem, y_stem = self._get_stem(scale_factor) - xs = x_shift + np.concatenate(x_lobes + x_stem) - ys = y_shift + np.concatenate(y_lobes + y_stem) + x_center, y_center = dataset.data_bounds.center + xs = x_center + np.concatenate(x_lobes + x_stem) + ys = y_center + np.concatenate(y_lobes + y_stem) super().__init__(*np.stack([xs, ys], axis=1)) diff --git a/src/data_morph/shapes/points/heart.py b/src/data_morph/shapes/points/heart.py index c8db9382..71eda930 100644 --- a/src/data_morph/shapes/points/heart.py +++ b/src/data_morph/shapes/points/heart.py @@ -36,10 +36,7 @@ class Heart(PointCollection): def __init__(self, dataset: Dataset) -> None: x_bounds = dataset.data_bounds.x_bounds - y_bounds = dataset.data_bounds.y_bounds - - x_shift = sum(x_bounds) / 2 - y_shift = sum(y_bounds) / 2 + x_shift, y_shift = dataset.data_bounds.center t = np.linspace(-3, 3, num=80) diff --git a/src/data_morph/shapes/points/spade.py b/src/data_morph/shapes/points/spade.py index 5acdefff..4104a946 100644 --- a/src/data_morph/shapes/points/spade.py +++ b/src/data_morph/shapes/points/spade.py @@ -31,10 +31,7 @@ class Spade(PointCollection): def __init__(self, dataset: Dataset) -> None: x_bounds = dataset.data_bounds.x_bounds - y_bounds = dataset.data_bounds.y_bounds - - x_shift = sum(x_bounds) / 2 - y_shift = sum(y_bounds) / 2 + x_shift, y_shift = dataset.data_bounds.center # upside-down heart heart_points = self._get_inverted_heart(dataset, y_shift) diff --git a/tests/bounds/test_bounding_box.py b/tests/bounds/test_bounding_box.py index c96eb1f4..83633262 100644 --- a/tests/bounds/test_bounding_box.py +++ b/tests/bounds/test_bounding_box.py @@ -205,3 +205,8 @@ def test_range(self): """Test that the range property is working.""" bbox = BoundingBox([0, 10], [5, 10]) assert bbox.range == (10, 5) + + def test_center(self): + """Test that the center property is working.""" + bbox = BoundingBox([0, 10], [5, 10]) + assert bbox.center == (5, 7.5) diff --git a/tests/bounds/test_interval.py b/tests/bounds/test_interval.py index 2312e052..4c744dcb 100644 --- a/tests/bounds/test_interval.py +++ b/tests/bounds/test_interval.py @@ -172,3 +172,18 @@ def test_range(self, limits, inclusive, expected): """Test that the range property is working.""" bounds = Interval(limits, inclusive) assert bounds.range == expected + + @pytest.mark.parametrize('inclusive', [True, False]) + @pytest.mark.parametrize( + ('limits', 'expected'), + [ + ([-10, -5], -7.5), + ([-1, 1], 0), + ([2, 100], 51), + ], + ids=str, + ) + def test_center(self, limits, inclusive, expected): + """Test that the center property is working.""" + bounds = Interval(limits, inclusive) + assert bounds.center == expected