Skip to content
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
199a2eb
add comments and data
MaHaWo Feb 5, 2025
090f954
improve pyvista plotting functions
MaHaWo Feb 18, 2025
a6eeab8
add tests for pyvista plotting utils
MaHaWo Mar 11, 2025
2ab4fa6
finish utils test
MaHaWo Mar 12, 2025
d6e6d31
reorganize tests
MaHaWo Mar 12, 2025
5df8906
make test compatible with coverage reporting
MaHaWo Mar 12, 2025
b4c51a8
remove superfluous function, make docstring better
MaHaWo Mar 12, 2025
6d9b68b
adjust python versions
MaHaWo Mar 12, 2025
ab05b25
add pyvista headless display to
MaHaWo Mar 12, 2025
677e766
make code compatible with old python versions
MaHaWo Mar 12, 2025
45c1de6
try to fix docstring issues
MaHaWo Mar 12, 2025
9520ac8
fix documentation issue
MaHaWo Mar 12, 2025
f14453c
fix issues with default args
MaHaWo Mar 12, 2025
efee1f6
fix bug in constructor call
MaHaWo Mar 12, 2025
6d5ff5a
add functions that eat sme.SimulationResult
MaHaWo Apr 3, 2025
8a579f7
add 3D wrapper functions and docs notebook stuff
MaHaWo Apr 3, 2025
7c1fc44
correct type annotations
MaHaWo Apr 3, 2025
1a23b49
correct typos
MaHaWo Apr 3, 2025
57901a7
make notebook work
MaHaWo Apr 3, 2025
beb7026
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 3, 2025
9a8279e
try to add static backend to make pv work
MaHaWo Apr 3, 2025
f91081b
Merge branch 'add-pyvista-3D-visualization' of github.com:spatial-mod…
MaHaWo Apr 3, 2025
aca5a2d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 3, 2025
7cced94
try to make docs thing work again damnit
MaHaWo Apr 3, 2025
76a8f85
Merge branch 'add-pyvista-3D-visualization' of github.com:spatial-mod…
MaHaWo Apr 3, 2025
4d92d5c
fix import issue
MaHaWo Apr 4, 2025
24de02d
fix improve README section in notebook
MaHaWo Apr 4, 2025
6cd63e1
use client mode for pv to try and make pipeline work
MaHaWo Apr 4, 2025
723b4fc
try to make doc build on ci work
MaHaWo Apr 4, 2025
35e5218
try once more to make docs ci work
MaHaWo Apr 4, 2025
9ea64e7
adjust headings
MaHaWo Apr 4, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- uses: pyvista/setup-headless-display-action@v3
- run: pip install -e .[tests]
- run: python -m pytest --cov=sme_contrib --cov-report=xml -v
- uses: codecov/codecov-action@v3
Expand Down
10 changes: 6 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@ classifiers = [
"Operating System :: Microsoft :: Windows",
"Operating System :: POSIX :: Linux",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",

]
dependencies = ["matplotlib", "numpy", "pillow", "pyswarms", "sme>=1.4.0"]
dependencies = ["matplotlib", "numpy", "pillow", "pyswarms", "sme>=1.4.0","pyvista[all]", "imageio[ffmpeg]"]

dynamic = ["version"]

[project.urls]
Expand All @@ -42,7 +44,7 @@ docs = [
"nbsphinx",
"pandoc",
"sphinx>=4.5.0",
"sphinx_rtd_theme>=1.0.0"
"sphinx_rtd_theme>=1.0.0",
]

[tool.setuptools.dynamic]
Expand Down
2 changes: 1 addition & 1 deletion src/sme_contrib/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.0.16"
__version__ = "0.0.17"
162 changes: 162 additions & 0 deletions src/sme_contrib/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap as lscmap
from matplotlib import animation
import pyvista as pv
from typing import Any, Callable, Union

from .pyvista_utils import (
find_layout,
)


def colormap(color, name="my colormap"):
Expand Down Expand Up @@ -120,3 +126,159 @@ def concentration_heatmap_animation(
)
plt.close()
return anim


