diff --git a/changelog/780.bugfix.rst b/changelog/780.bugfix.rst new file mode 100644 index 000000000..f6c872dd6 --- /dev/null +++ b/changelog/780.bugfix.rst @@ -0,0 +1 @@ +Added an internal method to shortcut non-correlated axes avoiding the creation of a full coordinate grid, reducing memory use in specific circumstances. diff --git a/ndcube/conftest.py b/ndcube/conftest.py index ad34271d0..8960e8879 100644 --- a/ndcube/conftest.py +++ b/ndcube/conftest.py @@ -197,6 +197,30 @@ def wcs_3d_lt_ln_l(): return WCS(header=header) +@pytest.fixture +def wcs_3d_wave_lt_ln(): + header = { + 'CTYPE1': 'WAVE ', + 'CUNIT1': 'Angstrom', + 'CDELT1': 0.2, + 'CRPIX1': 0, + 'CRVAL1': 10, + + 'CTYPE2': 'HPLT-TAN', + 'CUNIT2': 'deg', + 'CDELT2': 0.5, + 'CRPIX2': 2, + 'CRVAL2': 0.5, + + 'CTYPE3': 'HPLN-TAN ', + 'CUNIT3': 'deg', + 'CDELT3': 0.4, + 'CRPIX3': 2, + 'CRVAL3': 1, + } + return WCS(header=header) + + @pytest.fixture def wcs_2d_lt_ln(): spatial = { @@ -445,6 +469,24 @@ def ndcube_3d_ln_lt_l_ec_time(wcs_3d_l_lt_ln, time_and_simple_extra_coords_2d): return cube +@pytest.fixture +def ndcube_3d_wave_lt_ln_ec_time(wcs_3d_wave_lt_ln): + shape = (3, 4, 5) + wcs_3d_wave_lt_ln.array_shape = shape + data = data_nd(shape) + mask = data > 0 + cube = NDCube( + data, + wcs_3d_wave_lt_ln, + mask=mask, + uncertainty=data, + ) + base_time = Time('2000-01-01', format='fits', scale='utc') + timestamps = Time([base_time + TimeDelta(60 * i, format='sec') for i in range(data.shape[0])]) + cube.extra_coords.add('time', 0, timestamps) + return cube + + @pytest.fixture def ndcube_3d_rotated(wcs_3d_ln_lt_t_rotated, simple_extra_coords_3d): data_rotated = np.array([[[1, 2, 3, 4, 6], [2, 4, 5, 3, 1], [0, -1, 2, 4, 2], [3, 5, 1, 2, 0]], diff --git a/ndcube/ndcube.py b/ndcube/ndcube.py index d219331e7..b767eec10 100644 --- a/ndcube/ndcube.py +++ b/ndcube/ndcube.py @@ -12,6 +12,9 @@ import astropy.nddata import astropy.units as u from astropy.units import UnitsError +from astropy.wcs.utils import _split_matrix + +from ndcube.utils.wcs import world_axis_to_pixel_axes try: # Import sunpy coordinates if available to register the frames and WCS functions with astropy @@ -20,7 +23,6 @@ pass from astropy.wcs import WCS -from astropy.wcs.utils import _split_matrix from astropy.wcs.wcsapi import BaseHighLevelWCS, HighLevelWCSWrapper from astropy.wcs.wcsapi.high_level_api import values_to_high_level_objects @@ -479,24 +481,76 @@ def quantity(self): """Unitful representation of the NDCube data.""" return u.Quantity(self.data, self.unit, copy=_NUMPY_COPY_IF_NEEDED) - def _generate_world_coords(self, pixel_corners, wcs, needed_axes=None, *, units): - # Create meshgrid of all pixel coordinates. - # If user wants pixel_corners, set pixel values to pixel corners. - # Else make pixel centers. + def _generate_independent_world_coords(self, pixel_corners, wcs, needed_axes, units): + """ + Generate world coordinates for independent axes. + + The idea is to workout only the specific grid that is needed for independent axes. + This speeds up the calculation of world coordinates and reduces memory usage. + + Parameters + ---------- + pixel_corners : bool + If one needs pixel corners, otherwise pixel centers. + wcs : astropy.wcs.WCS + The WCS. + needed_axes : array-like + The required pixel axes. + units : bool + If units are needed. + + Returns + ------- + array-like + The world coordinates. + """ + needed_axes = np.array(needed_axes).squeeze() + if self.data.ndim in needed_axes: + required_axes = needed_axes - 1 + else: + required_axes = needed_axes + lims = (-0.5, self.data.shape[::-1][required_axes] + 1) if pixel_corners else (0, self.data.shape[::-1][required_axes]) + indices = [np.arange(lims[0], lims[1]) if wanted else [0] for wanted in wcs.axis_correlation_matrix[required_axes]] + world_coords = wcs.pixel_to_world_values(*indices) + if units: + world_coords = world_coords << u.Unit(wcs.world_axis_units[needed_axes]) + return world_coords + + def _generate_dependent_world_coords(self, pixel_corners, wcs, needed_axes, units): + """ + Generate world coordinates for dependent axes. + + This will work out the exact grid that is needed for dependent axes + and can be time and memory consuming. + + Parameters + ---------- + pixel_corners : bool + If one needs pixel corners, otherwise pixel centers. + wcs : astropy.wcs.WCS + The WCS. + needed_axes : array-like + The required pixel axes. + units : bool + If units are needed. + + Returns + ------- + array-like + The world coordinates. + """ pixel_shape = self.data.shape[::-1] if pixel_corners: pixel_shape = tuple(np.array(pixel_shape) + 1) ranges = [np.arange(i) - 0.5 for i in pixel_shape] else: ranges = [np.arange(i) for i in pixel_shape] - # Limit the pixel dimensions to the ones present in the ExtraCoords if isinstance(wcs, ExtraCoords): ranges = [ranges[i] for i in wcs.mapping] wcs = wcs.wcs if wcs is None: - return [] - + return () # This value of zero will be returned as a throwaway for unneeded axes, and a numerical value is # required so values_to_high_level_objects in the calling function doesn't crash or warn world_coords = [0] * wcs.world_n_dim @@ -528,11 +582,44 @@ def _generate_world_coords(self, pixel_corners, wcs, needed_axes=None, *, units) array_slice[wcs.axis_correlation_matrix[idx]] = slice(None) tmp_world = world[idx][tuple(array_slice)].T world_coords[idx] = tmp_world - if units: for i, (coord, unit) in enumerate(zip(world_coords, wcs.world_axis_units)): world_coords[i] = coord << u.Unit(unit) + return world_coords + + def _generate_world_coords(self, pixel_corners, wcs, *, needed_axes, units=None): + """ + Private method to generate world coordinates. + + Handles both dependent and independent axes. + + Parameters + ---------- + pixel_corners : bool + If one needs pixel corners, otherwise pixel centers. + wcs : astropy.wcs.WCS + The WCS. + needed_axes : array-like + The axes that are needed. + units : bool + If units are needed. + Returns + ------- + array-like + The world coordinates. + """ + axes_are_independent = [] + pixel_axes = set() + for world_axis in needed_axes: + pix_ax = world_axis_to_pixel_axes(world_axis, wcs.axis_correlation_matrix) + axes_are_independent.append(len(pix_ax) == 1) + pixel_axes = pixel_axes.union(set(pix_ax)) + pixel_axes = list(pixel_axes) + if all(axes_are_independent) and len(pixel_axes) == len(needed_axes) and len(needed_axes) != 0: + world_coords = self._generate_independent_world_coords(pixel_corners, wcs, needed_axes, units) + else: + world_coords = self._generate_dependent_world_coords(pixel_corners, wcs, needed_axes, units) return world_coords @utils.cube.sanitize_wcs @@ -540,35 +627,27 @@ def axis_world_coords(self, *axes, pixel_corners=False, wcs=None): # Docstring in NDCubeABC. if isinstance(wcs, BaseHighLevelWCS): wcs = wcs.low_level_wcs - orig_wcs = wcs if isinstance(wcs, ExtraCoords): wcs = wcs.wcs if not wcs: return () - object_names = np.array([wao_comp[0] for wao_comp in wcs.world_axis_object_components]) unique_obj_names = utils.misc.unique_sorted(object_names) world_axes_for_obj = [np.where(object_names == name)[0] for name in unique_obj_names] - # Create a mapping from world index in the WCS to object index in axes_coords world_index_to_object_index = {} for object_index, world_axes in enumerate(world_axes_for_obj): for world_index in world_axes: world_index_to_object_index[world_index] = object_index - world_indices = utils.wcs.calculate_world_indices_from_axes(wcs, axes) object_indices = utils.misc.unique_sorted( [world_index_to_object_index[world_index] for world_index in world_indices] ) - - axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, world_indices, units=False) - + axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, needed_axes=world_indices, units=False) axes_coords = values_to_high_level_objects(*axes_coords, low_level_wcs=wcs) - if not axes: return tuple(axes_coords) - return tuple(axes_coords[i] for i in object_indices) @utils.cube.sanitize_wcs @@ -576,23 +655,19 @@ def axis_world_coords_values(self, *axes, pixel_corners=False, wcs=None): # Docstring in NDCubeABC. if isinstance(wcs, BaseHighLevelWCS): wcs = wcs.low_level_wcs - orig_wcs = wcs if isinstance(wcs, ExtraCoords): wcs = wcs.wcs - + if not wcs: + return () world_indices = utils.wcs.calculate_world_indices_from_axes(wcs, axes) - - axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, world_indices, units=True) - + axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, needed_axes=world_indices, units=True) world_axis_physical_types = wcs.world_axis_physical_types - # If user has supplied axes, extract only the # world coords that correspond to those axes. if axes: axes_coords = [axes_coords[i] for i in world_indices] world_axis_physical_types = tuple(np.array(world_axis_physical_types)[world_indices]) - # Return in array order. # First replace characters in physical types forbidden for namedtuple identifiers. identifiers = [] diff --git a/ndcube/tests/test_ndcube.py b/ndcube/tests/test_ndcube.py index a40cb4434..248a6f6be 100644 --- a/ndcube/tests/test_ndcube.py +++ b/ndcube/tests/test_ndcube.py @@ -12,6 +12,7 @@ from astropy.coordinates import SkyCoord, SpectralCoord from astropy.io import fits from astropy.nddata import UnknownUncertainty +from astropy.tests.helper import assert_quantity_allclose from astropy.time import Time from astropy.units import UnitsError from astropy.wcs import WCS @@ -177,9 +178,19 @@ def test_axis_world_coords_wave_ec(ndcube_3d_l_ln_lt_ectime): coords = cube.axis_world_coords() assert len(coords) == 2 + assert isinstance(coords[0], SkyCoord) + assert coords[0].shape == (5, 8) + assert isinstance(coords[1], SpectralCoord) + assert coords[1].shape == (10,) coords = cube.axis_world_coords(wcs=cube.combined_wcs) assert len(coords) == 3 + assert isinstance(coords[0], SkyCoord) + assert coords[0].shape == (5, 8) + assert isinstance(coords[1], SpectralCoord) + assert coords[1].shape == (10,) + assert isinstance(coords[2], Time) + assert coords[2].shape == (5,) coords = cube.axis_world_coords(wcs=cube.extra_coords) assert len(coords) == 1 @@ -199,8 +210,6 @@ def test_axis_world_coords_empty_ec(ndcube_3d_l_ln_lt_ectime): # slice the cube so extra_coords is empty, and then try and run axis_world_coords awc = sub_cube.axis_world_coords(wcs=sub_cube.extra_coords) assert awc == () - sub_cube._generate_world_coords(pixel_corners=False, wcs=sub_cube.extra_coords, units=True) - assert awc == () @pytest.mark.xfail(reason=">1D Tables not supported") @@ -235,13 +244,31 @@ def test_axis_world_coords_single(axes, ndcube_3d_ln_lt_l): assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m) +def test_axis_world_coords_combined_wcs(ndcube_3d_wave_lt_ln_ec_time): + # This replicates a specific NDCube object in visualization.rst + coords = ndcube_3d_wave_lt_ln_ec_time.axis_world_coords('time', wcs=ndcube_3d_wave_lt_ln_ec_time.combined_wcs) + assert len(coords) == 1 + assert isinstance(coords[0], Time) + assert np.all(coords[0] == Time(['2000-01-01T00:00:00.000', '2000-01-01T00:01:00.000', '2000-01-01T00:02:00.000'])) + + coords = ndcube_3d_wave_lt_ln_ec_time.axis_world_coords_values('time', wcs=ndcube_3d_wave_lt_ln_ec_time.combined_wcs) + assert len(coords) == 1 + assert isinstance(coords.time, u.Quantity) + assert_quantity_allclose(coords.time, [0, 60, 120] * u.second) + + @pytest.mark.parametrize("axes", [[-1], [2], ["em"]]) def test_axis_world_coords_single_pixel_corners(axes, ndcube_3d_ln_lt_l): + + # We go from 4 pixels to 6 pixels when we add pixel corners + coords = ndcube_3d_ln_lt_l.axis_world_coords_values(*axes, pixel_corners=False) + assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m) + coords = ndcube_3d_ln_lt_l.axis_world_coords_values(*axes, pixel_corners=True) - assert u.allclose(coords, [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09] * u.m) + assert u.allclose(coords[0], [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09, 1.11e-09] * u.m) coords = ndcube_3d_ln_lt_l.axis_world_coords(*axes, pixel_corners=True) - assert u.allclose(coords, [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09] * u.m) + assert u.allclose(coords[0], [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09, 1.11e-09] * u.m) @pytest.mark.parametrize(("ndc", "item"), @@ -252,10 +279,10 @@ def test_axis_world_coords_single_pixel_corners(axes, ndcube_3d_ln_lt_l): indirect=("ndc",)) def test_axis_world_coords_sliced_all_3d(ndc, item): coords = ndc[item].axis_world_coords_values() - assert u.allclose(coords, [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m) + assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m) coords = ndc[item].axis_world_coords() - assert u.allclose(coords, [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m) + assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m) @pytest.mark.parametrize(("ndc", "item"),