Skip to content

Commit 1c9faa1

Browse files
authored
Merge pull request #798 from Cadair/maybe_this_will_work
Reduce memory consuption in axis_world_coords (again)
2 parents 067ddf8 + 1a5bd72 commit 1c9faa1

File tree

7 files changed

+137
-84
lines changed

7 files changed

+137
-84
lines changed

changelog/780.bugfix.rst

Lines changed: 0 additions & 1 deletion
This file was deleted.

changelog/798.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added an internal code to shortcut non-correlated axes avoiding the creation of a full coordinate grid, reducing memory use in specific circumstances.

ndcube/conftest.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,86 @@ def wcs_3d_ln_lt_t_rotated():
292292
return WCS(header=h_rotated)
293293

294294

295+
@pytest.fixture
296+
def wcs_3d_ln_lt_l_coupled():
297+
# WCS for a 3D data cube with two celestial axes and one wavelength axis.
298+
# The latitudinal dimension is coupled to the third pixel dimension through
299+
# a single off diagonal element in the PCij matrix
300+
header = {
301+
'CTYPE1': 'HPLN-TAN',
302+
'CRPIX1': 5,
303+
'CDELT1': 5,
304+
'CUNIT1': 'arcsec',
305+
'CRVAL1': 0.0,
306+
307+
'CTYPE2': 'HPLT-TAN',
308+
'CRPIX2': 5,
309+
'CDELT2': 5,
310+
'CUNIT2': 'arcsec',
311+
'CRVAL2': 0.0,
312+
313+
'CTYPE3': 'WAVE',
314+
'CRPIX3': 1.0,
315+
'CDELT3': 1,
316+
'CUNIT3': 'Angstrom',
317+
'CRVAL3': 1.0,
318+
319+
'PC1_1': 1,
320+
'PC1_2': 0,
321+
'PC1_3': 0,
322+
'PC2_1': 0,
323+
'PC2_2': 1,
324+
'PC2_3': -1.0,
325+
'PC3_1': 0.0,
326+
'PC3_2': 0.0,
327+
'PC3_3': 1.0,
328+
329+
'WCSAXES': 3,
330+
331+
'DATEREF': "2020-01-01T00:00:00"
332+
}
333+
return WCS(header=header)
334+
335+
336+
@pytest.fixture
337+
def wcs_3d_ln_lt_t_coupled():
338+
# WCS for a 3D data cube with two celestial axes and one time axis.
339+
header = {
340+
'CTYPE1': 'HPLN-TAN',
341+
'CRPIX1': 5,
342+
'CDELT1': 5,
343+
'CUNIT1': 'arcsec',
344+
'CRVAL1': 0.0,
345+
346+
'CTYPE2': 'HPLT-TAN',
347+
'CRPIX2': 5,
348+
'CDELT2': 5,
349+
'CUNIT2': 'arcsec',
350+
'CRVAL2': 0.0,
351+
352+
'CTYPE3': 'UTC',
353+
'CRPIX3': 1.0,
354+
'CDELT3': 1,
355+
'CUNIT3': 's',
356+
'CRVAL3': 1.0,
357+
358+
'PC1_1': 1,
359+
'PC1_2': 0,
360+
'PC1_3': 0,
361+
'PC2_1': 0,
362+
'PC2_2': 1,
363+
'PC2_3': 0,
364+
'PC3_1': 0,
365+
'PC3_2': 1,
366+
'PC3_3': 1,
367+
368+
'WCSAXES': 3,
369+
370+
'DATEREF': "2020-01-01T00:00:00"
371+
}
372+
return WCS(header=header)
373+
374+
295375
################################################################################
296376
# Extra and Global Coords Fixtures
297377
################################################################################
@@ -519,6 +599,31 @@ def ndcube_3d_rotated(wcs_3d_ln_lt_t_rotated, simple_extra_coords_3d):
519599
return cube
520600

521601

