1+ from __future__ import annotations
2+
13import warnings
4+ from typing import TYPE_CHECKING
25
36import astropy .units as units
47import matplotlib
8+ import matplotlib as mpl
59import matplotlib .pyplot as plt
610import numpy as np
711from astropy .time import Time
12+ from matplotlib .ticker import NullFormatter
13+ from mpl_toolkits .axes_grid1 import make_axes_locatable
14+
15+ if TYPE_CHECKING :
16+ from pyvisgrid .core .gridder import Gridder
817
918__all__ = ["plot_ungridded_uv" , "plot_dirty_image" , "plot_mask" ]
1019
@@ -155,15 +164,36 @@ def _apply_crop(ax: matplotlib.axes.Axes, crop: tuple[list[float | None]]):
155164 ax .set_ylim (crop [1 ][0 ], crop [1 ][1 ])
156165
157166
167+ # based on https://stackoverflow.com/a/18195921 by "bogatron"
168+ def _configure_colorbar (
169+ mappable : mpl .cm .ScalarMappable ,
170+ ax : mpl .axes .Axes ,
171+ fig : mpl .figure .Figure ,
172+ label : str | None ,
173+ show_ticks : bool = True ,
174+ fontsize : str = "medium" ,
175+ ) -> mpl .colorbar .Colorbar :
176+ divider = make_axes_locatable (ax )
177+ cax = divider .append_axes ("right" , size = "5%" , pad = 0.05 )
178+ cbar = fig .colorbar (mappable , cax = cax )
179+ cbar .set_label (label , fontsize = fontsize )
180+
181+ if not show_ticks :
182+ cbar .set_ticks ([])
183+ cbar .ax .yaxis .set_major_formatter (NullFormatter ())
184+ cbar .ax .yaxis .set_minor_formatter (NullFormatter ())
185+ else :
186+ cbar .ax .tick_params (labelsize = fontsize )
187+
188+ return cbar
189+
190+
158191def plot_ungridded_uv (
159- u : np .ndarray ,
160- v : np .ndarray ,
161- times : np .ndarray ,
192+ gridder : Gridder ,
162193 mode : str = "wave" ,
163194 show_times : bool = True ,
164195 use_relative_time : bool = True ,
165196 time_cmap : str | matplotlib .colors .Colormap = "inferno" ,
166- colorbar_shrink : float = 1.0 ,
167197 marker_size : float | None = None ,
168198 aspect_args : dict | None = None ,
169199 plot_args : dict = None ,
@@ -195,10 +225,6 @@ def plot_ungridded_uv(
195225 times_cmap: str | matplotlib.colors.Colormap, optional
196226 The colormap to be used for the time component of the plot.
197227 Default is ``'inferno'``.
198- colorbar_shrink: float, optional
199- The shrink parameter of the colorbar. This can be needed if the plot is
200- included as a subplot to adjust the size of the colorbar.
201- Default is ``1``, meaning original scale.
202228 marker_size : float | None, optional
203229 The size of the scatter markers in points**2.
204230 Default is ``None``, meaning the default value supplied by
@@ -251,15 +277,21 @@ def plot_ungridded_uv(
251277
252278 match mode :
253279 case "wave" :
280+ u = gridder .u_wave
281+ v = gridder .v_wave
254282 unit = "$\\ lambda$"
255283 case "meter" :
284+ u = gridder .u_meter
285+ v = gridder .v_meter
256286 unit = "m"
257287 case _:
258288 raise ValueError (
259289 "The given mode does not exist! Valid modes are: wave, meter."
260290 )
261291
262- times = Time (np .tile (times , reps = 2 ), format = "mjd" ) if show_times else None
292+ times = (
293+ Time (np .tile (gridder .times .mjd , reps = 2 ), format = "mjd" ) if show_times else None
294+ )
263295 time_unit = "MJD"
264296
265297 if use_relative_time and show_times :
@@ -279,7 +311,7 @@ def plot_ungridded_uv(
279311 )
280312
281313 if show_times :
282- fig . colorbar ( scat , ax = ax , shrink = colorbar_shrink , label = "Time / " + time_unit )
314+ _configure_colorbar ( mappable = scat , ax = ax , fig = fig , label = "Time / " + time_unit )
283315
284316 ax .set_aspect (** aspect_args )
285317 scat .set_rasterized (True )
@@ -290,15 +322,14 @@ def plot_ungridded_uv(
290322 if save_to is not None :
291323 fig .savefig (save_to , ** save_args )
292324
293- return fig , ax , scat
325+ return fig , ax
294326
295327
296328def plot_mask (
297329 grid_data ,
298330 mode : str = "hist" ,
299331 crop : tuple [list [float | None ]] = ([None , None ], [None , None ]),
300332 norm : str | matplotlib .colors .Normalize = None ,
301- colorbar_shrink : float = 1 ,
302333 cmap : str | matplotlib .colors .Colormap | None = None ,
303334 plot_args : dict = None ,
304335 fig_args : dict = None ,
@@ -369,10 +400,6 @@ def plot_mask(
369400 itself.
370401
371402 Default is ``None``, meaning no norm will be applied.
372- colorbar_shrink: float, optional
373- The shrink parameter of the colorbar. This can be needed if the plot is
374- included as a subplot to adjust the size of the colorbar.
375- Default is ``1``, meaning original scale.
376403 cmap: str | matplotlib.colors.Colormap | None, optional
377404 The colormap to be used for the plot.
378405 Default is ``None``, meaning the colormap will be default to a value
@@ -424,8 +451,8 @@ def plot_mask(
424451 "hist" : "inferno" ,
425452 "abs" : "viridis" ,
426453 "phase" : "RdBu" ,
427- "real" : "RdBu " ,
428- "imag" : "RdBu " ,
454+ "real" : "PiYG " ,
455+ "imag" : "PuOr " ,
429456 }
430457
431458 cmap = cmap_dict [mode ] if cmap is None else cmap
@@ -442,9 +469,10 @@ def plot_mask(
442469 cmap = cmap ,
443470 ** plot_args ,
444471 )
445- fig . colorbar (
446- im , ax = ax , shrink = colorbar_shrink , label = "$(u,v)$ per frequel / 1/fq"
472+ _configure_colorbar (
473+ mappable = im , ax = ax , fig = fig , label = "$(u,v)$ per frequel / 1/fq"
447474 )
475+
448476 case "abs" :
449477 mask_abs , _ = grid_data .get_mask_abs_phase ()
450478 im = ax .imshow (
@@ -455,7 +483,7 @@ def plot_mask(
455483 cmap = cmap ,
456484 ** plot_args ,
457485 )
458- fig . colorbar ( im , ax = ax , shrink = colorbar_shrink , label = "Amplitude / a.u." )
486+ _configure_colorbar ( mappable = im , ax = ax , fig = fig , label = "Amplitude / a.u." )
459487 case "phase" :
460488 _ , mask_phase = grid_data .get_mask_abs_phase ()
461489 im = ax .imshow (
@@ -466,8 +494,7 @@ def plot_mask(
466494 cmap = cmap ,
467495 ** plot_args ,
468496 )
469- cbar = fig .colorbar (im , ax = ax , shrink = colorbar_shrink , label = "Phase / rad" )
470-
497+ cbar = _configure_colorbar (mappable = im , ax = ax , fig = fig , label = "Phase / rad" )
471498 cbar .set_ticks (np .arange (- np .pi , 3 / 2 * np .pi , np .pi / 2 ))
472499 cbar .set_ticklabels (["$-\\ pi$" , "$-\\ pi/2$" , "$0$" , "$\\ pi/2$" , "$\\ pi$" ])
473500 case "real" :
@@ -479,7 +506,7 @@ def plot_mask(
479506 cmap = cmap ,
480507 ** plot_args ,
481508 )
482- fig . colorbar ( im , ax = ax , shrink = colorbar_shrink , label = "Real Part / a.u." )
509+ _configure_colorbar ( mappable = im , ax = ax , fig = fig , label = "Real Part / a.u." )
483510 case "imag" :
484511 im = ax .imshow (
485512 grid_data .mask_imag ,
@@ -489,9 +516,11 @@ def plot_mask(
489516 cmap = cmap ,
490517 ** plot_args ,
491518 )
492- fig .colorbar (
493- im , ax = ax , shrink = colorbar_shrink , label = "Imaginary Part / a.u."
519+
520+ _configure_colorbar (
521+ mappable = im , ax = ax , fig = fig , label = "Imaginary Part / a.u."
494522 )
523+
495524 case _:
496525 raise ValueError (
497526 f"The given mode does not exist!"
@@ -506,7 +535,7 @@ def plot_mask(
506535 if save_to is not None :
507536 fig .savefig (save_to , ** save_args )
508537
509- return fig , ax , im
538+ return fig , ax
510539
511540
512541def plot_dirty_image (
@@ -693,9 +722,9 @@ def plot_dirty_image(
693722 ** plot_args ,
694723 )
695724
696- fig . colorbar ( im , ax = ax , shrink = colorbar_shrink , label = "Flux Density / Jy/pix" )
725+ _configure_colorbar ( mappable = im , ax = ax , fig = fig , label = "Flux Density / Jy/pix" )
697726
698727 if save_to is not None :
699728 fig .savefig (save_to , ** save_args )
700729
701- return fig , ax , im
730+ return fig , ax
0 commit comments