@@ -467,40 +467,36 @@ def _generate_world_coords(self, pixel_corners, wcs, needed_axes=None, *, units)
467467 pixel_shape = self .data .shape [::- 1 ]
468468 if pixel_corners :
469469 pixel_shape = tuple (np .array (pixel_shape ) + 1 )
470- pixel_ranges = [np .arange (i ) - 0.5 for i in pixel_shape ]
470+ ranges = [np .arange (i ) - 0.5 for i in pixel_shape ]
471471 else :
472- pixel_ranges = [np .arange (i ) for i in pixel_shape ]
472+ ranges = [np .arange (i ) for i in pixel_shape ]
473473
474474 # Limit the pixel dimensions to the ones present in the ExtraCoords
475475 if isinstance (wcs , ExtraCoords ):
476- pixel_ranges = [pixel_ranges [i ] for i in wcs .mapping ]
476+ ranges = [ranges [i ] for i in wcs .mapping ]
477477 wcs = wcs .wcs
478478 if wcs is None :
479479 return []
480+
480481 # This value of zero will be returned as a throwaway for unneeded axes, and a numerical value is
481482 # required so values_to_high_level_objects in the calling function doesn't crash or warn
482483 world_coords = [0 ] * wcs .world_n_dim
483484 for (pixel_axes_indices , world_axes_indices ) in _split_matrix (wcs .axis_correlation_matrix ):
484- if (
485- needed_axes is not None
486- and len (needed_axes )
487- and all (
488- world_axis not in needed_axes
489- for world_axis in world_axes_indices
490- )
491- ):
485+ if (needed_axes is not None
486+ and len (needed_axes )
487+ and not any (world_axis in needed_axes for world_axis in world_axes_indices )):
492488 # needed_axes indicates which values in world_coords will be used by the calling
493489 # function, so skip this iteration if we won't be producing any of those values
494490 continue
495491 # First construct a range of pixel indices for this set of coupled dimensions
496- pixel_ranges_subset = [pixel_ranges [idx ] for idx in pixel_axes_indices ]
497- # Then get a set of non- correlated dimensions
492+ sub_range = [ranges [idx ] for idx in pixel_axes_indices ]
493+ # Then get a set of non correlated dimensions
498494 non_corr_axes = set (range (wcs .pixel_n_dim )) - set (pixel_axes_indices )
499495 # And inject 0s for those coordinates
500496 for idx in non_corr_axes :
501- pixel_ranges_subset .insert (idx , 0 )
502- # Generate a grid of broadcast-able pixel indices for all pixel dimensions
503- grid = np .meshgrid (* pixel_ranges_subset , indexing = 'ij' )
497+ sub_range .insert (idx , 0 )
498+ # Generate a grid of broadcastable pixel indices for all pixel dimensions
499+ grid = np .meshgrid (* sub_range , indexing = 'ij' )
504500 # Convert to world coordinates
505501 world = wcs .pixel_to_world_values (* grid )
506502 # TODO: this isinstance check is to mitigate https://github.com/spacetelescope/gwcs/pull/332
@@ -513,6 +509,7 @@ def _generate_world_coords(self, pixel_corners, wcs, needed_axes=None, *, units)
513509 array_slice [wcs .axis_correlation_matrix [idx ]] = slice (None )
514510 tmp_world = world [idx ][tuple (array_slice )].T
515511 world_coords [idx ] = tmp_world
512+
516513 if units :
517514 for i , (coord , unit ) in enumerate (zip (world_coords , wcs .world_axis_units )):
518515 world_coords [i ] = coord << u .Unit (unit )
@@ -524,47 +521,59 @@ def axis_world_coords(self, *axes, pixel_corners=False, wcs=None):
524521 # Docstring in NDCubeABC.
525522 if isinstance (wcs , BaseHighLevelWCS ):
526523 wcs = wcs .low_level_wcs
524+
527525 orig_wcs = wcs
528526 if isinstance (wcs , ExtraCoords ):
529527 wcs = wcs .wcs
530- if not wcs :
531- return ()
528+
532529 object_names = np .array ([wao_comp [0 ] for wao_comp in wcs .world_axis_object_components ])
533530 unique_obj_names = utils .misc .unique_sorted (object_names )
534531 world_axes_for_obj = [np .where (object_names == name )[0 ] for name in unique_obj_names ]
532+
535533 # Create a mapping from world index in the WCS to object index in axes_coords
536534 world_index_to_object_index = {}
537535 for object_index , world_axes in enumerate (world_axes_for_obj ):
538536 for world_index in world_axes :
539537 world_index_to_object_index [world_index ] = object_index
538+
540539 world_indices = utils .wcs .calculate_world_indices_from_axes (wcs , axes )
541540 object_indices = utils .misc .unique_sorted (
542541 [world_index_to_object_index [world_index ] for world_index in world_indices ]
543542 )
543+
544544 axes_coords = self ._generate_world_coords (pixel_corners , orig_wcs , world_indices , units = False )
545+
545546 axes_coords = values_to_high_level_objects (* axes_coords , low_level_wcs = wcs )
547+
546548 if not axes :
547549 return tuple (axes_coords )
550+
548551 return tuple (axes_coords [i ] for i in object_indices )
549552
550553 @utils .cube .sanitize_wcs
551554 def axis_world_coords_values (self , * axes , pixel_corners = False , wcs = None ):
552555 # Docstring in NDCubeABC.
553556 if isinstance (wcs , BaseHighLevelWCS ):
554557 wcs = wcs .low_level_wcs
558+
555559 orig_wcs = wcs
556560 if isinstance (wcs , ExtraCoords ):
557561 wcs = wcs .wcs
562+
558563 if not wcs :
559564 return ()
560565 world_indices = utils .wcs .calculate_world_indices_from_axes (wcs , axes )
566+
561567 axes_coords = self ._generate_world_coords (pixel_corners , orig_wcs , world_indices , units = True )
568+
562569 world_axis_physical_types = wcs .world_axis_physical_types
570+
563571 # If user has supplied axes, extract only the
564572 # world coords that correspond to those axes.
565573 if axes :
566574 axes_coords = [axes_coords [i ] for i in world_indices ]
567575 world_axis_physical_types = tuple (np .array (world_axis_physical_types )[world_indices ])
576+
568577 # Return in array order.
569578 # First replace characters in physical types forbidden for namedtuple identifiers.
570579 identifiers = []
0 commit comments