Skip to content

Commit c50d101

Browse files
authored
Merge pull request #17 from radionets-project/update_ms_gridding
Update MS Gridding and Dirty Image Plotting
2 parents ded9de5 + 136e693 commit c50d101

File tree

5 files changed

+152
-101
lines changed

5 files changed

+152
-101
lines changed

pyvisgrid/core/gridder.py

Lines changed: 90 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
from dataclasses import dataclass
44
from pathlib import Path
55

6-
import numpy
76
import numpy as np
7+
from numpy.typing import ArrayLike
8+
89
from astropy.constants import c
910
from astropy.io import fits
1011
from casatools.table import table
@@ -25,12 +26,15 @@ class GridData:
2526
DataClass to save the gridded and non-gridded visibilities for a
2627
specific Stokes component.
2728
28-
Parameters
29+
Attributes
2930
----------
3031
3132
vis_data : numpy.ndarray
3233
The ungridded visibilities.
3334
35+
fov : float
36+
The size of the Field Of View of the gridded data in arcseconds.
37+
3438
mask : numpy.ndarray, optional
3539
The mask created from the given (u,v) coordinates. The mask contains
3640
the number of (u,v) coordinates per pixel.
@@ -47,25 +51,58 @@ class GridData:
4751
4852
"""
4953

50-
vis_data: numpy.ndarray
51-
mask: numpy.ndarray | None = None
52-
mask_real: numpy.ndarray | None = None
53-
mask_imag: numpy.ndarray | None = None
54-
dirty_image: numpy.ndarray | None = None
54+
vis_data: np.ndarray
55+
fov: float | None = None
56+
mask: np.ndarray | None = None
57+
mask_real: np.ndarray | None = None
58+
mask_imag: np.ndarray | None = None
59+
dirty_image: np.ndarray | None = None
5560

5661
def __str__(self):
5762
return self.__dict__
5863

64+
def get_mask_complex(self):
65+
"""
66+
Returns the gridded mask as a complex array with
67+
the form ``mask.real + 1j * mask.imag``.
68+
69+
Returns
70+
-------
71+
numpy.ndarray: Complex mask
72+
"""
73+
return self.mask_real + 1j * self.mask_imag
74+
75+
def get_mask_abs_phase(self):
76+
"""
77+
Returns the gridded mask as amplitude and phase.
78+
79+
Returns
80+
-------
81+
tuple[numpy.ndarray]: Amplitude and phase of the complex mask.
82+
"""
83+
mask_complex = self.get_mask_complex()
84+
return np.abs(mask_complex), np.angle(mask_complex)
85+
86+
def get_mask_real_imag(self):
87+
"""
88+
Returns the gridded mask as real and imaginary parts.
89+
90+
Returns
91+
-------
92+
tuple[numpy.ndarray]: Real and parts of the complex mask.
93+
"""
94+
return self.mask_real, self.mask_imag
95+
5996

6097
class Gridder:
6198
def __init__(
6299
self,
63-
u_meter: numpy.ndarray,
64-
v_meter: numpy.ndarray,
100+
u_meter: np.ndarray,
101+
v_meter: np.ndarray,
65102
img_size: int,
66103
fov: float,
67104
ref_frequency: float,
68-
frequency_offsets: numpy.typing.ArrayLike,
105+
frequency_offsets: ArrayLike,
69106
):
70107
"""
71108
@@ -191,27 +228,35 @@ def grid(self, stokes_component: str = "I"):
191228
)
192229

193230
mask, *_ = np.histogram2d(
194-
u_wave_full, v_wave_full, bins=[bins, bins], density=False
231+
x=u_wave_full, y=v_wave_full, bins=[bins, bins], density=False
195232
)
233+
mask = (
234+
mask.T
235+
) # u (x) is histogrammed along the first dimension (rows) --> transpose
196236
mask[mask == 0] = 1
197237

