Skip to content
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 56 additions & 6 deletions src/plopp/backends/pythreejs/scatter3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@ def __init__(
opacity: float = 1,
pixel_size: sc.Variable | float | None = None,
):
import pythreejs as p3

check_ndim(data, ndim=1, origin='Scatter3d')
self.uid = uid if uid is not None else uuid.uuid4().hex
self._canvas = canvas
Expand All @@ -73,6 +71,7 @@ def __init__(
self._x = x
self._y = y
self._z = z
self._opacity = opacity

# TODO: remove pixel_size in the next release
self._size = size if pixel_size is None else pixel_size
Expand All @@ -88,14 +87,29 @@ def __init__(
dtype=float, unit=self._data.coords[x].unit
).value

self.points = None
self._make_point_cloud()

if self._colormapper is not None:
self._colormapper.add_artist(self.uid, self)
colors = self._colormapper.rgba(self.data)[..., :3].astype('float32')
self._update_colors()
else:
colors = np.broadcast_to(
np.array(to_rgb(f'C{artist_number}' if color is None else color)),
(self._data.coords[self._x].shape[0], 3),
).astype('float32')
self.geometry.attributes["color"].array = colors

self._add_point_cloud_to_scene()

def _make_point_cloud(self) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How long does this take approximately? Is it as fast as updating a 2d plot?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It depends on the number of points. I did some basic timings: it's about 0.01s for 100_000 points, and ~0.04s for 500_000 points.
But this is only run if the updated values have a different shape than the existing ones, which I am guessing is not super common.

A more relevant question is probably how long does it take to update the positions every time?
We are now running

        self.geometry.attributes["position"].array = np.array(
            [
                self._data.coords[self._x].values.astype('float32'),
                self._data.coords[self._y].values.astype('float32'),
                self._data.coords[self._z].values.astype('float32'),
            ]
        ).T

on every update, which is potentially quite a large allocation?
Before, we only have one array of floats for the colors, now we have 4 (colors + 3 positions).

We could only update if the coords have changed, but we would have to check something like

if any(not sc.identical(old_coords[dim], self._data.coords[dim]) for dim in "xyz"):

I need to check the timings of such a check, maybe it's fast enough?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, check it before you optimise too much! 0.04s seems fine. I don't think anyone expects 60fps.

Just a guess, but maybe you can make the big allocation slightly cheaper by using np.stack or any of its variants instead of np.array([..]).T.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The check can actually be pretty fast (10x or more) compared to setting the position, so I added it in.

"""
Create the point cloud geometry and material.
"""
import pythreejs as p3

if self.points is not None:
self._canvas.remove(self.points)

self.geometry = p3.BufferGeometry(
attributes={
Expand All @@ -108,7 +122,11 @@ def __init__(
]
).T
),
'color': p3.BufferAttribute(array=colors),
'color': p3.BufferAttribute(
array=np.zeros(
(self._data.coords[self._x].shape[0], 3), dtype='float32'
)
),
}
)

Expand All @@ -120,9 +138,14 @@ def __init__(
vertexColors='VertexColors',
size=2.5 * self._size * pixel_ratio,
transparent=True,
opacity=opacity,
opacity=self._opacity,
)
self.points = p3.Points(geometry=self.geometry, material=self.material)

def _add_point_cloud_to_scene(self) -> None:
"""
Add the point cloud to the canvas scene.
"""
self._canvas.add(self.points)

def notify_artist(self, message: str) -> None:
Expand All @@ -137,28 +160,55 @@ def notify_artist(self, message: str) -> None:
"""
self._update_colors()

def _update_colors(self):
def _update_colors(self) -> None:
"""
Set the point cloud's rgba colors:
"""
self.geometry.attributes["color"].array = self._colormapper.rgba(self.data)[
..., :3
].astype('float32')

def _update_positions(self) -> None:
"""
Update the point cloud's positions from the data.
"""
self.geometry.attributes["position"].array = np.array(
[
self._data.coords[self._x].values.astype('float32'),
self._data.coords[self._y].values.astype('float32'),
self._data.coords[self._z].values.astype('float32'),
]
).T

def update(self, new_values):
"""
Update point cloud array with new values.
If the coordinates have changed, the positions of the points are re-computed,
only if ``validate_on_update`` is ``True``.

Parameters
----------
new_values:
New data to update the point cloud values from.
"""
check_ndim(new_values, ndim=1, origin='Scatter3d')
need_new_point_cloud = False
if self._data.shape != new_values.shape:
need_new_point_cloud = True

self._data = new_values

if need_new_point_cloud:
self._make_point_cloud()
else:
self._update_positions()

if self._colormapper is not None:
self._update_colors()

if need_new_point_cloud:
self._add_point_cloud_to_scene()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have to remove the old cloud?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is done above on L111, but maybe it makes more sense to do it here... and remove of the if self.points is not None


@property
def opacity(self) -> float:
"""
Expand Down
12 changes: 12 additions & 0 deletions tests/backends/pythreejs/pythreejs_scatter3d_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,15 @@ def test_update_raises_when_data_is_not_1d():
sc.DimensionError, match='Scatter3d only accepts data with 1 dimension'
):
scat.update(da2d)


def test_update_with_different_number_of_points():
da = scatter(npoints=500)
scat = Scatter3d(canvas=Canvas(), data=da, x='x', y='y', z='z')
assert scat.points.geometry.attributes['position'].array.shape[0] == 500
assert scat.points.geometry.attributes['color'].array.shape[0] == 500
new = scatter(npoints=200)
scat.update(new)
assert scat.points.geometry.attributes['position'].array.shape[0] == 200
assert scat.points.geometry.attributes['color'].array.shape[0] == 200
assert sc.identical(scat._data, new)
Loading