Skip to content

Commit 5ab8616

Browse files
committed
feat: Add time recording, Add plotting utils
1 parent 56b3755 commit 5ab8616

File tree

5 files changed

+297
-14
lines changed

5 files changed

+297
-14
lines changed

src/radiosim/ppdisks/plotting/__init__.py

Whitespace-only changes.

src/radiosim/ppdisks/plotting/plotting.py

Whitespace-only changes.
Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
import warnings
2+
3+
import matplotlib
4+
import matplotlib.pyplot as plt
5+
import numpy as np
6+
from matplotlib.ticker import NullFormatter
7+
from mpl_toolkits.axes_grid1 import make_axes_locatable
8+
9+
10+
def configure_axes(
11+
fig: matplotlib.figure.Figure | None,
12+
ax: matplotlib.axes.Axes | None,
13+
fig_args: dict = None,
14+
):
15+
"""Configures figure and axis depending if they were given
16+
as parameters.
17+
18+
If neither figure nor axis are given, a new subplot will be created.
19+
If they are given the given ones will be returned.
20+
If only one of both is not given, this will cause an exception.
21+
22+
Parameters
23+
----------
24+
fig : matplotlib.figure.Figure | None
25+
The figure object.
26+
ax : matplotlib.axes.Axes | None
27+
The axes object.
28+
fig_args : dict, optional
29+
Optional arguments to be supplied to the ``plt.subplots`` call.
30+
31+
Returns
32+
-------
33+
fig : matplotlib.figure.Figure
34+
The figure object.
35+
ax : matplotlib.axes.Axes
36+
The axes object.
37+
"""
38+
if fig_args is None:
39+
fig_args = {}
40+
41+
if None in (fig, ax) and not all(x is None for x in (fig, ax)):
42+
raise KeyError("The parameters ax and fig have to be both None or not None!")
43+
44+
if ax is None:
45+
fig, ax = plt.subplots(layout="tight", **fig_args)
46+
47+
return fig, ax
48+
49+
50+
def _get_norm(
51+
norm: str,
52+
vmax: float | None = None,
53+
vmin: float | None = None,
54+
vcenter: float = 0,
55+
):
56+
"""Converts a string parameter to a matplotlib norm.
57+
58+
Parameters
59+
----------
60+
norm : str
61+
The name of the norm.
62+
Possible values are:
63+
64+
- ``log``: Returns a logarithmic norm with clipping on (!), meaning
65+
values above the maximum will be mapped to the maximum and
66+
values below the minimum will be mapped to the minimum, thus
67+
avoiding the appearance of a colormaps 'over' and 'under'
68+
colors (e.g. in case of negative values). Depending on the
69+
use case this is desirable but in case that it is not, one
70+
can set the norm to ``log_noclip`` or provide a custom norm.
71+
72+
- ``log_noclip``: Returns a logarithmic norm with clipping off.
73+
74+
- ``centered``: Returns a linear norm which centered around zero.
75+
76+
- ``sqrt``: Returns a power norm with exponent 0.5, meaning the
77+
square-root of the values.
78+
79+
- other: A value not declared above will be returned as is, meaning
80+
that this could be any value which exists in matplotlib
81+
itself.
82+
83+
vmax : float | None, optional
84+
The maximum value of the range to normalize. This might not have an effect
85+
for every norm. Default is ``None``.
86+
87+
vmin : float | None, optional
88+
The minimum value of the range to normalize. This might not have an effect
89+
for every norm. Default is ``None``.
90+
91+
vcenter : float | None, optional
92+
The central value of the range to normalize. This might not have an effect
93+
for every norm. Default is ``0``.
94+
95+
Returns
96+
-------
97+
matplotlib.colors.Normalize | str
98+
The norm or the str if no specific norm is defined for the string.
99+
"""
100+
match norm:
101+
case "log":
102+
if vmin == 0:
103+
vmin = np.finfo(float).eps
104+
warnings.warn(
105+
f"Since the given vmin is 0, the value was set to {vmin}"
106+
" to enable logarithmic normalization.",
107+
stacklevel=1,
108+
)
109+
110+
return matplotlib.colors.LogNorm(clip=True, vmin=vmin, vmax=vmax)
111+
case "log_noclip":
112+
if vmin == 0:
113+
vmin = np.finfo(float).eps
114+
warnings.warn(
115+
f"Since the given vmin is 0, the value was set to {vmin}"
116+
" to enable logarithmic normalization.",
117+
stacklevel=1,
118+
)
119+
120+
return matplotlib.colors.LogNorm(clip=False, vmin=vmin, vmax=vmax)
121+
case "centered":
122+
if vmin is not None and vmax is not None:
123+
return matplotlib.colors.CenteredNorm(
124+
vcenter=vcenter, halfrange=np.max([np.abs(vmin), np.abs(vmax)])
125+
)
126+
else:
127+
return matplotlib.colors.CenteredNorm(vcenter=vcenter)
128+
129+
case "sqrt":
130+
return matplotlib.colors.PowerNorm(0.5, vmin=vmin, vmax=vmax)
131+
case _:
132+
return norm
133+
134+
135+
def apply_crop(ax: matplotlib.axes.Axes, crop: tuple[list[float | None]]):
136+
"""Applies a specific x and y limit ('crop') to the given axis.
137+
This will effectively crop the image.
138+
139+
Parameters
140+
----------
141+
ax : matplotlib.axes.Axes
142+
The axis which to apply the limits to.
143+
crop : tuple[list[float | None]]
144+
The crop of the image. This has to have the format
145+
``([x_left, x_right], [y_left, y_right])``, where the left and right
146+
values for each axis are the upper and lower limits of the axes which
147+
should be shown.
148+
IMPORTANT: If one supplies the ``plt.imshow`` an ``extent`` parameter,
149+
this will be the scale in which one has to give the crop! If not, the crop
150+
has to be in pixels.
151+
"""
152+
ax.set_xlim(crop[0][0], crop[0][1])
153+
ax.set_ylim(crop[1][0], crop[1][1])
154+
155+
156+
# based on https://stackoverflow.com/a/18195921 by "bogatron"
157+
# Marked code (inside >>> BEGIN / <<< END) is licensed under CC BY-SA 3.0
158+
def configure_colorbar(
159+
mappable: matplotlib.cm.ScalarMappable,
160+
ax: matplotlib.axes.Axes,
161+
fig: matplotlib.figure.Figure,
162+
label: str | None,
163+
show_ticks: bool = True,
164+
fontsize: str = "medium",
165+
) -> matplotlib.colorbar.Colorbar:
166+
# >>> BEGIN
167+
divider = make_axes_locatable(ax)
168+
cax = divider.append_axes("right", size="5%", pad=0.05)
169+
cbar = fig.colorbar(mappable, cax=cax)
170+
cbar.set_label(label, fontsize=fontsize)
171+
172+
if not show_ticks:
173+
cbar.set_ticks([])
174+
cbar.ax.yaxis.set_major_formatter(NullFormatter())
175+
cbar.ax.yaxis.set_minor_formatter(NullFormatter())
176+
else:
177+
cbar.ax.tick_params(labelsize=fontsize)
178+
# <<< END
179+
180+
return cbar
181+
182+
183+
def ellipse2cartesian(r: np.ndarray, phi: np.ndarray, a: float, b: float, alpha: float):
184+
alpha = np.deg2rad(alpha)
185+
return (
186+
r * (a * np.cos(phi) * np.cos(alpha) - b * np.sin(phi) * np.sin(alpha)),
187+
r * (a * np.cos(phi) * np.sin(alpha) + b * np.sin(phi) * np.cos(alpha)),
188+
)
189+
190+
191+
def xy2pix(
192+
x: np.ndarray,
193+
y: np.ndarray,
194+
shape: tuple[int],
195+
xy_lims: tuple[list[float]] = ([-1, 1], [-1, 1]),
196+
):
197+
xy_lims = np.ndarray(xy_lims)
198+
199+
delta_x = (np.abs(np.diff(xy_lims[0])) / shape[1])[0]
200+
delta_y = (np.abs(np.diff(xy_lims[1])) / shape[0])[0]
201+
202+
col_idx = np.floor((x - xy_lims[0, 0]) // delta_x).int()
203+
row_idx = np.floor((y - xy_lims[1, 0]) // delta_y).int()
204+
205+
return row_idx, col_idx
206+
207+
208+
def ellipse_img2cartesian_img(
209+
r: np.ndarray,
210+
phi: np.ndarray,
211+
intensities: np.ndarray,
212+
grid_shape: tuple[int],
213+
a: float,
214+
b: float,
215+
alpha: float,
216+
dtype: type,
217+
xy_lims: tuple[list[float]] = ([-1, 1], [-1, 1]),
218+
):
219+
image = np.zeros(grid_shape, dtype=dtype)
220+
221+
x, y = ellipse2cartesian(r, phi, a=a, b=b, alpha=alpha)
222+
row, col = xy2pix(x=x, y=y, shape=grid_shape, xy_lims=xy_lims)
223+
224+
row_mask = np.logical_and(row < grid_shape[0], row > 0)
225+
col_mask = np.logical_and(col < grid_shape[1], col > 0)
226+
mask = np.logical_and(row_mask, col_mask)
227+
228+
row = row[mask]
229+
col = col[mask]
230+
231+
image[row, col] = intensities[mask]
232+
return image

src/radiosim/ppdisks/setup.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import subprocess
2+
import time
23
from pathlib import Path
34

45
from tqdm.auto import tqdm
@@ -37,9 +38,10 @@ def compile(
3738
rescale: bool = False,
3839
show_progress: bool = True,
3940
model_id: int | None = None,
41+
return_execution_time: bool = False,
4042
verbose: bool = False,
4143
show_fargo_output: bool = False,
42-
) -> None:
44+
) -> int | None:
4345
model_desc = f" | Model {model_id}" if model_id is not None else ""
4446
with tqdm(
4547
desc="Compiling" + model_desc, total=1, disable=not show_progress
@@ -59,6 +61,8 @@ def compile(
5961
if verbose:
6062
print(f"CMD @ {Variables.get('FARGO_ROOT')} $ " + " ".join(clean_cmd))
6163

64+
starting_time = time.time_ns()
65+
6266
subprocess.run(
6367
clean_cmd,
6468
cwd=Variables.get("FARGO_ROOT"),
@@ -88,11 +92,18 @@ def compile(
8892
shell=True,
8993
)
9094

95+
compile_time = time.time_ns() - starting_time
96+
9197
if verbose:
9298
print("============ FINISH COMPILATION ============")
9399

94100
progress.update(n=1)
95101

102+
if return_execution_time:
103+
return compile_time
104+
else:
105+
return None
106+
96107
# Subprocess output capture adapted from https://stackoverflow.com/a/28319191
97108
# Marked code (inside >>> BEGIN / <<< END) is licensed under CC BY-SA 3.0
98109
def run(
@@ -102,8 +113,9 @@ def run(
102113
parallel: bool = False,
103114
num_nodes: int = 1,
104115
cuda_device_id: int = 0,
116+
return_execution_time: bool = False,
105117
verbose: bool = False,
106-
) -> None:
118+
) -> tuple[float, list[float]] | None:
107119
total_steps = self._param_config["output_parameters.ntot"].value
108120
steps_between_outputs = self._param_config["output_parameters.ninterm"].value
109121

@@ -124,6 +136,9 @@ def run(
124136
if verbose:
125137
print(f"CMD @ {Variables.get('FARGO_ROOT')} $ " + " ".join(processes))
126138

139+
starting_time = time.time_ns()
140+
output_times = []
141+
127142
# >>> BEGIN
128143
with (
129144
tqdm(
@@ -145,6 +160,13 @@ def run(
145160
if not line.startswith("OUTPUT"):
146161
continue
147162

163+
output_times.append(time.time_ns() - starting_time)
148164
progress.update(n=steps_between_outputs)
149-
150165
# <<< END
166+
167+
run_time = time.time_ns() - starting_time
168+
169+
if return_execution_time:
170+
return run_time, output_times
171+
else:
172+
return None

0 commit comments

Comments
 (0)