Skip to content

Commit c2c1251

Browse files
authored
Simplify scatter generation (#326)
1 parent a36562b commit c2c1251

File tree

1 file changed

+5
-10
lines changed

1 file changed

+5
-10
lines changed

src/data_morph/shapes/points/scatter.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,21 +30,16 @@ class Scatter(PointCollection):
3030

3131
def __init__(self, dataset: Dataset) -> None:
3232
rng = np.random.default_rng(1)
33-
center = (dataset.data.x.mean(), dataset.data.y.mean())
33+
morph_range = dataset.morph_bounds.range
34+
center = dataset.morph_bounds.center
3435
points = [center]
35-
max_radius = max(dataset.data.x.std(), dataset.data.y.std())
3636
points.extend(
3737
[
3838
(
39-
center[0]
40-
+ np.cos(angle) * radius
41-
+ rng.standard_normal() * max_radius,
42-
center[1]
43-
+ np.sin(angle) * radius
44-
+ rng.standard_normal() * max_radius,
39+
center[0] + np.cos(angle) * rng.uniform(0, morph_range[0] / 2),
40+
center[1] + np.sin(angle) * rng.uniform(0, morph_range[1] / 2),
4541
)
46-
for radius in np.linspace(max_radius // 5, max_radius, num=5)
47-
for angle in np.linspace(0, 360, num=50, endpoint=False)
42+
for angle in np.linspace(0, 720, num=100, endpoint=False)
4843
]
4944
)
5045
super().__init__(*points)

0 commit comments

Comments
 (0)