Skip to content

Commit 2f9491a

Browse files
authored
Merge pull request #140 from stefmolin/rings-shape
Add Rings shape
2 parents 6ecd0bf + 847c818 commit 2f9491a

File tree

3 files changed

+95
-12
lines changed

3 files changed

+95
-12
lines changed

src/data_morph/shapes/circles.py

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from numbers import Number
44

55
import matplotlib.pyplot as plt
6+
import numpy as np
67
from matplotlib.axes import Axes
78

89
from ..data.dataset import Dataset
@@ -85,37 +86,51 @@ def plot(self, ax: Axes = None) -> Axes:
8586
return ax
8687

8788

88-
class Bullseye(Shape):
89+
class Rings(Shape):
8990
"""
90-
Class representing a bullseye shape comprising two concentric circles.
91+
Class representing rings comprising multiple concentric circles.
9192
9293
.. plot::
9394
:scale: 75
9495
:caption:
9596
This shape is generated using the panda dataset.
9697
9798
from data_morph.data.loader import DataLoader
98-
from data_morph.shapes.circles import Bullseye
99+
from data_morph.shapes.circles import Rings
99100
100-
_ = Bullseye(DataLoader.load_dataset('panda')).plot()
101+
_ = Rings(DataLoader.load_dataset('panda')).plot()
101102
102103
Parameters
103104
----------
104105
dataset : Dataset
105106
The starting dataset to morph into other shapes.
107+
num_rings : int, default 4
108+
The number of rings to include. Must be greater than 1.
109+
110+
See Also
111+
--------
112+
Circle : The individual rings are represented as circles.
106113
"""
107114

108-
def __init__(self, dataset: Dataset) -> None:
115+
def __init__(self, dataset: Dataset, num_rings: int = 4) -> None:
116+
if not isinstance(num_rings, int):
117+
raise TypeError('num_rings must be an integer')
118+
if num_rings <= 1:
119+
raise ValueError('num_rings must be greater than 1')
120+
109121
stdev = dataset.df.std().mean()
110-
self.circles: list[Circle] = [Circle(dataset, r) for r in [stdev, stdev * 2]]
111-
"""list[Circle]: The inner and outer :class:`Circle` objects."""
122+
self.circles: list[Circle] = [
123+
Circle(dataset, r)
124+
for r in np.linspace(stdev / num_rings * 2, stdev * 2, num_rings)
125+
]
126+
"""list[Circle]: The individual rings represented by :class:`Circle` objects."""
112127

113128
def __repr__(self) -> str:
114129
return self._recursive_repr('circles')
115130

116131
def distance(self, x: Number, y: Number) -> float:
117132
"""
118-
Calculate the minimum absolute distance between this bullseye's inner and outer
133+
Calculate the minimum absolute distance between any of this shape's
119134
circles' edges and a point (x, y).
120135
121136
Parameters
@@ -126,13 +141,13 @@ def distance(self, x: Number, y: Number) -> float:
126141
Returns
127142
-------
128143
float
129-
The minimum absolute distance between this bullseye's inner and outer
144+
The minimum absolute distance between any of this shape's
130145
circles' edges and the point (x, y).
131146
132147
See Also
133148
--------
134149
Circle.distance :
135-
A bullseye consists of two circles, so we use the minimum
150+
Rings consists of multiple circles, so we use the minimum
136151
distance to one of the circles.
137152
"""
138153
return min(circle.distance(x, y) for circle in self.circles)
@@ -155,3 +170,31 @@ def plot(self, ax: Axes = None) -> Axes:
155170
for circle in self.circles:
156171
ax = circle.plot(ax)
157172
return ax
173+
174+
175+
class Bullseye(Rings):
176+
"""
177+
Class representing a bullseye shape comprising two concentric circles.
178+
179+
.. plot::
180+
:scale: 75
181+
:caption:
182+
This shape is generated using the panda dataset.
183+
184+
from data_morph.data.loader import DataLoader
185+
from data_morph.shapes.circles import Bullseye
186+
187+
_ = Bullseye(DataLoader.load_dataset('panda')).plot()
188+
189+
Parameters
190+
----------
191+
dataset : Dataset
192+
The starting dataset to morph into other shapes.
193+
194+
See Also
195+
--------
196+
Rings : The Bullseye is a special case where we only have 2 rings.
197+
"""
198+
199+
def __init__(self, dataset: Dataset) -> None:
200+
super().__init__(dataset=dataset, num_rings=2)