def facet_grid_3D(
data: dict[str, np.ndarray],
plotfuncs: dict[str, Callable],
show_cmap: bool = False,
cmap: Union[str, np.ndarray, pv.LookupTable] = "viridis",
portrait: bool = False,
linked_views: bool = True,
plotter_kwargs: Union[dict, None] = None,
plotfuncs_kwargs: Union[dict[str, dict[str, Any]], None] = None,
) -> pv.Plotter:
"""
Create a 3D facet plot using PyVista.

This follows the seaborn.FacetGrid concept. This function creates a grid of subplots where each subplot is filled by a function in the plotfuncs argument. The keys for plotfuncs and data must be the same, such that plotfuncs can be unambiguously mapped over the data dictionary. Do not attempt to plot 2D images and 3D images into the same facet grid, as this will create odd artifacts and may not work as expected.

Args:
data : (dict[str, np.ndarray]) A dictionary where keys are labels and values are numpy arrays containing the data to be plotted.
plotfuncs : (dict[str, Callable]) A dictionary where keys are labels and values are functions with signature ``f(label:str, data:np.ndarray | pyvista.ImageData | pyvista.UniformGrid, plotter:pv.Plotter, panel:tuple[int, int], show_cmap:bool=show_cmap, cmap=cmap, **plotfuncs_kwargs )`` -> None
show_cmap : bool, optional Whether to show the color map. Default is False.
cmap : (str | np.ndarray | pv.LookupTable), optional The color map to use. Default is "viridis".
portrait : (bool), optional Whether to use a portrait layout. Default is False.
linked_views : (bool), optional Whether to link the views of the subplots. Default is True.
plotter_kwargs : (dict, optional) Additional keyword arguments to pass to the PyVista Plotter.
plotfuncs_kwargs : (dict[str, dict[str, Any]]), optional Additional keyword arguments to pass to each plotting function.

Returns:
pv.Plotter The PyVista Plotter object with the created facet plot.
"""
if data.keys() != plotfuncs.keys():
raise ValueError(
"The keys for the data and plotfuncs dictionaries must be the same."
)

layout = find_layout(len(data), portrait=portrait)

plotter = pv.Plotter(
shape=layout, **(plotter_kwargs if plotter_kwargs is not None else {})
)

label = iter(plotfuncs.keys())

for i in range(layout[0]):
for j in range(layout[1]):
current_label = next(label)
plotfuncs[current_label](
current_label,
data[current_label],
plotter,
panel=(i, j),
show_cmap=show_cmap,
cmap=cmap,
**(
plotfuncs_kwargs.get(current_label, {})
if plotfuncs_kwargs is not None
else {}
),
)

if linked_views:
plotter.link_views()

return plotter


def facet_grid_animate_3D(
filename: str,
data: list[dict[str, np.ndarray]],
plotfuncs: dict[str, Callable],
show_cmap: bool = False,
cmap: Union[str, np.ndarray, pv.LookupTable] = "viridis",
portrait: bool = False,
linked_views: bool = True,
titles: Union[list[dict[str, str]], None] = None,
plotter_kwargs: Union[dict, None] = None,
plotfuncs_kwargs: Union[dict[str, dict[str, Any]], None] = None,
) -> str:
"""
Create a 3D animation from a series of data snapshots using PyVista.

This series must be a list of dictionaries with the data for each frame keyed by a label used to title the panel it will be plotted into. The final plot will have as many subplots as there are labels in the data dictionaries. The keys for plotfuncs and data must be the same.

Args:
filename : (str) The name of the output movie file.
data : (list[dict[str, np.ndarray]]) A list of dictionaries containing the data for each timestep.
plotfuncs : (dict[str, Callable]) A dictionary of plotting functions keyed by data label. The keys for plotfuncs and data must be the same.
show_cmap : (bool), optional Whether to show the color map (default is False).
cmap : (str | np.ndarray | pv.LookupTable, optional) The colormap to use (default is "viridis").
portrait : (bool), optional Whether to use portrait layout (default is False).
linked_views : (bool), optional Whether to link the views of the subplots (default is True).
titles : (list[dict[str, str]]), optional A list of dictionaries containing titles for each subplot (default is an empty list).
plotter_kwargs : (dict), optional Additional keyword arguments to pass to the PyVista Plotter (default is an empty dictionary).
plotfuncs_kwargs : (dict[str, dict[str, Any]]), optional Additional keyword arguments to pass to each plotting function (default is an empty dictionary).

Returns:
str The filename of the created movie.
"""
if titles is None:
titles = []

if len(titles) > 0 and len(titles) != len(data):
raise ValueError(
"The number of titles must be the same as the number of data dictionaries."
)

if data[0].keys() != plotfuncs.keys():
raise ValueError(
"The keys for the data and plotfuncs dictionaries must be the same."
)

# main function, called for each frame in the movie
def create_frame(
data_dict: dict[str, np.ndarray], title: dict[str:str], layout=(1, 1)
):
label = iter(data_dict.keys())
for i in range(layout[0]):
for j in range(layout[1]):
current_label = next(label)
plotfuncs[current_label](
title.get(current_label, current_label),
data_dict[current_label],
plotter,
panel=(i, j),
show_cmap=show_cmap,
cmap=cmap,
**plotfuncs_kwargs.get(current_label, {})
if plotfuncs_kwargs is not None
else {},
)

plotter.write_frame()

# preparations
layout = find_layout(len(plotfuncs), portrait=portrait)

