Skip to content

Commit cd93cf8

Browse files
authored
Merge pull request #780 from sunpy/YOLO
Faster shortcut for working out coordinates values for non-correlated WCS
2 parents ee91f49 + f821a69 commit cd93cf8

File tree

4 files changed

+176
-31
lines changed

4 files changed

+176
-31
lines changed

changelog/780.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added an internal method to shortcut non-correlated axes avoiding the creation of a full coordinate grid, reducing memory use in specific circumstances.

ndcube/conftest.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,30 @@ def wcs_3d_lt_ln_l():
197197
return WCS(header=header)
198198

199199

200+
@pytest.fixture
201+
def wcs_3d_wave_lt_ln():
202+
header = {
203+
'CTYPE1': 'WAVE ',
204+
'CUNIT1': 'Angstrom',
205+
'CDELT1': 0.2,
206+
'CRPIX1': 0,
207+
'CRVAL1': 10,
208+
209+
'CTYPE2': 'HPLT-TAN',
210+
'CUNIT2': 'deg',
211+
'CDELT2': 0.5,
212+
'CRPIX2': 2,
213+
'CRVAL2': 0.5,
214+
215+
'CTYPE3': 'HPLN-TAN ',
216+
'CUNIT3': 'deg',
217+
'CDELT3': 0.4,
218+
'CRPIX3': 2,
219+
'CRVAL3': 1,
220+
}
221+
return WCS(header=header)
222+
223+
200224
@pytest.fixture
201225
def wcs_2d_lt_ln():
202226
spatial = {
@@ -445,6 +469,24 @@ def ndcube_3d_ln_lt_l_ec_time(wcs_3d_l_lt_ln, time_and_simple_extra_coords_2d):
445469
return cube
446470

447471

472+
@pytest.fixture
473+
def ndcube_3d_wave_lt_ln_ec_time(wcs_3d_wave_lt_ln):
474+
shape = (3, 4, 5)
475+
wcs_3d_wave_lt_ln.array_shape = shape
476+
data = data_nd(shape)
477+
mask = data > 0
478+
cube = NDCube(
479+
data,
480+
wcs_3d_wave_lt_ln,
481+
mask=mask,
482+
uncertainty=data,
483+
)
484+
base_time = Time('2000-01-01', format='fits', scale='utc')
485+
timestamps = Time([base_time + TimeDelta(60 * i, format='sec') for i in range(data.shape[0])])
486+
cube.extra_coords.add('time', 0, timestamps)
487+
return cube
488+
489+
448490
@pytest.fixture
449491
def ndcube_3d_rotated(wcs_3d_ln_lt_t_rotated, simple_extra_coords_3d):
450492
data_rotated = np.array([[[1, 2, 3, 4, 6], [2, 4, 5, 3, 1], [0, -1, 2, 4, 2], [3, 5, 1, 2, 0]],

ndcube/ndcube.py

Lines changed: 100 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
import astropy.nddata
1313
import astropy.units as u
1414
from astropy.units import UnitsError
15+
from astropy.wcs.utils import _split_matrix
16+
17+
from ndcube.utils.wcs import world_axis_to_pixel_axes
1518

1619
try:
1720
# Import sunpy coordinates if available to register the frames and WCS functions with astropy
@@ -20,7 +23,6 @@
2023
pass
2124

2225
from astropy.wcs import WCS
23-
from astropy.wcs.utils import _split_matrix
2426
from astropy.wcs.wcsapi import BaseHighLevelWCS, HighLevelWCSWrapper
2527
from astropy.wcs.wcsapi.high_level_api import values_to_high_level_objects
2628

@@ -479,24 +481,76 @@ def quantity(self):
479481
"""Unitful representation of the NDCube data."""
480482
return u.Quantity(self.data, self.unit, copy=_NUMPY_COPY_IF_NEEDED)
481483

482-
def _generate_world_coords(self, pixel_corners, wcs, needed_axes=None, *, units):
483-
# Create meshgrid of all pixel coordinates.
484-
# If user wants pixel_corners, set pixel values to pixel corners.
485-
# Else make pixel centers.
484+
def _generate_independent_world_coords(self, pixel_corners, wcs, needed_axes, units):
485+
"""
486+
Generate world coordinates for independent axes.
487+
488+
The idea is to workout only the specific grid that is needed for independent axes.
489+
This speeds up the calculation of world coordinates and reduces memory usage.
490+
491+
Parameters
492+
----------
493+
pixel_corners : bool
494+
If one needs pixel corners, otherwise pixel centers.
495+
wcs : astropy.wcs.WCS
496+
The WCS.
497+
needed_axes : array-like
498+
The required pixel axes.
499+
units : bool
500+
If units are needed.
501+
502+
Returns
503+
-------
504+
array-like
505+
The world coordinates.
506+
"""
507+
needed_axes = np.array(needed_axes).squeeze()
508+
if self.data.ndim in needed_axes:
509+
required_axes = needed_axes - 1
510+
else:
511+
required_axes = needed_axes
512+
lims = (-0.5, self.data.shape[::-1][required_axes] + 1) if pixel_corners else (0, self.data.shape[::-1][required_axes])
513+
indices = [np.arange(lims[0], lims[1]) if wanted else [0] for wanted in wcs.axis_correlation_matrix[required_axes]]
514+
world_coords = wcs.pixel_to_world_values(*indices)
515+
if units:
516+
world_coords = world_coords << u.Unit(wcs.world_axis_units[needed_axes])
517+
return world_coords
518+
519+
def _generate_dependent_world_coords(self, pixel_corners, wcs, needed_axes, units):
520+
"""
521+
Generate world coordinates for dependent axes.
522+
523+
This will work out the exact grid that is needed for dependent axes
524+
and can be time and memory consuming.
525+
526+
Parameters
527+
----------
528+
pixel_corners : bool
529+
If one needs pixel corners, otherwise pixel centers.
530+
wcs : astropy.wcs.WCS
531+
The WCS.
532+
needed_axes : array-like
533+
The required pixel axes.
534+
units : bool
535+
If units are needed.
536+
537+
Returns
538+
-------
539+
array-like
540+
The world coordinates.
541+
"""
486542
pixel_shape = self.data.shape[::-1]
487543
if pixel_corners:
488544
pixel_shape = tuple(np.array(pixel_shape) + 1)
489545
ranges = [np.arange(i) - 0.5 for i in pixel_shape]
490546
else:
491547
ranges = [np.arange(i) for i in pixel_shape]
492-
493548
# Limit the pixel dimensions to the ones present in the ExtraCoords
494549
if isinstance(wcs, ExtraCoords):
495550
ranges = [ranges[i] for i in wcs.mapping]
496551
wcs = wcs.wcs
497552
if wcs is None:
498-
return []
499-
553+
return ()
500554
# This value of zero will be returned as a throwaway for unneeded axes, and a numerical value is
501555
# required so values_to_high_level_objects in the calling function doesn't crash or warn
502556
world_coords = [0] * wcs.world_n_dim
@@ -528,71 +582,92 @@ def _generate_world_coords(self, pixel_corners, wcs, needed_axes=None, *, units)
528582
array_slice[wcs.axis_correlation_matrix[idx]] = slice(None)
529583
tmp_world = world[idx][tuple(array_slice)].T
530584
world_coords[idx] = tmp_world
531-
532585
if units:
533586
for i, (coord, unit) in enumerate(zip(world_coords, wcs.world_axis_units)):
534587
world_coords[i] = coord << u.Unit(unit)
588+
return world_coords
589+
590+
def _generate_world_coords(self, pixel_corners, wcs, *, needed_axes, units=None):
591+
"""
592+
Private method to generate world coordinates.
593+
594+
Handles both dependent and independent axes.
595+
596+
Parameters
597+
----------
598+
pixel_corners : bool
599+
If one needs pixel corners, otherwise pixel centers.
600+
wcs : astropy.wcs.WCS
601+
The WCS.
602+
needed_axes : array-like
603+
The axes that are needed.
604+
units : bool
605+
If units are needed.
535606
607+
Returns
608+
-------
609+
array-like
610+
The world coordinates.
611+
"""
612+
axes_are_independent = []
613+
pixel_axes = set()
614+
for world_axis in needed_axes:
615+
pix_ax = world_axis_to_pixel_axes(world_axis, wcs.axis_correlation_matrix)
616+
axes_are_independent.append(len(pix_ax) == 1)
617+
pixel_axes = pixel_axes.union(set(pix_ax))
618+
pixel_axes = list(pixel_axes)
619+
if all(axes_are_independent) and len(pixel_axes) == len(needed_axes) and len(needed_axes) != 0:
620+
world_coords = self._generate_independent_world_coords(pixel_corners, wcs, needed_axes, units)
621+
else:
622+
world_coords = self._generate_dependent_world_coords(pixel_corners, wcs, needed_axes, units)
536623
return world_coords
537624

538625
@utils.cube.sanitize_wcs
539626
def axis_world_coords(self, *axes, pixel_corners=False, wcs=None):
540627
# Docstring in NDCubeABC.
541628
if isinstance(wcs, BaseHighLevelWCS):
542629
wcs = wcs.low_level_wcs
543-
544630
orig_wcs = wcs
545631
if isinstance(wcs, ExtraCoords):
546632
wcs = wcs.wcs
547633
if not wcs:
548634
return ()
549-
550635
object_names = np.array([wao_comp[0] for wao_comp in wcs.world_axis_object_components])
551636
unique_obj_names = utils.misc.unique_sorted(object_names)
552637
world_axes_for_obj = [np.where(object_names == name)[0] for name in unique_obj_names]
553-
554638
# Create a mapping from world index in the WCS to object index in axes_coords
555639
world_index_to_object_index = {}
556640
for object_index, world_axes in enumerate(world_axes_for_obj):
557641
for world_index in world_axes:
558642
world_index_to_object_index[world_index] = object_index
559-
560643
world_indices = utils.wcs.calculate_world_indices_from_axes(wcs, axes)
561644
object_indices = utils.misc.unique_sorted(
562645
[world_index_to_object_index[world_index] for world_index in world_indices]
563646
)
564-
565-
axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, world_indices, units=False)
566-
647+
axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, needed_axes=world_indices, units=False)
567648
axes_coords = values_to_high_level_objects(*axes_coords, low_level_wcs=wcs)
568-
569649
if not axes:
570650
return tuple(axes_coords)
571-
572651
return tuple(axes_coords[i] for i in object_indices)
573652

574653
@utils.cube.sanitize_wcs
575654
def axis_world_coords_values(self, *axes, pixel_corners=False, wcs=None):
576655
# Docstring in NDCubeABC.
577656
if isinstance(wcs, BaseHighLevelWCS):
578657
wcs = wcs.low_level_wcs
579-
580658
orig_wcs = wcs
581659
if isinstance(wcs, ExtraCoords):
582660
wcs = wcs.wcs
583-
661+
if not wcs:
662+
return ()
584663
world_indices = utils.wcs.calculate_world_indices_from_axes(wcs, axes)
585-
586-
axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, world_indices, units=True)
587-
664+
axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, needed_axes=world_indices, units=True)
588665
world_axis_physical_types = wcs.world_axis_physical_types
589-
590666
# If user has supplied axes, extract only the
591667
# world coords that correspond to those axes.
592668
if axes:
593669
axes_coords = [axes_coords[i] for i in world_indices]
594670
world_axis_physical_types = tuple(np.array(world_axis_physical_types)[world_indices])
595-
596671
# Return in array order.
597672
# First replace characters in physical types forbidden for namedtuple identifiers.
598673
identifiers = []