src/data_morph/shapes/factory.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class ShapeFactory:
5151
'up_parab': points.UpParabola,
5252
'diamond': polygons.Diamond,
5353
'rectangle': polygons.Rectangle,
54+
'rings': circles.Rings,
5455
'star': polygons.Star,
5556
}
5657

@@ -61,22 +62,25 @@ class ShapeFactory:
6162
def __init__(self, dataset: Dataset) -> None:
6263
self._dataset: Dataset = dataset
6364

64-
def generate_shape(self, shape: str) -> Shape:
65+
def generate_shape(self, shape: str, **kwargs) -> Shape:
6566
"""
6667
Generate the shape object based on the dataset.
6768
6869
Parameters
6970
----------
7071
shape : str
7172
The desired shape. See :attr:`.AVAILABLE_SHAPES`.
73+
**kwargs
74+
Additional keyword arguments to pass down when creating
75+
the shape.
7276
7377
Returns
7478
-------
7579
Shape
7680
An shape object of the requested type.
7781
"""
7882
try:
79-
return self._SHAPE_MAPPING[shape](self._dataset)
83+
return self._SHAPE_MAPPING[shape](self._dataset, **kwargs)
8084
except KeyError as err:
8185
raise ValueError(f'No such shape as {shape}.') from err
8286

tests/shapes/test_circles.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,39 @@ def test_is_circle(self, shape):
7171
shape.cy + shape.r * np.sin(angles),
7272
):
7373
assert pytest.approx(shape.distance(x, y)) == 0
74+
75+
76+
class TestRings(CirclesModuleTestBase):
77+
"""Test the Rings class."""
78+
79+
shape_name = 'rings'
80+
distance_test_cases = [[(20, 50), 3.16987], [(10, 25), 9.08004]]
81+
repr_regex = (
82+
r'^<Rings>\n'
83+
r' circles=\n'
84+
r' <Circle cx=(\d+\.*\d*) cy=(\d+\.*\d*) r=(\d+\.*\d*)>\n'
85+
r' <Circle cx=(\d+\.*\d*) cy=(\d+\.*\d*) r=(\d+\.*\d*)>'
86+
)
87+
88+
@pytest.mark.parametrize('num_rings', [3, 5])
89+
def test_init(self, shape_factory, num_rings):
90+
"""Test that the Rings contains multiple concentric circles."""
91+
shape = shape_factory.generate_shape(self.shape_name, num_rings=num_rings)
92+
93+
assert len(shape.circles) == num_rings
94+
assert all(
95+
getattr(circle, center_coord) == getattr(shape.circles[0], center_coord)
96+
for circle in shape.circles[1:]
97+
for center_coord in ['cx', 'cy']
98+
)
99+
assert len({circle.r for circle in shape.circles}) == num_rings
100+
101+
@pytest.mark.parametrize('num_rings', ['3', -5, 1, True])
102+
def test_num_rings_is_valid(self, shape_factory, num_rings):
103+
"""Test that num_rings input validation is working."""
104+
if isinstance(num_rings, int):
105+
with pytest.raises(ValueError, match='num_rings must be greater than 1'):
106+
_ = shape_factory.generate_shape(self.shape_name, num_rings=num_rings)
107+
else:
108+
with pytest.raises(TypeError, match='num_rings must be an integer'):
109+
_ = shape_factory.generate_shape(self.shape_name, num_rings=num_rings)

0 commit comments

Comments
 (0)