Skip to content

Commit 20bd4d4

Browse files
authored
Merge pull request #874 from ayshih/crop_fix
Fixed cropping when a world point is exactly on a pixel edge
2 parents 86fb80e + 57de6b4 commit 20bd4d4

File tree

6 files changed

+205
-70
lines changed

6 files changed

+205
-70
lines changed

changelog/874.breaking.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Make `ndcube.NDCube.crop` exclude rightward pixel when upper limit determined from world points falls exactly on a pixel edge.

changelog/874.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Prevent `~ndcube.NDCube.crop` cropping array axes of a cube to length 0 when: 1, an input point is below the extent of the cube due to misinterpreting negative array indices; 2, all point lie above the extent of the cube.

ndcube/ndcube.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def crop(self,
209209
in the data array in real world coordinates.
210210
211211
The coordinates of the points as they are passed to
212-
`~astropy.wcs.wcsapi.BaseHighLevelWCS.world_to_array_index`.
212+
`~astropy.wcs.wcsapi.BaseHighLevelWCS.world_to_pixel`.
213213
Therefore their number and order must be compatible with the API
214214
of that method, i.e. they must be passed in world order.
215215
@@ -258,7 +258,7 @@ def crop_by_values(self,
258258
points: iterable
259259
Tuples of coordinate values, the length of the tuples must be
260260
equal to the number of world dimensions. These points are
261-
passed to ``wcs.world_to_array_index_values`` so their units
261+
passed to ``wcs.world_to_pixel_values`` so their units
262262
and order must be compatible with that method.
263263
264264
units: `str` or `~astropy.units.Unit`
@@ -647,7 +647,8 @@ def _get_crop_item(self, *points, wcs=None, keepdims=False):
647647
raise TypeError(f"{type(value)} of component {j} in point {i} is "
648648
f"incompatible with WCS component {comp[j]} "
649649
f"{classes[j]}.")
650-
return utils.cube.get_crop_item_from_points(points, wcs, False, keepdims=keepdims)
650+
return utils.cube.get_crop_item_from_points(points, wcs, False, keepdims=keepdims,
651+
original_shape=self.data.shape)
651652

652653
def crop_by_values(self, *points, units=None, wcs=None, keepdims=False):
653654
# The docstring is defined in NDCubeABC
@@ -689,7 +690,8 @@ def _get_crop_by_values_item(self, *points, units=None, wcs=None, keepdims=False
689690
raise UnitsError(f"Unit '{points[i][j].unit}' of coordinate object {j} in point {i} is "
690691
f"incompatible with WCS unit '{wcs.world_axis_units[j]}'") from err
691692

692-
return utils.cube.get_crop_item_from_points(points, wcs, True, keepdims=keepdims)
693+
return utils.cube.get_crop_item_from_points(points, wcs, True, keepdims=keepdims,
694+
original_shape=self.data.shape)
693695

694696
def __str__(self):
695697
return textwrap.dedent(f"""\

ndcube/tests/test_ndcube_slice_and_crop.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,3 +544,55 @@ def test_crop_1d():
544544
output = cube.crop((7*u.nm,), (15*u.nm,))
545545

546546
helpers.assert_cubes_equal(output, expected)
547+
548+
549+
@pytest.mark.filterwarnings("ignore::Warning")
550+
@pytest.mark.parametrize(("points", "expected_slice", "crop_by_values", "keepdims"),
551+
[
552+
(((15*u.m,), (45*u.m,)), np.s_[1:4], False, False), # A range starting and ending at different pixel edges
553+
(((15*u.m,), (45*u.m,)), np.s_[1:4], True, False),
554+
(((15*u.m,)), np.s_[1:2], False, True), # A range starting and ending on same pixel edge.
555+
(((15*u.m,)), np.s_[1:2], True, True),
556+
(((5*u.m,)), np.s_[0:1], False, True), # A range starting and ending at the exact start of the cube extent.
557+
(((5*u.m,)), np.s_[0:1], True, True),
558+
(((104*u.m,)), np.s_[9:10], False, True), # A range starting and ending slightly below the end of cube extent.
559+
(((104*u.m,)), np.s_[9:10], True, True),
560+
(((1*u.m,), (40*u.m,)), np.s_[:4], False, False), # A range starting below cube extent.
561+
(((1*u.m,), (40*u.m,)), np.s_[:4], True, False),
562+
(((15*u.m,), (200*u.m,)), np.s_[1:], False, False), # A range ending above cube extent.
563+
(((15*u.m,), (200*u.m,)), np.s_[1:], True, False),
564+
])
565+
def test_crop_at_pixel_edges(points, expected_slice, crop_by_values, keepdims):
566+
wcs = astropy.wcs.WCS(naxis=1)
567+
wcs.wcs.ctype = 'WAVE',
568+
wcs.wcs.cunit = 'm',
569+
wcs.wcs.cdelt = 10,
570+
wcs.wcs.crpix = 1,
571+
wcs.wcs.crval = 10,
572+
cube = NDCube(np.arange(10), wcs=wcs)
573+
574+
expected = cube[expected_slice]
575+
576+
output = cube.crop_by_values(*points, keepdims=keepdims) if crop_by_values else cube.crop(*points, keepdims=keepdims)
577+
578+
helpers.assert_cubes_equal(output, expected)
579+
580+
581+
@pytest.mark.filterwarnings("ignore::Warning")
582+
@pytest.mark.parametrize("points",
583+
[
584+
((1*u.m,),),
585+
((105*u.m,),), # Exactly at the end of the cube extent.
586+
((200*u.m,),),
587+
])
588+
def test_crop_all_points_beyond_cube_extent_error(points):
589+
wcs = astropy.wcs.WCS(naxis=1)
590+
wcs.wcs.ctype = 'WAVE',
591+
wcs.wcs.cunit = 'm',
592+
wcs.wcs.cdelt = 10,
593+
wcs.wcs.crpix = 1,
594+
wcs.wcs.crval = 10,
595+
cube = NDCube(np.arange(10), wcs=wcs)
596+
597+
with pytest.raises(ValueError, match="are outside the range of the NDCube being cropped"):
598+
cube.crop(*points, keepdims=True)

ndcube/utils/cube.py

Lines changed: 90 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from astropy.wcs.wcsapi import BaseHighLevelWCS, BaseLowLevelWCS, HighLevelWCSWrapper, SlicedLowLevelWCS
99

1010
from ndcube.utils import wcs as wcs_utils
11+
from ndcube.utils.exceptions import warn_user
1112

1213
__all__ = [
1314
"get_crop_item_from_points",
@@ -106,7 +107,7 @@ def sanitize_crop_inputs(points, wcs):
106107
return False, points, wcs
107108

108109

109-
def get_crop_item_from_points(points, wcs, crop_by_values, keepdims):
110+
def get_crop_item_from_points(points, wcs, crop_by_values, keepdims, original_shape):
110111
"""
111112
Find slice item that crops to minimum cube in array-space containing specified world points.
112113
@@ -130,84 +131,126 @@ def get_crop_item_from_points(points, wcs, crop_by_values, keepdims):
130131
keepdims : `bool`
131132
If `False`, returned item will drop length-1 dimensions otherwise, item will keep length-1 dimensions.
132133
134+
original_shape: `tuple` of `int`
135+
The shape of the data cube before cropping.
136+
133137
Returns
134138
-------
135139
item : `tuple` of `slice`
136140
The slice item for each axis of the cube which, when applied to the cube,
137141
will return the minimum cube in array-index-space that contains all the
138142
input world points.
139143
"""
140-
# Define a list of lists to hold the array indices of the points
141-
# where each inner list gives the index of all points for that array axis.
142-
combined_points_array_idx = [[]] * wcs.pixel_n_dim
144+
# Define a list of lists to hold the pixel coordinates of the points
145+
# where each inner list gives the pixel coordinates of all points for that pixel axis.
146+
# Recall that pixel axis ordering is reversed compared to array axis ordering.
147+
combined_points_pixel_idx = [[]] * wcs.pixel_n_dim
143148
high_level_wcs = HighLevelWCSWrapper(wcs) if isinstance(wcs, BaseLowLevelWCS) else wcs
144149
low_level_wcs = high_level_wcs.low_level_wcs
145150
# For each point compute the corresponding array indices.
146151
for point in points:
147-
# Get the arrays axes associated with each element in point.
148-
if crop_by_values:
149-
point_inputs_array_axes = []
150-
for i in range(low_level_wcs.world_n_dim):
151-
pix_axes = np.array(
152-
wcs_utils.world_axis_to_pixel_axes(i, low_level_wcs.axis_correlation_matrix))
153-
point_inputs_array_axes.append(tuple(
154-
wcs_utils.convert_between_array_and_pixel_axes(pix_axes, low_level_wcs.pixel_n_dim)))
155-
point_inputs_array_axes = tuple(point_inputs_array_axes)
156-
else:
157-
point_inputs_array_axes = wcs_utils.array_indices_for_world_objects(high_level_wcs)
158-
# Get indices of array axes which correspond to only None inputs in point
152+
# Get the pixel axes associated with each element in point.
153+
point_inputs_pixel_axes = (
154+
tuple(wcs_utils.world_axis_to_pixel_axes(i, low_level_wcs.axis_correlation_matrix)
155+
for i in range(low_level_wcs.world_n_dim)) if crop_by_values
156+
else wcs_utils.pixel_indices_for_world_objects(high_level_wcs))
157+
# Get indices of pixel axes which correspond to only None inputs in point
159158
# as well as those that correspond to a coord.
160159
point_indices_with_inputs = []
161-
array_axes_with_input = []
160+
pixel_axes_with_input = []
162161
for i, coord in enumerate(point):
163162
if coord is not None:
164163
point_indices_with_inputs.append(i)
165-
array_axes_with_input.append(point_inputs_array_axes[i])
166-
array_axes_with_input = set(chain.from_iterable(array_axes_with_input))
167-
array_axes_without_input = set(range(low_level_wcs.pixel_n_dim)) - array_axes_with_input
164+
pixel_axes_with_input.append(point_inputs_pixel_axes[i])
165+
pixel_axes_with_input = set(chain.from_iterable(pixel_axes_with_input))
166+
pixel_axes_without_input = set(range(low_level_wcs.pixel_n_dim)) - pixel_axes_with_input
167+
pixel_axes_with_input = np.array(list(pixel_axes_with_input))
168+
pixel_axes_without_input = np.array(list(pixel_axes_without_input))
168169
# Slice out the axes that do not correspond to a coord
169170
# from the WCS and the input point.
170-
if len(array_axes_without_input) > 0:
171+
if len(pixel_axes_without_input) > 0:
172+
array_axes_without_input = wcs_utils.convert_between_array_and_pixel_axes(
173+
pixel_axes_without_input, low_level_wcs.pixel_n_dim)
171174
wcs_slice = np.array([slice(None)] * low_level_wcs.pixel_n_dim)
172-
wcs_slice[np.array(list(array_axes_without_input))] = 0
175+
wcs_slice[array_axes_without_input] = 0
173176
sliced_wcs = SlicedLowLevelWCS(low_level_wcs, slices=tuple(wcs_slice))
174177
sliced_point = np.array(point, dtype=object)[np.array(point_indices_with_inputs)]
175178
else:
176179
# Else, if all axes have at least one crop input, no need to slice the WCS.
177180
sliced_wcs, sliced_point = low_level_wcs, np.array(point, dtype=object)
178-
# Derive the array indices of the input point and place each index
181+
# Derive the pixel indices of the input point and place each index
179182
# in the list corresponding to its axis.
180-
if crop_by_values:
181-
point_array_indices = sliced_wcs.world_to_array_index_values(*sliced_point)
182-
# If returned value is a 0-d array, convert to a length-1 tuple.
183-
if isinstance(point_array_indices, np.ndarray) and point_array_indices.ndim == 0:
184-
point_array_indices = (point_array_indices.item(),)
185-
else:
186-
# Convert from scalar arrays to scalars
187-
point_array_indices = tuple(a.item() for a in point_array_indices)
188-
else:
189-
point_array_indices = HighLevelWCSWrapper(sliced_wcs).world_to_array_index(
190-
*sliced_point)
191-
# If returned value is a 0-d array, convert to a length-1 tuple.
192-
if isinstance(point_array_indices, np.ndarray) and point_array_indices.ndim == 0:
193-
point_array_indices = (point_array_indices.item(),)
194-
for axis, index in zip(array_axes_with_input, point_array_indices):
195-
combined_points_array_idx[axis] = combined_points_array_idx[axis] + [index]
196-
# Define slice item with which to slice cube.
183+
# Use the to_pixel methods to preserve fractional indices for future rounding.
184+
point_pixel_indices = (sliced_wcs.world_to_pixel_values(*sliced_point) if crop_by_values
185+
else HighLevelWCSWrapper(sliced_wcs).world_to_pixel(*sliced_point))
186+
# For each pixel axis associated with this point, place the pixel coords for
187+
# that pixel axis into the corresponding list within combined_points_pixel_idx.
188+
if sliced_wcs.pixel_n_dim == 1:
189+
point_pixel_indices = (point_pixel_indices,)
190+
for axis, index in zip(pixel_axes_with_input, point_pixel_indices):
191+
combined_points_pixel_idx[axis] = combined_points_pixel_idx[axis] + [index]
192+
193+
# Iterate through each array axis to determine the min and max pixel coords
194+
# and then convert to array indices. Note that combined_points_pixel_idx holds the
195+
# pixel coords for each pixel axis. Therefore, to iterate in array axis order,
196+
# combined_points_pixel_idx must be reversed.
197197
item = []
198+
ambiguous = False
199+
message = ""
198200
result_is_scalar = True
199-
for axis_indices in combined_points_array_idx:
200-
if axis_indices == []:
201+
for array_axis, pixel_coords in enumerate(combined_points_pixel_idx[::-1]):
202+
if pixel_coords == []:
201203
result_is_scalar = False
202204
item.append(slice(None))
203205
else:
204-
min_idx = min(axis_indices)
205-
max_idx = max(axis_indices) + 1
206-
if max_idx - min_idx == 1 and not keepdims:
207-
item.append(min_idx)
206+
# Calculate the index of the array element containing the pixel coordinate.
207+
# Note that integer pixel coordinates correspond to the pixel center,
208+
# while integer array indices correspond to lower edge of desired array element.
209+
# Therefore a shift of 0.5 is required in the conversion.
210+
# The max idx conversion below will discard right-ward array element if
211+
# max pixel coord corresponds to a pixel edge.
212+
min_array_idx = int(np.floor(min(pixel_coords) + 0.5))
213+
max_array_idx = int(np.ceil(max(pixel_coords) - 0.5)) + 1
214+
# Raise error if indices all lie below or all lie above array axis's extent.
215+
# Exception: min_array_idx == max_array_idx == 0 is allowed because max_array_idx
216+
# will be later changed to 1.
217+
if (min_array_idx < 0 and max_array_idx <= 0) or min_array_idx >= original_shape[array_axis]:
218+
raise ValueError(f"All world points associated with array axis {array_axis}"
219+
" are outside the range of the NDCube being cropped.")
220+
# world_to_array_index uses negative indices to represent locations to the left
221+
# of the 0th pixel, while python slicing uses them to count backwards from the
222+
# last element in the array. Therefore, set negative indices to 0.
223+
# Note that we've already checked that the max pixel_coord is >= 0.
224+
# Also note that there's no need to clip the max array idx, as values above
225+
# the array extent does not cause ambiguity in the slicing so long as the
226+
# min array idx is below that upper extent, which has also already been checked
227+
# by the above error.
228+
if min_array_idx < 0:
229+
min_array_idx = 0
230+
# Due to the above calculation, the above min and max array indices can only be
231+
# same if the original pixel coords correspond to the same pixel edge.
232+
# If this is the case, increment the max array index by 1 so the rightward array
233+
# element is kept. Also, build a warning message about this to be raised later.
234+
if min_array_idx == max_array_idx:
235+
ambiguous = True
236+
max_array_idx += 1
237+
if min_array_idx == 0:
238+
message += (f"All input points corresponding to array axis {array_axis} lie on "
239+
"the lower boundary of array element 0 (the first element). "
240+
"The cropped NDCube will only include array element 0.\n")
241+
else:
242+
message += (f"All input points corresponding to array axis {array_axis} lie on "
243+
f"the boundary between array elements {min_array_idx - 1} and "
244+
f"{min_array_idx}. The cropped NDCube will only include array "
245+
f"element {min_array_idx}.\n")
246+
if max_array_idx - min_array_idx == 1 and not keepdims:
247+
item.append(min_array_idx)
208248
else:
209-
item.append(slice(min_idx, max_idx))
249+
item.append(slice(min_array_idx, max_array_idx))
210250
result_is_scalar = False
251+
# Raise warning if all world values for any array axes correspond to a pixel edge.
252+
if ambiguous:
253+
warn_user(message)
211254
# If item will result in a scalar cube, raise an error as this is not currently supported.
212255
if result_is_scalar:
213256
raise ValueError("Input points causes cube to be cropped to a single pixel. "

0 commit comments

Comments
 (0)