|
12 | 12 | import astropy.nddata |
13 | 13 | import astropy.units as u |
14 | 14 | from astropy.units import UnitsError |
| 15 | +from astropy.wcs.utils import _split_matrix |
| 16 | + |
| 17 | +from ndcube.utils.wcs import world_axis_to_pixel_axes |
15 | 18 |
|
16 | 19 | try: |
17 | 20 | # Import sunpy coordinates if available to register the frames and WCS functions with astropy |
|
20 | 23 | pass |
21 | 24 |
|
22 | 25 | from astropy.wcs import WCS |
23 | | -from astropy.wcs.utils import _split_matrix |
24 | 26 | from astropy.wcs.wcsapi import BaseHighLevelWCS, HighLevelWCSWrapper |
25 | 27 | from astropy.wcs.wcsapi.high_level_api import values_to_high_level_objects |
26 | 28 |
|
@@ -479,24 +481,76 @@ def quantity(self): |
479 | 481 | """Unitful representation of the NDCube data.""" |
480 | 482 | return u.Quantity(self.data, self.unit, copy=_NUMPY_COPY_IF_NEEDED) |
481 | 483 |
|
482 | | - def _generate_world_coords(self, pixel_corners, wcs, needed_axes=None, *, units): |
483 | | - # Create meshgrid of all pixel coordinates. |
484 | | - # If user wants pixel_corners, set pixel values to pixel corners. |
485 | | - # Else make pixel centers. |
| 484 | + def _generate_independent_world_coords(self, pixel_corners, wcs, needed_axes, units): |
| 485 | + """ |
| 486 | + Generate world coordinates for independent axes. |
| 487 | +
|
| 488 | + The idea is to workout only the specific grid that is needed for independent axes. |
| 489 | + This speeds up the calculation of world coordinates and reduces memory usage. |
| 490 | +
|
| 491 | + Parameters |
| 492 | + ---------- |
| 493 | + pixel_corners : bool |
| 494 | + If one needs pixel corners, otherwise pixel centers. |
| 495 | + wcs : astropy.wcs.WCS |
| 496 | + The WCS. |
| 497 | + needed_axes : array-like |
| 498 | + The required pixel axes. |
| 499 | + units : bool |
| 500 | + If units are needed. |
| 501 | +
|
| 502 | + Returns |
| 503 | + ------- |
| 504 | + array-like |
| 505 | + The world coordinates. |
| 506 | + """ |
| 507 | + needed_axes = np.array(needed_axes).squeeze() |
| 508 | + if self.data.ndim in needed_axes: |
| 509 | + required_axes = needed_axes - 1 |
| 510 | + else: |
| 511 | + required_axes = needed_axes |
| 512 | + lims = (-0.5, self.data.shape[::-1][required_axes] + 1) if pixel_corners else (0, self.data.shape[::-1][required_axes]) |
| 513 | + indices = [np.arange(lims[0], lims[1]) if wanted else [0] for wanted in wcs.axis_correlation_matrix[required_axes]] |
| 514 | + world_coords = wcs.pixel_to_world_values(*indices) |
| 515 | + if units: |
| 516 | + world_coords = world_coords << u.Unit(wcs.world_axis_units[needed_axes]) |
| 517 | + return world_coords |
| 518 | + |
| 519 | + def _generate_dependent_world_coords(self, pixel_corners, wcs, needed_axes, units): |
| 520 | + """ |
| 521 | + Generate world coordinates for dependent axes. |
| 522 | +
|
| 523 | + This will work out the exact grid that is needed for dependent axes |
| 524 | + and can be time and memory consuming. |
| 525 | +
|
| 526 | + Parameters |
| 527 | + ---------- |
| 528 | + pixel_corners : bool |
| 529 | + If one needs pixel corners, otherwise pixel centers. |
| 530 | + wcs : astropy.wcs.WCS |
| 531 | + The WCS. |
| 532 | + needed_axes : array-like |
| 533 | + The required pixel axes. |
| 534 | + units : bool |
| 535 | + If units are needed. |
| 536 | +
|
| 537 | + Returns |
| 538 | + ------- |
| 539 | + array-like |
| 540 | + The world coordinates. |
| 541 | + """ |
486 | 542 | pixel_shape = self.data.shape[::-1] |
487 | 543 | if pixel_corners: |
488 | 544 | pixel_shape = tuple(np.array(pixel_shape) + 1) |
489 | 545 | ranges = [np.arange(i) - 0.5 for i in pixel_shape] |
490 | 546 | else: |
491 | 547 | ranges = [np.arange(i) for i in pixel_shape] |
492 | | - |
493 | 548 | # Limit the pixel dimensions to the ones present in the ExtraCoords |
494 | 549 | if isinstance(wcs, ExtraCoords): |
495 | 550 | ranges = [ranges[i] for i in wcs.mapping] |
496 | 551 | wcs = wcs.wcs |
497 | 552 | if wcs is None: |
498 | | - return [] |
499 | | - |
| 553 | + return () |
500 | 554 | # This value of zero will be returned as a throwaway for unneeded axes, and a numerical value is |
501 | 555 | # required so values_to_high_level_objects in the calling function doesn't crash or warn |
502 | 556 | world_coords = [0] * wcs.world_n_dim |
@@ -528,71 +582,92 @@ def _generate_world_coords(self, pixel_corners, wcs, needed_axes=None, *, units) |
528 | 582 | array_slice[wcs.axis_correlation_matrix[idx]] = slice(None) |
529 | 583 | tmp_world = world[idx][tuple(array_slice)].T |
530 | 584 | world_coords[idx] = tmp_world |
531 | | - |
532 | 585 | if units: |
533 | 586 | for i, (coord, unit) in enumerate(zip(world_coords, wcs.world_axis_units)): |
534 | 587 | world_coords[i] = coord << u.Unit(unit) |
| 588 | + return world_coords |
| 589 | + |
| 590 | + def _generate_world_coords(self, pixel_corners, wcs, *, needed_axes, units=None): |
| 591 | + """ |
| 592 | + Private method to generate world coordinates. |
| 593 | +
|
| 594 | + Handles both dependent and independent axes. |
| 595 | +
|
| 596 | + Parameters |
| 597 | + ---------- |
| 598 | + pixel_corners : bool |
| 599 | + If one needs pixel corners, otherwise pixel centers. |
| 600 | + wcs : astropy.wcs.WCS |
| 601 | + The WCS. |
| 602 | + needed_axes : array-like |
| 603 | + The axes that are needed. |
| 604 | + units : bool |
| 605 | + If units are needed. |
535 | 606 |
|
| 607 | + Returns |
| 608 | + ------- |
| 609 | + array-like |
| 610 | + The world coordinates. |
| 611 | + """ |
| 612 | + axes_are_independent = [] |
| 613 | + pixel_axes = set() |
| 614 | + for world_axis in needed_axes: |
| 615 | + pix_ax = world_axis_to_pixel_axes(world_axis, wcs.axis_correlation_matrix) |
| 616 | + axes_are_independent.append(len(pix_ax) == 1) |
| 617 | + pixel_axes = pixel_axes.union(set(pix_ax)) |
| 618 | + pixel_axes = list(pixel_axes) |
| 619 | + if all(axes_are_independent) and len(pixel_axes) == len(needed_axes) and len(needed_axes) != 0: |
| 620 | + world_coords = self._generate_independent_world_coords(pixel_corners, wcs, needed_axes, units) |
| 621 | + else: |
| 622 | + world_coords = self._generate_dependent_world_coords(pixel_corners, wcs, needed_axes, units) |
536 | 623 | return world_coords |
537 | 624 |
|
538 | 625 | @utils.cube.sanitize_wcs |
539 | 626 | def axis_world_coords(self, *axes, pixel_corners=False, wcs=None): |
540 | 627 | # Docstring in NDCubeABC. |
541 | 628 | if isinstance(wcs, BaseHighLevelWCS): |
542 | 629 | wcs = wcs.low_level_wcs |
543 | | - |
544 | 630 | orig_wcs = wcs |
545 | 631 | if isinstance(wcs, ExtraCoords): |
546 | 632 | wcs = wcs.wcs |
547 | 633 | if not wcs: |
548 | 634 | return () |
549 | | - |
550 | 635 | object_names = np.array([wao_comp[0] for wao_comp in wcs.world_axis_object_components]) |
551 | 636 | unique_obj_names = utils.misc.unique_sorted(object_names) |
552 | 637 | world_axes_for_obj = [np.where(object_names == name)[0] for name in unique_obj_names] |
553 | | - |
554 | 638 | # Create a mapping from world index in the WCS to object index in axes_coords |
555 | 639 | world_index_to_object_index = {} |
556 | 640 | for object_index, world_axes in enumerate(world_axes_for_obj): |
557 | 641 | for world_index in world_axes: |
558 | 642 | world_index_to_object_index[world_index] = object_index |
559 | | - |
560 | 643 | world_indices = utils.wcs.calculate_world_indices_from_axes(wcs, axes) |
561 | 644 | object_indices = utils.misc.unique_sorted( |
562 | 645 | [world_index_to_object_index[world_index] for world_index in world_indices] |
563 | 646 | ) |
564 | | - |
565 | | - axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, world_indices, units=False) |
566 | | - |
| 647 | + axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, needed_axes=world_indices, units=False) |
567 | 648 | axes_coords = values_to_high_level_objects(*axes_coords, low_level_wcs=wcs) |
568 | | - |
569 | 649 | if not axes: |
570 | 650 | return tuple(axes_coords) |
571 | | - |
572 | 651 | return tuple(axes_coords[i] for i in object_indices) |
573 | 652 |
|
574 | 653 | @utils.cube.sanitize_wcs |
575 | 654 | def axis_world_coords_values(self, *axes, pixel_corners=False, wcs=None): |
576 | 655 | # Docstring in NDCubeABC. |
577 | 656 | if isinstance(wcs, BaseHighLevelWCS): |
578 | 657 | wcs = wcs.low_level_wcs |
579 | | - |
580 | 658 | orig_wcs = wcs |
581 | 659 | if isinstance(wcs, ExtraCoords): |
582 | 660 | wcs = wcs.wcs |
583 | | - |
| 661 | + if not wcs: |
| 662 | + return () |
584 | 663 | world_indices = utils.wcs.calculate_world_indices_from_axes(wcs, axes) |
585 | | - |
586 | | - axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, world_indices, units=True) |
587 | | - |
| 664 | + axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, needed_axes=world_indices, units=True) |
588 | 665 | world_axis_physical_types = wcs.world_axis_physical_types |
589 | | - |
590 | 666 | # If user has supplied axes, extract only the |
591 | 667 | # world coords that correspond to those axes. |
592 | 668 | if axes: |
593 | 669 | axes_coords = [axes_coords[i] for i in world_indices] |
594 | 670 | world_axis_physical_types = tuple(np.array(world_axis_physical_types)[world_indices]) |
595 | | - |
596 | 671 | # Return in array order. |
597 | 672 | # First replace characters in physical types forbidden for namedtuple identifiers. |
598 | 673 | identifiers = [] |
|
0 commit comments