602+
@pytest.fixture
603+
def ndcube_3d_coupled(wcs_3d_ln_lt_l_coupled):
604+
shape = (128, 256, 512)
605+
wcs_3d_ln_lt_l_coupled.array_shape = shape
606+
data = data_nd(shape)
607+
mask = data > 0
608+
return NDCube(
609+
data,
610+
wcs_3d_ln_lt_l_coupled,
611+
mask=mask,
612+
uncertainty=data,
613+
)
614+
615+
616+
@pytest.fixture
617+
def ndcube_3d_coupled_time(wcs_3d_ln_lt_t_coupled):
618+
shape = (128, 256, 512)
619+
wcs_3d_ln_lt_t_coupled.array_shape = shape
620+
data = data_nd(shape)
621+
return NDCube(
622+
data,
623+
wcs_3d_ln_lt_t_coupled,
624+
)
625+
626+
522627
@pytest.fixture
523628
def ndcube_3d_l_ln_lt_ectime(wcs_3d_lt_ln_l):
524629
return gen_ndcube_3d_l_ln_lt_ectime(wcs_3d_lt_ln_l,

ndcube/ndcube.py

Lines changed: 9 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
from astropy.units import UnitsError
1515
from astropy.wcs.utils import _split_matrix
1616

17-
from ndcube.utils.wcs import world_axis_to_pixel_axes
18-
1917
try:
2018
# Import sunpy coordinates if available to register the frames and WCS functions with astropy
2119
import sunpy.coordinates # NOQA
@@ -486,47 +484,9 @@ def quantity(self):
486484
"""Unitful representation of the NDCube data."""
487485
return u.Quantity(self.data, self.unit, copy=_NUMPY_COPY_IF_NEEDED)
488486

489-
def _generate_independent_world_coords(self, pixel_corners, wcs, needed_axes, units):
490-
"""
491-
Generate world coordinates for independent axes.
492-
493-
The idea is to workout only the specific grid that is needed for independent axes.
494-
This speeds up the calculation of world coordinates and reduces memory usage.
495-
496-
Parameters
497-
----------
498-
pixel_corners : bool
499-
If one needs pixel corners, otherwise pixel centers.
500-
wcs : astropy.wcs.WCS
501-
The WCS.
502-
needed_axes : array-like
503-
The required pixel axes.
504-
units : bool
505-
If units are needed.
506-
507-
Returns
508-
-------
509-
array-like
510-
The world coordinates.
511-
"""
512-
needed_axes = np.array(needed_axes).squeeze()
513-
if self.data.ndim in needed_axes:
514-
required_axes = needed_axes - 1
515-
else:
516-
required_axes = needed_axes
517-
lims = (-0.5, self.data.shape[::-1][required_axes] + 1) if pixel_corners else (0, self.data.shape[::-1][required_axes])
518-
indices = [np.arange(lims[0], lims[1]) if wanted else [0] for wanted in wcs.axis_correlation_matrix[required_axes]]
519-
world_coords = wcs.pixel_to_world_values(*indices)
520-
if units:
521-
world_coords = world_coords << u.Unit(wcs.world_axis_units[needed_axes])
522-
return world_coords
523-
524-
def _generate_dependent_world_coords(self, pixel_corners, wcs, needed_axes, units):
487+
def _generate_world_coords(self, pixel_corners, wcs, *, needed_axes, units=None):
525488
"""
526-
Generate world coordinates for dependent axes.
527-
528-
This will work out the exact grid that is needed for dependent axes
529-
and can be time and memory consuming.
489+
Private method to generate world coordinates.
530490
531491
Parameters
532492
----------
@@ -535,7 +495,7 @@ def _generate_dependent_world_coords(self, pixel_corners, wcs, needed_axes, unit
535495
wcs : astropy.wcs.WCS
536496
The WCS.
537497
needed_axes : array-like
538-
The required pixel axes.
498+
The axes that are needed.
539499
units : bool
540500
If units are needed.
541501
@@ -573,6 +533,12 @@ def _generate_dependent_world_coords(self, pixel_corners, wcs, needed_axes, unit
573533
# And inject 0s for those coordinates
574534
for idx in non_corr_axes:
575535
sub_range.insert(idx, 0)
536+
# If we are subsetting world axes, ignore any pixel axes which are not correlated with our requested world axis.
537+
if any(world_axis in needed_axes for world_axis in world_axes_indices):
538+
needed_pixel_axes = wcs.axis_correlation_matrix[needed_axes]
539+
unneeded_pixel_axes = np.argwhere(needed_pixel_axes.sum(axis=0) == 0)[:, 0]
540+
for idx in unneeded_pixel_axes:
541+
sub_range[idx] = 0
576542
# Generate a grid of broadcastable pixel indices for all pixel dimensions
577543
grid = np.meshgrid(*sub_range, indexing='ij')
578544
# Convert to world coordinates
@@ -592,41 +558,6 @@ def _generate_dependent_world_coords(self, pixel_corners, wcs, needed_axes, unit
592558
world_coords[i] = coord << u.Unit(unit)
593559
return world_coords
594560

595-
def _generate_world_coords(self, pixel_corners, wcs, *, needed_axes, units=None):
596-
"""
597-
Private method to generate world coordinates.
598-
599-
Handles both dependent and independent axes.
600-
601-
Parameters
602-
----------
603-
pixel_corners : bool
604-
If one needs pixel corners, otherwise pixel centers.
605-
wcs : astropy.wcs.WCS
606-
The WCS.
607-
needed_axes : array-like
608-
The axes that are needed.
609-
units : bool
610-
If units are needed.
611-
612-
Returns
613-
-------
614-
array-like
615-
The world coordinates.
616-
"""
617-
axes_are_independent = []
618-
pixel_axes = set()
619-
for world_axis in needed_axes:
620-
pix_ax = world_axis_to_pixel_axes(world_axis, wcs.axis_correlation_matrix)
621-
axes_are_independent.append(len(pix_ax) == 1)
622-
pixel_axes = pixel_axes.union(set(pix_ax))
623-
pixel_axes = list(pixel_axes)
624-
if all(axes_are_independent) and len(pixel_axes) == len(needed_axes) and len(needed_axes) != 0:
625-
world_coords = self._generate_independent_world_coords(pixel_corners, wcs, needed_axes, units)
626-
else:
627-
world_coords = self._generate_dependent_world_coords(pixel_corners, wcs, needed_axes, units)
628-
return world_coords
629-
630561
@utils.cube.sanitize_wcs
631562
def axis_world_coords(self, *axes, pixel_corners=False, wcs=None):
632563
# Docstring in NDCubeABC.

ndcube/tests/test_ndcube.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,20 @@ def test_axis_world_coords_wave_ec(ndcube_3d_l_ln_lt_ectime):
230230
assert coords[0].shape == (5,)
231231

232232

233+
@pytest.mark.limit_memory("12 MB")
234+
def test_axis_world_coords_wave_coupled_dims(ndcube_3d_coupled):
235+
cube = ndcube_3d_coupled
236+
237+
cube.axis_world_coords('em.wl')
238+
239+
240+
@pytest.mark.limit_memory("12 MB")
241+
def test_axis_world_coords_time_coupled_dims(ndcube_3d_coupled_time):
242+
cube = ndcube_3d_coupled_time
243+
244+
cube.axis_world_coords('time')
245+
246+
233247
def test_axis_world_coords_empty_ec(ndcube_3d_l_ln_lt_ectime):
234248
cube = ndcube_3d_l_ln_lt_ectime
235249
sub_cube = cube[:, 0]
@@ -292,10 +306,10 @@ def test_axis_world_coords_single_pixel_corners(axes, ndcube_3d_ln_lt_l):
292306
assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m)
293307

294308
coords = ndcube_3d_ln_lt_l.axis_world_coords_values(*axes, pixel_corners=True)
295-
assert u.allclose(coords[0], [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09, 1.11e-09] * u.m)
309+
assert u.allclose(coords, [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09] * u.m)
296310

297311
coords = ndcube_3d_ln_lt_l.axis_world_coords(*axes, pixel_corners=True)
298-
assert u.allclose(coords[0], [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09, 1.11e-09] * u.m)
312+
assert u.allclose(coords, [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09] * u.m)
299313

300314

301315
@pytest.mark.parametrize(("ndc", "item"),

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ tests = [
3232
"pytest-mpl>=0.12",
3333
"pytest-xdist",
3434
"pytest",
35+
"pytest-memray; sys_platform != 'win32'",
3536
"scipy",
3637
"specutils",
3738
"sunpy>=5.0.0",

pytest.ini

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ addopts =
2828
--doctest-continue-on-failure
2929
mpl-results-path = figure_test_images
3030
mpl-use-full-test-name = true
31+
remote_data_strict = True
32+
doctest_subpackage_requires =
33+
docs/explaining_ndcube/* = numpy>=2.0.0
34+
markers =
35+
limit_memory: pytest-memray marker to fail a test if too much memory used
3136
filterwarnings =
3237
# Turn all warnings into errors so they do not pass silently.
3338
error
@@ -53,6 +58,3 @@ filterwarnings =
5358
ignore:FigureCanvasAgg is non-interactive, and thus cannot be shown:UserWarning
5459
# Oldestdeps from gWCS
5560
ignore:pkg_resources is deprecated as an API:DeprecationWarning
56-
remote_data_strict = True
57-
doctest_subpackage_requires =
58-
docs/explaining_ndcube/* = numpy>=2.0.0

0 commit comments

Comments
 (0)