-
-
Notifications
You must be signed in to change notification settings - Fork 54
Faster shortcut for working out coordinates values for non-correlated WCS #780
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
21ec56e
45087d0
5f75177
e00fb42
ed577eb
180e965
a251f20
0beed40
6caae98
0a9c19c
26675f6
f821a69
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| Added an internal method to shortcut non-correlated axes avoiding the creation of a full coordinate grid, reducing memory use in specific circumstances. | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,6 +12,9 @@ | |
| import astropy.nddata | ||
| import astropy.units as u | ||
| from astropy.units import UnitsError | ||
| from astropy.wcs.utils import _split_matrix | ||
|
|
||
| from ndcube.utils.wcs import world_axis_to_pixel_axes | ||
|
|
||
| try: | ||
| # Import sunpy coordinates if available to register the frames and WCS functions with astropy | ||
|
|
@@ -20,7 +23,6 @@ | |
| pass | ||
|
|
||
| from astropy.wcs import WCS | ||
| from astropy.wcs.utils import _split_matrix | ||
| from astropy.wcs.wcsapi import BaseHighLevelWCS, HighLevelWCSWrapper | ||
| from astropy.wcs.wcsapi.high_level_api import values_to_high_level_objects | ||
|
|
||
|
|
@@ -479,24 +481,76 @@ | |
| """Unitful representation of the NDCube data.""" | ||
| return u.Quantity(self.data, self.unit, copy=_NUMPY_COPY_IF_NEEDED) | ||
|
|
||
| def _generate_world_coords(self, pixel_corners, wcs, needed_axes=None, *, units): | ||
| # Create meshgrid of all pixel coordinates. | ||
| # If user wants pixel_corners, set pixel values to pixel corners. | ||
| # Else make pixel centers. | ||
| def _generate_independent_world_coords(self, pixel_corners, wcs, needed_axes, units): | ||
| """ | ||
| Generate world coordinates for independent axes. | ||
|
|
||
| The idea is to workout only the specific grid that is needed for independent axes. | ||
| This speeds up the calculation of world coordinates and reduces memory usage. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| pixel_corners : bool | ||
| If one needs pixel corners, otherwise pixel centers. | ||
| wcs : astropy.wcs.WCS | ||
| The WCS. | ||
| needed_axes : array-like | ||
| The required pixel axes. | ||
| units : bool | ||
| If units are needed. | ||
|
|
||
| Returns | ||
| ------- | ||
| array-like | ||
| The world coordinates. | ||
| """ | ||
| needed_axes = np.array(needed_axes).squeeze() | ||
| if self.data.ndim in needed_axes: | ||
| required_axes = needed_axes - 1 | ||
| else: | ||
| required_axes = needed_axes | ||
| lims = (-0.5, self.data.shape[::-1][required_axes] + 1) if pixel_corners else (0, self.data.shape[::-1][required_axes]) | ||
| indices = [np.arange(lims[0], lims[1]) if wanted else [0] for wanted in wcs.axis_correlation_matrix[required_axes]] | ||
| world_coords = wcs.pixel_to_world_values(*indices) | ||
| if units: | ||
| world_coords = world_coords << u.Unit(wcs.world_axis_units[needed_axes]) | ||
| return world_coords | ||
|
|
||
| def _generate_dependent_world_coords(self, pixel_corners, wcs, needed_axes, units): | ||
| """ | ||
| Generate world coordinates for dependent axes. | ||
|
|
||
| This will work out the exact grid that is needed for dependent axes | ||
| and can be time and memory consuming. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| pixel_corners : bool | ||
| If one needs pixel corners, otherwise pixel centers. | ||
| wcs : astropy.wcs.WCS | ||
| The WCS. | ||
| needed_axes : array-like | ||
| The required pixel axes. | ||
| units : bool | ||
| If units are needed. | ||
|
|
||
| Returns | ||
| ------- | ||
| array-like | ||
| The world coordinates. | ||
| """ | ||
| pixel_shape = self.data.shape[::-1] | ||
| if pixel_corners: | ||
| pixel_shape = tuple(np.array(pixel_shape) + 1) | ||
| ranges = [np.arange(i) - 0.5 for i in pixel_shape] | ||
| else: | ||
| ranges = [np.arange(i) for i in pixel_shape] | ||
|
|
||
| # Limit the pixel dimensions to the ones present in the ExtraCoords | ||
| if isinstance(wcs, ExtraCoords): | ||
| ranges = [ranges[i] for i in wcs.mapping] | ||
| wcs = wcs.wcs | ||
| if wcs is None: | ||
| return [] | ||
|
|
||
| return () | ||
| # This value of zero will be returned as a throwaway for unneeded axes, and a numerical value is | ||
| # required so values_to_high_level_objects in the calling function doesn't crash or warn | ||
| world_coords = [0] * wcs.world_n_dim | ||
|
|
@@ -528,71 +582,92 @@ | |
| array_slice[wcs.axis_correlation_matrix[idx]] = slice(None) | ||
| tmp_world = world[idx][tuple(array_slice)].T | ||
| world_coords[idx] = tmp_world | ||
|
|
||
| if units: | ||
| for i, (coord, unit) in enumerate(zip(world_coords, wcs.world_axis_units)): | ||
| world_coords[i] = coord << u.Unit(unit) | ||
| return world_coords | ||
|
|
||
| def _generate_world_coords(self, pixel_corners, wcs, *, needed_axes, units=None): | ||
| """ | ||
| Private method to generate world coordinates. | ||
|
|
||
| Handles both dependent and independent axes. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| pixel_corners : bool | ||
| If one needs pixel corners, otherwise pixel centers. | ||
| wcs : astropy.wcs.WCS | ||
| The WCS. | ||
| needed_axes : array-like | ||
| The axes that are needed. | ||
| units : bool | ||
| If units are needed. | ||
|
|
||
| Returns | ||
| ------- | ||
| array-like | ||
| The world coordinates. | ||
| """ | ||
| axes_are_independent = [] | ||
| pixel_axes = set() | ||
| for world_axis in needed_axes: | ||
| pix_ax = world_axis_to_pixel_axes(world_axis, wcs.axis_correlation_matrix) | ||
| axes_are_independent.append(len(pix_ax) == 1) | ||
| pixel_axes = pixel_axes.union(set(pix_ax)) | ||
| pixel_axes = list(pixel_axes) | ||
| if all(axes_are_independent) and len(pixel_axes) == len(needed_axes) and len(needed_axes) != 0: | ||
| world_coords = self._generate_independent_world_coords(pixel_corners, wcs, needed_axes, units) | ||
| else: | ||
| world_coords = self._generate_dependent_world_coords(pixel_corners, wcs, needed_axes, units) | ||
| return world_coords | ||
|
|
||
| @utils.cube.sanitize_wcs | ||
| def axis_world_coords(self, *axes, pixel_corners=False, wcs=None): | ||
| # Docstring in NDCubeABC. | ||
| if isinstance(wcs, BaseHighLevelWCS): | ||
| wcs = wcs.low_level_wcs | ||
|
|
||
| orig_wcs = wcs | ||
| if isinstance(wcs, ExtraCoords): | ||
| wcs = wcs.wcs | ||
| if not wcs: | ||
| return () | ||
|
|
||
| object_names = np.array([wao_comp[0] for wao_comp in wcs.world_axis_object_components]) | ||
| unique_obj_names = utils.misc.unique_sorted(object_names) | ||
| world_axes_for_obj = [np.where(object_names == name)[0] for name in unique_obj_names] | ||
|
|
||
| # Create a mapping from world index in the WCS to object index in axes_coords | ||
| world_index_to_object_index = {} | ||
| for object_index, world_axes in enumerate(world_axes_for_obj): | ||
| for world_index in world_axes: | ||
| world_index_to_object_index[world_index] = object_index | ||
|
|
||
| world_indices = utils.wcs.calculate_world_indices_from_axes(wcs, axes) | ||
| object_indices = utils.misc.unique_sorted( | ||
| [world_index_to_object_index[world_index] for world_index in world_indices] | ||
| ) | ||
|
|
||
| axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, world_indices, units=False) | ||
|
|
||
| axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, needed_axes=world_indices, units=False) | ||
| axes_coords = values_to_high_level_objects(*axes_coords, low_level_wcs=wcs) | ||
|
|
||
| if not axes: | ||
| return tuple(axes_coords) | ||
|
|
||
| return tuple(axes_coords[i] for i in object_indices) | ||
|
|
||
| @utils.cube.sanitize_wcs | ||
| def axis_world_coords_values(self, *axes, pixel_corners=False, wcs=None): | ||
| # Docstring in NDCubeABC. | ||
| if isinstance(wcs, BaseHighLevelWCS): | ||
| wcs = wcs.low_level_wcs | ||
|
|
||
| orig_wcs = wcs | ||
| if isinstance(wcs, ExtraCoords): | ||
| wcs = wcs.wcs | ||
|
|
||
| if not wcs: | ||
| return () | ||
|
Comment on lines
+661
to
+662
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was missing and is included in the other version of this method, so I added it. |
||
| world_indices = utils.wcs.calculate_world_indices_from_axes(wcs, axes) | ||
|
|
||
| axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, world_indices, units=True) | ||
|
|
||
| axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, needed_axes=world_indices, units=True) | ||
| world_axis_physical_types = wcs.world_axis_physical_types | ||
|
|
||
| # If user has supplied axes, extract only the | ||
| # world coords that correspond to those axes. | ||
| if axes: | ||
| axes_coords = [axes_coords[i] for i in world_indices] | ||
| world_axis_physical_types = tuple(np.array(world_axis_physical_types)[world_indices]) | ||
|
|
||
| # Return in array order. | ||
| # First replace characters in physical types forbidden for namedtuple identifiers. | ||
| identifiers = [] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,6 +12,7 @@ | |
| from astropy.coordinates import SkyCoord, SpectralCoord | ||
| from astropy.io import fits | ||
| from astropy.nddata import UnknownUncertainty | ||
| from astropy.tests.helper import assert_quantity_allclose | ||
| from astropy.time import Time | ||
| from astropy.units import UnitsError | ||
| from astropy.wcs import WCS | ||
|
|
@@ -177,9 +178,19 @@ def test_axis_world_coords_wave_ec(ndcube_3d_l_ln_lt_ectime): | |
|
|
||
| coords = cube.axis_world_coords() | ||
| assert len(coords) == 2 | ||
| assert isinstance(coords[0], SkyCoord) | ||
| assert coords[0].shape == (5, 8) | ||
| assert isinstance(coords[1], SpectralCoord) | ||
| assert coords[1].shape == (10,) | ||
|
|
||
| coords = cube.axis_world_coords(wcs=cube.combined_wcs) | ||
| assert len(coords) == 3 | ||
| assert isinstance(coords[0], SkyCoord) | ||
| assert coords[0].shape == (5, 8) | ||
| assert isinstance(coords[1], SpectralCoord) | ||
| assert coords[1].shape == (10,) | ||
| assert isinstance(coords[2], Time) | ||
| assert coords[2].shape == (5,) | ||
|
|
||
| coords = cube.axis_world_coords(wcs=cube.extra_coords) | ||
| assert len(coords) == 1 | ||
|
|
@@ -199,8 +210,6 @@ def test_axis_world_coords_empty_ec(ndcube_3d_l_ln_lt_ectime): | |
| # slice the cube so extra_coords is empty, and then try and run axis_world_coords | ||
| awc = sub_cube.axis_world_coords(wcs=sub_cube.extra_coords) | ||
| assert awc == () | ||
| sub_cube._generate_world_coords(pixel_corners=False, wcs=sub_cube.extra_coords, units=True) | ||
| assert awc == () | ||
|
Comment on lines
-202
to
-203
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now the private method does not handle ECs, this happens now at the higher level. |
||
|
|
||
|
|
||
| @pytest.mark.xfail(reason=">1D Tables not supported") | ||
|
|
@@ -235,13 +244,31 @@ def test_axis_world_coords_single(axes, ndcube_3d_ln_lt_l): | |
| assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m) | ||
|
|
||
|
|
||
| def test_axis_world_coords_combined_wcs(ndcube_3d_wave_lt_ln_ec_time): | ||
| # This replicates a specific NDCube object in visualization.rst | ||
| coords = ndcube_3d_wave_lt_ln_ec_time.axis_world_coords('time', wcs=ndcube_3d_wave_lt_ln_ec_time.combined_wcs) | ||
| assert len(coords) == 1 | ||
| assert isinstance(coords[0], Time) | ||
| assert np.all(coords[0] == Time(['2000-01-01T00:00:00.000', '2000-01-01T00:01:00.000', '2000-01-01T00:02:00.000'])) | ||
|
|
||
| coords = ndcube_3d_wave_lt_ln_ec_time.axis_world_coords_values('time', wcs=ndcube_3d_wave_lt_ln_ec_time.combined_wcs) | ||
| assert len(coords) == 1 | ||
| assert isinstance(coords.time, u.Quantity) | ||
| assert_quantity_allclose(coords.time, [0, 60, 120] * u.second) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("axes", [[-1], [2], ["em"]]) | ||
| def test_axis_world_coords_single_pixel_corners(axes, ndcube_3d_ln_lt_l): | ||
|
|
||
| # We go from 4 pixels to 6 pixels when we add pixel corners | ||
nabobalis marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| coords = ndcube_3d_ln_lt_l.axis_world_coords_values(*axes, pixel_corners=False) | ||
| assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m) | ||
|
|
||
| coords = ndcube_3d_ln_lt_l.axis_world_coords_values(*axes, pixel_corners=True) | ||
| assert u.allclose(coords, [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09] * u.m) | ||
| assert u.allclose(coords[0], [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09, 1.11e-09] * u.m) | ||
nabobalis marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| coords = ndcube_3d_ln_lt_l.axis_world_coords(*axes, pixel_corners=True) | ||
| assert u.allclose(coords, [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09] * u.m) | ||
| assert u.allclose(coords[0], [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09, 1.11e-09] * u.m) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize(("ndc", "item"), | ||
|
|
@@ -252,10 +279,10 @@ def test_axis_world_coords_single_pixel_corners(axes, ndcube_3d_ln_lt_l): | |
| indirect=("ndc",)) | ||
| def test_axis_world_coords_sliced_all_3d(ndc, item): | ||
| coords = ndc[item].axis_world_coords_values() | ||
| assert u.allclose(coords, [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m) | ||
| assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m) | ||
|
|
||
| coords = ndc[item].axis_world_coords() | ||
| assert u.allclose(coords, [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m) | ||
| assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why have these asserts had to change?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wish I knew, I assume since I broke the code.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But this is common in most of the other tests, you return a tuple of length N and you have to escape it always to get the coord. There are lots of tests where you do len(coords) and it is 1 but then you need to index the return to get the coord info. |
||
|
|
||
|
|
||
| @pytest.mark.parametrize(("ndc", "item"), | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.