Skip to content

Commit d1592cd

Browse files
authored
Merge pull request #29 from radionets-project/fix_ms_gridding
Fix Measurement Set Gridding
2 parents 555d767 + f263eec commit d1592cd

File tree

4 files changed

+37
-41
lines changed

4 files changed

+37
-41
lines changed

docs/changes/29.bugfix.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- Downgraded the pinned version for ``numpy`` in ``pyproject.toml`` to avoid potential conflicts.
2+
- Updated the docstring for the ``vis_data`` attribute in `GridData` to specify its shape for better clarity.
3+
4+
- Improved data selection and masking logic in ``GridData.from_ms`` to ensure correct row selection, flag masking, and shape alignment for measurement and channel data. This also fixes the handling of flagged data and ensures that the returned arrays are correctly filtered and shaped.

pyproject.toml

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -33,24 +33,22 @@ classifiers = [
3333
requires-python = ">=3.10"
3434

3535
dependencies = [
36-
"astropy>=6.1.0",
37-
"click>=8.2.1",
38-
"h5py>=3.14.0",
39-
"matplotlib>=3.10.5",
40-
"numpy>=2.2.6",
41-
"pandas>=2.3.2",
42-
"python-casacore>=3.7.1",
43-
"scipy>=1.15.3",
44-
"toml>=0.10.2",
45-
"torch>=2.8.0",
46-
"tqdm>=4.67.1",
36+
"astropy>=6.1.0",
37+
"click>=8.2.1",
38+
"h5py>=3.14.0",
39+
"matplotlib>=3.10.5",
40+
"numpy>=1.26.0",
41+
"pandas>=2.3.2",
42+
"python-casacore>=3.7.1",
43+
"scipy>=1.15.3",
44+
"toml>=0.10.2",
45+
"torch>=2.8.0",
46+
"tqdm>=4.67.1",
4747
]
4848

4949
[project.optional-dependencies]
5050
pyvisgen = ["pyvisgen"]
51-
all = [
52-
"pyvisgrid[pyvisgen]",
53-
]
51+
all = ["pyvisgrid[pyvisgen]"]
5452

5553

5654
[dependency-groups]
@@ -92,8 +90,8 @@ dev = [
9290
"ipython",
9391
"jupyter",
9492
"pre-commit",
95-
{include-group = "tests"},
96-
{include-group = "docs"},
93+
{ include-group = "tests" },
94+
{ include-group = "docs" },
9795
]
9896

9997
[project.urls]
@@ -121,11 +119,11 @@ extend-exclude = ["tests"]
121119

122120
[tool.ruff.lint]
123121
extend-select = [
124-
"I", # isort
125-
"E", # pycodestyle
126-
"F", # Pyflakes
127-
"UP", # pyupgrade
128-
"B", # flake8-bugbear
122+
"I", # isort
123+
"E", # pycodestyle
124+
"F", # Pyflakes
125+
"UP", # pyupgrade
126+
"B", # flake8-bugbear
129127
"SIM", # flake8-simplify
130128
]
131129
ignore = ["B905"]

src/pyvisgrid/core/gridder.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class GridData:
2929
Attributes
3030
----------
3131
vis_data : numpy.ndarray
32-
The ungridded visibilities.
32+
The ungridded visibilities of shape ``(N_MEASUREMENTS * N_CHANNELS,)``.
3333
fov : float
3434
The size of the Field Of View of the gridded data in arcseconds.
3535
mask : numpy.ndarray, optional
@@ -247,7 +247,7 @@ def grid(self, stokes_component: str = "I"):
247247
np.fft.ifft2(np.fft.fftshift(mask_real + 1j * mask_imag))
248248
)
249249

250-
return self[stokes_component]
250+
return self.stokes[stokes_component]
251251

