Skip to content

Commit d600159

Browse files
committed
Another go that failed
1 parent 8d79d7d commit d600159

File tree

2 files changed

+34
-51
lines changed

2 files changed

+34
-51
lines changed

ndcube/ndcube.py

Lines changed: 29 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -445,8 +445,7 @@ def quantity(self):
445445
"""Unitful representation of the NDCube data."""
446446
return u.Quantity(self.data, self.unit, copy=_NUMPY_COPY_IF_NEEDED)
447447

448-
449-
def _generate_independent_world_coords(self, pixel_corners, wcs, pixel_axes, units):
448+
def _generate_independent_world_coords(self, pixel_corners, wcs, needed_axes, units):
450449
"""
451450
Generate world coordinates for independent axes.
452451
@@ -459,8 +458,8 @@ def _generate_independent_world_coords(self, pixel_corners, wcs, pixel_axes, uni
459458
If one needs pixel corners, otherwise pixel centers.
460459
wcs : astropy.wcs.WCS
461460
The WCS.
462-
pixel_axes : array-like
463-
The pixel axes.
461+
needed_axes : array-like
462+
The required pixel axes.
464463
units : bool
465464
If units are needed.
466465
@@ -469,29 +468,19 @@ def _generate_independent_world_coords(self, pixel_corners, wcs, pixel_axes, uni
469468
array-like
470469
The world coordinates.
471470
"""
472-
naxes = len(self.data.shape)
473-
pixel_indices = [np.array([0], dtype=int).reshape([1] * naxes).squeeze()] * naxes
474-
for pixel_axis in pixel_axes:
475-
len_axis = self.data.shape[::-1][pixel_axis]
476-
# Define limits of desired pixel range based on whether corners or centers are desired
477-
lims = (-0.5, len_axis + 1) if pixel_corners else (0, len_axis)
478-
pix_ind = np.arange(lims[0], lims[1])
479-
shape = [1] * naxes
480-
shape[pixel_axis] = len(pix_ind)
481-
pixel_indices[pixel_axis] = pix_ind.reshape(shape)
482-
world_coords = wcs.pixel_to_world_values(*pixel_indices)
483-
# TODO: Remove NaNs??? These should not be here
484-
if np.isnan(world_coords).any():
485-
if isinstance(world_coords, tuple| list):
486-
world_coords = [world_coord[~np.isnan(world_coord)] for world_coord in world_coords]
487-
else:
488-
world_coords = world_coords[~np.isnan(world_coords)]
471+
needed_axes = np.array(needed_axes).squeeze()
472+
if self.data.ndim in needed_axes:
473+
required_axes = needed_axes - 1
474+
else:
475+
required_axes = needed_axes
476+
lims = (-0.5, self.data.shape[::-1][required_axes] + 1) if pixel_corners else (0, self.data.shape[::-1][required_axes])
477+
indices = [np.arange(lims[0], lims[1]) if wanted else [0] for wanted in wcs.axis_correlation_matrix[required_axes]]
478+
world_coords = wcs.pixel_to_world_values(*indices)
489479
if units:
490-
mod = abs(wcs.world_n_dim - naxes) if wcs.world_n_dim > naxes else 0
491-
world_coords = world_coords << u.Unit(wcs.world_axis_units[np.squeeze(pixel_axes)+mod])
480+
world_coords = world_coords << u.Unit(wcs.world_axis_units[needed_axes])
492481
return world_coords
493482

