Skip to content

Commit ce4cb7f

Browse files
committed
feat: Rework animations
1 parent ec32cdd commit ce4cb7f

File tree

2 files changed

+138
-72
lines changed

2 files changed

+138
-72
lines changed

src/radiosim/ppdisks/plotting/plotting.py

Lines changed: 7 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,13 @@
66
from .utils import (
77
configure_axes,
88
configure_colorbar,
9-
ellipse_img2cartesian_img,
109
get_norm,
1110
)
1211

1312

14-
def plot_polar_image(
15-
polar_intensities: np.ndarray,
16-
grid_shape: tuple,
17-
r_lims: ArrayLike,
18-
phi_lims: ArrayLike | None = None,
19-
intensity_unit: str | None = None,
13+
def plot_image(
14+
data: np.ndarray,
15+
intensity_label: str | None = None,
2016
a_maj: float = 1.0,
2117
b_min: float = 1.0,
2218
rot_angle: float = 0.0,
@@ -33,58 +29,30 @@ def plot_polar_image(
3329
fig: matplotlib.figure.Figure | None = None,
3430
ax: matplotlib.axes.Axes | None = None,
3531
) -> tuple[matplotlib.image.AxesImage, matplotlib.figure.Figure, matplotlib.axes.Axes]:
36-
if phi_lims is None:
37-
phi_lims = [-np.pi, np.pi]
38-
39-
if xy_lims is None:
40-
xy_lims = ([-r_lims[1], r_lims[1]], [-r_lims[1], r_lims[1]])
41-
4232
norm = get_norm(norm) if isinstance(norm, str) else norm
4333

44-
intensity_unit = "Intensity / a.u." if intensity_unit is None else intensity_unit
34+
intensity_label = "Intensity / a.u." if intensity_label is None else intensity_label
4535

4636
save_args = {} if save_args is None else save_args
4737

4838
plot_args = {} if plot_args is None else plot_args
4939
fig_args = {} if fig_args is None else fig_args
5040

51-
r = np.linspace(
52-
r_lims[0],
53-
r_lims[1],
54-
polar_intensities.shape[0],
55-
)
56-
phi = np.linspace(phi_lims[0], phi_lims[1], polar_intensities.shape[1])
57-
58-
rs, phis = np.meshgrid(r, phi)
59-
rs = rs.T
60-
phis = phis.T
61-
62-
data_trafo = ellipse_img2cartesian_img(
63-
r=rs,
64-
phi=phis,
65-
intensities=polar_intensities,
66-
grid_shape=grid_shape,
67-
a=a_maj,
68-
b=b_min,
69-
alpha=rot_angle,
70-
xy_lims=xy_lims,
71-
)
72-
7341
fig, ax = configure_axes(fig=fig, ax=ax, fig_args=fig_args)
7442

7543
im = ax.imshow(
76-
data_trafo,
44+
data,
7745
origin="lower",
7846
cmap=cmap,
7947
interpolation="none",
8048
norm=get_norm(norm=norm)
8149
if intensity_limits is None
8250
else get_norm(norm=norm, vmin=intensity_limits[0], vmax=intensity_limits[1]),
83-
extent=np.ravel(xy_lims),
51+
extent=np.ravel(xy_lims) if xy_lims is not None else None,
8452
**plot_args,
8553
)
8654

87-
configure_colorbar(mappable=im, ax=ax, fig=fig, label=intensity_unit)
55+
configure_colorbar(mappable=im, ax=ax, fig=fig, label=intensity_label)
8856

8957
ax.set_xlabel(f"$x$ / {xy_unit}")
9058
ax.set_ylabel(f"$y$ / {xy_unit}")

src/radiosim/ppdisks/simulation.py

Lines changed: 131 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,20 @@
22
import shutil
33
from os import PathLike
44
from pathlib import Path
5+
from time import time
56

67
import matplotlib
78
import matplotlib.animation as animation
8-
import matplotlib.pyplot as plt
99
import numpy as np
1010
from astropy import constants as const
1111
from astropy import units as un
12+
from astropy.time import Time
13+
from numpy.typing import ArrayLike
1214
from tqdm.auto import tqdm
1315

1416
from radiosim.ppdisks.config import TOMLConfiguration
15-
from radiosim.ppdisks.plotting.plotting import plot_polar_image
17+
from radiosim.ppdisks.plotting.plotting import plot_image
18+
from radiosim.ppdisks.plotting.utils import ellipse_img2cartesian_img
1619

1720
from .config import Variables
1821
from .config.fargo import Constants, Planet, PlanetConfig, UnitSystem
@@ -746,36 +749,95 @@ def get_dust_density(self, output_idx: int = -1, dust_idx: int = 1) -> np.ndarra
746749
dtype=self._run.get_float_type(),
747750
).reshape(self._run.get_polar_img_size())
748751

752+
def get_polar_dust_density(
753+
self,
754+
grid_shape: tuple[int],
755+
output_idx: int = -1,
756+
dust_idx: int = 1,
757+
a_maj: float = 1.0,
758+
b_min: float = 1.0,
759+
rot_angle: float = 0.0,
760+
xy_lims: ArrayLike | None = None,
761+
xy_unit: un.Unit = un.AU,
762+
) -> tuple[np.ndarray, float, float]:
763+
polar_intensities = self.get_dust_density(
764+
output_idx=output_idx, dust_idx=dust_idx
765+
)
766+
767+
r_min, r_max = self.get_radius_lims()
768+
r_min = (r_min * un.AU).to(xy_unit).value
769+
r_max = (r_max * un.AU).to(xy_unit).value
770+
771+
r = np.linspace(
772+
r_min,
773+
r_max,
774+
polar_intensities.shape[0],
775+
)
776+
phi = np.linspace(0, 2 * np.pi, polar_intensities.shape[1])
777+
778+
rs, phis = np.meshgrid(r, phi)
779+
rs = rs.T
780+
phis = phis.T
781+
782+
return (
783+
ellipse_img2cartesian_img(
784+
r=rs,
785+
phi=phis,
786+
intensities=polar_intensities,
787+
grid_shape=grid_shape,
788+
a=a_maj,
789+
b=b_min,
790+
alpha=rot_angle,
791+
xy_lims=xy_lims
792+
if xy_lims is not None
793+
else [[-r_max, r_max], [-r_max, r_max]],
794+
),
795+
r_min,
796+
r_max,
797+
)
798+
749799
def plot_dust_density(
750800
self,
751801
grid_shape: tuple,
752802
output_idx: int = -1,
753803
dust_idx: int = 1,
804+
a_maj: float = 1.0,
805+
b_min: float = 1.0,
806+
rot_angle: float = 0.0,
807+
xy_lims: ArrayLike | None = None,
754808
xy_unit: un.Unit = un.AU,
809+
intensity_limits: ArrayLike | None = None,
755810
save_to: str | None = None,
756811
save_args: dict = None,
757812
**kwargs,
758813
) -> tuple[
759814
matplotlib.image.AxesImage, matplotlib.figure.Figure, matplotlib.axes.Axes
760815
]:
761-
r_min, r_max = self.get_radius_lims()
762-
r_min = (r_min * un.AU).to(xy_unit).value
763-
r_max = (r_max * un.AU).to(xy_unit).value
764-
765816
unit_system = self._run._sim._unit_system
766817

767-
return plot_polar_image(
768-
polar_intensities=self.get_dust_density(
769-
output_idx=output_idx, dust_idx=dust_idx
770-
),
818+
polar_intensities, _, r_max = self.get_polar_dust_density(
771819
grid_shape=grid_shape,
772-
r_lims=[r_min, r_max],
773-
intensity_unit=(
820+
output_idx=output_idx,
821+
dust_idx=dust_idx,
822+
a_maj=a_maj,
823+
b_min=b_min,
824+
rot_angle=rot_angle,
825+
xy_lims=xy_lims,
826+
xy_unit=xy_unit,
827+
)
828+
829+
xy_lims = xy_lims if xy_lims is not None else [[-r_max, r_max], [-r_max, r_max]]
830+
831+
return plot_image(
832+
data=polar_intensities,
833+
xy_lims=xy_lims,
834+
intensity_label=(
774835
"Dust density / "
775836
f"{
776837
(unit_system.mass / unit_system.length**2).to_string(format='latex')
777838
}"
778839
),
840+
intensity_limits=intensity_limits,
779841
dtype=self._run.get_float_type(),
780842
save_to=save_to,
781843
save_args=save_args,
@@ -785,9 +847,17 @@ def plot_dust_density(
785847
def animate_dust_density(
786848
self,
787849
grid_shape: tuple,
788-
save_to: str | PathLike,
789850
step_size: int,
851+
output_fmt: str = "mp4",
852+
output_dir: str | PathLike | None = None,
853+
save_to: str | PathLike | None = None,
854+
save_with_timestamp: bool = False,
790855
dust_idx: int = 1,
856+
start_idx: int = 0,
857+
a_maj: float = 1.0,
858+
b_min: float = 1.0,
859+
rot_angle: float = 0.0,
860+
end_idx: int | None = None,
791861
xy_unit: un.Unit = un.AU,
792862
save_args: dict = None,
793863
fps: int = 30,
@@ -796,36 +866,60 @@ def animate_dust_density(
796866
show_progress: bool = True,
797867
**kwargs,
798868
) -> None:
799-
save_to = Path(save_to)
800-
801-
ims = []
802-
803-
fig, ax = plt.subplots()
869+
if save_to is not None:
870+
save_to = Path(save_to)
871+
else:
872+
output_dir = Path(output_dir)
873+
save_to = (
874+
output_dir / f"{self._run._sim.name}-run_{self._run._id}-"
875+
f"model_{self._id}.{output_fmt}"
876+
)
804877

805878
print(
806879
"Animation length will be: "
807880
f"{np.round(self.get_num_outputs() // step_size / fps, 2)} seconds"
808881
)
809882

810-
for i in tqdm(
811-
np.arange(start=0, stop=self.get_num_outputs(), step=step_size),
812-
desc="Plotting densities",
813-
disable=not show_progress,
814-
):
815-
im, _, _ = self.plot_dust_density(
883+
end_idx = self.get_num_outputs() if end_idx is None else end_idx
884+
num_outputs = (end_idx - start_idx) // step_size
885+
886+
data = np.zeros((num_outputs, *grid_shape))
887+
888+
output_idcs = np.arange(start_idx, end_idx + step_size, step=step_size)
889+
890+
for i in np.arange(0, num_outputs):
891+
img, _, _ = self.get_polar_dust_density(
816892
grid_shape=grid_shape,
817-
output_idx=i,
893+
output_idx=output_idcs[i],
818894
dust_idx=dust_idx,
895+
a_maj=a_maj,
896+
b_min=b_min,
897+
rot_angle=rot_angle,
819898
xy_unit=xy_unit,
820-
fig=fig,
821-
ax=ax,
822-
**kwargs,
823899
)
824-
ims.append([im])
825-
if hasattr(im, "colorbar") and im.colorbar is not None:
826-
im.colorbar.ax.remove()
900+
data[i] = img
901+
902+
print(f"{[data[data > 0].min(), data.max()]}")
903+
904+
im, fig, ax = self.plot_dust_density(
905+
grid_shape=grid_shape,
906+
output_idx=start_idx,
907+
dust_idx=dust_idx,
908+
xy_unit=xy_unit,
909+
a_maj=a_maj,
910+
b_min=b_min,
911+
rot_angle=rot_angle,
912+
intensity_limits=[data[data > 0].min(), data.max()],
913+
**kwargs,
914+
)
827915

828-
anim = animation.ArtistAnimation(fig, ims, blit=blit, interval=1e3 / fps)
916+
def update(frame: int):
917+
im.set_data(data[frame + 1])
918+
return [im]
919+
920+
anim = animation.FuncAnimation(
921+
fig=fig, func=update, frames=num_outputs - 1, blit=blit, interval=1e3 / fps
922+
)
829923

830924
writer = None
831925
if save_to.suffix.lower() == ".gif":
@@ -839,8 +933,12 @@ def _progress_func(_i, _n):
839933
progress_bar.update(1)
840934

841935
with tqdm(
842-
total=len(ims), desc="Saving animation", disable=not show_progress
936+
total=num_outputs - 1, desc="Saving animation", disable=not show_progress
843937
) as progress_bar:
938+
if save_with_timestamp:
939+
save_to = save_to.with_name(
940+
f"{save_to.stem}-{Time(time(), format='unix').isot}"
941+
)
844942
if writer is None:
845943
anim.save(save_to, progress_callback=_progress_func, dpi=dpi)
846944
else:

0 commit comments

Comments
 (0)