Skip to content

Commit 80efae9

Browse files
committed
feature: Add better cbar behaviour to plotting
1 parent 595c8da commit 80efae9

File tree

3 files changed

+60
-67
lines changed

3 files changed

+60
-67
lines changed

src/pyvisgrid/core/gridder.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -732,17 +732,7 @@ def plot_ungridded_uv(self, **kwargs):
732732
The axes object.
733733
"""
734734

735-
if "mode" not in kwargs:
736-
kwargs["mode"] = "wave"
737-
738-
if kwargs["mode"] == "wave":
739-
u = self.u_wave
740-
v = self.v_wave
741-
else:
742-
u = self.u_meter
743-
v = self.v_meter
744-
745-
return plotting.plot_ungridded_uv(u=u, v=v, times=self.times.mjd, **kwargs)
735+
return plotting.plot_ungridded_uv(gridder=self, **kwargs)
746736

747737
def plot_mask(self, stokes_component: str = "I", **kwargs):
748738
"""Plots the (u,v) mask (the binned visibilities) of the gridded

src/pyvisgrid/plotting/animations.py

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,14 @@
1414
from astropy.coordinates import ITRS, SkyCoord
1515
from astropy.time import Time
1616
from cartopy.feature.nightshade import Nightshade
17-
from matplotlib.ticker import NullFormatter
1817
from mergedeep import merge
19-
from mpl_toolkits.axes_grid1 import make_axes_locatable
2018
from radiotools.layouts import Layout
2119
from tqdm.auto import tqdm
2220

2321
if TYPE_CHECKING:
2422
from pyvisgrid.core.gridder import GridData, GridDataSeries
2523

26-
from pyvisgrid.plotting.plotting import _configure_axes, _get_norm
24+
from pyvisgrid.plotting.plotting import _configure_axes, _configure_colorbar, _get_norm
2725

2826
__all__ = ["plot_earth_layout", "plot_observation_state", "animate_observation"]
2927

@@ -34,30 +32,6 @@ def _is_value_in(value: object, lst: list):
3432
return value in np.ravel(lst)
3533

3634

37-
# based on https://stackoverflow.com/a/18195921 by "bogatron"
38-
def _configure_colorbar(
39-
mappable: mpl.cm.ScalarMappable,
40-
ax: mpl.axes.Axes,
41-
fig: mpl.figure.Figure,
42-
label: str | None,
43-
show_ticks: bool,
44-
fontsize: str = "medium",
45-
) -> mpl.colorbar.Colorbar:
46-
divider = make_axes_locatable(ax)
47-
cax = divider.append_axes("right", size="5%", pad=0.05)
48-
cbar = fig.colorbar(mappable, cax=cax)
49-
cbar.set_label(label, fontsize=fontsize)
50-
51-
if not show_ticks:
52-
cbar.set_ticks([])
53-
cbar.ax.yaxis.set_major_formatter(NullFormatter())
54-
cbar.ax.yaxis.set_minor_formatter(NullFormatter())
55-
else:
56-
cbar.ax.tick_params(labelsize=fontsize)
57-
58-
return cbar
59-
60-
6135
def _times2hours(times: np.ndarray):
6236
times = Time(times, format="mjd")
6337
times = times.unix

src/pyvisgrid/plotting/plotting.py

Lines changed: 58 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,19 @@
1+
from __future__ import annotations
2+
13
import warnings
4+
from typing import TYPE_CHECKING
25

36
import astropy.units as units
47
import matplotlib
8+
import matplotlib as mpl
59
import matplotlib.pyplot as plt
610
import numpy as np
711
from astropy.time import Time
12+
from matplotlib.ticker import NullFormatter
13+
from mpl_toolkits.axes_grid1 import make_axes_locatable
14+
15+
if TYPE_CHECKING:
16+
from pyvisgrid.core.gridder import Gridder
817

918
__all__ = ["plot_ungridded_uv", "plot_dirty_image", "plot_mask"]
1019

@@ -155,15 +164,36 @@ def _apply_crop(ax: matplotlib.axes.Axes, crop: tuple[list[float | None]]):
155164
ax.set_ylim(crop[1][0], crop[1][1])
156165

157166

