Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/892.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`~ndcube.NDCube` now accepts ``global_coords=`` and ``extra_coords=`` in the constructor of the class.
85 changes: 37 additions & 48 deletions ndcube/ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@
except ImportError:
pass

COPY = object()


class NDCubeABC(astropy.nddata.NDDataBase):

Expand Down Expand Up @@ -381,25 +379,27 @@ class NDCubeBase(NDCubeABC, astropy.nddata.NDData, NDCubeSlicingMixin):
_global_coords = NDCubeLinkedDescriptor(GlobalCoords)

def __init__(self, data, wcs=None, uncertainty=None, mask=None, meta=None,
unit=None, copy=False, **kwargs):
unit=None, copy=False, psf=None, *, extra_coords=None, global_coords=None, **kwargs):

super().__init__(data, wcs=wcs, uncertainty=uncertainty, mask=mask,
meta=meta, unit=unit, copy=copy, **kwargs)
meta=meta, unit=unit, copy=copy, psf=psf, **kwargs)

# Enforce that the WCS object is not None
if self.wcs is None:
raise TypeError("The WCS argument can not be None.")

# Get existing extra_coords if initializing from an NDCube
if hasattr(data, "extra_coords"):
if extra_coords is None and getattr(data, "extra_coords", None) is not None:
extra_coords = data.extra_coords
if extra_coords is not None:
if copy:
extra_coords = deepcopy(extra_coords)
self._extra_coords = extra_coords

# Get existing global_coords if initializing from an NDCube
if hasattr(data, "global_coords"):
if global_coords is None and getattr(data, "global_coords", None) is not None:
global_coords = data._global_coords
if global_coords is not None:
if copy:
global_coords = deepcopy(global_coords)
self._global_coords = global_coords
Expand Down Expand Up @@ -1465,24 +1465,28 @@ def fill_masked(self, fill_value, uncertainty_fill_value=None, unmask=False, fil

def to_nddata(self,
*,
data=COPY,
wcs=COPY,
uncertainty=COPY,
mask=COPY,
unit=COPY,
meta=COPY,
psf=COPY,
extra_coords=COPY,
global_coords=COPY,
data=True,
wcs=True,
uncertainty=True,
mask=True,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't use mask=True here as True is a valid user input for mask.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, that's so freaking stupid.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The solution to this is that the one user who wants to do this uses np.True_ rather than True to set the whole mask to be a single boolean.

unit=True,
meta=True,
psf=True,
nddata_type=NDData,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My only remaining question is if we want to add a copy=False kwarg here which changes the copy behaviour to be copy by value rather that the default of copy-by-reference?

**kwargs,
):
"""
Constructs new type instance with the same attribute values as this `~ndcube.NDCube`.

Attribute values can be altered on the output object by setting a kwarg with the new
value, e.g. ``data=new_data``.
Any attributes not supported by the new class (``nddata_type``), will be discarded.
Constructs a new `~astropy.nddata.NDData` instance from this object.

By default all known ``NDData`` attributes are copied (by reference) from
this object, values can be altered on the output object by
setting a kwarg with the new value, e.g. ``data=new_data``.
Custom attributes on this class can be passed by setting that
keyword to `True`, for example ``mycube.to_nddata(spam=True)``
is the equivalent of setting
``mycube.to_nddata(spam=mycube.spam)``.
Any attributes not supported by the new class
(``nddata_type``), will be discarded.

Parameters
----------
Expand All @@ -1500,16 +1504,11 @@ def to_nddata(self,
Metadata object of new instance. Default is to use data of this instance.
psf: Any, optional
PSF object of new instance. Default is to use data of this instance.
extra_coords: `ndcube.ExtraCoordsABC`, optional
Extra coords object of new instance. Default is to use data of this instance.
global_coords: `ndcube.GlobalCoordsABC`, optional
WCS object of new instance. Default is to use data of this instance.
nddata_type: Any, optional
The type of the returned object. Must be a subclass of `~astropy.nddata.NDData`
or a class that behaves like one. Default=`~astropy.nddata.NDData`.
kwargs:
Additional inputs to the ``nddata_type`` constructor that should differ from,
or are not represented by, the attributes of this instance. For example, to
Additional inputs to the ``nddata_type`` constructor. For example, to
set different data values on the returned object, set a kwarg ``data=new_data``,
where ``new_data`` is an array of compatible shape and dtype. Note that kwargs
given by the user and attributes on this instance that are not supported by the
Expand All @@ -1525,39 +1524,29 @@ def to_nddata(self,
Examples
--------
To create an `~astropy.nddata.NDData` instance which is a copy of an `~ndcube.NDCube`
(called ``cube``) without a WCS, do:
(called ``cube``) without a WCS, do::

>>> nddata_without_coords = cube.to_nddata(wcs=None) # doctest: +SKIP

To create a new `~ndcube.NDCube` instance which is a copy of
an `~ndcube.NDCube` (called ``cube``) without an uncertainty,
but with ``global_coords`` and ``extra_coords`` do::

>>> nddata_without_coords = cube.to_nddata(uncertainty=None, global_coords=True, extra_coords=True) # doctest: +SKIP
"""
# Build dictionary of new attribute values from this NDCube instance
# and update with user-defined kwargs. Remove any kwargs not set by user.
# Put all NDData kwargs in a dict
user_kwargs = {"data": data,
"wcs": wcs,
"uncertainty": uncertainty,
"mask": mask,
"unit": unit,
"meta": meta,
"psf": psf,
"extra_coords": extra_coords,
"global_coords": global_coords}
user_kwargs = {key: value for key, value in user_kwargs.items() if value is not COPY}
user_kwargs.update(kwargs)
all_kwargs = {key.strip("_"): value for key, value in self.__dict__.items()}
all_kwargs.update(user_kwargs)
# Inspect call signature of new_nddata class and
# remove unsupported items from new_kwargs.
all_kwargs = {key: value for key, value in all_kwargs.items()
if key in inspect.signature(nddata_type).parameters.keys()}
**kwargs}
# If any are True then copy by reference
user_kwargs = {key: getattr(self, key) if value is True else value for key, value in user_kwargs.items()}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
user_kwargs = {key: getattr(self, key) if value is True else value for key, value in user_kwargs.items()}
user_kwargs.update(kwargs)
nddata_sig = inspect.signature(nddata_type).parameters.keys()
user_kwargs = {key: value for key, value in user_kwargs.items() if key in nddata_sig and value is not COPY}
all_kwargs = {key.strip("_"): value for key, value in self.__dict__.items() if key in nddata_sig}
all_kwargs.update(user_kwargs)

