Skip to content

Commit 941216f

Browse files
committed
Added unit test and broke it
1 parent 70ad715 commit 941216f

File tree

3 files changed

+67
-9
lines changed

3 files changed

+67
-9
lines changed

ndcube/conftest.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,30 @@ def wcs_3d_lt_ln_l():
197197
return WCS(header=header)
198198

199199

200+
@pytest.fixture
201+
def wcs_3d_wave_lt_ln():
202+
header = {
203+
'CTYPE1': 'WAVE ',
204+
'CUNIT1': 'Angstrom',
205+
'CDELT1': 0.2,
206+
'CRPIX1': 0,
207+
'CRVAL1': 10,
208+
209+
'CTYPE2': 'HPLT-TAN',
210+
'CUNIT2': 'deg',
211+
'CDELT2': 0.5,
212+
'CRPIX2': 2,
213+
'CRVAL2': 0.5,
214+
215+
'CTYPE3': 'HPLN-TAN ',
216+
'CUNIT3': 'deg',
217+
'CDELT3': 0.4,
218+
'CRPIX3': 2,
219+
'CRVAL3': 1,
220+
}
221+
return WCS(header=header)
222+
223+
200224
@pytest.fixture
201225
def wcs_2d_lt_ln():
202226
spatial = {
@@ -445,6 +469,24 @@ def ndcube_3d_ln_lt_l_ec_time(wcs_3d_l_lt_ln, time_and_simple_extra_coords_2d):
445469
return cube
446470

447471

472+
@pytest.fixture
473+
def ndcube_3d_wave_lt_ln_ec_time(wcs_3d_wave_lt_ln):
474+
shape = (3, 4, 5)
475+
wcs_3d_wave_lt_ln.array_shape = shape
476+
data = data_nd(shape)
477+
mask = data > 0
478+
cube = NDCube(
479+
data,
480+
wcs_3d_wave_lt_ln,
481+
mask=mask,
482+
uncertainty=data,
483+
)
484+
base_time = Time('2000-01-01', format='fits', scale='utc')
485+
timestamps = Time([base_time + TimeDelta(60 * i, format='sec') for i in range(data.shape[0])])
486+
cube.extra_coords.add('time', 0, timestamps)
487+
return cube
488+
489+
448490
@pytest.fixture
449491
def ndcube_3d_rotated(wcs_3d_ln_lt_t_rotated, simple_extra_coords_3d):
450492
data_rotated = np.array([[[1, 2, 3, 4, 6], [2, 4, 5, 3, 1], [0, -1, 2, 4, 2], [3, 5, 1, 2, 0]],

ndcube/ndcube.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -448,11 +448,17 @@ def _generate_world_coords(self, pixel_corners, wcs, needed_axes=None, *, units)
448448
# This bypasses the entire rest of the function below which works out the full set of coordinates
449449
# This only works for WCS that have the same number of world and pixel dimensions
450450
if needed_axes is not None and not isinstance(wcs, ExtraCoords) and np.sum(wcs.axis_correlation_matrix[needed_axes]) == 1:
451-
lims = (-0.5, self.data.shape[::-1][needed_axes[0]] + 1) if pixel_corners else (0, self.data.shape[::-1][needed_axes[0]])
452-
indices = [np.arange(lims[0], lims[1]) if wanted else [0] for wanted in wcs.axis_correlation_matrix[needed_axes][0]]
451+
# Account for non-pixel axes affecting the value of needed_axes
452+
# Only works for one axis
453+
if np.max(wcs.axis_correlation_matrix[needed_axes][0].shape) == needed_axes[0]:
454+
needed_axis = needed_axes[0] - 1
455+
else:
456+
needed_axis = needed_axes[0]
457+
lims = (-0.5, self.data.shape[::-1][needed_axis] + 1) if pixel_corners else (0, self.data.shape[::-1][needed_axis])
458+
indices = [np.arange(lims[0], lims[1]) if wanted else [0] for wanted in wcs.axis_correlation_matrix[needed_axis]]
453459
world_coords = wcs.pixel_to_world_values(*indices)
454460
if units:
455-
world_coords = world_coords << u.Unit(wcs.world_axis_units[needed_axes[0]])
461+
world_coords = world_coords << u.Unit(wcs.world_axis_units[needed_axis])
456462
return world_coords
457463

458464
# Create a meshgrid of all pixel coordinates.
@@ -546,23 +552,19 @@ def axis_world_coords_values(self, *axes, pixel_corners=False, wcs=None):
546552
# Docstring in NDCubeABC.
547553
if isinstance(wcs, BaseHighLevelWCS):
548554
wcs = wcs.low_level_wcs
549-
550555
orig_wcs = wcs
551556
if isinstance(wcs, ExtraCoords):
552557
wcs = wcs.wcs
553-
558+
if not wcs:
559+
return ()
554560
world_indices = utils.wcs.calculate_world_indices_from_axes(wcs, axes)
555-
556561
axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, world_indices, units=True)
557-
558562
world_axis_physical_types = wcs.world_axis_physical_types
559-
560563
# If user has supplied axes, extract only the
561564
# world coords that correspond to those axes.
562565
if axes:
563566
axes_coords = [axes_coords[i] for i in world_indices]
564567
world_axis_physical_types = tuple(np.array(world_axis_physical_types)[world_indices])
565-
566568
# Return in array order.
567569
# First replace characters in physical types forbidden for namedtuple identifiers.
568570
identifiers = []

ndcube/tests/test_ndcube.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,20 @@ def test_axis_world_coords_single(axes, ndcube_3d_ln_lt_l):
235235
assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m)
236236

237237

238+
def test_axis_world_coords_crazy(ndcube_3d_wave_lt_ln_ec_time):
239+
# This replicates a specific NDCube test in the visualization.rst
240+
coords = ndcube_3d_wave_lt_ln_ec_time.axis_world_coords('time', wcs=ndcube_3d_wave_lt_ln_ec_time.combined_wcs)
241+
assert len(coords) == 1
242+
assert isinstance(coords[0], Time)
243+
assert np.all(coords[0] == Time(['2000-01-01T00:00:00.000', '2000-01-01T00:01:00.000', '2000-01-01T00:02:00.000']))
244+
245+
# This fails and returns the wrong coords
246+
coords = ndcube_3d_wave_lt_ln_ec_time.axis_world_coords_values('time', wcs=ndcube_3d_wave_lt_ln_ec_time.combined_wcs)
247+
assert len(coords) == 1
248+
assert isinstance(coords.time, Time)
249+
assert np.all(coords.time == Time(['2000-01-01T00:00:00.000', '2000-01-01T00:01:00.000', '2000-01-01T00:02:00.000']))
250+
251+
238252
@pytest.mark.parametrize("axes", [[-1], [2], ["em"]])
239253
def test_axis_world_coords_single_pixel_corners(axes, ndcube_3d_ln_lt_l):
240254

0 commit comments

Comments
 (0)