From c4bebb6b2d28768e9f0b6c871ea6a4811bd0ee6e Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 9 Oct 2025 15:02:48 +0100 Subject: [PATCH] Simplify NDCube.to_nddata --- changelog/892.feature.rst | 1 + ndcube/ndcube.py | 85 ++++++++++++++++--------------------- ndcube/tests/test_ndcube.py | 35 ++++++++++++++- 3 files changed, 72 insertions(+), 49 deletions(-) create mode 100644 changelog/892.feature.rst diff --git a/changelog/892.feature.rst b/changelog/892.feature.rst new file mode 100644 index 000000000..528d9059d --- /dev/null +++ b/changelog/892.feature.rst @@ -0,0 +1 @@ +`~ndcube.NDCube` now accepts ``global_coords=`` and ``extra_coords=`` in the constructor of the class. diff --git a/ndcube/ndcube.py b/ndcube/ndcube.py index 2d672329b..ede144880 100644 --- a/ndcube/ndcube.py +++ b/ndcube/ndcube.py @@ -48,8 +48,6 @@ except ImportError: pass -COPY = object() - class NDCubeABC(astropy.nddata.NDDataBase): @@ -381,25 +379,27 @@ class NDCubeBase(NDCubeABC, astropy.nddata.NDData, NDCubeSlicingMixin): _global_coords = NDCubeLinkedDescriptor(GlobalCoords) def __init__(self, data, wcs=None, uncertainty=None, mask=None, meta=None, - unit=None, copy=False, **kwargs): + unit=None, copy=False, psf=None, *, extra_coords=None, global_coords=None, **kwargs): super().__init__(data, wcs=wcs, uncertainty=uncertainty, mask=mask, - meta=meta, unit=unit, copy=copy, **kwargs) + meta=meta, unit=unit, copy=copy, psf=psf, **kwargs) # Enforce that the WCS object is not None if self.wcs is None: raise TypeError("The WCS argument can not be None.") # Get existing extra_coords if initializing from an NDCube - if hasattr(data, "extra_coords"): + if extra_coords is None and getattr(data, "extra_coords", None) is not None: extra_coords = data.extra_coords + if extra_coords is not None: if copy: extra_coords = deepcopy(extra_coords) self._extra_coords = extra_coords # Get existing global_coords if initializing from an NDCube - if hasattr(data, "global_coords"): + if global_coords is None and getattr(data, "global_coords", None) is not None: global_coords = data._global_coords + if global_coords is not None: if copy: global_coords = deepcopy(global_coords) self._global_coords = global_coords @@ -1465,24 +1465,28 @@ def fill_masked(self, fill_value, uncertainty_fill_value=None, unmask=False, fil def to_nddata(self, *, - data=COPY, - wcs=COPY, - uncertainty=COPY, - mask=COPY, - unit=COPY, - meta=COPY, - psf=COPY, - extra_coords=COPY, - global_coords=COPY, + data=True, + wcs=True, + uncertainty=True, + mask=True, + unit=True, + meta=True, + psf=True, nddata_type=NDData, **kwargs, ): """ - Constructs new type instance with the same attribute values as this `~ndcube.NDCube`. - - Attribute values can be altered on the output object by setting a kwarg with the new - value, e.g. ``data=new_data``. - Any attributes not supported by the new class (``nddata_type``), will be discarded. + Constructs a new `~astropy.nddata.NDData` instance from this object. + + By default all known ``NDData`` attributes are copied (by reference) from + this object, values can be altered on the output object by + setting a kwarg with the new value, e.g. ``data=new_data``. + Custom attributes on this class can be passed by setting that + keyword to `True`, for example ``mycube.to_nddata(spam=True)`` + is the equivalent of setting + ``mycube.to_nddata(spam=mycube.spam)``. + Any attributes not supported by the new class + (``nddata_type``), will be discarded. Parameters ---------- @@ -1500,16 +1504,11 @@ def to_nddata(self, Metadata object of new instance. Default is to use data of this instance. psf: Any, optional PSF object of new instance. Default is to use data of this instance. - extra_coords: `ndcube.ExtraCoordsABC`, optional - Extra coords object of new instance. Default is to use data of this instance. - global_coords: `ndcube.GlobalCoordsABC`, optional - WCS object of new instance. Default is to use data of this instance. nddata_type: Any, optional The type of the returned object. Must be a subclass of `~astropy.nddata.NDData` or a class that behaves like one. Default=`~astropy.nddata.NDData`. kwargs: - Additional inputs to the ``nddata_type`` constructor that should differ from, - or are not represented by, the attributes of this instance. For example, to + Additional inputs to the ``nddata_type`` constructor. For example, to set different data values on the returned object, set a kwarg ``data=new_data``, where ``new_data`` is an array of compatible shape and dtype. Note that kwargs given by the user and attributes on this instance that are not supported by the @@ -1525,12 +1524,17 @@ def to_nddata(self, Examples -------- To create an `~astropy.nddata.NDData` instance which is a copy of an `~ndcube.NDCube` - (called ``cube``) without a WCS, do: + (called ``cube``) without a WCS, do:: >>> nddata_without_coords = cube.to_nddata(wcs=None) # doctest: +SKIP + + To create a new `~ndcube.NDCube` instance which is a copy of + an `~ndcube.NDCube` (called ``cube``) without an uncertainty, + but with ``global_coords`` and ``extra_coords`` do:: + + >>> nddata_without_coords = cube.to_nddata(uncertainty=None, global_coords=True, extra_coords=True) # doctest: +SKIP """ - # Build dictionary of new attribute values from this NDCube instance - # and update with user-defined kwargs. Remove any kwargs not set by user. + # Put all NDData kwargs in a dict user_kwargs = {"data": data, "wcs": wcs, "uncertainty": uncertainty, @@ -1538,26 +1542,11 @@ def to_nddata(self, "unit": unit, "meta": meta, "psf": psf, - "extra_coords": extra_coords, - "global_coords": global_coords} - user_kwargs = {key: value for key, value in user_kwargs.items() if value is not COPY} - user_kwargs.update(kwargs) - all_kwargs = {key.strip("_"): value for key, value in self.__dict__.items()} - all_kwargs.update(user_kwargs) - # Inspect call signature of new_nddata class and - # remove unsupported items from new_kwargs. - all_kwargs = {key: value for key, value in all_kwargs.items() - if key in inspect.signature(nddata_type).parameters.keys()} + **kwargs} + # If any are True then copy by reference + user_kwargs = {key: getattr(self, key) if value is True else value for key, value in user_kwargs.items()} # Construct and return new instance. - new_nddata = nddata_type(**all_kwargs) - if isinstance(new_nddata, NDCubeBase): - if extra_coords is COPY: - extra_coords = copy.copy(self._extra_coords) - extra_coords._ndcube = new_nddata - new_nddata._extra_coords = extra_coords - if global_coords is COPY: - new_nddata._global_coords = copy.copy(self._global_coords) - return new_nddata + return nddata_type(**user_kwargs) def _create_masked_array_for_rebinning(data, mask, operation_ignores_mask): diff --git a/ndcube/tests/test_ndcube.py b/ndcube/tests/test_ndcube.py index 0e1e74eda..d6322a22b 100644 --- a/ndcube/tests/test_ndcube.py +++ b/ndcube/tests/test_ndcube.py @@ -44,6 +44,24 @@ def test_initialize_from_ndcube(ndcube_3d_l_ln_lt_ectime): assert ec is not ec3 +def test_initialize_with_extra_global_coords(ndcube_3d_ln_lt_l_ec_all_axes): + ndc = ndcube_3d_ln_lt_l_ec_all_axes[:, :, 0] + data = ndc.data + wcs = ndc.wcs + ec = ndc.extra_coords + gc = ndc.global_coords + + new_cube = NDCube(data, wcs=wcs, extra_coords=ec, global_coords=gc) + assert new_cube.extra_coords is ec + assert new_cube.global_coords is gc + + new_cube_copy = NDCube(data, wcs=wcs, extra_coords=ec, global_coords=gc, copy=True) + assert new_cube_copy.extra_coords is not ec + assert new_cube_copy.global_coords is not gc + helpers.assert_extra_coords_equal(new_cube_copy.extra_coords, ec) + helpers.assert_global_coords_equal(new_cube_copy.global_coords, gc) + + def test_wcs_type_after_init(ndcube_3d_ln_lt_l, wcs_3d_l_lt_ln): # Generate a low level WCS slices = np.s_[:, :, 0] @@ -254,8 +272,23 @@ def test_to_nddata_type_ndcube(ndcube_2d_ln_lt_uncert_ec): ndc = ndcube_2d_ln_lt_uncert_ec ndc.global_coords.add("wavelength", "em.wl", 100*u.nm) new_data = ndc.data * 2 - output = ndc.to_nddata(data=new_data, nddata_type=NDCube) + output = ndc.to_nddata(data=new_data, extra_coords=True, global_coords=True, nddata_type=NDCube) assert type(output) is NDCube assert (output.data == new_data).all() helpers.assert_extra_coords_equal(output.extra_coords, ndc.extra_coords) helpers.assert_global_coords_equal(output.global_coords, ndc.global_coords) + + +def test_custom_tonddata_type(ndcube_2d_ln_lt): + ndc = ndcube_2d_ln_lt + ndc.spam = "Eggs" + + class MyNDData(astropy.nddata.NDData): + def __init__(self, data, *, spam=None, **kwargs): + super().__init__(data, **kwargs) + self.spam = spam + + new_ndd = ndc.to_nddata(spam=True, nddata_type=MyNDData) + assert new_ndd.spam == "Eggs" + assert new_ndd.data is ndc.data + assert new_ndd.wcs is ndc.wcs