# Construct and return new instance.
new_nddata = nddata_type(**all_kwargs)
if isinstance(new_nddata, NDCubeBase):
if extra_coords is COPY:
extra_coords = copy.copy(self._extra_coords)
extra_coords._ndcube = new_nddata
new_nddata._extra_coords = extra_coords
if global_coords is COPY:
new_nddata._global_coords = copy.copy(self._global_coords)
return new_nddata
return nddata_type(**user_kwargs)


def _create_masked_array_for_rebinning(data, mask, operation_ignores_mask):
Expand Down
35 changes: 34 additions & 1 deletion ndcube/tests/test_ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,24 @@ def test_initialize_from_ndcube(ndcube_3d_l_ln_lt_ectime):
assert ec is not ec3


def test_initialize_with_extra_global_coords(ndcube_3d_ln_lt_l_ec_all_axes):
ndc = ndcube_3d_ln_lt_l_ec_all_axes[:, :, 0]
data = ndc.data
wcs = ndc.wcs
ec = ndc.extra_coords
gc = ndc.global_coords

new_cube = NDCube(data, wcs=wcs, extra_coords=ec, global_coords=gc)
assert new_cube.extra_coords is ec
assert new_cube.global_coords is gc

new_cube_copy = NDCube(data, wcs=wcs, extra_coords=ec, global_coords=gc, copy=True)
assert new_cube_copy.extra_coords is not ec
assert new_cube_copy.global_coords is not gc
helpers.assert_extra_coords_equal(new_cube_copy.extra_coords, ec)
helpers.assert_global_coords_equal(new_cube_copy.global_coords, gc)


def test_wcs_type_after_init(ndcube_3d_ln_lt_l, wcs_3d_l_lt_ln):
# Generate a low level WCS
slices = np.s_[:, :, 0]
Expand Down Expand Up @@ -254,8 +272,23 @@ def test_to_nddata_type_ndcube(ndcube_2d_ln_lt_uncert_ec):
ndc = ndcube_2d_ln_lt_uncert_ec
ndc.global_coords.add("wavelength", "em.wl", 100*u.nm)
new_data = ndc.data * 2
output = ndc.to_nddata(data=new_data, nddata_type=NDCube)
output = ndc.to_nddata(data=new_data, extra_coords=True, global_coords=True, nddata_type=NDCube)
assert type(output) is NDCube
assert (output.data == new_data).all()
helpers.assert_extra_coords_equal(output.extra_coords, ndc.extra_coords)
helpers.assert_global_coords_equal(output.global_coords, ndc.global_coords)


def test_custom_tonddata_type(ndcube_2d_ln_lt):
ndc = ndcube_2d_ln_lt
ndc.spam = "Eggs"

class MyNDData(astropy.nddata.NDData):
def __init__(self, data, *, spam=None, **kwargs):
super().__init__(data, **kwargs)
self.spam = spam

new_ndd = ndc.to_nddata(spam=True, nddata_type=MyNDData)
assert new_ndd.spam == "Eggs"
assert new_ndd.data is ndc.data
assert new_ndd.wcs is ndc.wcs