167+
# based on https://stackoverflow.com/a/18195921 by "bogatron"
168+
def _configure_colorbar(
169+
mappable: mpl.cm.ScalarMappable,
170+
ax: mpl.axes.Axes,
171+
fig: mpl.figure.Figure,
172+
label: str | None,
173+
show_ticks: bool = True,
174+
fontsize: str = "medium",
175+
) -> mpl.colorbar.Colorbar:
176+
divider = make_axes_locatable(ax)
177+
cax = divider.append_axes("right", size="5%", pad=0.05)
178+
cbar = fig.colorbar(mappable, cax=cax)
179+
cbar.set_label(label, fontsize=fontsize)
180+
181+
if not show_ticks:
182+
cbar.set_ticks([])
183+
cbar.ax.yaxis.set_major_formatter(NullFormatter())
184+
cbar.ax.yaxis.set_minor_formatter(NullFormatter())
185+
else:
186+
cbar.ax.tick_params(labelsize=fontsize)
187+
188+
return cbar
189+
190+
158191
def plot_ungridded_uv(
159-
u: np.ndarray,
160-
v: np.ndarray,
161-
times: np.ndarray,
192+
gridder: Gridder,
162193
mode: str = "wave",
163194
show_times: bool = True,
164195
use_relative_time: bool = True,
165196
time_cmap: str | matplotlib.colors.Colormap = "inferno",
166-
colorbar_shrink: float = 1.0,
167197
marker_size: float | None = None,
168198
aspect_args: dict | None = None,
169199
plot_args: dict = None,
@@ -195,10 +225,6 @@ def plot_ungridded_uv(
195225
times_cmap: str | matplotlib.colors.Colormap, optional
196226
The colormap to be used for the time component of the plot.
197227
Default is ``'inferno'``.
198-
colorbar_shrink: float, optional
199-
The shrink parameter of the colorbar. This can be needed if the plot is
200-
included as a subplot to adjust the size of the colorbar.
201-
Default is ``1``, meaning original scale.
202228
marker_size : float | None, optional
203229
The size of the scatter markers in points**2.
204230
Default is ``None``, meaning the default value supplied by
@@ -251,15 +277,21 @@ def plot_ungridded_uv(
251277

252278
match mode:
253279
case "wave":
280+
u = gridder.u_wave
281+
v = gridder.v_wave
254282
unit = "$\\lambda$"
255283
case "meter":
284+
u = gridder.u_meter
285+
v = gridder.v_meter
256286
unit = "m"
257287
case _:
258288
raise ValueError(
259289
"The given mode does not exist! Valid modes are: wave, meter."
260290
)
261291

262-
times = Time(np.tile(times, reps=2), format="mjd") if show_times else None
292+
times = (
293+
Time(np.tile(gridder.times.mjd, reps=2), format="mjd") if show_times else None
294+
)
263295
time_unit = "MJD"
264296

265297
if use_relative_time and show_times:
@@ -279,7 +311,7 @@ def plot_ungridded_uv(
279311
)
280312

281313
if show_times:
282-
fig.colorbar(scat, ax=ax, shrink=colorbar_shrink, label="Time / " + time_unit)
314+
_configure_colorbar(mappable=scat, ax=ax, fig=fig, label="Time / " + time_unit)
283315

284316
ax.set_aspect(**aspect_args)
285317
scat.set_rasterized(True)
@@ -290,15 +322,14 @@ def plot_ungridded_uv(
290322
if save_to is not None:
291323
fig.savefig(save_to, **save_args)
292324

293-
return fig, ax, scat
325+
return fig, ax
294326

295327

296328
def plot_mask(
297329
grid_data,
298330
mode: str = "hist",
299331
crop: tuple[list[float | None]] = ([None, None], [None, None]),
300332
norm: str | matplotlib.colors.Normalize = None,
301-
colorbar_shrink: float = 1,
302333
cmap: str | matplotlib.colors.Colormap | None = None,
303334
plot_args: dict = None,
304335
fig_args: dict = None,
@@ -369,10 +400,6 @@ def plot_mask(
369400
itself.
370401
371402
Default is ``None``, meaning no norm will be applied.
372-
colorbar_shrink: float, optional
373-
The shrink parameter of the colorbar. This can be needed if the plot is
374-
included as a subplot to adjust the size of the colorbar.
375-
Default is ``1``, meaning original scale.
376403
cmap: str | matplotlib.colors.Colormap | None, optional
377404
The colormap to be used for the plot.
378405
Default is ``None``, meaning the colormap will be default to a value
@@ -424,8 +451,8 @@ def plot_mask(
424451
"hist": "inferno",
425452
"abs": "viridis",
426453
"phase": "RdBu",
427-
"real": "RdBu",
428-
"imag": "RdBu",
454+
"real": "PiYG",
455+
"imag": "PuOr",
429456
}
430457

431458
cmap = cmap_dict[mode] if cmap is None else cmap
@@ -442,9 +469,10 @@ def plot_mask(
442469
cmap=cmap,
443470
**plot_args,
444471
)
445-
fig.colorbar(
446-
im, ax=ax, shrink=colorbar_shrink, label="$(u,v)$ per frequel / 1/fq"
472+
_configure_colorbar(
473+
mappable=im, ax=ax, fig=fig, label="$(u,v)$ per frequel / 1/fq"
447474
)
475+
448476
case "abs":
449477
mask_abs, _ = grid_data.get_mask_abs_phase()
450478
im = ax.imshow(
@@ -455,7 +483,7 @@ def plot_mask(
455483
cmap=cmap,
456484
**plot_args,
457485
)
458-
fig.colorbar(im, ax=ax, shrink=colorbar_shrink, label="Amplitude / a.u.")
486+
_configure_colorbar(mappable=im, ax=ax, fig=fig, label="Amplitude / a.u.")
459487
case "phase":
460488
_, mask_phase = grid_data.get_mask_abs_phase()
461489
im = ax.imshow(
@@ -466,8 +494,7 @@ def plot_mask(
466494
cmap=cmap,
467495
**plot_args,
468496
)
469-
cbar = fig.colorbar(im, ax=ax, shrink=colorbar_shrink, label="Phase / rad")
470-
497+
cbar = _configure_colorbar(mappable=im, ax=ax, fig=fig, label="Phase / rad")
471498
cbar.set_ticks(np.arange(-np.pi, 3 / 2 * np.pi, np.pi / 2))
472499
cbar.set_ticklabels(["$-\\pi$", "$-\\pi/2$", "$0$", "$\\pi/2$", "$\\pi$"])
473500
case "real":
@@ -479,7 +506,7 @@ def plot_mask(
479506
cmap=cmap,
480507
**plot_args,
481508
)
482-
fig.colorbar(im, ax=ax, shrink=colorbar_shrink, label="Real Part / a.u.")
509+
_configure_colorbar(mappable=im, ax=ax, fig=fig, label="Real Part / a.u.")
483510
case "imag":
484511
im = ax.imshow(
485512
grid_data.mask_imag,
@@ -489,9 +516,11 @@ def plot_mask(
489516
cmap=cmap,
490517
**plot_args,
491518
)
492-
fig.colorbar(
493-
im, ax=ax, shrink=colorbar_shrink, label="Imaginary Part / a.u."
519+
520+
_configure_colorbar(
521+
mappable=im, ax=ax, fig=fig, label="Imaginary Part / a.u."
494522
)
523+
495524
case _:
496525
raise ValueError(
497526
f"The given mode does not exist!"
@@ -506,7 +535,7 @@ def plot_mask(
506535
if save_to is not None:
507536
fig.savefig(save_to, **save_args)
508537

509-
return fig, ax, im
538+
return fig, ax
510539

511540

512541
def plot_dirty_image(
@@ -693,9 +722,9 @@ def plot_dirty_image(
693722
**plot_args,
694723
)
695724

696-
fig.colorbar(im, ax=ax, shrink=colorbar_shrink, label="Flux Density / Jy/pix")
725+
_configure_colorbar(mappable=im, ax=ax, fig=fig, label="Flux Density / Jy/pix")
697726

698727
if save_to is not None:
699728
fig.savefig(save_to, **save_args)
700729

701-
return fig, ax, im
730+
return fig, ax

0 commit comments

Comments
 (0)