198238
mask_real, _, _ = np.histogram2d(
199-
u_wave_full,
200-
v_wave_full,
239+
x=u_wave_full,
240+
y=v_wave_full,
201241
bins=[bins, bins],
202242
weights=stokes_real_full,
203243
density=False,
204244
)
205245
mask_imag, _, _ = np.histogram2d(
206-
u_wave_full,
207-
v_wave_full,
246+
x=u_wave_full,
247+
y=v_wave_full,
208248
bins=[bins, bins],
209249
weights=stokes_imag_full,
210250
density=False,
211251
)
252+
253+
mask_real = mask_real.T # see above
254+
mask_imag = mask_imag.T # see above
255+
212256
mask_real /= mask
213257
mask_imag /= mask
214258

259+
grid_data.fov = self.fov
215260
grid_data.mask = mask
216261
grid_data.mask_real = mask_real
217262
grid_data.mask_imag = mask_imag
@@ -229,7 +274,7 @@ def from_pyvisgen(
229274
img_size: int,
230275
fov: float,
231276
stokes_components: list[str] | str = "I",
232-
polarizations: list[str] | str = "",
277+
polarizations: list[str] | str | None = None,
233278
):
234279
"""
235280
Initializes the gridder with the visibility data which is generated by the
@@ -262,8 +307,8 @@ def from_pyvisgen(
262307
This can either be a list of components (e.g. ``['I', 'V']``) or a single
263308
string. Default is ``'I'``.
264309
265-
polarizations : list[str] | str, optional
266-
The polarization type. Default is ``''``.
310+
polarizations : list[str] | str | None, optional
311+
The polarization type. Default is ``None``.
267312
"""
268313
u_meter = vis_data.u
269314
v_meter = vis_data.v
@@ -291,6 +336,9 @@ def from_pyvisgen(
291336
if isinstance(stokes_components, str):
292337
stokes_components = [stokes_components]
293338

339+
if polarizations is None:
340+
polarizations = ""
341+
294342
if isinstance(polarizations, str):
295343
polarizations = [polarizations]
296344

@@ -460,26 +508,15 @@ def from_ms(
460508
mask = tab.getcol("DATA_DESC_ID") == desc_id
461509
mask_idx = np.argwhere(mask).ravel()
462510

463-
nrow = mask_idx[-1] - mask_idx[0] + 1
511+
tab_subset = tab.selectrows(rownrs=mask_idx)
464512

465-
data = tab.getcolslice(
466-
data_colname,
467-
blc=[0, 0],
468-
trc=[-1, -1],
469-
incr=[1, 1],
470-
startrow=mask_idx[0],
471-
nrow=nrow,
472-
)
473-
data = data[..., mask_idx - mask_idx[0]]
513+
data = tab_subset.getcol(data_colname)
514+
uv = tab_subset.getcol("UVW")[:2]
474515

475-
uvw = tab.getcolslice(
476-
"UVW", blc=[0], trc=[1], incr=[1], startrow=mask_idx[0], nrow=nrow
477-
)
478-
uvw = uvw[..., mask_idx - mask_idx[0]]
479516
else:
480517
mask = np.ones_like(tab.getcol("DATA_DESC_ID")).astype(bool)
481518
data = tab.getcol(data_colname)
482-
uvw = tab.getcol("UVW")[:2]
519+
uv = tab.getcol("UVW")[:2]
483520

484521
spectral_tab = table(str(path / "SPECTRAL_WINDOW"))
485522

@@ -502,13 +539,13 @@ def from_ms(
502539

503540
flag_mask = np.logical_not(flag_mask.astype(bool))
504541
else:
505-
flag_mask = np.ones(uvw.shape[-1]).astype(bool)
542+
flag_mask = np.ones(uv.shape[-1]).astype(bool)
506543

507-
uvw = uvw[..., flag_mask]
544+
uv = uv[..., flag_mask]
508545
data = data[..., flag_mask]
509546

510-
u_meter = uvw[0]
511-
v_meter = uvw[1]
547+
u_meter = uv[0]
548+
v_meter = uv[1]
512549

513550
stokes_i = data[0] + data[1]
514551

@@ -736,14 +773,22 @@ def plot_dirty_image(self, stokes_component: str = "I", **kwargs):
736773
737774
Default is ``real``.
738775
739-
crop : tuple[list[float | None]], optional
740-
The crop of the image. This has to have the format
741-
``([x_left, x_right], [y_left, y_right])``, where the left and right
742-
values for each axis are the upper and lower limits of the axes which
743-
should be shown.
744-
IMPORTANT: If one supplies the ``plt.imshow`` an ``extent`` parameter
745-
via the ``plot_args`` parameter, this will be the scale in which one
746-
has to give the crop! If not, the crop has to be in pixels.
776+
ax_unit: str | astropy.units.unit, optional
777+
The unit in which to show the ticks of the x and y-axes in.
778+
The y-axis is the Declination (DEC) and the x-axis is the Right Ascension (RA).
779+
The latter one is defined as increasing from left to right!
780+
The unit has to be given as a string or an ``astropy.units.Unit``.
781+
The string must correspond to the string representation of an ``astropy.units.Unit``.
782+
783+
Valid units are either ``pixel`` or angle units like ``arcsec``, ``degree`` etc..
784+
Default is ``pixel``.
785+
786+
center_pos: tuple | None, optional
787+
The coordinate center of the image. The coordinates have to
788+
be given in the unit defined in the parameter ``ax_unit`` above.
789+
If ``ax_unit`` is set to ``pixel`` this parameter is ignored.
790+
Default is ``None``, meaning the coordinates of the axes will be
791+
given as relative.
747792
748793
norm : str | matplotlib.colors.Normalize | None, optional
749794
The name of the norm or a matplotlib norm.

pyvisgrid/plotting/plotting.py

Lines changed: 62 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
import matplotlib.pyplot as plt
33
import numpy as np
44

5+
import warnings
6+
7+
import astropy.units as units
8+
59

610
def _configure_axes(
711
fig: matplotlib.figure.Figure | None,
@@ -206,7 +210,7 @@ def plot_ungridded_uv(
206210

207211
ax.scatter(x=np.append(-u, u), y=np.append(-v, v), s=marker_size, **plot_args)
208212

209-
ax.axis("equal")
213+
ax.set_aspect("equal", "box")
210214

211215
ax.set_xlabel(f"$u$ in {unit}")
212216
ax.set_ylabel(f"$v$ in {unit}")
@@ -375,8 +379,9 @@ def plot_mask(
375379
im, ax=ax, shrink=colorbar_shrink, label="$(u,v)$ per frequel in 1/fq"
376380
)
377381
case "abs":
382+
mask_abs, _ = grid_data.get_mask_abs_phase()
378383
im = ax.imshow(
379-
np.abs(grid_data.mask_real + 1j * grid_data.mask_imag),
384+
mask_abs,
380385
norm=norm,
381386
origin="lower",
382387
interpolation="none",
@@ -385,8 +390,9 @@ def plot_mask(
385390
)
386391
fig.colorbar(im, ax=ax, shrink=colorbar_shrink, label="Amplitude in a.u.")
387392
case "phase":
393+
_, mask_phase = grid_data.get_mask_abs_phase()
388394
im = ax.imshow(
389-
np.angle(grid_data.mask_real + 1j * grid_data.mask_imag),
395+
mask_phase,
390396
norm=norm,
391397
origin="lower",
392398
interpolation="none",
@@ -439,7 +445,8 @@ def plot_mask(
439445
def plot_dirty_image(
440446
grid_data,
441447
mode: str = "real",
442-
crop: tuple[list[float | None]] = ([None, None], [None, None]),
448+
ax_unit: str | units.Unit = "pixel",
449+
center_pos: tuple[float] | None = None,
443450
norm: str | matplotlib.colors.Normalize = None,
444451
colorbar_shrink: float = 1,
445452
cmap: str | matplotlib.colors.Colormap | None = "inferno",
@@ -464,7 +471,6 @@ def plot_dirty_image(
464471
465472
mode : str, optional
466473
The mode specifying which values of the mask should be plotted.
467-
468474
Possible values are:
469475
470476
- ``real``: Plots the real part of the dirty image.
@@ -475,14 +481,22 @@ def plot_dirty_image(
475481
476482
Default is ``real``.
477483
478-
crop : tuple[list[float | None]], optional
479-
The crop of the image. This has to have the format
480-
``([x_left, x_right], [y_left, y_right])``, where the left and right
481-
values for each axis are the upper and lower limits of the axes which
482-
should be shown.
483-
IMPORTANT: If one supplies the ``plt.imshow`` an ``extent`` parameter
484-
via the ``plot_args`` parameter, this will be the scale in which one
485-
has to give the crop! If not, the crop has to be in pixels.
484+
ax_unit: str | astropy.units.unit, optional
485+
The unit in which to show the ticks of the x and y-axes in.
486+
The y-axis is the Declination (DEC) and the x-axis is the Right Ascension (RA).
487+
The latter one is defined as increasing from left to right!
488+
The unit has to be given as a string or an ``astropy.units.Unit``.
489+
The string must correspond to the string representation of an ``astropy.units.Unit``.
490+
491+
Valid units are either ``pixel`` or angle units like ``arcsec``, ``degree`` etc..
492+
Default is ``pixel``.
493+
494+
center_pos: tuple | None, optional
495+
The coordinate center of the image. The coordinates have to
496+
be given in the unit defined in the parameter ``ax_unit`` above.
497+
If ``ax_unit`` is set to ``pixel`` this parameter is ignored.
498+
Default is ``None``, meaning the coordinates of the axes will be
499+
given as relative.
486500
487501
norm : str | matplotlib.colors.Normalize | None, optional
488502
The name of the norm or a matplotlib norm.
@@ -571,19 +585,51 @@ def plot_dirty_image(
571585
"The given mode does not exist! Valid modes are: real, imag, abs"
572586
)
573587

588+
unit = units.Unit(ax_unit)
589+
590+
if unit.physical_type == "angle":
591+
img_size = dirty_image.shape[0]
592+
cell_size = grid_data.fov / img_size
593+
594+
extent = (
595+
np.array([-img_size / 2, img_size / 2] * 2) * cell_size * units.arcsecond
596+
).to(unit)
597+
598+
if center_pos is not None:
599+
center_pos = np.array(center_pos) * unit
600+
extent[:2] += center_pos[0]
601+
extent[2:] += center_pos[1]
602+
label_prefix = ""
603+
else:
604+
label_prefix = "Relative "
605+
606+
ax.set_xlabel(f"{label_prefix}RA in {unit}")
607+
ax.set_ylabel(f"{label_prefix}DEC in {unit}")
608+
609+
extent = extent.value
610+
611+
else:
612+
if unit != units.pixel:
613+
warnings.warn(
614+
f"The given unit {unit} is no angle unit! Using pixels instead."
615+
)
616+
617+
extent = None
618+
619+
ax.set_xlabel("Pixels")
620+
ax.set_ylabel("Pixels")
621+
574622
im = ax.imshow(
575623
dirty_image,
576624
norm=norm,
577625
origin="lower",
578626
interpolation="none",
579627
cmap=cmap,
628+
extent=extent,
580629
**plot_args,
581630
)
582631
fig.colorbar(im, ax=ax, shrink=colorbar_shrink, label="Flux Density in Jy/px")
583632

584-
ax.set_xlabel("Pixels")
585-
ax.set_ylabel("Pixels")
586-
587633
if save_to is not None:
588634
fig.savefig(save_to, **save_args)
589635

tests/data/test_vis.ms.zip

-41.7 MB
Binary file not shown.

tests/data/test_vis_grid_ms.h5

-642 KB
Binary file not shown.

0 commit comments

Comments
 (0)