Skip to content

Commit 4858041

Browse files
authored
Merge pull request #693 from nabobalis/wcs
Check more things in test helpers
2 parents 3001645 + bbdfba7 commit 4858041

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

ndcube/tests/helpers.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from numpy.testing import assert_equal
1414

1515
import astropy
16+
from astropy.wcs.wcsapi import BaseHighLevelWCS
1617
from astropy.wcs.wcsapi.fitswcs import SlicedFITSWCS
1718
from astropy.wcs.wcsapi.low_level_api import BaseLowLevelWCS
1819
from astropy.wcs.wcsapi.wrappers.sliced_wcs import sanitize_slices
@@ -95,9 +96,11 @@ def assert_metas_equal(test_input, expected_output):
9596
assert test_input[key] == expected_output[key]
9697

9798

98-
def assert_cubes_equal(test_input, expected_cube):
99+
def assert_cubes_equal(test_input, expected_cube, check_data=True):
99100
assert isinstance(test_input, type(expected_cube))
100101
assert np.all(test_input.mask == expected_cube.mask)
102+
if check_data:
103+
np.testing.assert_array_equal(test_input.data, expected_cube.data)
101104
assert_wcs_are_equal(test_input.wcs, expected_cube.wcs)
102105
if test_input.uncertainty:
103106
assert test_input.uncertainty.array.shape == expected_cube.uncertainty.array.shape
@@ -110,12 +113,12 @@ def assert_cubes_equal(test_input, expected_cube):
110113
assert_extra_coords_equal(test_input.extra_coords, expected_cube.extra_coords)
111114

112115

113-
def assert_cubesequences_equal(test_input, expected_sequence):
116+
def assert_cubesequences_equal(test_input, expected_sequence, check_data=True):
114117
assert isinstance(test_input, type(expected_sequence))
115118
assert_metas_equal(test_input.meta, expected_sequence.meta)
116119
assert test_input._common_axis == expected_sequence._common_axis
117120
for i, cube in enumerate(test_input.data):
118-
assert_cubes_equal(cube, expected_sequence.data[i])
121+
assert_cubes_equal(cube, expected_sequence.data[i], check_data=check_data)
119122

120123

121124
def assert_wcs_are_equal(wcs1, wcs2):
@@ -140,7 +143,12 @@ def assert_wcs_are_equal(wcs1, wcs2):
140143
assert wcs1.world_axis_units == wcs2.world_axis_units
141144
assert_equal(wcs1.axis_correlation_matrix, wcs2.axis_correlation_matrix)
142145
assert wcs1.pixel_bounds == wcs2.pixel_bounds
143-
146+
if wcs1.pixel_shape is not None:
147+
random_idx = np.random.randint(wcs1.pixel_shape,size=[10,wcs1.pixel_n_dim])
148+
# SlicedLowLevelWCS vs BaseHighLevelWCS don't have the same pixel_to_world method
149+
low_level_wcs1 = wcs1.low_level_wcs if isinstance(wcs1, BaseHighLevelWCS) else wcs1
150+
low_level_wcs2 = wcs2.low_level_wcs if isinstance(wcs2, BaseHighLevelWCS) else wcs2
151+
np.testing.assert_array_equal(low_level_wcs1.pixel_to_world_values(*random_idx.T), low_level_wcs2.pixel_to_world_values(*random_idx.T))
144152

145153
def create_sliced_wcs(wcs, item, dim):
146154
"""
@@ -152,15 +160,15 @@ def create_sliced_wcs(wcs, item, dim):
152160
return SlicedFITSWCS(wcs, item)
153161

154162

155-
def assert_collections_equal(collection1, collection2):
163+
def assert_collections_equal(collection1, collection2, check_data=True):
156164
assert collection1.keys() == collection2.keys()
157165
assert collection1.aligned_axes == collection2.aligned_axes
158166
for cube1, cube2 in zip(collection1.values(), collection2.values()):
159167
# Check cubes are same type.
160168
assert type(cube1) is type(cube2)
161169
if isinstance(cube1, NDCube):
162-
assert_cubes_equal(cube1, cube2)
170+
assert_cubes_equal(cube1, cube2, check_data=check_data)
163171
elif isinstance(cube1, NDCubeSequence):
164-
assert_cubesequences_equal(cube1, cube2)
172+
assert_cubesequences_equal(cube1, cube2, check_data=check_data)
165173
else:
166174
raise TypeError(f"Unsupported Type in NDCollection: {type(cube1)}")

0 commit comments

Comments
 (0)