Skip to content

Commit 9d2bb0c

Browse files
committed
undo style changes
1 parent 941216f commit 9d2bb0c

File tree

2 files changed

+28
-17
lines changed

2 files changed

+28
-17
lines changed

ndcube/ndcube.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -467,40 +467,36 @@ def _generate_world_coords(self, pixel_corners, wcs, needed_axes=None, *, units)
467467
pixel_shape = self.data.shape[::-1]
468468
if pixel_corners:
469469
pixel_shape = tuple(np.array(pixel_shape) + 1)
470-
pixel_ranges = [np.arange(i) - 0.5 for i in pixel_shape]
470+
ranges = [np.arange(i) - 0.5 for i in pixel_shape]
471471
else:
472-
pixel_ranges = [np.arange(i) for i in pixel_shape]
472+
ranges = [np.arange(i) for i in pixel_shape]
473473

474474
# Limit the pixel dimensions to the ones present in the ExtraCoords
475475
if isinstance(wcs, ExtraCoords):
476-
pixel_ranges = [pixel_ranges[i] for i in wcs.mapping]
476+
ranges = [ranges[i] for i in wcs.mapping]
477477
wcs = wcs.wcs
478478
if wcs is None:
479479
return []
480+
480481
# This value of zero will be returned as a throwaway for unneeded axes, and a numerical value is
481482
# required so values_to_high_level_objects in the calling function doesn't crash or warn
482483
world_coords = [0] * wcs.world_n_dim
483484
for (pixel_axes_indices, world_axes_indices) in _split_matrix(wcs.axis_correlation_matrix):
484-
if (
485-
needed_axes is not None
486-
and len(needed_axes)
487-
and all(
488-
world_axis not in needed_axes
489-
for world_axis in world_axes_indices
490-
)
491-
):
485+
if (needed_axes is not None
486+
and len(needed_axes)
487+
and not any(world_axis in needed_axes for world_axis in world_axes_indices)):
492488
# needed_axes indicates which values in world_coords will be used by the calling
493489
# function, so skip this iteration if we won't be producing any of those values
494490
continue
495491
# First construct a range of pixel indices for this set of coupled dimensions
496-
pixel_ranges_subset = [pixel_ranges[idx] for idx in pixel_axes_indices]
497-
# Then get a set of non-correlated dimensions
492+
sub_range = [ranges[idx] for idx in pixel_axes_indices]
493+
# Then get a set of non correlated dimensions
498494
non_corr_axes = set(range(wcs.pixel_n_dim)) - set(pixel_axes_indices)
499495
# And inject 0s for those coordinates
500496
for idx in non_corr_axes:
501-
pixel_ranges_subset.insert(idx, 0)
502-
# Generate a grid of broadcast-able pixel indices for all pixel dimensions
503-
grid = np.meshgrid(*pixel_ranges_subset, indexing='ij')
497+
sub_range.insert(idx, 0)
498+
# Generate a grid of broadcastable pixel indices for all pixel dimensions
499+
grid = np.meshgrid(*sub_range, indexing='ij')
504500
# Convert to world coordinates
505501
world = wcs.pixel_to_world_values(*grid)
506502
# TODO: this isinstance check is to mitigate https://github.com/spacetelescope/gwcs/pull/332
@@ -513,6 +509,7 @@ def _generate_world_coords(self, pixel_corners, wcs, needed_axes=None, *, units)
513509
array_slice[wcs.axis_correlation_matrix[idx]] = slice(None)
514510
tmp_world = world[idx][tuple(array_slice)].T
515511
world_coords[idx] = tmp_world
512+
516513
if units:
517514
for i, (coord, unit) in enumerate(zip(world_coords, wcs.world_axis_units)):
518515
world_coords[i] = coord << u.Unit(unit)
@@ -524,47 +521,61 @@ def axis_world_coords(self, *axes, pixel_corners=False, wcs=None):
524521
# Docstring in NDCubeABC.
525522
if isinstance(wcs, BaseHighLevelWCS):
526523
wcs = wcs.low_level_wcs
524+
527525
orig_wcs = wcs
528526
if isinstance(wcs, ExtraCoords):
529527
wcs = wcs.wcs
530528
if not wcs:
531529
return ()
530+
532531
object_names = np.array([wao_comp[0] for wao_comp in wcs.world_axis_object_components])
533532
unique_obj_names = utils.misc.unique_sorted(object_names)
534533
world_axes_for_obj = [np.where(object_names == name)[0] for name in unique_obj_names]
534+
535535
# Create a mapping from world index in the WCS to object index in axes_coords
536536
world_index_to_object_index = {}
537537
for object_index, world_axes in enumerate(world_axes_for_obj):
538538
for world_index in world_axes:
539539
world_index_to_object_index[world_index] = object_index
540+
540541
world_indices = utils.wcs.calculate_world_indices_from_axes(wcs, axes)
541542
object_indices = utils.misc.unique_sorted(
542543
[world_index_to_object_index[world_index] for world_index in world_indices]
543544
)
545+
544546
axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, world_indices, units=False)
547+
545548
axes_coords = values_to_high_level_objects(*axes_coords, low_level_wcs=wcs)
549+
546550
if not axes:
547551
return tuple(axes_coords)
552+
548553
return tuple(axes_coords[i] for i in object_indices)
549554

550555
@utils.cube.sanitize_wcs
551556
def axis_world_coords_values(self, *axes, pixel_corners=False, wcs=None):
552557
# Docstring in NDCubeABC.
553558
if isinstance(wcs, BaseHighLevelWCS):
554559
wcs = wcs.low_level_wcs
560+
555561
orig_wcs = wcs
556562
if isinstance(wcs, ExtraCoords):
557563
wcs = wcs.wcs
558564
if not wcs:
559565
return ()
566+
560567
world_indices = utils.wcs.calculate_world_indices_from_axes(wcs, axes)
568+
561569
axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, world_indices, units=True)
570+
562571
world_axis_physical_types = wcs.world_axis_physical_types
572+
563573
# If user has supplied axes, extract only the
564574
# world coords that correspond to those axes.
565575
if axes:
566576
axes_coords = [axes_coords[i] for i in world_indices]
567577
world_axis_physical_types = tuple(np.array(world_axis_physical_types)[world_indices])
578+
568579
# Return in array order.
569580
# First replace characters in physical types forbidden for namedtuple identifiers.
570581
identifiers = []

ndcube/tests/test_ndcube.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def test_axis_world_coords_single(axes, ndcube_3d_ln_lt_l):
235235
assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m)
236236

237237

238-
def test_axis_world_coords_crazy(ndcube_3d_wave_lt_ln_ec_time):
238+
def test_axis_world_coords_combined_wcs(ndcube_3d_wave_lt_ln_ec_time):
239239
# This replicates a specific NDCube test in the visualization.rst
240240
coords = ndcube_3d_wave_lt_ln_ec_time.axis_world_coords('time', wcs=ndcube_3d_wave_lt_ln_ec_time.combined_wcs)
241241
assert len(coords) == 1

0 commit comments

Comments
 (0)