diff --git a/changelog/780.bugfix.rst b/changelog/780.bugfix.rst deleted file mode 100644 index f6c872dd6..000000000 --- a/changelog/780.bugfix.rst +++ /dev/null @@ -1 +0,0 @@ -Added an internal method to shortcut non-correlated axes avoiding the creation of a full coordinate grid, reducing memory use in specific circumstances. diff --git a/changelog/798.bugfix.rst b/changelog/798.bugfix.rst new file mode 100644 index 000000000..aa2e8ef18 --- /dev/null +++ b/changelog/798.bugfix.rst @@ -0,0 +1 @@ +Added an internal code to shortcut non-correlated axes avoiding the creation of a full coordinate grid, reducing memory use in specific circumstances. diff --git a/ndcube/conftest.py b/ndcube/conftest.py index e723a7171..e52ad7643 100644 --- a/ndcube/conftest.py +++ b/ndcube/conftest.py @@ -292,6 +292,86 @@ def wcs_3d_ln_lt_t_rotated(): return WCS(header=h_rotated) +@pytest.fixture +def wcs_3d_ln_lt_l_coupled(): + # WCS for a 3D data cube with two celestial axes and one wavelength axis. + # The latitudinal dimension is coupled to the third pixel dimension through + # a single off diagonal element in the PCij matrix + header = { + 'CTYPE1': 'HPLN-TAN', + 'CRPIX1': 5, + 'CDELT1': 5, + 'CUNIT1': 'arcsec', + 'CRVAL1': 0.0, + + 'CTYPE2': 'HPLT-TAN', + 'CRPIX2': 5, + 'CDELT2': 5, + 'CUNIT2': 'arcsec', + 'CRVAL2': 0.0, + + 'CTYPE3': 'WAVE', + 'CRPIX3': 1.0, + 'CDELT3': 1, + 'CUNIT3': 'Angstrom', + 'CRVAL3': 1.0, + + 'PC1_1': 1, + 'PC1_2': 0, + 'PC1_3': 0, + 'PC2_1': 0, + 'PC2_2': 1, + 'PC2_3': -1.0, + 'PC3_1': 0.0, + 'PC3_2': 0.0, + 'PC3_3': 1.0, + + 'WCSAXES': 3, + + 'DATEREF': "2020-01-01T00:00:00" + } + return WCS(header=header) + + +@pytest.fixture +def wcs_3d_ln_lt_t_coupled(): + # WCS for a 3D data cube with two celestial axes and one time axis. + header = { + 'CTYPE1': 'HPLN-TAN', + 'CRPIX1': 5, + 'CDELT1': 5, + 'CUNIT1': 'arcsec', + 'CRVAL1': 0.0, + + 'CTYPE2': 'HPLT-TAN', + 'CRPIX2': 5, + 'CDELT2': 5, + 'CUNIT2': 'arcsec', + 'CRVAL2': 0.0, + + 'CTYPE3': 'UTC', + 'CRPIX3': 1.0, + 'CDELT3': 1, + 'CUNIT3': 's', + 'CRVAL3': 1.0, + + 'PC1_1': 1, + 'PC1_2': 0, + 'PC1_3': 0, + 'PC2_1': 0, + 'PC2_2': 1, + 'PC2_3': 0, + 'PC3_1': 0, + 'PC3_2': 1, + 'PC3_3': 1, + + 'WCSAXES': 3, + + 'DATEREF': "2020-01-01T00:00:00" + } + return WCS(header=header) + + ################################################################################ # Extra and Global Coords Fixtures ################################################################################ @@ -519,6 +599,31 @@ def ndcube_3d_rotated(wcs_3d_ln_lt_t_rotated, simple_extra_coords_3d): return cube +@pytest.fixture +def ndcube_3d_coupled(wcs_3d_ln_lt_l_coupled): + shape = (128, 256, 512) + wcs_3d_ln_lt_l_coupled.array_shape = shape + data = data_nd(shape) + mask = data > 0 + return NDCube( + data, + wcs_3d_ln_lt_l_coupled, + mask=mask, + uncertainty=data, + ) + + +@pytest.fixture +def ndcube_3d_coupled_time(wcs_3d_ln_lt_t_coupled): + shape = (128, 256, 512) + wcs_3d_ln_lt_t_coupled.array_shape = shape + data = data_nd(shape) + return NDCube( + data, + wcs_3d_ln_lt_t_coupled, + ) + + @pytest.fixture def ndcube_3d_l_ln_lt_ectime(wcs_3d_lt_ln_l): return gen_ndcube_3d_l_ln_lt_ectime(wcs_3d_lt_ln_l, diff --git a/ndcube/ndcube.py b/ndcube/ndcube.py index 56be5487d..5f5f4ee40 100644 --- a/ndcube/ndcube.py +++ b/ndcube/ndcube.py @@ -14,8 +14,6 @@ from astropy.units import UnitsError from astropy.wcs.utils import _split_matrix -from ndcube.utils.wcs import world_axis_to_pixel_axes - try: # Import sunpy coordinates if available to register the frames and WCS functions with astropy import sunpy.coordinates # NOQA @@ -486,47 +484,9 @@ def quantity(self): """Unitful representation of the NDCube data.""" return u.Quantity(self.data, self.unit, copy=_NUMPY_COPY_IF_NEEDED) - def _generate_independent_world_coords(self, pixel_corners, wcs, needed_axes, units): - """ - Generate world coordinates for independent axes. - - The idea is to workout only the specific grid that is needed for independent axes. - This speeds up the calculation of world coordinates and reduces memory usage. - - Parameters - ---------- - pixel_corners : bool - If one needs pixel corners, otherwise pixel centers. - wcs : astropy.wcs.WCS - The WCS. - needed_axes : array-like - The required pixel axes. - units : bool - If units are needed. - - Returns - ------- - array-like - The world coordinates. - """ - needed_axes = np.array(needed_axes).squeeze() - if self.data.ndim in needed_axes: - required_axes = needed_axes - 1 - else: - required_axes = needed_axes - lims = (-0.5, self.data.shape[::-1][required_axes] + 1) if pixel_corners else (0, self.data.shape[::-1][required_axes]) - indices = [np.arange(lims[0], lims[1]) if wanted else [0] for wanted in wcs.axis_correlation_matrix[required_axes]] - world_coords = wcs.pixel_to_world_values(*indices) - if units: - world_coords = world_coords << u.Unit(wcs.world_axis_units[needed_axes]) - return world_coords - - def _generate_dependent_world_coords(self, pixel_corners, wcs, needed_axes, units): + def _generate_world_coords(self, pixel_corners, wcs, *, needed_axes, units=None): """ - Generate world coordinates for dependent axes. - - This will work out the exact grid that is needed for dependent axes - and can be time and memory consuming. + Private method to generate world coordinates. Parameters ---------- @@ -535,7 +495,7 @@ def _generate_dependent_world_coords(self, pixel_corners, wcs, needed_axes, unit wcs : astropy.wcs.WCS The WCS. needed_axes : array-like - The required pixel axes. + The axes that are needed. units : bool If units are needed. @@ -573,6 +533,12 @@ def _generate_dependent_world_coords(self, pixel_corners, wcs, needed_axes, unit # And inject 0s for those coordinates for idx in non_corr_axes: sub_range.insert(idx, 0) + # If we are subsetting world axes, ignore any pixel axes which are not correlated with our requested world axis. + if any(world_axis in needed_axes for world_axis in world_axes_indices): + needed_pixel_axes = wcs.axis_correlation_matrix[needed_axes] + unneeded_pixel_axes = np.argwhere(needed_pixel_axes.sum(axis=0) == 0)[:, 0] + for idx in unneeded_pixel_axes: + sub_range[idx] = 0 # Generate a grid of broadcastable pixel indices for all pixel dimensions grid = np.meshgrid(*sub_range, indexing='ij') # Convert to world coordinates @@ -592,41 +558,6 @@ def _generate_dependent_world_coords(self, pixel_corners, wcs, needed_axes, unit world_coords[i] = coord << u.Unit(unit) return world_coords - def _generate_world_coords(self, pixel_corners, wcs, *, needed_axes, units=None): - """ - Private method to generate world coordinates. - - Handles both dependent and independent axes. - - Parameters - ---------- - pixel_corners : bool - If one needs pixel corners, otherwise pixel centers. - wcs : astropy.wcs.WCS - The WCS. - needed_axes : array-like - The axes that are needed. - units : bool - If units are needed. - - Returns - ------- - array-like - The world coordinates. - """ - axes_are_independent = [] - pixel_axes = set() - for world_axis in needed_axes: - pix_ax = world_axis_to_pixel_axes(world_axis, wcs.axis_correlation_matrix) - axes_are_independent.append(len(pix_ax) == 1) - pixel_axes = pixel_axes.union(set(pix_ax)) - pixel_axes = list(pixel_axes) - if all(axes_are_independent) and len(pixel_axes) == len(needed_axes) and len(needed_axes) != 0: - world_coords = self._generate_independent_world_coords(pixel_corners, wcs, needed_axes, units) - else: - world_coords = self._generate_dependent_world_coords(pixel_corners, wcs, needed_axes, units) - return world_coords - @utils.cube.sanitize_wcs def axis_world_coords(self, *axes, pixel_corners=False, wcs=None): # Docstring in NDCubeABC. diff --git a/ndcube/tests/test_ndcube.py b/ndcube/tests/test_ndcube.py index daf263e5f..01f2aa0f2 100644 --- a/ndcube/tests/test_ndcube.py +++ b/ndcube/tests/test_ndcube.py @@ -230,6 +230,20 @@ def test_axis_world_coords_wave_ec(ndcube_3d_l_ln_lt_ectime): assert coords[0].shape == (5,) +@pytest.mark.limit_memory("12 MB") +def test_axis_world_coords_wave_coupled_dims(ndcube_3d_coupled): + cube = ndcube_3d_coupled + + cube.axis_world_coords('em.wl') + + +@pytest.mark.limit_memory("12 MB") +def test_axis_world_coords_time_coupled_dims(ndcube_3d_coupled_time): + cube = ndcube_3d_coupled_time + + cube.axis_world_coords('time') + + def test_axis_world_coords_empty_ec(ndcube_3d_l_ln_lt_ectime): cube = ndcube_3d_l_ln_lt_ectime sub_cube = cube[:, 0] @@ -292,10 +306,10 @@ def test_axis_world_coords_single_pixel_corners(axes, ndcube_3d_ln_lt_l): assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m) coords = ndcube_3d_ln_lt_l.axis_world_coords_values(*axes, pixel_corners=True) - assert u.allclose(coords[0], [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09, 1.11e-09] * u.m) + assert u.allclose(coords, [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09] * u.m) coords = ndcube_3d_ln_lt_l.axis_world_coords(*axes, pixel_corners=True) - assert u.allclose(coords[0], [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09, 1.11e-09] * u.m) + assert u.allclose(coords, [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09] * u.m) @pytest.mark.parametrize(("ndc", "item"), diff --git a/pyproject.toml b/pyproject.toml index 347b06750..17b45dc5b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ tests = [ "pytest-mpl>=0.12", "pytest-xdist", "pytest", + "pytest-memray; sys_platform != 'win32'", "scipy", "specutils", "sunpy>=5.0.0", diff --git a/pytest.ini b/pytest.ini index 597d9b120..cd1290262 100644 --- a/pytest.ini +++ b/pytest.ini @@ -28,6 +28,11 @@ addopts = --doctest-continue-on-failure mpl-results-path = figure_test_images mpl-use-full-test-name = true +remote_data_strict = True +doctest_subpackage_requires = + docs/explaining_ndcube/* = numpy>=2.0.0 +markers = + limit_memory: pytest-memray marker to fail a test if too much memory used filterwarnings = # Turn all warnings into errors so they do not pass silently. error @@ -53,6 +58,3 @@ filterwarnings = ignore:FigureCanvasAgg is non-interactive, and thus cannot be shown:UserWarning # Oldestdeps from gWCS ignore:pkg_resources is deprecated as an API:DeprecationWarning -remote_data_strict = True -doctest_subpackage_requires = - docs/explaining_ndcube/* = numpy>=2.0.0