1313from numpy .testing import assert_equal
1414
1515import astropy
16+ from astropy .wcs .wcsapi import BaseHighLevelWCS
1617from astropy .wcs .wcsapi .fitswcs import SlicedFITSWCS
1718from astropy .wcs .wcsapi .low_level_api import BaseLowLevelWCS
1819from astropy .wcs .wcsapi .wrappers .sliced_wcs import sanitize_slices
@@ -95,9 +96,11 @@ def assert_metas_equal(test_input, expected_output):
9596 assert test_input [key ] == expected_output [key ]
9697
9798
98- def assert_cubes_equal (test_input , expected_cube ):
99+ def assert_cubes_equal (test_input , expected_cube , check_data = True ):
99100 assert isinstance (test_input , type (expected_cube ))
100101 assert np .all (test_input .mask == expected_cube .mask )
102+ if check_data :
103+ np .testing .assert_array_equal (test_input .data , expected_cube .data )
101104 assert_wcs_are_equal (test_input .wcs , expected_cube .wcs )
102105 if test_input .uncertainty :
103106 assert test_input .uncertainty .array .shape == expected_cube .uncertainty .array .shape
@@ -110,12 +113,12 @@ def assert_cubes_equal(test_input, expected_cube):
110113 assert_extra_coords_equal (test_input .extra_coords , expected_cube .extra_coords )
111114
112115
113- def assert_cubesequences_equal (test_input , expected_sequence ):
116+ def assert_cubesequences_equal (test_input , expected_sequence , check_data = True ):
114117 assert isinstance (test_input , type (expected_sequence ))
115118 assert_metas_equal (test_input .meta , expected_sequence .meta )
116119 assert test_input ._common_axis == expected_sequence ._common_axis
117120 for i , cube in enumerate (test_input .data ):
118- assert_cubes_equal (cube , expected_sequence .data [i ])
121+ assert_cubes_equal (cube , expected_sequence .data [i ], check_data = check_data )
119122
120123
121124def assert_wcs_are_equal (wcs1 , wcs2 ):
@@ -140,7 +143,12 @@ def assert_wcs_are_equal(wcs1, wcs2):
140143 assert wcs1 .world_axis_units == wcs2 .world_axis_units
141144 assert_equal (wcs1 .axis_correlation_matrix , wcs2 .axis_correlation_matrix )
142145 assert wcs1 .pixel_bounds == wcs2 .pixel_bounds
143-
146+ if wcs1 .pixel_shape is not None :
147+ random_idx = np .random .randint (wcs1 .pixel_shape ,size = [10 ,wcs1 .pixel_n_dim ])
148+ # SlicedLowLevelWCS vs BaseHighLevelWCS don't have the same pixel_to_world method
149+ low_level_wcs1 = wcs1 .low_level_wcs if isinstance (wcs1 , BaseHighLevelWCS ) else wcs1
150+ low_level_wcs2 = wcs2 .low_level_wcs if isinstance (wcs2 , BaseHighLevelWCS ) else wcs2
151+ np .testing .assert_array_equal (low_level_wcs1 .pixel_to_world_values (* random_idx .T ), low_level_wcs2 .pixel_to_world_values (* random_idx .T ))
144152
145153def create_sliced_wcs (wcs , item , dim ):
146154 """
@@ -152,15 +160,15 @@ def create_sliced_wcs(wcs, item, dim):
152160 return SlicedFITSWCS (wcs , item )
153161
154162
155- def assert_collections_equal (collection1 , collection2 ):
163+ def assert_collections_equal (collection1 , collection2 , check_data = True ):
156164 assert collection1 .keys () == collection2 .keys ()
157165 assert collection1 .aligned_axes == collection2 .aligned_axes
158166 for cube1 , cube2 in zip (collection1 .values (), collection2 .values ()):
159167 # Check cubes are same type.
160168 assert type (cube1 ) is type (cube2 )
161169 if isinstance (cube1 , NDCube ):
162- assert_cubes_equal (cube1 , cube2 )
170+ assert_cubes_equal (cube1 , cube2 , check_data = check_data )
163171 elif isinstance (cube1 , NDCubeSequence ):
164- assert_cubesequences_equal (cube1 , cube2 )
172+ assert_cubesequences_equal (cube1 , cube2 , check_data = check_data )
165173 else :
166174 raise TypeError (f"Unsupported Type in NDCollection: { type (cube1 )} " )
0 commit comments