ndcube/tests/test_ndcube.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from astropy.coordinates import SkyCoord, SpectralCoord
1313
from astropy.io import fits
1414
from astropy.nddata import UnknownUncertainty
15+
from astropy.tests.helper import assert_quantity_allclose
1516
from astropy.time import Time
1617
from astropy.units import UnitsError
1718
from astropy.wcs import WCS
@@ -177,9 +178,19 @@ def test_axis_world_coords_wave_ec(ndcube_3d_l_ln_lt_ectime):
177178

178179
coords = cube.axis_world_coords()
179180
assert len(coords) == 2
181+
assert isinstance(coords[0], SkyCoord)
182+
assert coords[0].shape == (5, 8)
183+
assert isinstance(coords[1], SpectralCoord)
184+
assert coords[1].shape == (10,)
180185

181186
coords = cube.axis_world_coords(wcs=cube.combined_wcs)
182187
assert len(coords) == 3
188+
assert isinstance(coords[0], SkyCoord)
189+
assert coords[0].shape == (5, 8)
190+
assert isinstance(coords[1], SpectralCoord)
191+
assert coords[1].shape == (10,)
192+
assert isinstance(coords[2], Time)
193+
assert coords[2].shape == (5,)
183194

184195
coords = cube.axis_world_coords(wcs=cube.extra_coords)
185196
assert len(coords) == 1
@@ -199,8 +210,6 @@ def test_axis_world_coords_empty_ec(ndcube_3d_l_ln_lt_ectime):
199210
# slice the cube so extra_coords is empty, and then try and run axis_world_coords
200211
awc = sub_cube.axis_world_coords(wcs=sub_cube.extra_coords)
201212
assert awc == ()
202-
sub_cube._generate_world_coords(pixel_corners=False, wcs=sub_cube.extra_coords, units=True)
203-
assert awc == ()
204213

