Skip to content

Commit 5acbd35

Browse files
committed
Add more tests of dask-backed arithmetic operations.
1 parent 83c9f3d commit 5acbd35

File tree

2 files changed

+77
-5
lines changed

2 files changed

+77
-5
lines changed

ndcube/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -832,6 +832,13 @@ def ndcube_2d_dask(wcs_2d_lt_ln):
832832
return NDCube(da, wcs=wcs_2d_lt_ln, uncertainty=da_uncert, mask=da_mask, unit=u.J)
833833

834834

835+
@pytest.fixture
836+
def nddata_2d_dask(ndcube_2d_dask):
837+
value = astropy.nddata.NDData(ndcube_2d_dask)
838+
value._wcs = None
839+
return value
840+
841+
835842
@pytest.fixture
836843
def ndcube_2d(request):
837844
"""

ndcube/tests/test_ndcube_arithmetic.py

Lines changed: 70 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import astropy.units as u
55
import astropy.wcs
6+
import dask.array
67
from astropy.nddata import NDData, StdDevUncertainty
78

89
from ndcube import NDCube
@@ -171,14 +172,22 @@ def test_cube_arithmetic_subtract_nddata(ndc, value, expected_kwargs, wcs_2d_lt_
171172
assert_cubes_equal(output_cube, expected_cube, check_uncertainty_values=True)
172173

173174

174-
@pytest.mark.parametrize("value",
175-
[
176-
NDData(np.ones((8, 4)), wcs=None, unit=u.J)
177-
])
178-
def test_cube_dask_arithmetic_subtract_nddata(ndcube_2d_dask, value):
175+
def test_cube_dask_arithmetic_subtract_nddata(ndcube_2d_dask):
179176
ndc = ndcube_2d_dask
177+
value = NDData(np.ones(ndc.data.shape), wcs=None, unit=ndc.unit)
180178
output_cube = ndc - value
181179
assert type(output_cube.data) is type(ndc.data)
180+
assert type(output_cube.uncertainty.array) is type(ndc.uncertainty.array)
181+
assert type(output_cube.mask) is type(ndc.mask)
182+
183+
184+
def test_cube_arithmetic_subtract_nddata_dask(wcs_2d_lt_ln, nddata_2d_dask):
185+
value = nddata_2d_dask
186+
ndc = NDCube(np.ones(value.data.shape), wcs=wcs_2d_lt_ln, unit=value.unit)
187+
output_cube = ndc - value
188+
assert type(output_cube.data) is type(value.data)
189+
assert type(output_cube.uncertainty.array) is type(value.uncertainty.array)
190+
assert type(output_cube.mask) is type(value.mask)
182191

183192

184193
@pytest.mark.parametrize('value', [
@@ -193,6 +202,24 @@ def test_cube_arithmetic_rsubtract(ndcube_2d_ln_lt_units, value):
193202
check_arithmetic_value_and_units(new_cube, value - cube_quantity)
194203

195204

205+
def test_cube_dask_arithmetic_rsubtract_nddata(ndcube_2d_dask):
206+
ndc = ndcube_2d_dask
207+
value = NDData(np.ones(ndc.data.shape), wcs=None, unit=ndc.unit)
208+
output_cube = value - ndc
209+
assert type(output_cube.data) is type(ndc.data)
210+
assert type(output_cube.uncertainty.array) is type(ndc.uncertainty.array)
211+
assert type(output_cube.mask) is type(ndc.mask)
212+
213+
214+
def test_cube_arithmetic_rsubtract_nddata_dask(wcs_2d_lt_ln, nddata_2d_dask):
215+
value = nddata_2d_dask
216+
ndc = NDCube(np.ones(value.data.shape), wcs=wcs_2d_lt_ln, unit=value.unit)
217+
output_cube = value - ndc
218+
assert type(output_cube.data) is type(value.data)
219+
assert type(output_cube.uncertainty.array) is type(value.uncertainty.array)
220+
assert type(output_cube.mask) is type(value.mask)
221+
222+
196223
@pytest.mark.parametrize('value', [
197224
10 * u.ct,
198225
u.Quantity([10], u.ct),
@@ -352,6 +379,24 @@ def test_cube_dask_arithmetic_divide_nddata(ndcube_2d_dask, value):
352379
assert type(output_cube.data) is type(ndc.data)
353380

354381

382+
def test_cube_dask_arithmetic_divide_nddata(ndcube_2d_dask):
383+
ndc = ndcube_2d_dask
384+
value = NDData(np.ones(ndc.data.shape), wcs=None, unit=ndc.unit)
385+
output_cube = ndc / value
386+
assert type(output_cube.data) is type(ndc.data)
387+
assert type(output_cube.uncertainty.array) is type(ndc.uncertainty.array)
388+
assert type(output_cube.mask) is type(ndc.mask)
389+
390+
391+
def test_cube_arithmetic_divide_nddata_dask(wcs_2d_lt_ln, nddata_2d_dask):
392+
value = nddata_2d_dask
393+
ndc = NDCube(np.ones(value.data.shape), wcs=wcs_2d_lt_ln, unit=value.unit)
394+
output_cube = ndc / value
395+
assert type(output_cube.data) is type(value.data)
396+
assert type(output_cube.uncertainty.array) is type(value.uncertainty.array)
397+
assert type(output_cube.mask) is type(value.mask)
398+
399+
355400
@pytest.mark.parametrize('value', [1, 2, -1])
356401
def test_cube_arithmetic_rdivide(ndcube_2d_ln_lt_units, value):
357402
cube_quantity = u.Quantity(ndcube_2d_ln_lt_units.data, ndcube_2d_ln_lt_units.unit)
@@ -370,6 +415,26 @@ def test_cube_arithmetic_rdivide_uncertainty(ndcube_4d_unit_uncertainty, value):
370415
check_arithmetic_value_and_units(new_cube, value / cube_quantity)
371416

372417

418+
def test_cube_dask_arithmetic_rdivide_nddata(ndcube_2d_dask):
419+
ndc = ndcube_2d_dask
420+
value = NDData(np.ones(ndc.data.shape), wcs=None, unit=ndc.unit)
421+
match = "does not support propagation of uncertainties for power. Setting uncertainties to None."
422+
with pytest.warns(NDCubeUserWarning, match=match): # noqa: PT031
423+
with np.errstate(divide='ignore'):
424+
output_cube = value / ndc
425+
assert type(output_cube.data) is type(ndc.data)
426+
assert type(output_cube.mask) is type(ndc.mask)
427+
428+
429+
def test_cube_arithmetic_rdivide_nddata_dask(wcs_2d_lt_ln, nddata_2d_dask):
430+
value = nddata_2d_dask
431+
ndc = NDCube(np.ones(value.data.shape), wcs=wcs_2d_lt_ln, unit=value.unit)
432+
output_cube = value / ndc
433+
assert type(output_cube.data) is type(value.data)
434+
assert type(output_cube.uncertainty.array) is type(value.uncertainty.array)
435+
assert type(output_cube.mask) is type(value.mask)
436+
437+
373438
def test_cube_arithmetic_neg(ndcube_2d_ln_lt_units):
374439
check_arithmetic_value_and_units(
375440
-ndcube_2d_ln_lt_units,

0 commit comments

Comments
 (0)