252252
@classmethod
253253
def from_pyvisgen(
@@ -468,45 +468,41 @@ def from_ms(
468468

469469
if desc_id is not None:
470470
mask = main_tab.getcol("DATA_DESC_ID") == desc_id
471-
mask_idx = np.argwhere(mask).ravel()
472471

473-
main_tab = main_tab.selectrows(rownrs=mask_idx)
474-
475-
data = main_tab.getcol(data_colname)
476-
uv = main_tab.getcol("UVW")[:, :2]
477-
times = main_tab.getcol("TIME")
472+
main_tab = main_tab.selectrows(rownrs=np.argwhere(mask).ravel())
478473

479474
ref_frequency = spectral_tab.getcell("REF_FREQUENCY", desc_id)
480475
frequency_offsets = (
481476
spectral_tab.getcell("CHAN_FREQ", desc_id) - ref_frequency
482477
)
483478

484479
else:
485-
mask = np.ones_like(main_tab.getcol("DATA_DESC_ID")).astype(bool)
486-
data = main_tab.getcol(data_colname)
487-
uv = main_tab.getcol("UVW")[:, :2]
488-
times = main_tab.getcol("TIME")
489-
490480
ref_frequency = spectral_tab.getcell("REF_FREQUENCY", 0)
491481
frequency_offsets = spectral_tab.getcell("CHAN_FREQ", 0) - ref_frequency
492482

483+
data = main_tab.getcol(data_colname)
484+
uv = main_tab.getcol("UVW")[:, :2]
485+
times = main_tab.getcol("TIME")
486+
493487
if filter_flagged:
494488
flag_mask = main_tab.getcol("FLAG")
495-
flag_mask = flag_mask.reshape((-1, flag_mask.shape[0])).astype(np.uint8)
496-
flag_mask = np.prod(flag_mask, axis=0)
489+
flag_mask = flag_mask.reshape((flag_mask.shape[0], -1)).astype(np.uint8)
490+
flag_mask = np.prod(flag_mask, axis=1)
497491

498492
flag_mask = np.logical_not(flag_mask.astype(bool))
499493

500494
else:
501-
flag_mask = np.ones(uv.shape[-1]).astype(bool)
495+
flag_mask = np.ones(uv.shape[0]).astype(bool)
502496

503497
uv = uv[flag_mask]
504498
data = data[flag_mask]
499+
times = times[flag_mask]
505500

506501
u_meter = uv[:, 0]
507502
v_meter = uv[:, 1]
508503

509504
stokes_i = data[..., 0] + data[..., 1]
505+
stokes_i = stokes_i.T # ensure matching shape (N_CHANNELS, N_MEASUREMENTS)
510506

511507
# FIXME: probably some kind of difference in normalization.
512508
# Factor 0.5 fixes this for now. Has to be investigated.
@@ -515,7 +511,7 @@ def from_ms(
515511
cls = cls(
516512
u_meter=u_meter,
517513
v_meter=v_meter,
518-
times=Time(times[flag_mask] / 3600 / 24, format="mjd").mjd,
514+
times=Time(times / 3600 / 24, format="mjd").mjd,
519515
img_size=img_size,
520516
fov=fov,
521517
ref_frequency=ref_frequency,

src/pyvisgrid/plotting/plotting.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -497,8 +497,7 @@ def plot_dirty_image(
497497
- ``abs``: Plot the absolute value of the dirty image.
498498
499499
Default is ``real``.
500-
501-
ax_unit: str | astropy.units.unit, optional
500+
ax_unit: str | astropy.units.Unit, optional
502501
The unit in which to show the ticks of the x and y-axes in.
503502
The y-axis is the Declination (DEC) and the x-axis is the Right Ascension (RA).
504503
The latter one is defined as increasing from left to right!
@@ -508,7 +507,6 @@ def plot_dirty_image(
508507
509508
Valid units are either ``pixel`` or angle units like ``arcsec``, ``degree``
510509
etc. Default is ``pixel``.
511-
512510
center_pos: tuple | None, optional
513511
The coordinate center of the image. The coordinates have to
514512
be given in the unit defined in the parameter ``ax_unit`` above.

0 commit comments

Comments
 (0)