plotter = pv.Plotter(
shape=layout, **plotter_kwargs if plotter_kwargs is not None else {}
)

plotter.open_movie(filename)

# add first frame here to set up the plotter
create_frame(data[0], titles[0] if len(titles) > 0 else {}, layout)

if linked_views:
plotter.link_views()

for i, single_timestep_data in enumerate(data[1::]):
create_frame(
single_timestep_data, titles[i] if len(titles) > 0 else {}, layout=layout
)

plotter.close()

return filename
110 changes: 110 additions & 0 deletions src/sme_contrib/pyvista_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import pyvista as pv
import numpy as np
from itertools import cycle
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt


def rgb_to_scalar(img: np.ndarray) -> np.ndarray:
"""
Convert an RGB 3D image represented as a 4D tensor to a 3D image tensor where each unique RGB value is assigned a unique scalar, i.e., it contracts the dimension with the RGB values into scalars in such a way that 2 different colors are mapped to 2 different scalars, too. This is needed because PyVista doesn't work with RGB values directly and expects fields defined on a grid.

Args:
img (np.ndarray): A 3D numpy array representing an RGB image with shape (height, width, 3).

Retruns:
np.ndarray: A 2D numpy array with the same height and width as the input image, where each pixel's value corresponds to a unique scalar representing the original RGB value.
"""
reshaped = np.copy(img.reshape(-1, 3))
unique_rgb, ridx = np.unique(reshaped, axis=0, return_inverse=True)

values = np.arange(len(unique_rgb))
return values[ridx].reshape(img.shape[:-1])


def make_discrete_colormap(
cmap: str = "tab10", values: np.ndarray = np.array([])
) -> pv.LookupTable:
"""
Create a discrete colormap for use with PyVista with as many colors as unique values in the ``values``array based on a given matplotlbit colormap. The colors will possibly repeat if there are more unique values than colors in the colormap. In this case, the outcome is intended, e.g., for separability of regions in the visualization,

Parameters:
cmap (str): The name of the colormap to use. Default is 'tab10'.
values (np.ndarray): An array of values to map to colors. Default is an empty array.

Returns:
pv.LookupTable: A PyVista LookupTable object with the values drawn from the specified colormap in RGBA format.
"""
cm = []

if values.size == 0:
values = np.arange(0, 1, 1)
cm = [
mcolors.to_rgba(plt.get_cmap(cmap).colors[0]),
]
else:
i = 0
for c in cycle(plt.get_cmap(cmap).colors):
cm.append(mcolors.to_rgba(c))
if len(cm) >= len(values):
break
i += 1
lt = pv.LookupTable(
values=np.array(cm) * 255,
scalar_range=(0, len(values)),
n_values=len(values),
)

return lt


def find_layout(num_plots: int, portrait: bool = False) -> tuple[int, int]:
"""Find a reasonable layout for a grid of subplots. This splits num_subplots into n x m subplots where n and m are as close as possible to each other. This can include a case where n x m > num_plots. Then, the superficial panels in the grid are ignored in the plotting process.

Args:
num_plots (int): Number of plots to arrange
portrait (bool, optional): Whether the min or max of (n,m) should be the column number in the resulting grid. Defaults to False.

Returns:
tuple[int, int]: Tuple describing (n_rows, n_cols) of the grid
"""

# for checking approximation accuracy with ints. if root > root_int, then
# we need to adjust n_row, n_cols sucht that n_row * n_cols >= root^2
root = np.sqrt(num_plots)
root_int = np.rint(root)

if np.isclose(root, root_int):
return int(root_int), int(root_int) # perfect square because root is an integer
else:
# approximation by integer root is inexact

# find an approximation that is close to square such that n_row * n_cols - num_plots is
# as small as possible
a = int(np.floor(root))
b = int(np.ceil(root))

a_1 = int(a - 1)
b_1 = int(b + 1)

# make a couple of guesses that are close to the root and select the best one
guesses = [
(x, y)
for x, y in [
(a, b),
(a_1, b_1),
(a, b_1),
(a_1, b),
]
if x * y >= num_plots
]
best_guess = guesses[
np.argmin([x * y for x, y in guesses])
] # smallest possible approximation

# handle orientation of the grid. min => rows for landscape, min=> cols for portrait
return (
(np.min(best_guess), np.max(best_guess))
if not portrait
else (np.max(best_guess), np.min(best_guess))
)
15 changes: 15 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import pytest
from pyvista import examples


@pytest.fixture(scope="session")
def exampledata():
armadillo = examples.download_armadillo()
bloodvessel = examples.download_blood_vessels()
brain = examples.download_brain()

return {
"armadillo": armadillo,
"bloodvessel": bloodvessel,
"brain": brain,
}
Loading
Loading