494-
def _generate_dependent_world_coords(self, pixel_corners, wcs, pixel_axes, units):
483+
def _generate_dependent_world_coords(self, pixel_corners, wcs, needed_axes, units):
495484
"""
496485
Generate world coordinates for dependent axes.
497486
@@ -504,8 +493,8 @@ def _generate_dependent_world_coords(self, pixel_corners, wcs, pixel_axes, units
504493
If one needs pixel corners, otherwise pixel centers.
505494
wcs : astropy.wcs.WCS
506495
The WCS.
507-
pixel_axes : array-like
508-
The pixel axes.
496+
needed_axes : array-like
497+
The required pixel axes.
509498
units : bool
510499
If units are needed.
511500
@@ -520,13 +509,19 @@ def _generate_dependent_world_coords(self, pixel_corners, wcs, pixel_axes, units
520509
ranges = [np.arange(i) - 0.5 for i in pixel_shape]
521510
else:
522511
ranges = [np.arange(i) for i in pixel_shape]
512+
# Limit the pixel dimensions to the ones present in the ExtraCoords
513+
if isinstance(wcs, ExtraCoords):
514+
ranges = [ranges[i] for i in wcs.mapping]
515+
wcs = wcs.wcs
516+
if wcs is None:
517+
return []
523518
# This value of zero will be returned as a throwaway for unneeded axes, and a numerical value is
524519
# required so values_to_high_level_objects in the calling function doesn't crash or warn
525520
world_coords = [0] * wcs.world_n_dim
526521
for (pixel_axes_indices, world_axes_indices) in _split_matrix(wcs.axis_correlation_matrix):
527-
if (pixel_axes is not None
528-
and len(pixel_axes)
529-
and not any(world_axis in pixel_axes for world_axis in world_axes_indices)):
522+
if (needed_axes is not None
523+
and len(needed_axes)
524+
and not any(world_axis in needed_axes for world_axis in world_axes_indices)):
530525
# needed_axes indicates which values in world_coords will be used by the calling
531526
# function, so skip this iteration if we won't be producing any of those values
532527
continue
@@ -556,7 +551,6 @@ def _generate_dependent_world_coords(self, pixel_corners, wcs, pixel_axes, units
556551
world_coords[i] = coord << u.Unit(unit)
557552
return world_coords
558553

559-
560554
def _generate_world_coords(self, pixel_corners, wcs, *, needed_axes=None, units=None):
561555
"""
562556
Private method to generate world coordinates.
@@ -579,32 +573,21 @@ def _generate_world_coords(self, pixel_corners, wcs, *, needed_axes=None, units=
579573
array-like
580574
The world coordinates.
581575
"""
582-
# TODO: Workout why I need this twice now.
583576
if isinstance(wcs, ExtraCoords):
584577
wcs = wcs.wcs
585-
if not wcs:
586-
return ()
587-
if needed_axes is None or len(needed_axes) == 0:
588-
needed_axes = np.array(list(range(wcs.world_n_dim)),dtype=int)
589578
axes_are_independent = []
590579
pixel_axes = set()
591580
for world_axis in needed_axes:
592581
pix_ax = world_axis_to_pixel_axes(world_axis, wcs.axis_correlation_matrix)
593582
axes_are_independent.append(len(pix_ax) == 1)
594583
pixel_axes = pixel_axes.union(set(pix_ax))
595-
if len(pixel_axes) == 1:
596-
pixel_axes = list(pixel_axes)
597-
if all(axes_are_independent) and len(pixel_axes) == len(needed_axes):
598-
world_coords = self._generate_independent_world_coords(pixel_corners, wcs, pixel_axes, units)
599-
else:
600-
world_coords = self._generate_dependent_world_coords(pixel_corners, wcs, pixel_axes, units)
601-
if len(world_coords) > 1 and isinstance(world_coords, tuple | list):
602-
world_coords = [np.squeeze(world_coord) for world_coord in world_coords]
584+
pixel_axes = list(pixel_axes)
585+
if all(axes_are_independent) and len(pixel_axes) == len(needed_axes) and len(needed_axes) != 0:
586+
world_coords = self._generate_independent_world_coords(pixel_corners, wcs, needed_axes, units)
603587
else:
604-
world_coords = np.squeeze(world_coords)
588+
world_coords = self._generate_dependent_world_coords(pixel_corners, wcs, needed_axes, units)
605589
return world_coords
606590

607-
608591
@utils.cube.sanitize_wcs
609592
def axis_world_coords(self, *axes, pixel_corners=False, wcs=None):
610593
# Docstring in NDCubeABC.
@@ -628,8 +611,6 @@ def axis_world_coords(self, *axes, pixel_corners=False, wcs=None):
628611
[world_index_to_object_index[world_index] for world_index in world_indices]
629612
)
630613
axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, needed_axes=world_indices, units=False)
631-
if not isinstance(axes_coords, list):
632-
axes_coords = [axes_coords]
633614
axes_coords = values_to_high_level_objects(*axes_coords, low_level_wcs=wcs)
634615
if not axes:
635616
return tuple(axes_coords)
@@ -662,8 +643,7 @@ def axis_world_coords_values(self, *axes, pixel_corners=False, wcs=None):
662643
identifier = identifier.replace("-", "__")
663644
identifiers.append(identifier)
664645
CoordValues = namedtuple("CoordValues", identifiers)
665-
flag = len(axes_coords) == 1 or isinstance(axes_coords, tuple | list)
666-
return CoordValues(*axes_coords[::-1]) if flag else CoordValues(axes_coords)
646+
return CoordValues(*axes_coords[::-1])
667647

668648
def crop(self, *points, wcs=None, keepdims=False):
669649
# The docstring is defined in NDCubeABC

ndcube/tests/test_ndcube.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,9 +178,14 @@ def test_axis_world_coords_wave_ec(ndcube_3d_l_ln_lt_ectime):
178178

179179
coords = cube.axis_world_coords()
180180
assert len(coords) == 2
181+
assert isinstance(coords[0], SkyCoord)
182+
assert isinstance(coords[1], SpectralCoord)
181183

182184
coords = cube.axis_world_coords(wcs=cube.combined_wcs)
183185
assert len(coords) == 3
186+
assert isinstance(coords[0], SkyCoord)
187+
assert isinstance(coords[1], SpectralCoord)
188+
assert isinstance(coords[2], Time)
184189

185190
coords = cube.axis_world_coords(wcs=cube.extra_coords)
186191
assert len(coords) == 1
@@ -200,8 +205,6 @@ def test_axis_world_coords_empty_ec(ndcube_3d_l_ln_lt_ectime):
200205
# slice the cube so extra_coords is empty, and then try and run axis_world_coords
201206
awc = sub_cube.axis_world_coords(wcs=sub_cube.extra_coords)
202207
assert awc == ()
203-
sub_cube._generate_world_coords(pixel_corners=False, wcs=sub_cube.extra_coords, units=True)
204-
assert awc == ()
205208

206209

207210
@pytest.mark.xfail(reason=">1D Tables not supported")

0 commit comments

Comments
 (0)