205214

206215
@pytest.mark.xfail(reason=">1D Tables not supported")
@@ -235,13 +244,31 @@ def test_axis_world_coords_single(axes, ndcube_3d_ln_lt_l):
235244
assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m)
236245

237246

247+
def test_axis_world_coords_combined_wcs(ndcube_3d_wave_lt_ln_ec_time):
248+
# This replicates a specific NDCube object in visualization.rst
249+
coords = ndcube_3d_wave_lt_ln_ec_time.axis_world_coords('time', wcs=ndcube_3d_wave_lt_ln_ec_time.combined_wcs)
250+
assert len(coords) == 1
251+
assert isinstance(coords[0], Time)
252+
assert np.all(coords[0] == Time(['2000-01-01T00:00:00.000', '2000-01-01T00:01:00.000', '2000-01-01T00:02:00.000']))
253+
254+
coords = ndcube_3d_wave_lt_ln_ec_time.axis_world_coords_values('time', wcs=ndcube_3d_wave_lt_ln_ec_time.combined_wcs)
255+
assert len(coords) == 1
256+
assert isinstance(coords.time, u.Quantity)
257+
assert_quantity_allclose(coords.time, [0, 60, 120] * u.second)
258+
259+
238260
@pytest.mark.parametrize("axes", [[-1], [2], ["em"]])
239261
def test_axis_world_coords_single_pixel_corners(axes, ndcube_3d_ln_lt_l):
262+
263+
# We go from 4 pixels to 6 pixels when we add pixel corners
264+
coords = ndcube_3d_ln_lt_l.axis_world_coords_values(*axes, pixel_corners=False)
265+
assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m)
266+
240267
coords = ndcube_3d_ln_lt_l.axis_world_coords_values(*axes, pixel_corners=True)
241-
assert u.allclose(coords, [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09] * u.m)
268+
assert u.allclose(coords[0], [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09, 1.11e-09] * u.m)
242269

243270
coords = ndcube_3d_ln_lt_l.axis_world_coords(*axes, pixel_corners=True)
244-
assert u.allclose(coords, [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09] * u.m)
271+
assert u.allclose(coords[0], [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09, 1.11e-09] * u.m)
245272

246273

247274
@pytest.mark.parametrize(("ndc", "item"),
@@ -252,10 +279,10 @@ def test_axis_world_coords_single_pixel_corners(axes, ndcube_3d_ln_lt_l):
252279
indirect=("ndc",))
253280
def test_axis_world_coords_sliced_all_3d(ndc, item):
254281
coords = ndc[item].axis_world_coords_values()
255-
assert u.allclose(coords, [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m)
282+
assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m)
256283

257284
coords = ndc[item].axis_world_coords()
258-
assert u.allclose(coords, [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m)
285+
assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m)
259286

260287

261288
@pytest.mark.parametrize(("ndc", "item"),

0 commit comments

Comments
 (0)