diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 116cd8868..dfa3c4757 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ repos: rev: "v0.6.9" hooks: - id: ruff - args: ["--fix"] + args: ["--fix", "--unsafe-fixes"] - repo: https://github.com/PyCQA/isort rev: 5.13.2 hooks: diff --git a/.ruff.toml b/.ruff.toml index ea73580e8..cfac2f120 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -1,45 +1,64 @@ -target-version = "py310" -line-length = 110 -exclude = [ - ".git,", - "__pycache__", - "build", - "ndcube/version.py", +# Allow unused variables when underscore-prefixed. +lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" +target-version = "py312" +line-length = 120 +extend-exclude=[ + "__pycache__", + "build", + "tools/**", ] - -[lint] -select = [ - "E", - "F", - "W", - #"UP", - #"PT" -] -extend-ignore = [ - "E712", - "E721", - # pycodestyle (E, W) - "E501", # LineTooLong # TODO! fix - # pytest (PT) - "PT001", # Always use pytest.fixture() - "PT004", # Fixtures which don't return anything should have leading _ - "PT007", # Parametrize should be lists of tuples # TODO! fix - "PT011", # Too broad exception assert # TODO! fix - "PT023", # Always use () on pytest decorators +lint.select = [ + "ALL", +] +lint.ignore = [ + "ANN", +] +lint.extend-ignore = [ + "ANN", # All annotation + "COM812", # May cause conflicts when used with the formatter + "CPY001", # Missing copyright notice at top of file + "D200", # One-line docstring should fit on one line + "D205", # 1 blank line required between summary line and description + "D400", # First line should end with a period + "D401", # First line should be in imperative mood + "D404", # First word of the docstring should not be "This" + "E501", # Line too long + "FIX002", # Line contains TODO, consider resolving the issue + "TD003", # Missing issue link on the line following this TODO + "TD002", # Missing author in TODO + "ISC001", # May cause conflicts when used with the formatter + "PLC0415", # `import` should be at the top-level of a file + "PLR2004", # Magic value used in comparison + "S101", # Use of `assert` detected ] [lint.per-file-ignores] -# Part of configuration, not a package. -"setup.py" = ["INP001"] -"conftest.py" = ["INP001"] +"examples/*.py" = [ + "B018", # Not print but display + "D400", # First line should end with a period, question mark, or exclamation point + "ERA001", # Commented out code + "INP001", # Implicit namespace package + "T201", # Use print +] "docs/conf.py" = [ - "E402" # Module imports not at top of file + "INP001", # conf.py is part of an implicit namespace package ] -"docs/*.py" = [ - "INP001", # Implicit-namespace-package. The examples are not a package. +"test_*.py" = [ + "D", # All docs + "SLF001", # Access private member +] +"__init__.py" = [ + "D104", # Missing docstring in public package +] +"conftest.py" = [ + "D100", # Missing docstring in public module + "D103", # Missing docstring in public function ] -"__init__.py" = ["E402", "F401", "F403"] -"test_*.py" = ["B011", "D", "E402", "PGH001", "S101"] [lint.pydocstyle] convention = "numpy" + +[format] +docstring-code-format = true +indent-style = "space" +quote-style = "double" diff --git a/docs/conf.py b/docs/conf.py index 30b2b3b05..f96415b24 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,17 +1,17 @@ # # Configuration file for the Sphinx documentation builder.
+import datetime import os import warnings -import datetime from astropy.utils.exceptions import AstropyDeprecationWarning from matplotlib import MatplotlibDeprecationWarning from packaging.version import Version # -- Read the Docs Specific Configuration ------------------------------------ -on_rtd = os.environ.get('READTHEDOCS', None) == 'True' +on_rtd = os.environ.get("READTHEDOCS", None) == "True" if on_rtd: - os.environ['HIDE_PARFIVE_PROGESS'] = 'True' + os.environ["HIDE_PARFIVE_PROGESS"] = "True" # -- Project information ----------------------------------------------------- @@ -31,7 +31,7 @@ project = "ndcube" author = "The SunPy Community" -copyright = f'{datetime.datetime.now().year}, {author}' # noqa: A001 +copyright = f"{datetime.datetime.now().year}, {author}" # noqa: A001 warnings.filterwarnings("error", category=MatplotlibDeprecationWarning) warnings.filterwarnings("error", category=AstropyDeprecationWarning) @@ -39,21 +39,21 @@ # -- General configuration --------------------------------------------------- extensions = [ - 'matplotlib.sphinxext.plot_directive', - 'sphinx.ext.autodoc', - 'sphinx.ext.intersphinx', - 'sphinx.ext.todo', - 'sphinx.ext.coverage', - 'sphinx.ext.inheritance_diagram', - 'sphinx.ext.viewcode', - 'sphinx.ext.napoleon', - 'sphinx.ext.doctest', - 'sphinx.ext.mathjax', - 'sphinx_automodapi.automodapi', - 'sphinx_automodapi.smart_resolver', - 'ndcube.utils.sphinx.code_context', - 'sphinx_changelog', - 'sphinx_gallery.gen_gallery', + "matplotlib.sphinxext.plot_directive", + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx.ext.todo", + "sphinx.ext.coverage", + "sphinx.ext.inheritance_diagram", + "sphinx.ext.viewcode", + "sphinx.ext.napoleon", + "sphinx.ext.doctest", + "sphinx.ext.mathjax", + "sphinx_automodapi.automodapi", + "sphinx_automodapi.smart_resolver", + "ndcube.utils.sphinx.code_context", + "sphinx_changelog", + "sphinx_gallery.gen_gallery", "sphinxext.opengraph", ] @@ -78,17 +78,17 @@ # -- Options for intersphinx extension --------------------------------------- intersphinx_mapping = { - 'python': ('https://docs.python.org/3/', - (None, 'http://data.astropy.org/intersphinx/python3.inv')), - 'numpy': ('https://docs.scipy.org/doc/numpy/', - (None, 'http://data.astropy.org/intersphinx/numpy.inv')), - 'matplotlib': ('https://matplotlib.org/', - (None, 'http://data.astropy.org/intersphinx/matplotlib.inv')), - 'astropy': ('http://docs.astropy.org/en/stable/', None), - 'sunpy': ('https://docs.sunpy.org/en/stable/', None), - 'mpl_animators': ('https://docs.sunpy.org/projects/mpl-animators/en/stable/', None), - 'gwcs': ('https://gwcs.readthedocs.io/en/stable/', None), - 'reproject': ("https://reproject.readthedocs.io/en/stable/", None) + "python": ("https://docs.python.org/3/", + (None, "http://data.astropy.org/intersphinx/python3.inv")), + "numpy": ("https://docs.scipy.org/doc/numpy/", + (None, "http://data.astropy.org/intersphinx/numpy.inv")), + "matplotlib": ("https://matplotlib.org/", + (None, "http://data.astropy.org/intersphinx/matplotlib.inv")), + "astropy": ("http://docs.astropy.org/en/stable/", None), + "sunpy": ("https://docs.sunpy.org/en/stable/", None), + "mpl_animators": ("https://docs.sunpy.org/projects/mpl-animators/en/stable/", None), + "gwcs": ("https://gwcs.readthedocs.io/en/stable/", None), + "reproject": ("https://reproject.readthedocs.io/en/stable/", None), } # -- Options for HTML output ------------------------------------------------- @@ -97,9 +97,9 @@ # a list of builtin themes. 
html_theme = "sunpy" -html_logo = png_icon = 'logo/ndcube.png' +html_logo = png_icon = "logo/ndcube.png" -html_favicon = 'logo/favicon.png' +html_favicon = "logo/favicon.png" # Render inheritance diagrams in SVG graphviz_output_format = "svg" @@ -136,7 +136,7 @@ nitpicky = True # This is not used. See docs/nitpick-exceptions file for the actual listing. nitpick_ignore = [] -for line in open('nitpick-exceptions'): +for line in open("nitpick-exceptions"): if line.strip() == "" or line.startswith("#"): continue dtype, target = line.split(None, 1) @@ -146,18 +146,18 @@ # -- Sphinx Gallery --------------------------------------------------------- sphinx_gallery_conf = { - 'backreferences_dir': os.path.join('generated', 'modules'), - 'filename_pattern': '^((?!skip_).)*$', - 'examples_dirs': os.path.join('..', 'examples'), - 'within_subsection_order': "ExampleTitleSortKey", - 'gallery_dirs': os.path.join('generated', 'gallery'), - 'matplotlib_animations': True, + "backreferences_dir": os.path.join("generated", "modules"), + "filename_pattern": "^((?!skip_).)*$", + "examples_dirs": os.path.join("..", "examples"), + "within_subsection_order": "ExampleTitleSortKey", + "gallery_dirs": os.path.join("generated", "gallery"), + "matplotlib_animations": True, "default_thumb_file": png_icon, - 'abort_on_example_error': False, - 'plot_gallery': 'True', - 'remove_config_comments': True, - 'doc_module': ('ndcube'), - 'only_warn_on_example_error': True, + "abort_on_example_error": False, + "plot_gallery": "True", + "remove_config_comments": True, + "doc_module": ("ndcube"), + "only_warn_on_example_error": True, } # -- Sphinxext Opengraph ---------------------------------------------------- diff --git a/examples/creating_a_gwcs_from_quantities.py b/examples/creating_a_gwcs_from_quantities.py index a84495085..e82e422c6 100644 --- a/examples/creating_a_gwcs_from_quantities.py +++ b/examples/creating_a_gwcs_from_quantities.py @@ -18,17 +18,17 @@ # We aim to create coordinates that are focused around time and energies using astropy quantities. energy = np.arange(10) * u.keV -time = Time('2020-01-01 00:00:00') + np.arange(9)*u.s +time = Time("2020-01-01 00:00:00") + np.arange(9)*u.s ############################################################################## # Then, we need to turn these into lookup tables using # `~ndcube.extra_coords.table_coord.QuantityTableCoordinate` and # `~ndcube.extra_coords.table_coord.TimeTableCoordinate` to create table coordinates. -energy_coord = QuantityTableCoordinate(energy, names='energy', physical_types='em.energy') +energy_coord = QuantityTableCoordinate(energy, names="energy", physical_types="em.energy") print(energy_coord) -time_coord = TimeTableCoordinate(time, names='time', physical_types='time') +time_coord = TimeTableCoordinate(time, names="time", physical_types="time") print(time_coord) ############################################################################## diff --git a/examples/creating_even_spaced_wavelength_visualisation.py b/examples/creating_even_spaced_wavelength_visualisation.py index 8ea953828..2b17e5e70 100644 --- a/examples/creating_even_spaced_wavelength_visualisation.py +++ b/examples/creating_even_spaced_wavelength_visualisation.py @@ -34,7 +34,7 @@ # `sequence=True` causes a sequence of maps to be returned, one for each image file. sequence_of_maps = sunpy.map.Map(aia_files, sequence=True) # Sort the maps in the sequence in order of wavelength. 
-sequence_of_maps.maps = list(sorted(sequence_of_maps.maps, key=lambda m: m.wavelength)) +sequence_of_maps.maps = sorted(sequence_of_maps.maps, key=lambda m: m.wavelength) ############################################################################# # Using an `astropy.units.Quantity` of the wavelengths of the images, we can construct @@ -55,6 +55,6 @@ my_cube = NDCube(sequence_of_maps.as_array(), wcs=cube_wcs) # Produce an interactive plot of the spectral-image stack. -my_cube.plot(plot_axes=['y', 'x', None]) +my_cube.plot(plot_axes=["y", "x", None]) plt.show() diff --git a/examples/creating_ndcube_from_fitsfile.py b/examples/creating_ndcube_from_fitsfile.py index f2c050759..ca193eb8e 100644 --- a/examples/creating_ndcube_from_fitsfile.py +++ b/examples/creating_ndcube_from_fitsfile.py @@ -20,7 +20,7 @@ # `~ndcube.NDCube` from data stored in a FITS file. # Here we are using an example file from ``astropy``. -image_file = get_pkg_data_filename('tutorials/FITS-images/HorseHead.fits') +image_file = get_pkg_data_filename("tutorials/FITS-images/HorseHead.fits") ########################################################################### # Lets extract the image data and the header information from the FITS file. diff --git a/examples/dev/example_template.py b/examples/dev/example_template.py index babdd1f40..d97af6089 100644 --- a/examples/dev/example_template.py +++ b/examples/dev/example_template.py @@ -36,17 +36,17 @@ plt.figure() plt.imshow(z) plt.colorbar() -plt.xlabel('$x$') -plt.ylabel('$y$') +plt.xlabel("$x$") +plt.ylabel("$y$") ########################################################################### # Again it is possible to continue the discussion with a new Python string. This # time to introduce the next code block generates 2 separate figures. 
plt.figure() -plt.imshow(z, cmap=plt.cm.get_cmap('hot')) +plt.imshow(z, cmap=plt.cm.get_cmap("hot")) plt.figure() -plt.imshow(z, cmap=plt.cm.get_cmap('Spectral'), interpolation='none') +plt.imshow(z, cmap=plt.cm.get_cmap("Spectral"), interpolation="none") ########################################################################## # There's some subtle differences between rendered html rendered comment @@ -70,7 +70,7 @@ def dummy(): ############################################################################ # Output of the script is captured: -print('Some output from Python') +print("Some output from Python") ############################################################################ # Finally, I'll call ``show`` at the end just so someone running the Python diff --git a/examples/slicing_ndcube.py b/examples/slicing_ndcube.py index bfeb7b62c..a26a83722 100644 --- a/examples/slicing_ndcube.py +++ b/examples/slicing_ndcube.py @@ -34,12 +34,12 @@ data = np.random.rand(5, 45, 45) # Define the WCS wcs = astropy.wcs.WCS(naxis=3) -wcs.wcs.ctype = 'HPLT-TAN', 'HPLN-TAN', "WAVE" -wcs.wcs.cunit = 'arcsec', 'arcsec', 'Angstrom' +wcs.wcs.ctype = "HPLT-TAN", "HPLN-TAN", "WAVE" +wcs.wcs.cunit = "arcsec", "arcsec", "Angstrom" wcs.wcs.cdelt = 10, 10, 0.2 wcs.wcs.crpix = 2, 2, 0 wcs.wcs.crval = 1, 1, 10 -wcs.wcs.cname = 'HPC lat', 'HPC lon', 'wavelength' +wcs.wcs.cname = "HPC lat", "HPC lon", "wavelength" # Instantiate the `~ndcube.NDCube` example_cube = NDCube(data, wcs=wcs) diff --git a/ndcube/__init__.py b/ndcube/__init__.py index 74e39392d..fe2bb0979 100644 --- a/ndcube/__init__.py +++ b/ndcube/__init__.py @@ -14,4 +14,4 @@ from .ndcube_sequence import NDCubeSequence, NDCubeSequenceBase from .version import version as __version__ -__all__ = ['NDCube', 'NDCubeSequence', "NDCollection", "ExtraCoords", "GlobalCoords", "ExtraCoordsABC", "GlobalCoordsABC", "NDCubeBase", "NDCubeSequenceBase", "__version__"] +__all__ = ["NDCube", "NDCubeSequence", "NDCollection", "ExtraCoords", "GlobalCoords", "ExtraCoordsABC", "GlobalCoordsABC", "NDCubeBase", "NDCubeSequenceBase", "__version__"] diff --git a/ndcube/_dev/scm_version.py b/ndcube/_dev/scm_version.py index 1bcf0dd99..ac0ae56c1 100644 --- a/ndcube/_dev/scm_version.py +++ b/ndcube/_dev/scm_version.py @@ -5,8 +5,9 @@ try: from setuptools_scm import get_version - version = get_version(root=os.path.join('..', '..'), relative_to=__file__) + version = get_version(root=os.path.join("..", ".."), relative_to=__file__) except ImportError: raise except Exception as e: - raise ValueError('setuptools_scm can not determine version.') from e + msg = "setuptools_scm can not determine version." + raise ValueError(msg) from e diff --git a/ndcube/conftest.py b/ndcube/conftest.py index 130941e6a..d9b75efa9 100644 --- a/ndcube/conftest.py +++ b/ndcube/conftest.py @@ -19,17 +19,17 @@ # Force MPL to use non-gui backends for testing. 
try: - import matplotlib + import matplotlib as mpl import matplotlib.pyplot as plt except ImportError: HAVE_MATPLOTLIB = False else: HAVE_MATPLOTLIB = True - matplotlib.use('Agg') + mpl.use("Agg") console_logger = logging.getLogger() -console_logger.setLevel('INFO') +console_logger.setLevel("INFO") ################################################################################ # Helper Functions @@ -50,9 +50,9 @@ def data_nd(shape, dtype=float): def time_extra_coords(shape, axis, base): return ExtraCoords.from_lookup_tables( - ('time',), + ("time",), (axis,), - (base + TimeDelta([i * 60 for i in range(shape[axis])], format='sec'),)) + (base + TimeDelta([i * 60 for i in range(shape[axis])], format="sec"),)) def gen_ndcube_3d_l_ln_lt_ectime(wcs_3d_lt_ln_l, time_axis, time_base, global_coords=None): @@ -85,31 +85,31 @@ def gen_ndcube_3d_l_ln_lt_ectime(wcs_3d_lt_ln_l, time_axis, time_base, global_co @pytest.fixture def wcs_4d_t_l_lt_ln(): header = { - 'CTYPE1': 'TIME ', - 'CUNIT1': 'min', - 'CDELT1': 0.4, - 'CRPIX1': 0, - 'CRVAL1': 0, - - 'CTYPE2': 'WAVE ', - 'CUNIT2': 'Angstrom', - 'CDELT2': 0.2, - 'CRPIX2': 0, - 'CRVAL2': 0, - - 'CTYPE3': 'HPLT-TAN', - 'CUNIT3': 'arcsec', - 'CDELT3': 20, - 'CRPIX3': 0, - 'CRVAL3': 0, - - 'CTYPE4': 'HPLN-TAN', - 'CUNIT4': 'arcsec', - 'CDELT4': 5, - 'CRPIX4': 5, - 'CRVAL4': 0, - - 'DATEREF': "2020-01-01T00:00:00" + "CTYPE1": "TIME ", + "CUNIT1": "min", + "CDELT1": 0.4, + "CRPIX1": 0, + "CRVAL1": 0, + + "CTYPE2": "WAVE ", + "CUNIT2": "Angstrom", + "CDELT2": 0.2, + "CRPIX2": 0, + "CRVAL2": 0, + + "CTYPE3": "HPLT-TAN", + "CUNIT3": "arcsec", + "CDELT3": 20, + "CRPIX3": 0, + "CRVAL3": 0, + + "CTYPE4": "HPLN-TAN", + "CUNIT4": "arcsec", + "CDELT4": 5, + "CRPIX4": 5, + "CRVAL4": 0, + + "DATEREF": "2020-01-01T00:00:00", } return WCS(header=header) @@ -117,31 +117,31 @@ def wcs_4d_t_l_lt_ln(): @pytest.fixture def wcs_4d_lt_t_l_ln(): header = { - 'CTYPE1': 'HPLT-TAN', - 'CUNIT1': 'arcsec', - 'CDELT1': 20, - 'CRPIX1': 0, - 'CRVAL1': 0, - - 'CTYPE2': 'TIME ', - 'CUNIT2': 'min', - 'CDELT2': 0.4, - 'CRPIX2': 0, - 'CRVAL2': 0, - - 'CTYPE3': 'WAVE ', - 'CUNIT3': 'Angstrom', - 'CDELT3': 0.2, - 'CRPIX3': 0, - 'CRVAL3': 0, - - 'CTYPE4': 'HPLN-TAN', - 'CUNIT4': 'arcsec', - 'CDELT4': 5, - 'CRPIX4': 5, - 'CRVAL4': 0, - - 'DATEREF': "2020-01-01T00:00:00" + "CTYPE1": "HPLT-TAN", + "CUNIT1": "arcsec", + "CDELT1": 20, + "CRPIX1": 0, + "CRVAL1": 0, + + "CTYPE2": "TIME ", + "CUNIT2": "min", + "CDELT2": 0.4, + "CRPIX2": 0, + "CRVAL2": 0, + + "CTYPE3": "WAVE ", + "CUNIT3": "Angstrom", + "CDELT3": 0.2, + "CRPIX3": 0, + "CRVAL3": 0, + + "CTYPE4": "HPLN-TAN", + "CUNIT4": "arcsec", + "CDELT4": 5, + "CRPIX4": 5, + "CRVAL4": 0, + + "DATEREF": "2020-01-01T00:00:00", } return WCS(header=header) @@ -149,23 +149,23 @@ def wcs_4d_lt_t_l_ln(): @pytest.fixture def wcs_3d_l_lt_ln(): header = { - 'CTYPE1': 'WAVE ', - 'CUNIT1': 'Angstrom', - 'CDELT1': 0.2, - 'CRPIX1': 0, - 'CRVAL1': 10, - - 'CTYPE2': 'HPLT-TAN', - 'CUNIT2': 'arcsec', - 'CDELT2': 5, - 'CRPIX2': 5, - 'CRVAL2': 0, - - 'CTYPE3': 'HPLN-TAN', - 'CUNIT3': 'arcsec', - 'CDELT3': 10, - 'CRPIX3': 0, - 'CRVAL3': 0, + "CTYPE1": "WAVE ", + "CUNIT1": "Angstrom", + "CDELT1": 0.2, + "CRPIX1": 0, + "CRVAL1": 10, + + "CTYPE2": "HPLT-TAN", + "CUNIT2": "arcsec", + "CDELT2": 5, + "CRPIX2": 5, + "CRVAL2": 0, + + "CTYPE3": "HPLN-TAN", + "CUNIT3": "arcsec", + "CDELT3": 10, + "CRPIX3": 0, + "CRVAL3": 0, } return WCS(header=header) @@ -175,23 +175,23 @@ def wcs_3d_l_lt_ln(): def wcs_3d_lt_ln_l(): header = { - 'CTYPE1': 'HPLN-TAN', - 'CUNIT1': 'arcsec', - 'CDELT1': 
10, - 'CRPIX1': 0, - 'CRVAL1': 0, - - 'CTYPE2': 'HPLT-TAN', - 'CUNIT2': 'arcsec', - 'CDELT2': 5, - 'CRPIX2': 5, - 'CRVAL2': 0, - - 'CTYPE3': 'WAVE ', - 'CUNIT3': 'Angstrom', - 'CDELT3': 0.2, - 'CRPIX3': 0, - 'CRVAL3': 10, + "CTYPE1": "HPLN-TAN", + "CUNIT1": "arcsec", + "CDELT1": 10, + "CRPIX1": 0, + "CRVAL1": 0, + + "CTYPE2": "HPLT-TAN", + "CUNIT2": "arcsec", + "CDELT2": 5, + "CRPIX2": 5, + "CRVAL2": 0, + + "CTYPE3": "WAVE ", + "CUNIT3": "Angstrom", + "CDELT3": 0.2, + "CRPIX3": 0, + "CRVAL3": 10, } return WCS(header=header) @@ -200,17 +200,17 @@ def wcs_3d_lt_ln_l(): @pytest.fixture def wcs_2d_lt_ln(): spatial = { - 'CTYPE1': 'HPLT-TAN', - 'CUNIT1': 'arcsec', - 'CDELT1': 2, - 'CRPIX1': 5, - 'CRVAL1': 0, - - 'CTYPE2': 'HPLN-TAN', - 'CUNIT2': 'arcsec', - 'CDELT2': 4, - 'CRPIX2': 5, - 'CRVAL2': 0, + "CTYPE1": "HPLT-TAN", + "CUNIT1": "arcsec", + "CDELT1": 2, + "CRPIX1": 5, + "CRVAL1": 0, + + "CTYPE2": "HPLN-TAN", + "CUNIT2": "arcsec", + "CDELT2": 4, + "CRPIX2": 5, + "CRVAL2": 0, } return WCS(header=spatial) @@ -218,12 +218,12 @@ def wcs_2d_lt_ln(): @pytest.fixture def wcs_1d_l(): spatial = { - 'CNAME1': 'spectral', - 'CTYPE1': 'WAVE', - 'CUNIT1': 'nm', - 'CDELT1': 0.5, - 'CRPIX1': 2, - 'CRVAL1': 0.5, + "CNAME1": "spectral", + "CTYPE1": "WAVE", + "CUNIT1": "nm", + "CDELT1": 0.5, + "CRPIX1": 2, + "CRVAL1": 0.5, } return WCS(header=spatial) @@ -231,38 +231,38 @@ def wcs_1d_l(): @pytest.fixture def wcs_3d_ln_lt_t_rotated(): h_rotated = { - 'CTYPE1': 'HPLN-TAN', - 'CUNIT1': 'arcsec', - 'CDELT1': 0.4, - 'CRPIX1': 0, - 'CRVAL1': 0, - 'NAXIS1': 5, - - 'CTYPE2': 'HPLT-TAN', - 'CUNIT2': 'arcsec', - 'CDELT2': 0.5, - 'CRPIX2': 0, - 'CRVAL2': 0, - 'NAXIS2': 5, - - 'CTYPE3': 'TIME ', - 'CUNIT3': 's', - 'CDELT3': 3, - 'CRPIX3': 0, - 'CRVAL3': 0, - 'NAXIS3': 2, - - 'DATEREF': "2020-01-01T00:00:00", - - 'PC1_1': 0.714963912964, - 'PC1_2': -0.699137151241, - 'PC1_3': 0.0, - 'PC2_1': 0.699137151241, - 'PC2_2': 0.714963912964, - 'PC2_3': 0.0, - 'PC3_1': 0.0, - 'PC3_2': 0.0, - 'PC3_3': 1.0 + "CTYPE1": "HPLN-TAN", + "CUNIT1": "arcsec", + "CDELT1": 0.4, + "CRPIX1": 0, + "CRVAL1": 0, + "NAXIS1": 5, + + "CTYPE2": "HPLT-TAN", + "CUNIT2": "arcsec", + "CDELT2": 0.5, + "CRPIX2": 0, + "CRVAL2": 0, + "NAXIS2": 5, + + "CTYPE3": "TIME ", + "CUNIT3": "s", + "CDELT3": 3, + "CRPIX3": 0, + "CRVAL3": 0, + "NAXIS3": 2, + + "DATEREF": "2020-01-01T00:00:00", + + "PC1_1": 0.714963912964, + "PC1_2": -0.699137151241, + "PC1_3": 0.0, + "PC2_1": 0.699137151241, + "PC2_2": 0.714963912964, + "PC2_3": 0.0, + "PC3_1": 0.0, + "PC3_2": 0.0, + "PC3_3": 1.0, } return WCS(header=h_rotated) @@ -274,12 +274,12 @@ def wcs_3d_ln_lt_t_rotated(): @pytest.fixture def simple_extra_coords_3d(): - return ExtraCoords.from_lookup_tables(('time', 'hello', 'bye'), + return ExtraCoords.from_lookup_tables(("time", "hello", "bye"), (0, 1, 2), (list(range(2)) * u.pix, list(range(3)) * u.pix, - list(range(4)) * u.pix - ) + list(range(4)) * u.pix, + ), ) @@ -289,7 +289,7 @@ def time_and_simple_extra_coords_2d(): (0, 1), (Time(["2000-01-01T12:00:00", "2000-01-02T12:00:00"], scale="utc", format="fits"), - list(range(3)) * u.pix) + list(range(3)) * u.pix), ) @@ -298,19 +298,19 @@ def extra_coords_3d(): coord0 = Time(["2000-01-01T12:00:00", "2000-01-02T12:00:00"], scale="utc", format="fits") coord1 = list(range(3)) * u.pix coord2 = list(range(4)) * u.m - return ExtraCoords.from_lookup_tables(('time', 'bye', 'hello'), + return ExtraCoords.from_lookup_tables(("time", "bye", "hello"), (0, 1, 2), - (coord0, coord1, coord2) + (coord0, coord1, coord2), ) @pytest.fixture def 
extra_coords_sharing_axis(): - return ExtraCoords.from_lookup_tables(('hello', 'bye'), + return ExtraCoords.from_lookup_tables(("hello", "bye"), (1, 1), (list(range(3)) * u.m, list(range(3)) * u.keV, - ) + ), ) @@ -464,7 +464,7 @@ def ndcube_3d_rotated(wcs_3d_ln_lt_t_rotated, simple_extra_coords_3d): def ndcube_3d_l_ln_lt_ectime(wcs_3d_lt_ln_l): return gen_ndcube_3d_l_ln_lt_ectime(wcs_3d_lt_ln_l, 1, - Time('2000-01-01', format='fits', scale='utc')) + Time("2000-01-01", format="fits", scale="utc")) @pytest.fixture @@ -479,8 +479,7 @@ def ndcube_2d_ln_lt_uncert(wcs_2d_lt_ln): shape = (10, 12) data_cube = data_nd(shape) uncertainty = astropy.nddata.StdDevUncertainty(data_cube * 0.1) - cube = NDCube(data_cube, wcs=wcs_2d_lt_ln, uncertainty=uncertainty) - return cube + return NDCube(data_cube, wcs=wcs_2d_lt_ln, uncertainty=uncertainty) @pytest.fixture @@ -493,8 +492,7 @@ def ndcube_2d_ln_lt_mask_uncert(wcs_2d_lt_ln): mask[2, 0] = True mask[3, 3] = True mask[4:6, :4] = True - cube = NDCube(data_cube, wcs=wcs_2d_lt_ln, uncertainty=uncertainty, mask=mask) - return cube + return NDCube(data_cube, wcs=wcs_2d_lt_ln, uncertainty=uncertainty, mask=mask) @pytest.fixture @@ -606,22 +604,22 @@ def ndcubesequence_4c_ln_lt_l_cax1(ndcube_3d_ln_lt_l): def ndcubesequence_3c_l_ln_lt_cax1(wcs_3d_lt_ln_l): common_axis = 1 - base_time1 = Time('2000-01-01', format='fits', scale='utc') + base_time1 = Time("2000-01-01", format="fits", scale="utc") gc1 = GlobalCoords() - gc1.add('distance', 'custom:distance', 1*u.m) + gc1.add("distance", "custom:distance", 1*u.m) cube1 = gen_ndcube_3d_l_ln_lt_ectime(wcs_3d_lt_ln_l, 1, base_time1, gc1) shape = cube1.data.shape - base_time2 = base_time1 + TimeDelta([shape[common_axis] * 60], format='sec') + base_time2 = base_time1 + TimeDelta([shape[common_axis] * 60], format="sec") gc2 = GlobalCoords() - gc2.add('distance', 'custom:distance', 2*u.m) - gc2.add('global coord', 'custom:physical_type', 0*u.pix) + gc2.add("distance", "custom:distance", 2*u.m) + gc2.add("global coord", "custom:physical_type", 0*u.pix) cube2 = gen_ndcube_3d_l_ln_lt_ectime(wcs_3d_lt_ln_l, 1, base_time2, gc2) cube2.data[:] *= 2 - base_time3 = base_time2 + TimeDelta([shape[common_axis] * 60], format='sec') + base_time3 = base_time2 + TimeDelta([shape[common_axis] * 60], format="sec") gc3 = GlobalCoords() - gc3.add('distance', 'custom:distance', 3*u.m) + gc3.add("distance", "custom:distance", 3*u.m) cube3 = gen_ndcube_3d_l_ln_lt_ectime(wcs_3d_lt_ln_l, 1, base_time3, gc3) cube3.data[:] *= 3 @@ -634,4 +632,4 @@ def pytest_runtest_teardown(item): if HAVE_MATPLOTLIB and plt.get_fignums(): console_logger.info(f"Removing {len(plt.get_fignums())} pyplot figure(s) " f"left open by {item.name}") - plt.close('all') + plt.close("all") diff --git a/ndcube/extra_coords/__init__.py b/ndcube/extra_coords/__init__.py index 315b95209..ad04a0472 100644 --- a/ndcube/extra_coords/__init__.py +++ b/ndcube/extra_coords/__init__.py @@ -6,5 +6,5 @@ TimeTableCoordinate, ) -__all__ = ['TimeTableCoordinate', "MultipleTableCoordinate", - 'SkyCoordTableCoordinate', 'QuantityTableCoordinate', "BaseTableCoordinate"] +__all__ = ["TimeTableCoordinate", "MultipleTableCoordinate", + "SkyCoordTableCoordinate", "QuantityTableCoordinate", "BaseTableCoordinate"] diff --git a/ndcube/extra_coords/extra_coords.py b/ndcube/extra_coords/extra_coords.py index 87f4de9b7..ce5f1c24c 100644 --- a/ndcube/extra_coords/extra_coords.py +++ b/ndcube/extra_coords/extra_coords.py @@ -24,7 +24,7 @@ TimeTableCoordinate, ) -__all__ = ['ExtraCoordsABC', 'ExtraCoords'] 
+__all__ = ["ExtraCoordsABC", "ExtraCoords"] class ExtraCoordsABC(abc.ABC): @@ -45,12 +45,13 @@ class ExtraCoordsABC(abc.ABC): of length equal to the number of pixel dimensions in the extra coords. """ + @abc.abstractmethod def add(self, name: str | Iterable[str], array_dimension: int | Iterable[int], lookup_table: Any, - physical_types: str | Iterable[str] = None, + physical_types: str | Iterable[str] | None = None, **kwargs): """ Add a coordinate to this `~ndcube.ExtraCoords` based on a lookup table. @@ -140,7 +141,7 @@ class ExtraCoords(ExtraCoordsABC): """ - def __init__(self, ndcube=None): + def __init__(self, ndcube=None) -> None: super().__init__() # Setup private attributes @@ -149,8 +150,8 @@ def __init__(self, ndcube=None): # Lookup tables is a list of (pixel_dim, LookupTableCoord) to allow for # one pixel dimension having more than one lookup coord. - self._lookup_tables = list() - self._dropped_tables = list() + self._lookup_tables = [] + self._dropped_tables = [] # We need a reference to the parent NDCube self._ndcube = ndcube @@ -187,19 +188,21 @@ def from_lookup_tables(cls, names, pixel_dimensions, lookup_tables, physical_typ """ if len(pixel_dimensions) != len(lookup_tables): + msg = "The length of pixel_dimensions and lookup_tables must match." raise ValueError( - "The length of pixel_dimensions and lookup_tables must match." + msg, ) if physical_types is None: physical_types = len(lookup_tables) * [physical_types] elif len(physical_types) != len(lookup_tables): - raise ValueError("The number of physical types and lookup_tables must match.") + msg = "The number of physical types and lookup_tables must match." + raise ValueError(msg) extra_coords = cls() for name, pixel_dim, lookup_table, physical_type in zip(names, pixel_dimensions, - lookup_tables, physical_types): + lookup_tables, physical_types, strict=False): extra_coords.add(name, pixel_dim, lookup_table, physical_types=physical_type) return extra_coords @@ -208,11 +211,12 @@ def add(self, name, array_dimension, lookup_table, physical_types=None, **kwargs # docstring in ABC if self._wcs is not None: + msg = "Can not add a lookup_table to an ExtraCoords which was instantiated with a WCS object." raise ValueError( - "Can not add a lookup_table to an ExtraCoords which was instantiated with a WCS object." 
+ msg, ) - kwargs['names'] = [name] if not isinstance(name, (list, tuple)) else name + kwargs["names"] = [name] if not isinstance(name, list | tuple) else name if isinstance(lookup_table, BaseTableCoordinate): coord = lookup_table @@ -220,18 +224,19 @@ def add(self, name, array_dimension, lookup_table, physical_types=None, **kwargs coord = TimeTableCoordinate(lookup_table, physical_types=physical_types, **kwargs) elif isinstance(lookup_table, SkyCoord): coord = SkyCoordTableCoordinate(lookup_table, physical_types=physical_types, **kwargs) - elif isinstance(lookup_table, (list, tuple)): + elif isinstance(lookup_table, list | tuple): coord = QuantityTableCoordinate(*lookup_table, physical_types=physical_types, **kwargs) elif isinstance(lookup_table, u.Quantity): coord = QuantityTableCoordinate(lookup_table, physical_types=physical_types, **kwargs) else: - raise TypeError(f"The input type {type(lookup_table)} isn't supported") + msg = f"The input type {type(lookup_table)} isn't supported" + raise TypeError(msg) self._lookup_tables.append((array_dimension, coord)) # Sort the LUTs so that the mapping and the wcs are ordered in pixel dim order - self._lookup_tables = list(sorted(self._lookup_tables, - key=lambda x: x[0] if isinstance(x[0], Integral) else x[0][0])) + self._lookup_tables = sorted(self._lookup_tables, + key=lambda x: x[0] if isinstance(x[0], Integral) else x[0][0]) @property def _name_lut_map(self): @@ -243,7 +248,7 @@ def _name_lut_map(self): def keys(self): # docstring in ABC if not self.wcs: - return tuple() + return () return tuple(self.wcs.world_axis_names) if self.wcs.world_axis_names else None @@ -256,7 +261,7 @@ def mapping(self): # If mapping is not set but lookup_tables is empty then the extra # coords is empty, so there is no mapping. if not self._lookup_tables: - return tuple() + return () # The mapping is from the array index (position in the list) to the # pixel dimensions (numbers in the list) @@ -268,18 +273,20 @@ def mapping(self): @mapping.setter def mapping(self, mapping): if self._mapping is not None: - raise AttributeError("Can't set mapping if a mapping has already been specified.") + msg = "Can't set mapping if a mapping has already been specified." + raise AttributeError(msg) if self._lookup_tables: + msg = "Can't set mapping manually when ExtraCoords is built from lookup tables." raise AttributeError( - "Can't set mapping manually when ExtraCoords is built from lookup tables." + msg, ) - if self._wcs is not None: - if not max(mapping) <= self._wcs.pixel_n_dim - 1: - raise ValueError( - "Values in the mapping can not be larger than the number of pixel dimensions in the WCS." - ) + if self._wcs is not None and not max(mapping) <= self._wcs.pixel_n_dim - 1: + msg = "Values in the mapping can not be larger than the number of pixel dimensions in the WCS." + raise ValueError( + msg, + ) self._mapping = mapping @@ -292,7 +299,7 @@ def wcs(self): if not self._lookup_tables: return None - tcoords = set(lt[1] for lt in self._lookup_tables) + tcoords = {lt[1] for lt in self._lookup_tables} # created a sorted list of unique items _tmp = set() # a temporary set tcoords = [x[1] for x in self._lookup_tables if x[1] not in _tmp and _tmp.add(x[1]) is None] @@ -301,43 +308,42 @@ def wcs(self): @wcs.setter def wcs(self, wcs): if self._wcs is not None: + msg = "Can't set wcs if a WCS has already been specified." raise AttributeError( - "Can't set wcs if a WCS has already been specified." 
+ msg, ) if self._lookup_tables: + msg = "Can't set wcs manually when ExtraCoords is built from lookup tables." raise AttributeError( - "Can't set wcs manually when ExtraCoords is built from lookup tables." + msg, ) - if self._mapping is not None: - if not max(self._mapping) <= wcs.pixel_n_dim - 1: - raise ValueError( - "Values in the mapping can not be larger than the number of pixel dimensions in the WCS." - ) + if self._mapping is not None and not max(self._mapping) <= wcs.pixel_n_dim - 1: + msg = "Values in the mapping can not be larger than the number of pixel dimensions in the WCS." + raise ValueError( + msg, + ) self._wcs = wcs @property def is_empty(self): # docstring in ABC - if not self._wcs and not self._lookup_tables: - return True - else: - return False + return bool(not self._wcs and not self._lookup_tables) def _getitem_string(self, item): """ Slice the Extracoords based on axis names. """ - for names, lut in self._name_lut_map.items(): if item in names: new_ec = ExtraCoords(ndcube=self._ndcube) new_ec._lookup_tables = [lut] return new_ec - raise KeyError(f"Can't find the world axis named {item} in this ExtraCoords object.") + msg = f"Can't find the world axis named {item} in this ExtraCoords object." + raise KeyError(msg) def _getitem_lookup_tables(self, item): """ @@ -402,7 +408,7 @@ def __getitem__(self, item): if self._wcs: return self._getitem_wcs(item) - elif self._lookup_tables: + if self._lookup_tables: return self._getitem_lookup_tables(item) # If we get here this object is empty, so just return an empty extra coords @@ -414,10 +420,8 @@ def dropped_world_dimensions(self): """ Return an APE-14 like representation of any sliced out world dimensions. """ - - if self._wcs: - if isinstance(self._wcs, SlicedLowLevelWCS): - return self._wcs.dropped_world_dimensions + if self._wcs and isinstance(self._wcs, SlicedLowLevelWCS): + return self._wcs.dropped_world_dimensions if self._lookup_tables or self._dropped_tables: mtc = MultipleTableCoordinate(*[lt[1] for lt in self._lookup_tables]) @@ -425,7 +429,7 @@ def dropped_world_dimensions(self): return mtc.dropped_world_dimensions - return dict() + return {} def resample(self, factor, offset=0, ndcube=None, **kwargs): """ @@ -466,22 +470,31 @@ def resample(self, factor, offset=0, ndcube=None, **kwargs): elif self._wcs is not None: ndim = self._wcs.pixel_n_dim else: - raise NotImplementedError( + msg = ( "Resampling a lookup-table-based ExtraCoords not yet implemented. " "Please raise an issue at https://github.com/sunpy/ndcube/issues " - "if you need this functionality") + "if you need this functionality" + ) + raise NotImplementedError( + msg) if np.isscalar(factor): factor = [factor] * ndim if len(factor) != ndim: - raise ValueError( + msg = ( "factor must be scalar or an iterable with length equal to number of cube " - f"dimensions: len(factor) = {len(factor)}; No. cube dimensions = {ndim}.") + f"dimensions: len(factor) = {len(factor)}; No. cube dimensions = {ndim}." + ) + raise ValueError( + msg) if np.isscalar(offset): offset = [offset] * ndim if len(offset) != ndim: - raise ValueError( + msg = ( "offset must be scalar or an iterable with length equal to number of cube " - f"dimensions: len(offset) = {len(offset)}; No. cube dimensions = {ndim}.") + f"dimensions: len(offset) = {len(offset)}; No. cube dimensions = {ndim}." 
+ ) + raise ValueError( + msg) # If ExtraCoords object built on WCS, resample using WCS insfrastructure if self._wcs is not None: new_ec.wcs = HighLevelWCSWrapper(ResampledLowLevelWCS(self._wcs.low_level_wcs, @@ -490,7 +503,7 @@ def resample(self, factor, offset=0, ndcube=None, **kwargs): # Else interpolate the lookup table coordinates. factor = np.asarray(factor) new_grids = [] - for c, d, f in zip(offset, cube_shape, factor): + for c, d, f in zip(offset, cube_shape, factor, strict=False): x = np.arange(c, d+f, f) x = x[x <= d-1] new_grids.append(x) @@ -529,17 +542,14 @@ def _cube_array_axes_without_extra_coords(self): """Return the array axes not associated with any extra coord.""" return set(range(len(self._ndcube.shape))) - set(self.mapping) - def __str__(self): + def __str__(self) -> str: classname = self.__class__.__name__ elements = [f"{', '.join(table.names)} ({axes}) {table.physical_types}: {table}" for axes, table in self._lookup_tables] length = len(classname) + 2 * len(elements) + sum(len(e) for e in elements) - if length > np.get_printoptions()['linewidth']: - joiner = ',\n ' + len(classname) * ' ' - else: - joiner = ', ' + joiner = ",\n " + len(classname) * " " if length > np.get_printoptions()["linewidth"] else ", " return f"{classname}({joiner.join(elements)})" - def __repr__(self): + def __repr__(self) -> str: return f"{object.__repr__(self)}\n{self}" diff --git a/ndcube/extra_coords/table_coord.py b/ndcube/extra_coords/table_coord.py index 503d6f1c7..f3bfc3289 100644 --- a/ndcube/extra_coords/table_coord.py +++ b/ndcube/extra_coords/table_coord.py @@ -1,5 +1,6 @@ import abc import copy +import contextlib from numbers import Integral from collections import defaultdict @@ -15,12 +16,10 @@ from astropy.time import Time from astropy.wcs.wcsapi.wrappers.sliced_wcs import combine_slices, sanitize_slices -try: +with contextlib.suppress(ImportError): import scipy.interpolate -except ImportError: - pass -__all__ = ['TimeTableCoordinate', 'SkyCoordTableCoordinate', 'QuantityTableCoordinate', "BaseTableCoordinate", "MultipleTableCoordinate"] +__all__ = ["TimeTableCoordinate", "SkyCoordTableCoordinate", "QuantityTableCoordinate", "BaseTableCoordinate", "MultipleTableCoordinate"] class Length1Tabular(_Tabular): @@ -35,7 +34,7 @@ class Length1Tabular(_Tabular): points = np.zeros([1]) def __init__(self, points=None, lookup_table=None, point_width=None, value_width=None, - method='linear', bounds_error=True, fill_value=np.nan, **kwargs): + method="linear", bounds_error=True, fill_value=np.nan, **kwargs) -> None: """Create a Length-1 1-D Tabular model. Parameters @@ -53,7 +52,8 @@ def __init__(self, points=None, lookup_table=None, point_width=None, value_width Other parameters are defined by the parent class. """ if len(lookup_table) != 1: - raise ValueError("lookup_table must have length 1.") + msg = "lookup_table must have length 1." + raise ValueError(msg) super().__init__(points=points, lookup_table=lookup_table, method=method, bounds_error=bounds_error, fill_value=fill_value, **kwargs) self._value_width = value_width # Width of point in world units. 
@@ -67,10 +67,7 @@ def evaluate(self, x): output = np.full(x.shape, self.fill_value) diff = abs(x - self.points[0]) margin = self._point_width / 2 - if margin.value == 0: - idx = diff == margin - else: - idx = np.logical_and(diff >= -1 * margin, diff < margin) + idx = diff == margin if margin.value == 0 else np.logical_and(diff >= -1 * margin, diff < margin) output[idx] = self.lookup_table[0].value return output * self.lookup_table.unit @@ -88,7 +85,7 @@ class InverseLength1Tabular(Length1Tabular): This is the opposite direction to Length1Tabular. """ - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: # Same inputs as Length1Tabular points = kwargs.pop("points", None) lookup_table = kwargs.pop("lookup_table", None) @@ -99,7 +96,7 @@ def __init__(self, **kwargs): def evaluate(self, x): # When calling evaluate with a bounding box, astropy strips the units. - x = u.Quantity(x, unit=self.input_units['x'], copy=False) + x = u.Quantity(x, unit=self.input_units["x"], copy=False) return super().evaluate(x) @@ -112,13 +109,13 @@ def _generate_generic_frame(naxes, unit, names=None, physical_types=None): name = None axes_type = "CUSTOM" - if isinstance(unit, (u.Unit, u.IrreducibleUnit, u.CompositeUnit)): + if isinstance(unit, u.Unit | u.IrreducibleUnit | u.CompositeUnit): unit = tuple([unit] * naxes) - if all([u.m.is_equivalent(un) for un in unit]): + if all(u.m.is_equivalent(un) for un in unit): axes_type = "SPATIAL" - if all([u.pix.is_equivalent(un) for un in unit]): + if all(u.pix.is_equivalent(un) for un in unit): name = "PixelFrame" axes_type = "PIXEL" @@ -128,12 +125,13 @@ def _generate_generic_frame(naxes, unit, names=None, physical_types=None): axes_names=names, name=name, axis_physical_types=physical_types) -def _generate_tabular(lookup_table, interpolation='linear', points_unit=u.pix, **kwargs): +def _generate_tabular(lookup_table, interpolation="linear", points_unit=u.pix, **kwargs): """ Generate a Tabular model class and instance. """ if not isinstance(lookup_table, u.Quantity): - raise TypeError("lookup_table must be a Quantity.") # pragma: no cover + msg = "lookup_table must be a Quantity." + raise TypeError(msg) # pragma: no cover ndim = lookup_table.ndim TabularND = tabular_model(ndim, name=f"Tabular{ndim}D") @@ -143,9 +141,9 @@ def _generate_tabular(lookup_table, interpolation='linear', points_unit=u.pix, * if len(points) == 1: points = points[0] - kwargs = {'bounds_error': False, - 'fill_value': np.nan, - 'method': interpolation, + kwargs = {"bounds_error": False, + "fill_value": np.nan, + "method": interpolation, **kwargs} if len(lookup_table) == 1: @@ -196,13 +194,13 @@ class BaseTableCoordinate(abc.ABC): coordinates, meaning it can have multiple gWCS frames. 
""" - def __init__(self, *tables, mesh=False, names=None, physical_types=None): + def __init__(self, *tables, mesh=False, names=None, physical_types=None) -> None: self.table = tables self.mesh = mesh self.names = names if not isinstance(names, str) else [names] self.physical_types = physical_types if not isinstance(physical_types, str) else [physical_types] self._dropped_world_dimensions = defaultdict(list) - self._dropped_world_dimensions["world_axis_object_classes"] = dict() + self._dropped_world_dimensions["world_axis_object_classes"] = {} @abc.abstractmethod def __getitem__(self, item): @@ -220,15 +218,14 @@ def __and__(self, other): return MultipleTableCoordinate(self, other) - def __str__(self): + def __str__(self) -> str: header = f"{self.__class__.__name__} {self.names or ''} {self.physical_types or '[None]'}:" - content = str(self.table).lstrip('(').rstrip(',)') - if len(header) + len(content) >= np.get_printoptions()['linewidth']: - return '\n'.join((header, content)) - else: - return ' '.join((header, content)) + content = str(self.table).lstrip("(").rstrip(",)") + if len(header) + len(content) >= np.get_printoptions()["linewidth"]: + return f"{header}\n{content}" + return f"{header} {content}" - def __repr__(self): + def __repr__(self) -> str: return f"{object.__repr__(self)}\n{self}" @property @@ -301,26 +298,33 @@ class QuantityTableCoordinate(BaseTableCoordinate): a physical type must be given for each component. """ - def __init__(self, *tables, names=None, physical_types=None): - if not all([isinstance(t, u.Quantity) for t in tables]): - raise TypeError("All tables must be astropy Quantity objects") - if not all([t.unit.is_equivalent(tables[0].unit) for t in tables]): - raise u.UnitsError("All tables must have equivalent units.") + def __init__(self, *tables, names=None, physical_types=None) -> None: + if not all(isinstance(t, u.Quantity) for t in tables): + msg = "All tables must be astropy Quantity objects" + raise TypeError(msg) + if not all(t.unit.is_equivalent(tables[0].unit) for t in tables): + msg = "All tables must have equivalent units." + raise u.UnitsError(msg) ndim = len(tables) dims = np.array([t.ndim for t in tables]) if any(dims > 1): - raise ValueError( + msg = ( "Currently all tables must be 1-D. 
If you need >1D support, please " - "raise an issue at https://github.con/sunpy/ndcube/issues") + "raise an issue at https://github.con/sunpy/ndcube/issues" + ) + raise ValueError( + msg) if isinstance(names, str): names = [names] if names is not None and len(names) != ndim: - raise ValueError("The number of names should match the number of world dimensions") + msg = "The number of names should match the number of world dimensions" + raise ValueError(msg) if isinstance(physical_types, str): physical_types = [physical_types] if physical_types is not None and len(physical_types) != ndim: - raise ValueError("The number of physical types should match the number of world dimensions") + msg = "The number of physical types should match the number of world dimensions" + raise ValueError(msg) self.unit = tables[0].unit @@ -349,7 +353,7 @@ def _slice_table(self, i, table, item, new_components, whole_slice): dwd["world_axis_physical_types"].append(self.frame.axis_physical_types[i]) dwd["world_axis_units"].append(table.unit.to_string()) dwd["world_axis_object_components"].append((f"quantity{i}", 0, "value")) - dwd["world_axis_object_classes"].update({f"quantity{i}": (u.Quantity, tuple(), {"unit", table.unit.to_string()})}) + dwd["world_axis_object_classes"].update({f"quantity{i}": (u.Quantity, (), {"unit", table.unit.to_string()})}) return new_components["tables"].append(table[item]) @@ -359,15 +363,16 @@ def _slice_table(self, i, table, item, new_components, whole_slice): new_components["physical_types"].append(self.physical_types[i]) def __getitem__(self, item): - if isinstance(item, (slice, Integral)): + if isinstance(item, slice | Integral): item = (item,) if not (len(item) == len(self.table) or len(item) == self.table[0].ndim): - raise ValueError("Can not slice with incorrect length") + msg = "Can not slice with incorrect length" + raise ValueError(msg) new_components = defaultdict(list) new_components["dropped_world_dimensions"] = copy.deepcopy(self._dropped_world_dimensions) - for i, (ele, table) in enumerate(zip(item, self.table)): + for i, (ele, table) in enumerate(zip(item, self.table, strict=False)): self._slice_table(i, table, ele, new_components, whole_slice=item) names = new_components["names"] or None @@ -382,7 +387,7 @@ def n_inputs(self): return len(self.table) def is_scalar(self): - return all(t.shape == tuple() for t in self.table) + return all(t.shape == () for t in self.table) @property def frame(self): @@ -441,20 +446,23 @@ def interpolate(self, *new_array_grids, **kwargs): """ if self.is_scalar(): - raise ValueError("Cannot interpolate a scalar QuantityTableCoordinate.") + msg = "Cannot interpolate a scalar QuantityTableCoordinate." + raise ValueError(msg) # Sanitize input. ndim = self.ndim if len(new_array_grids) != ndim: + msg = f"A new array grid must be given for each array axis/table, i.e. {ndim}" raise ValueError( - f"A new array grid must be given for each array axis/table, i.e. {ndim}") + msg) if any(new_grid.shape != new_array_grids[0].shape for new_grid in new_array_grids): - raise ValueError("New array grids must all be same shape.") + msg = "New array grids must all be same shape." + raise ValueError(msg) # Build array grids for non-interpolated table. old_array_grids = tuple(np.arange(d) for d in self.shape) # Iterate through tables and interpolate each. 
new_tables = [ np.interp(new_grid, old_grid, t.value, **kwargs) * t.unit - for new_grid, old_grid, t in zip(new_array_grids, old_array_grids, self.table)] + for new_grid, old_grid, t in zip(new_array_grids, old_array_grids, self.table, strict=False)] # Rebuild return interpolated coord. new_coord = type(self)(*new_tables, names=self.names, physical_types=self.physical_types) new_coord._dropped_world_dimensions = self._dropped_world_dimensions @@ -489,21 +497,29 @@ class SkyCoordTableCoordinate(BaseTableCoordinate): same length. """ - def __init__(self, *tables, mesh=False, names=None, physical_types=None): - if not len(tables) == 1 and isinstance(tables[0], SkyCoord): - raise ValueError("SkyCoordLookupTable can only be constructed from a single SkyCoord object") + def __init__(self, *tables, mesh=False, names=None, physical_types=None) -> None: + if len(tables) != 1 and isinstance(tables[0], SkyCoord): + msg = "SkyCoordLookupTable can only be constructed from a single SkyCoord object" + raise ValueError(msg) if mesh and tables[0].ndim > 1: - raise ValueError("If mesh is True, input SkyCoord must be 1-D.") + msg = "If mesh is True, input SkyCoord must be 1-D." + raise ValueError(msg) if isinstance(names, str): names = [names] n_components = len(tables[0].data.components) if names is not None and len(names) != n_components: - raise ValueError("The number of names must equal number of components in the input " - f"SkyCoord: {n_components}.") + msg = ( + "The number of names must equal number of components in the input " + f"SkyCoord: {n_components}." + ) + raise ValueError(msg) if physical_types is not None and len(physical_types) != n_components: - raise ValueError("The number of physical types must equal number of components in " - f"the input SkyCoord: {n_components}.") + msg = ( + "The number of physical types must equal number of components in " + f"the input SkyCoord: {n_components}." + ) + raise ValueError(msg) sc = tables[0] @@ -516,13 +532,14 @@ def n_inputs(self): return len(self.table.data.components) def is_scalar(self): - return self.table.shape == tuple() + return self.table.shape == () @staticmethod def combine_slices(slice1, slice2): ints = [isinstance(s, Integral) for s in (slice1, slice2)] if all(ints): - raise ValueError("Can not combine two integers") + msg = "Can not combine two integers" + raise ValueError(msg) if any(ints): return (slice1, slice2)[ints.index(True)] return combine_slices(slice1, slice2) @@ -532,23 +549,23 @@ def __getitem__(self, item): try: sane_item = sanitize_slices(item, self.n_inputs) except ValueError as ex: - raise ValueError("Can not slice with incorrect length") from ex + msg = "Can not slice with incorrect length" + raise ValueError(msg) from ex if not self.mesh: return type(self)(self.table[item], mesh=False, names=self.names, physical_types=self.physical_types) - else: - self._slice = [self.combine_slices(a, b) for a, b in zip(sane_item, self._slice)] - if all([isinstance(s, Integral) for s in self._slice]): - # Here we rebuild the SkyCoord with the slice applied to the individual components. - new_sc = SkyCoord(self.table.realize_frame(type(self.table.data)(*self._sliced_components))) - return type(self)(new_sc, - mesh=False, - names=self.names, - physical_types=self.physical_types) - return self + self._slice = [self.combine_slices(a, b) for a, b in zip(sane_item, self._slice, strict=False)] + if all(isinstance(s, Integral) for s in self._slice): + # Here we rebuild the SkyCoord with the slice applied to the individual components. 
+ new_sc = SkyCoord(self.table.realize_frame(type(self.table.data)(*self._sliced_components))) + return type(self)(new_sc, + mesh=False, + names=self.names, + physical_types=self.physical_types) + return self @property def frame(self): @@ -558,7 +575,7 @@ def frame(self): sc = self.table components = tuple(getattr(sc.data, comp) for comp in sc.data.components) ref_frame = sc.frame.replicate_without_data() - units = list(c.unit for c in components) + units = [c.unit for c in components] # TODO: Currently this limits you to 2D due to gwcs#120 return cf.CelestialFrame(reference_frame=ref_frame, @@ -570,7 +587,7 @@ def frame(self): @property def _sliced_components(self): return tuple(getattr(self.table.data, comp)[slc] - for comp, slc in zip(self.table.data.components, self._slice)) + for comp, slc in zip(self.table.data.components, self._slice, strict=False)) @property def model(self): @@ -590,8 +607,7 @@ def ndim(self): """ if self.mesh: return len(self.table.data.components) - else: - return self.table.ndim + return self.table.ndim @property def shape(self): @@ -605,8 +621,7 @@ def shape(self): """ if self.mesh: return tuple(list(self.table.shape) * self.ndim) - else: - return self.table.shape + return self.table.shape def interpolate(self, *new_array_grids, mesh_output=None, **kwargs): """ @@ -636,7 +651,8 @@ def interpolate(self, *new_array_grids, mesh_output=None, **kwargs): New TableCoordinate object holding the interpolated coords. """ if self.is_scalar(): - raise ValueError("Cannot interpolate a scalar SkyCoordTableCoordinate.") + msg = "Cannot interpolate a scalar SkyCoordTableCoordinate." + raise ValueError(msg) # SkyCoords have multiple world components, e.g. lat and lon, even if # it 1-D. Interpolate the components separately then recombine into a new SkyCoord. # First, inspect underlying SkyCoord. @@ -644,23 +660,22 @@ def interpolate(self, *new_array_grids, mesh_output=None, **kwargs): ndim = self.ndim shape = self.shape if len(new_array_grids) != ndim: - raise ValueError(f"A new array grid must be given for each array axis, i.e. {ndim}") + msg = f"A new array grid must be given for each array axis, i.e. {ndim}" + raise ValueError(msg) if any(new_grid.shape != new_array_grids[0].shape for new_grid in new_array_grids): - raise ValueError("New array grids must all be same shape.") + msg = "New array grids must all be same shape." + raise ValueError(msg) if mesh_output is None: - if new_array_grids[0].ndim > 1: - mesh_output = False - else: - mesh_output = self.mesh + mesh_output = False if new_array_grids[0].ndim > 1 else self.mesh # Build old array grids. Note self._slice give the slice item(s) required to # make the underlying SkyCoord match the dimensionality of the associated data cube. - old_array_grids = [np.arange(d)[slc] for d, slc in zip(shape, self._slice)] + old_array_grids = [np.arange(d)[slc] for d, slc in zip(shape, self._slice, strict=False)] # Iterate through components and interpolate each. if self.mesh: new_components = [np.interp(new_grid, old_grid, comp, **kwargs) for new_grid, old_grid, comp - in zip(new_array_grids, old_array_grids, self._sliced_components)] + in zip(new_array_grids, old_array_grids, self._sliced_components, strict=False)] elif ndim == 1: new_components = [np.interp(*new_array_grids, *old_array_grids, comp, **kwargs) for comp in self._sliced_components] @@ -701,9 +716,10 @@ class TimeTableCoordinate(BaseTableCoordinate): Default is first time coordinate in table input. 
""" - def __init__(self, *tables, names=None, physical_types=None, reference_time=None): - if not len(tables) == 1 and isinstance(tables[0], Time): - raise ValueError("TimeLookupTable can only be constructed from a single Time object.") + def __init__(self, *tables, names=None, physical_types=None, reference_time=None) -> None: + if len(tables) != 1 and isinstance(tables[0], Time): + msg = "TimeLookupTable can only be constructed from a single Time object." + raise ValueError(msg) if isinstance(names, str): names = [names] @@ -711,17 +727,20 @@ def __init__(self, *tables, names=None, physical_types=None, reference_time=None physical_types = [physical_types] if names is not None and len(names) != 1: - raise ValueError("A Time coordinate can only have one name.") + msg = "A Time coordinate can only have one name." + raise ValueError(msg) if physical_types is not None and len(physical_types) != 1: - raise ValueError("A Time coordinate can only have one physical type.") + msg = "A Time coordinate can only have one physical type." + raise ValueError(msg) super().__init__(*tables, mesh=False, names=names, physical_types=physical_types) self.table = self.table[0] self.reference_time = reference_time or self.table[0] def __getitem__(self, item): - if not (isinstance(item, (slice, Integral)) or len(item) == 1): - raise ValueError("Can not slice with incorrect length") + if not (isinstance(item, slice | Integral) or len(item) == 1): + msg = "Can not slice with incorrect length" + raise ValueError(msg) return type(self)(self.table[item], names=self.names, @@ -733,7 +752,7 @@ def n_inputs(self): return 1 # The time table has to be one dimensional def is_scalar(self): - return self.table.shape == tuple() + return self.table.shape == () @property def frame(self): @@ -775,7 +794,8 @@ def interpolate(self, new_array_grids, **kwargs): """ if self.is_scalar(): - raise ValueError("Cannot interpolate a scalar TimeTableCoordinate.") + msg = "Cannot interpolate a scalar TimeTableCoordinate." + raise ValueError(msg) # Build pixel grids for current TimeTableCoord. old_array_grids = np.arange(len(self.table)) # Interpolate using MJD format and convert back to a Time object. @@ -808,21 +828,21 @@ class MultipleTableCoordinate(BaseTableCoordinate): operator. """ - def __init__(self, *table_coordinates): + def __init__(self, *table_coordinates) -> None: if not all(isinstance(lt, BaseTableCoordinate) and not (isinstance(lt, MultipleTableCoordinate)) for lt in table_coordinates): - raise TypeError("All arguments must be BaseTableCoordinate instances, such as QuantityTableCoordinate, " - "and not instances of MultipleTableCoordinate.") + msg = ( + "All arguments must be BaseTableCoordinate instances, such as QuantityTableCoordinate, " + "and not instances of MultipleTableCoordinate." 
+ ) + raise TypeError(msg) self._table_coords = list(table_coordinates) - self._dropped_coords = list() + self._dropped_coords = [] - def __str__(self): + def __str__(self) -> str: classname = self.__class__.__name__ length = len(classname) + sum(len(str(t)) for t in self._table_coords) + 10 - if length > np.get_printoptions()['linewidth']: - joiner = ',\n ' + (len(classname) + 8) * ' ' - else: - joiner = ', ' + joiner = ",\n " + (len(classname) + 8) * " " if length > np.get_printoptions()["linewidth"] else ", " return f"{classname}(tables=[{joiner.join([str(t) for t in self._table_coords])}])" @@ -830,10 +850,7 @@ def __and__(self, other): if not isinstance(other, BaseTableCoordinate): return NotImplemented - if isinstance(other, MultipleTableCoordinate): - others = other._table_coords - else: - others = [other] + others = other._table_coords if isinstance(other, MultipleTableCoordinate) else [other] return type(self)(*(self._table_coords + others)) @@ -842,15 +859,16 @@ def __rand__(self, other): if not isinstance(other, BaseTableCoordinate) or isinstance(other, MultipleTableCoordinate): return NotImplemented - return type(self)(*([other] + self._table_coords)) + return type(self)(*([other, *self._table_coords])) def __getitem__(self, item): - if isinstance(item, (slice, Integral)): + if isinstance(item, slice | Integral): item = (item,) - if not len(item) == self.n_inputs: + if len(item) != self.n_inputs: + msg = f"length of the slice ({len(item)}) must match the number of coordinates {self.n_inputs}" raise ValueError( - f"length of the slice ({len(item)}) must match the number of coordinates {self.n_inputs}") + msg) new_tables = [] dropped_tables = [] @@ -892,24 +910,23 @@ def frame(self): """ if len(self._table_coords) == 1: return self._table_coords[0].frame - else: - frames = [t.frame for t in self._table_coords] + frames = [t.frame for t in self._table_coords] - # We now have to set the axes_order of all the frames so that we - # have one consistent WCS with the correct number of pixel - # dimensions. - ind = 0 - for f in frames: - new_ind = ind + f.naxes - f._axes_order = tuple(range(ind, new_ind)) - ind = new_ind + # We now have to set the axes_order of all the frames so that we + # have one consistent WCS with the correct number of pixel + # dimensions. 
+ ind = 0 + for f in frames: + new_ind = ind + f.naxes + f._axes_order = tuple(range(ind, new_ind)) + ind = new_ind - return cf.CompositeFrame(frames) + return cf.CompositeFrame(frames) @property def dropped_world_dimensions(self): dropped_world_dimensions = defaultdict(list) - dropped_world_dimensions["world_axis_object_classes"] = dict() + dropped_world_dimensions["world_axis_object_classes"] = {} # Combine the dicts on the tables with our dict for lutc in self._table_coords: diff --git a/ndcube/extra_coords/tests/test_extra_coords.py b/ndcube/extra_coords/tests/test_extra_coords.py index 485bda918..87c29931d 100644 --- a/ndcube/extra_coords/tests/test_extra_coords.py +++ b/ndcube/extra_coords/tests/test_extra_coords.py @@ -54,9 +54,9 @@ def test_empty_ec(wcs_1d_l): # Test slice of an empty EC assert ec[0].wcs is None - assert ec.mapping == tuple() + assert ec.mapping == () assert ec.wcs is None - assert ec.keys() == tuple() + assert ec.keys() == () ec.wcs = wcs_1d_l assert ec.wcs is wcs_1d_l @@ -78,7 +78,7 @@ def test_exceptions(wcs_1d_l): ec.add(None, 0, None) with pytest.raises(KeyError): - ExtraCoords()['empty'] + ExtraCoords()["empty"] def test_mapping_setter(wcs_1d_l, wave_lut): @@ -124,12 +124,12 @@ def test_wcs_1d(wcs_1d_l): ec.wcs = wcs_1d_l ec.mapping = (0,) - assert ec.keys() == ('spectral',) + assert ec.keys() == ("spectral",) assert ec.mapping == (0,) assert ec.wcs is wcs_1d_l subec = ec[1:] - assert ec.keys() == ('spectral',) + assert ec.keys() == ("spectral",) assert ec.mapping == (0,) assert np.allclose(ec.wcs.pixel_to_world_values(1), subec.wcs.pixel_to_world_values(1)) @@ -179,7 +179,6 @@ def test_two_1d_from_lookup_tables(time_lut): """ Create ExtraCoords from both tables at once using `from_lookup_tables` with `physical_types`. 
""" - exposure_lut = range(10) * u.s pt = ["custom:time:creation"] @@ -329,10 +328,10 @@ def test_slice_extra_1d(time_lut, wave_lut): sec = ec[:, 3:7] assert len(sec._lookup_tables) == 2 - assert u.allclose(sec['wavey'].wcs.pixel_to_world_values(list(range(4))), - ec['wavey'].wcs.pixel_to_world_values(list(range(3, 7)))) - assert u.allclose(sec['time'].wcs.pixel_to_world_values(list(range(4))), - ec['time'].wcs.pixel_to_world_values(list(range(4)))) + assert u.allclose(sec["wavey"].wcs.pixel_to_world_values(list(range(4))), + ec["wavey"].wcs.pixel_to_world_values(list(range(3, 7)))) + assert u.allclose(sec["time"].wcs.pixel_to_world_values(list(range(4))), + ec["time"].wcs.pixel_to_world_values(list(range(4)))) def test_slice_extra_2d(time_lut, skycoord_2d_lut): @@ -343,13 +342,13 @@ def test_slice_extra_2d(time_lut, skycoord_2d_lut): sec = ec[1:5, 1:5] assert len(sec._lookup_tables) == 2 - assert u.allclose(sec['lat'].wcs.pixel_to_world_values(list(range(2)), list(range(2))), - ec['lat'].wcs.pixel_to_world_values(list(range(1, 3)), list(range(1, 3)))) - assert u.allclose(sec['lon'].wcs.pixel_to_world_values(list(range(2)), list(range(2))), - ec['lon'].wcs.pixel_to_world_values(list(range(1, 3)), list(range(1, 3)))) + assert u.allclose(sec["lat"].wcs.pixel_to_world_values(list(range(2)), list(range(2))), + ec["lat"].wcs.pixel_to_world_values(list(range(1, 3)), list(range(1, 3)))) + assert u.allclose(sec["lon"].wcs.pixel_to_world_values(list(range(2)), list(range(2))), + ec["lon"].wcs.pixel_to_world_values(list(range(1, 3)), list(range(1, 3)))) - assert u.allclose(sec['exposure_time'].wcs.pixel_to_world_values(list(range(0, 3))), - ec['exposure_time'].wcs.pixel_to_world_values(list(range(1, 4)))) + assert u.allclose(sec["exposure_time"].wcs.pixel_to_world_values(list(range(3))), + ec["exposure_time"].wcs.pixel_to_world_values(list(range(1, 4)))) def test_slice_drop_dimensions(time_lut, skycoord_2d_lut): @@ -360,21 +359,21 @@ def test_slice_drop_dimensions(time_lut, skycoord_2d_lut): sec = ec[0, :] assert len(sec._lookup_tables) == 1 - assert u.allclose(sec['lat'].wcs.pixel_to_world_values(list(range(2))), - ec['lat'].wcs.pixel_to_world_values([0, 0], list(range(2)))) - assert u.allclose(sec['lon'].wcs.pixel_to_world_values(list(range(2))), - ec['lon'].wcs.pixel_to_world_values([0, 0], list(range(2)))) + assert u.allclose(sec["lat"].wcs.pixel_to_world_values(list(range(2))), + ec["lat"].wcs.pixel_to_world_values([0, 0], list(range(2)))) + assert u.allclose(sec["lon"].wcs.pixel_to_world_values(list(range(2))), + ec["lon"].wcs.pixel_to_world_values([0, 0], list(range(2)))) sec = ec[:, 0] assert len(sec._lookup_tables) == 2 - assert u.allclose(sec['lat'].wcs.pixel_to_world_values(list(range(2))), - ec['lat'].wcs.pixel_to_world_values(list(range(2)), [0, 0])) - assert u.allclose(sec['lon'].wcs.pixel_to_world_values(list(range(2))), - ec['lon'].wcs.pixel_to_world_values(list(range(2)), [0, 0])) + assert u.allclose(sec["lat"].wcs.pixel_to_world_values(list(range(2))), + ec["lat"].wcs.pixel_to_world_values(list(range(2)), [0, 0])) + assert u.allclose(sec["lon"].wcs.pixel_to_world_values(list(range(2))), + ec["lon"].wcs.pixel_to_world_values(list(range(2)), [0, 0])) - assert u.allclose(sec['exposure_time'].wcs.pixel_to_world_values(list(range(2)), list(range(2))), - ec['exposure_time'].wcs.pixel_to_world_values(list(range(2)), list(range(2)))) + assert u.allclose(sec["exposure_time"].wcs.pixel_to_world_values(list(range(2)), list(range(2))), + 
ec["exposure_time"].wcs.pixel_to_world_values(list(range(2)), list(range(2)))) def test_slice_extra_twice(time_lut, wave_lut): @@ -385,14 +384,14 @@ def test_slice_extra_twice(time_lut, wave_lut): sec = ec[1:, 0] assert len(sec._lookup_tables) == 1 - assert u.allclose(sec['time'].wcs.pixel_to_world_values(list(range(0, 2))), - ec['time'].wcs.pixel_to_world_values(list(range(1, 3)))) + assert u.allclose(sec["time"].wcs.pixel_to_world_values(list(range(2))), + ec["time"].wcs.pixel_to_world_values(list(range(1, 3)))) sec = sec[1:, 0] assert len(sec._lookup_tables) == 1 - assert u.allclose(sec['time'].wcs.pixel_to_world_values(list(range(0, 2))), - ec['time'].wcs.pixel_to_world_values(list(range(2, 4)))) + assert u.allclose(sec["time"].wcs.pixel_to_world_values(list(range(2))), + ec["time"].wcs.pixel_to_world_values(list(range(2, 4)))) def test_slice_extra_1d_drop(time_lut, wave_lut): @@ -403,8 +402,8 @@ def test_slice_extra_1d_drop(time_lut, wave_lut): sec = ec[:, 3] assert len(sec._lookup_tables) == 1 - assert u.allclose(sec['time'].wcs.pixel_to_world_values(list(range(4))), - ec['time'].wcs.pixel_to_world_values(list(range(4)))) + assert u.allclose(sec["time"].wcs.pixel_to_world_values(list(range(4))), + ec["time"].wcs.pixel_to_world_values(list(range(4)))) dwd = sec.dropped_world_dimensions dwd.pop("world_axis_object_classes") @@ -420,8 +419,8 @@ def test_slice_extra_1d_drop_alter_mapping_tuple_item(time_lut, wave_lut): sec = ec[0, :] assert len(sec._lookup_tables) == 1 assert sec._lookup_tables[0][0] == (0,) - assert u.allclose(sec['wavey'].wcs.pixel_to_world_values(list(range(10))), - ec['wavey'].wcs.pixel_to_world_values(list(range(10)))) + assert u.allclose(sec["wavey"].wcs.pixel_to_world_values(list(range(10))), + ec["wavey"].wcs.pixel_to_world_values(list(range(10)))) dwd = sec.dropped_world_dimensions dwd.pop("world_axis_object_classes") @@ -437,8 +436,8 @@ def test_slice_extra_1d_drop_alter_mapping_int_item(time_lut, wave_lut): sec = ec[0] assert len(sec._lookup_tables) == 1 assert sec._lookup_tables[0][0] == (0,) - assert u.allclose(sec['wavey'].wcs.pixel_to_world_values(list(range(10))), - ec['wavey'].wcs.pixel_to_world_values(list(range(10)))) + assert u.allclose(sec["wavey"].wcs.pixel_to_world_values(list(range(10))), + ec["wavey"].wcs.pixel_to_world_values(list(range(10)))) dwd = sec.dropped_world_dimensions dwd.pop("world_axis_object_classes") @@ -449,16 +448,16 @@ def test_slice_extra_1d_drop_alter_mapping_int_item(time_lut, wave_lut): def test_dropped_dimension_reordering(): data = np.ones((3, 4, 5)) wcs_input_dict = { - 'CTYPE1': 'WAVE ', 'CUNIT1': 'Angstrom', 'CDELT1': 0.2, 'CRPIX1': 0, 'CRVAL1': 10, 'NAXIS1': 5, - 'CTYPE2': 'HPLT-TAN', 'CUNIT2': 'deg', 'CDELT2': 0.5, 'CRPIX2': 2, 'CRVAL2': 0.5, 'NAXIS2': 4, - 'CTYPE3': 'HPLN-TAN', 'CUNIT3': 'deg', 'CDELT3': 0.4, 'CRPIX3': 2, 'CRVAL3': 1, 'NAXIS3': 3} + "CTYPE1": "WAVE ", "CUNIT1": "Angstrom", "CDELT1": 0.2, "CRPIX1": 0, "CRVAL1": 10, "NAXIS1": 5, + "CTYPE2": "HPLT-TAN", "CUNIT2": "deg", "CDELT2": 0.5, "CRPIX2": 2, "CRVAL2": 0.5, "NAXIS2": 4, + "CTYPE3": "HPLN-TAN", "CUNIT3": "deg", "CDELT3": 0.4, "CRPIX3": 2, "CRVAL3": 1, "NAXIS3": 3} input_wcs = WCS(wcs_input_dict) - base_time = Time('2000-01-01', format='fits', scale='utc') - timestamps = Time([base_time + TimeDelta(60 * i, format='sec') for i in range(data.shape[0])]) + base_time = Time("2000-01-01", format="fits", scale="utc") + timestamps = Time([base_time + TimeDelta(60 * i, format="sec") for i in range(data.shape[0])]) my_cube = NDCube(data, input_wcs) - 
my_cube.extra_coords.add('time', (0,), timestamps) + my_cube.extra_coords.add("time", (0,), timestamps) # If the argument to extra_coords.add is array index then it should end up # in the first element of array_axis_physical_types diff --git a/ndcube/extra_coords/tests/test_lookup_table_coord.py b/ndcube/extra_coords/tests/test_lookup_table_coord.py index e810d5b5f..ffeb3b051 100644 --- a/ndcube/extra_coords/tests/test_lookup_table_coord.py +++ b/ndcube/extra_coords/tests/test_lookup_table_coord.py @@ -17,7 +17,7 @@ @pytest.fixture def lut_1d_distance(): lookup_table = u.Quantity(np.arange(10) * u.km) - return QuantityTableCoordinate(lookup_table, names='x') + return QuantityTableCoordinate(lookup_table, names="x") @pytest.fixture @@ -26,7 +26,7 @@ def lut_3d_distance_mesh(): u.Quantity(np.arange(10, 20) * u.km), u.Quantity(np.arange(20, 30) * u.km)) - return QuantityTableCoordinate(*lookup_table, names=['x', 'y', 'z']) + return QuantityTableCoordinate(*lookup_table, names=["x", "y", "z"]) @pytest.fixture @@ -38,7 +38,7 @@ def lut_2d_distance_no_mesh(): @pytest.fixture def lut_1d_skycoord_no_mesh(): sc = SkyCoord(range(10), range(10), unit=u.deg) - return SkyCoordTableCoordinate(sc, mesh=False, names=['lon', 'lat']) + return SkyCoordTableCoordinate(sc, mesh=False, names=["lon", "lat"]) @pytest.fixture @@ -66,7 +66,7 @@ def lut_1d_time(): "2011-01-01T00:00:10", "2011-01-01T00:00:20", "2011-01-01T00:00:30"], format="isot") - return TimeTableCoordinate(data, names='time', physical_types='time') + return TimeTableCoordinate(data, names="time", physical_types="time") @pytest.fixture @@ -85,11 +85,11 @@ def test_exceptions(): assert "All tables must have equivalent units." in str(ei) with pytest.raises(ValueError) as ei: - QuantityTableCoordinate(u.Quantity([1, 2, 3], u.nm), [1, 2, 3] * u.m, names='x') + QuantityTableCoordinate(u.Quantity([1, 2, 3], u.nm), [1, 2, 3] * u.m, names="x") assert "The number of names should match the number of world dimensions" in str(ei) with pytest.raises(ValueError) as ei: - QuantityTableCoordinate(u.Quantity([1, 2, 3], u.nm), [1, 2, 3] * u.m, physical_types='x') + QuantityTableCoordinate(u.Quantity([1, 2, 3], u.nm), [1, 2, 3] * u.m, physical_types="x") assert "The number of physical types should match the number of world dimensions" in str(ei) # Test two Time @@ -98,11 +98,11 @@ def test_exceptions(): assert "single Time object" in str(ei) with pytest.raises(ValueError) as ei: - TimeTableCoordinate(Time("2011-01-01"), names=['a', 'b']) + TimeTableCoordinate(Time("2011-01-01"), names=["a", "b"]) assert "only have one name." in str(ei) with pytest.raises(ValueError) as ei: - TimeTableCoordinate(Time("2011-01-01"), physical_types=['a', 'b']) + TimeTableCoordinate(Time("2011-01-01"), physical_types=["a", "b"]) assert "only have one physical type." in str(ei) # Test two SkyCoord @@ -111,11 +111,11 @@ def test_exceptions(): assert "single SkyCoord object" in str(ei) with pytest.raises(ValueError) as ei: - SkyCoordTableCoordinate(SkyCoord(10, 10, unit=u.deg), names='x') + SkyCoordTableCoordinate(SkyCoord(10, 10, unit=u.deg), names="x") assert "The number of names must equal number of components in the input SkyCoord: 2." in str(ei) with pytest.raises(ValueError) as ei: - SkyCoordTableCoordinate(SkyCoord(10, 10, unit=u.deg), physical_types='x') + SkyCoordTableCoordinate(SkyCoord(10, 10, unit=u.deg), physical_types="x") assert "The number of physical types must equal number of components in the input SkyCoord: 2." 
in str(ei) with pytest.raises(TypeError) as ei: @@ -296,12 +296,12 @@ def test_repr_str(lut_1d_time, lut_1d_wave): def test_slicing_quantity_table_coordinate(): - qtc = QuantityTableCoordinate(range(10)*u.m, names='x', physical_types='pos:x') + qtc = QuantityTableCoordinate(range(10)*u.m, names="x", physical_types="pos:x") assert u.allclose(qtc[2:8].table[0], range(2, 8)*u.m) assert u.allclose(qtc[2].table[0], 2*u.m) - assert qtc.names == ['x'] - assert qtc.physical_types == ['pos:x'] + assert qtc.names == ["x"] + assert qtc.physical_types == ["pos:x"] qtc = QuantityTableCoordinate(range(10)*u.m) @@ -309,7 +309,7 @@ def test_slicing_quantity_table_coordinate(): assert u.allclose(qtc[2].table[0], 2*u.m) qtc = QuantityTableCoordinate(range(10)*u.m, range(10)*u.m, - names=['x', 'y'], physical_types=['pos:x', 'pos:y']) + names=["x", "y"], physical_types=["pos:x", "pos:y"]) assert u.allclose(qtc[2:8, 2:8].table[0], range(2, 8)*u.m) assert u.allclose(qtc[2:8, 2:8].table[1], range(2, 8)*u.m) @@ -317,25 +317,25 @@ def test_slicing_quantity_table_coordinate(): assert len(qtc[2, 2:8].table) == 1 assert u.allclose(qtc[2, 2:8].table[0], range(2, 8)*u.m) - assert qtc.names == ['x', 'y'] - assert qtc.physical_types == ['pos:x', 'pos:y'] + assert qtc.names == ["x", "y"] + assert qtc.physical_types == ["pos:x", "pos:y"] - assert qtc.frame.axes_names == ('x', 'y') - assert qtc.frame.axis_physical_types == ('custom:pos:x', 'custom:pos:y') + assert qtc.frame.axes_names == ("x", "y") + assert qtc.frame.axis_physical_types == ("custom:pos:x", "custom:pos:y") @pytest.mark.xfail(reason=">1D Tables not supported") def test_slicing_quantity_table_coordinate_2d(): qtc = QuantityTableCoordinate(*np.mgrid[0:10, 0:10]*u.m, - names=['x', 'y'], physical_types=['pos:x', 'pos:y']) + names=["x", "y"], physical_types=["pos:x", "pos:y"]) assert u.allclose(qtc[2:8, 2:8].table[0], (np.mgrid[2:8, 2:8]*u.m)[0]) assert u.allclose(qtc[2:8, 2:8].table[1], (np.mgrid[2:8, 2:8]*u.m)[1]) - assert qtc.names == ['x', 'y'] - assert qtc.physical_types == ['pos:x', 'pos:y'] + assert qtc.names == ["x", "y"] + assert qtc.physical_types == ["pos:x", "pos:y"] - assert qtc.frame.axes_names == ('x', 'y') - assert qtc.frame.axis_physical_types == ('custom:pos:x', 'custom:pos:y') + assert qtc.frame.axes_names == ("x", "y") + assert qtc.frame.axis_physical_types == ("custom:pos:x", "custom:pos:y") assert u.allclose(qtc[2, 2:8].table[0], 2*u.m) assert u.allclose(qtc[2, 2:8].table[1], (np.mgrid[2:8, 2:8]*u.m)[1]) @@ -349,22 +349,22 @@ def _assert_skycoord_equal(sc1, sc2): components1 = tuple(getattr(sc1.data, comp) for comp in sc1.data.components) components2 = tuple(getattr(sc2.data, comp) for comp in sc2.data.components) - for c1, c2 in zip(components1, components2): + for c1, c2 in zip(components1, components2, strict=False): assert u.allclose(c1, c2) def test_slicing_skycoord_table_coordinate(): # 1D, no mesh sc = SkyCoord(range(10)*u.deg, range(10)*u.deg) - stc = SkyCoordTableCoordinate(sc, mesh=False, names=['lon', 'lat'], physical_types=['pos:x', 'pos:y']) + stc = SkyCoordTableCoordinate(sc, mesh=False, names=["lon", "lat"], physical_types=["pos:x", "pos:y"]) _assert_skycoord_equal(stc[2:8].table, sc[2:8]) _assert_skycoord_equal(stc[2].table, sc[2]) - assert stc.names == ['lon', 'lat'] - assert stc.physical_types == ['pos:x', 'pos:y'] + assert stc.names == ["lon", "lat"] + assert stc.physical_types == ["pos:x", "pos:y"] - assert stc.frame.axes_names == ('lon', 'lat') - assert stc.frame.axis_physical_types == ('custom:pos:x', 'custom:pos:y') + 
assert stc.frame.axes_names == ("lon", "lat") + assert stc.frame.axis_physical_types == ("custom:pos:x", "custom:pos:y") # 2D, no mesh sc = SkyCoord(*np.mgrid[0:10, 0:10]*u.deg) @@ -715,10 +715,10 @@ def test_quantity_interpolate(lut_3d_distance_mesh): def test_time_interpolate(lut_1d_time): lutc = lut_1d_time - new_array_grids = np.arange(1.5, 4,) + new_array_grids = np.arange(1.5, 4) output = lutc.interpolate(new_array_grids) - expected_table = Time(['2011-01-01T00:00:05.000', '2011-01-01T00:00:15.000', - '2011-01-01T00:00:25.000'], scale="utc", format="isot") + expected_table = Time(["2011-01-01T00:00:05.000", "2011-01-01T00:00:15.000", + "2011-01-01T00:00:25.000"], scale="utc", format="isot") assert np.allclose(output.table.mjd, expected_table.mjd) assert_lutc_ancilliary_data_same(output, lutc) diff --git a/ndcube/global_coords.py b/ndcube/global_coords.py index a15c1facf..a7e32efb6 100644 --- a/ndcube/global_coords.py +++ b/ndcube/global_coords.py @@ -29,6 +29,7 @@ class GlobalCoordsABC(Mapping): from the wcs and extra coords of the ndcube. If not specified only coordinates explicitly added will be shown. """ + @abc.abstractmethod def add(self, name: str, physical_type: str, coord: Any): """ @@ -74,7 +75,7 @@ def __iter__(self): """ @abc.abstractmethod - def __len__(self): + def __len__(self) -> int: """ Establish the length of the collection. """ @@ -83,7 +84,7 @@ def __len__(self): class GlobalCoords(GlobalCoordsABC): # Docstring in GlobalCoordsABC - def __init__(self, ndcube=None): + def __init__(self, ndcube=None) -> None: super().__init__() self._ndcube = ndcube self._internal_coords = OrderedDict() @@ -140,7 +141,8 @@ def _convert_dropped_to_internal(dropped_dimensions): elif len(rest) == 1: klass_gen = rest[0] else: - raise ValueError("Tuples in world_axis_object_classes should have length 3 or 4") + msg = "Tuples in world_axis_object_classes should have length 3 or 4" + raise ValueError(msg) high_level_object = klass_gen(*args[key], *ar, **kwargs[key], **kw) @@ -177,9 +179,12 @@ def _all_coords(self): def add(self, name, physical_type, coord): # Docstring in GlobalCoordsABC - if name in self._internal_coords.keys(): - raise ValueError("coordinate with same name already exists: " - f"{name}: {self._internal_coords[name]}") + if name in self._internal_coords: + msg = ( + "coordinate with same name already exists: " + f"{name}: {self._internal_coords[name]}" + ) + raise ValueError(msg) # Ensure the physical type is valid validate_physical_types((physical_type,)) @@ -193,7 +198,7 @@ def remove(self, name): @property def physical_types(self): # Docstring in GlobalCoordsABC - return dict((name, value[0]) for name, value in self._all_coords.items()) + return {name: value[0] for name, value in self._all_coords.items()} def filter_by_physical_type(self, physical_type): """ @@ -226,21 +231,18 @@ def __iter__(self): # Docstring in GlobalCoordsABC return iter(self._all_coords) - def __len__(self): + def __len__(self) -> int: # Docstring in GlobalCoordsABC return len(self._all_coords) - def __str__(self): + def __str__(self) -> str: classname = self.__class__.__name__ - elements = [f"{name} {[ptype]}:\n{repr(coord)}" for (name, coord), ptype in - zip(self.items(), self.physical_types.values())] + elements = [f"{name} {[ptype]}:\n{coord!r}" for (name, coord), ptype in + zip(self.items(), self.physical_types.values(), strict=False)] length = len(classname) + 2 * len(elements) + sum(len(e) for e in elements) - if length > np.get_printoptions()['linewidth']: - joiner = ',\n ' + len(classname) 
* ' ' - else: - joiner = ', ' + joiner = ",\n " + len(classname) * " " if length > np.get_printoptions()["linewidth"] else ", " return f"{classname}({joiner.join(elements)})" - def __repr__(self): - return f"{object.__repr__(self)}\n{str(self)}" + def __repr__(self) -> str: + return f"{object.__repr__(self)}\n{self!s}" diff --git a/ndcube/mixins/__init__.py b/ndcube/mixins/__init__.py index c680bad4f..d9b06d617 100644 --- a/ndcube/mixins/__init__.py +++ b/ndcube/mixins/__init__.py @@ -1,3 +1,3 @@ from .ndslicing import NDCubeSlicingMixin -__all__ = ['NDCubeSlicingMixin'] +__all__ = ["NDCubeSlicingMixin"] diff --git a/ndcube/mixins/ndslicing.py b/ndcube/mixins/ndslicing.py index 919b4686e..3fa8337af 100644 --- a/ndcube/mixins/ndslicing.py +++ b/ndcube/mixins/ndslicing.py @@ -2,7 +2,7 @@ from astropy.nddata.mixins.ndslicing import NDSlicingMixin from astropy.wcs.wcsapi.wrappers.sliced_wcs import sanitize_slices -__all__ = ['NDCubeSlicingMixin'] +__all__ = ["NDCubeSlicingMixin"] class NDCubeSlicingMixin(NDSlicingMixin): @@ -17,7 +17,8 @@ def __getitem__(self, item): using the kwargs returned by ``_slice``. """ if item is None or (isinstance(item, tuple) and None in item): - raise IndexError("None indices not supported") + msg = "None indices not supported" + raise IndexError(msg) item = tuple(sanitize_slices(item, len(self.shape))) sliced_cube = super().__getitem__(item) diff --git a/ndcube/ndcollection.py b/ndcube/ndcollection.py index 0b37e4e91..244e6960b 100644 --- a/ndcube/ndcollection.py +++ b/ndcube/ndcollection.py @@ -46,7 +46,7 @@ class NDCollection(dict): axis 1 of cube0 is aligned with axis 1 of cube1. """ - def __init__(self, key_data_pairs, aligned_axes=None, meta=None, **kwargs): + def __init__(self, key_data_pairs, aligned_axes=None, meta=None, **kwargs) -> None: # Enter data and metadata into object. super().__init__(key_data_pairs) self.meta = meta @@ -54,15 +54,16 @@ def __init__(self, key_data_pairs, aligned_axes=None, meta=None, **kwargs): # Convert aligned axes to required format. sanitize_inputs = kwargs.pop("sanitize_inputs", True) if aligned_axes is not None: - keys, data = zip(*key_data_pairs) + keys, data = zip(*key_data_pairs, strict=False) # Sanitize aligned axes unless hidden kwarg indicates not to. if sanitize_inputs: aligned_axes = collection_utils._sanitize_aligned_axes(keys, data, aligned_axes) else: - aligned_axes = dict(zip(keys, aligned_axes)) + aligned_axes = dict(zip(keys, aligned_axes, strict=False)) if kwargs: + msg = f"__init__() got an unexpected keyword argument: '{next(iter(kwargs.keys()))}'" raise TypeError( - f"__init__() got an unexpected keyword argument: '{list(kwargs.keys())[0]}'" + msg, ) # Attach aligned axes to object. 
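Reviewer note: the GlobalCoords __str__/__repr__ hunks above swap repr()/str() calls inside f-strings for the !r/!s conversion flags and turn dict(...) wrapped around a comprehension into a dict comprehension (ruff's RUF010 and C4xx checks, respectively). A quick standalone check of the equivalences, with made-up values:

    value = 2.5

    # The conversion flags render exactly what the explicit calls did.
    assert f"{value!r}" == f"{repr(value)}"
    assert f"{value!s}" == f"{str(value)}"

    # dict() around a generator builds the same mapping as a dict comprehension.
    pairs = [("a", 1), ("b", 2)]
    assert dict((k, v) for k, v in pairs) == {k: v for k, v in pairs}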
self._aligned_axes = aligned_axes @@ -80,9 +81,9 @@ def aligned_axes(self): @property def _first_key(self): - return list(self.keys())[0] + return next(iter(self.keys())) - def __str__(self): + def __str__(self) -> str: return (textwrap.dedent(f"""\ NDCollection ------------ @@ -91,8 +92,8 @@ def __str__(self): Aligned dimensions: {self.aligned_dimensions} Aligned physical types: {self.aligned_axis_physical_types}""")) - def __repr__(self): - return f"{object.__repr__(self)}\n{str(self)}" + def __repr__(self) -> str: + return f"{object.__repr__(self)}\n{self!s}" @property def aligned_dimensions(self): @@ -105,6 +106,7 @@ def aligned_dimensions(self): return np.asanyarray(self[self._first_key].shape, dtype=object)[ np.array(self.aligned_axes[self._first_key]) ] + return None @property def aligned_axis_physical_types(self): @@ -137,40 +139,41 @@ def __getitem__(self, item): return super().__getitem__(item) # If item is not a single string... + # If item is a sequence, ensure strings and numeric items are not mixed. + item_is_strings = False + if isinstance(item, collections.abc.Sequence): + item_strings = [isinstance(item_, str) for item_ in item] + item_is_strings = all(item_strings) + # Ensure strings are not mixed with slices. + if (not item_is_strings) and (not all(np.invert(item_strings))): + msg = "Cannot mix keys and non-keys when indexing instance." + raise TypeError(msg) + + # If sequence is all strings, extract the cubes corresponding to the string keys. + if item_is_strings: + new_data = [self[_item] for _item in item] + new_keys = item + new_aligned_axes = tuple([self.aligned_axes[item_] for item_ in item]) + + # Else, the item is assumed to be a typical slicing item. + # Slice each cube in collection using information in this item. + # However, this can only be done if there are aligned axes. else: - # If item is a sequence, ensure strings and numeric items are not mixed. - item_is_strings = False - if isinstance(item, collections.abc.Sequence): - item_strings = [isinstance(item_, str) for item_ in item] - item_is_strings = all(item_strings) - # Ensure strings are not mixed with slices. - if (not item_is_strings) and (not all(np.invert(item_strings))): - raise TypeError("Cannot mix keys and non-keys when indexing instance.") - - # If sequence is all strings, extract the cubes corresponding to the string keys. - if item_is_strings: - new_data = [self[_item] for _item in item] - new_keys = item - new_aligned_axes = tuple([self.aligned_axes[item_] for item_ in item]) - - # Else, the item is assumed to be a typical slicing item. - # Slice each cube in collection using information in this item. - # However, this can only be done if there are aligned axes. - else: - if self.aligned_axes is None: - raise IndexError("Cannot slice unless collection has aligned axes.") - # Derive item to be applied to each cube in collection and - # whether any aligned axes are dropped by the slicing. - collection_items, new_aligned_axes = self._generate_collection_getitems(item) - # Apply those slice items to each cube in collection. - new_data = [self[key][tuple(cube_item)] - for key, cube_item in zip(self, collection_items)] - # Since item is not strings, no cube in collection is dropped. - # Therefore the collection keys remain unchanged. - new_keys = list(self.keys()) - - return self.__class__(list(zip(new_keys, new_data)), aligned_axes=new_aligned_axes, - meta=self.meta, sanitize_inputs=False) + if self.aligned_axes is None: + msg = "Cannot slice unless collection has aligned axes." 
+ raise IndexError(msg) + # Derive item to be applied to each cube in collection and + # whether any aligned axes are dropped by the slicing. + collection_items, new_aligned_axes = self._generate_collection_getitems(item) + # Apply those slice items to each cube in collection. + new_data = [self[key][tuple(cube_item)] + for key, cube_item in zip(self, collection_items, strict=False)] + # Since item is not strings, no cube in collection is dropped. + # Therefore the collection keys remain unchanged. + new_keys = list(self.keys()) + + return self.__class__(list(zip(new_keys, new_data, strict=False)), aligned_axes=new_aligned_axes, + meta=self.meta, sanitize_inputs=False) def _generate_collection_getitems(self, item): # There are 3 supported cases of the slice item: int, slice, tuple of ints and/or slices. @@ -203,7 +206,8 @@ def _generate_collection_getitems(self, item): elif isinstance(item, tuple): # Ensure item is not longer than number of aligned axes if len(item) > self.n_aligned_axes: - raise IndexError("Too many indices") + msg = "Too many indices" + raise IndexError(msg) for i, axis_item in enumerate(item): if isinstance(axis_item, int): drop_aligned_axes_indices.append(i) @@ -211,7 +215,8 @@ def _generate_collection_getitems(self, item): collection_items[j][self.aligned_axes[key][i]] = axis_item else: - raise TypeError(f"Unsupported slicing type: {axis_item}") + msg = f"Unsupported slicing type: {axis_item}" + raise TypeError(msg) # Use indices of dropped axes determine above to update aligned_axes # by removing any that have been dropped. @@ -228,11 +233,13 @@ def copy(self): def setdefault(self): """Not supported by `~ndcube.NDCollection`""" - raise NotImplementedError("NDCollection does not support setdefault.") + msg = "NDCollection does not support setdefault." + raise NotImplementedError(msg) def popitem(self): """Not supported by `~ndcube.NDCollection`""" - raise NotImplementedError("NDCollection does not support popitem.") + msg = "NDCollection does not support popitem." + raise NotImplementedError(msg) def pop(self, key): """ @@ -261,13 +268,13 @@ def update(self, *args): # If two inputs, inputs must be key_data_pairs and aligned_axes. if len(args) == 2: key_data_pairs = args[0] - new_keys, new_data = zip(*key_data_pairs) + new_keys, new_data = zip(*key_data_pairs, strict=False) new_aligned_axes = collection_utils._sanitize_aligned_axes(new_keys, new_data, args[1]) else: # If one arg given, input must be NDCollection. collection = args[0] new_keys = list(collection.keys()) new_data = list(collection.values()) - key_data_pairs = zip(new_keys, new_data) + key_data_pairs = zip(new_keys, new_data, strict=False) new_aligned_axes = collection.aligned_axes # Check aligned axes of new inputs are compatible with those in self. 
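Reviewer note: every zip() in these NDCollection hunks gains strict=False, satisfying ruff's zip-without-explicit-strict check (B905, available on Python 3.10+). strict=False keeps the old silent-truncation behaviour; strict=True would raise instead when the iterables differ in length. A standalone illustration with made-up data:

    keys = ["cube0", "cube1"]
    axes = [(0, 1)]  # deliberately one entry short

    # strict=False: the unmatched key is silently dropped, as before this change.
    assert dict(zip(keys, axes, strict=False)) == {"cube0": (0, 1)}

    # strict=True: a ValueError is raised once the shorter iterable runs out.
    try:
        dict(zip(keys, axes, strict=True))
    except ValueError as err:
        print(f"strict zip raised: {err}")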
@@ -276,17 +283,20 @@ def update(self, *args): first_new_aligned_axes = new_aligned_axes[new_keys[0]] if new_aligned_axes is not None else None collection_utils.assert_aligned_axes_compatible( self[self._first_key].shape, new_data[0].shape, - first_old_aligned_axes, first_new_aligned_axes + first_old_aligned_axes, first_new_aligned_axes, ) # Update collection super().update(key_data_pairs) if first_old_aligned_axes is not None: # since the above assertion passed, if one aligned axes is not None, both are not None self.aligned_axes.update(new_aligned_axes) - def __delitem__(self, key): + def __delitem__(self, key) -> None: super().__delitem__(key) self.aligned_axes.__delitem__(key) - def __setitem__(self, key, value): - raise NotImplementedError("NDCollection does not support __setitem__. " - "Use NDCollection.update instead") + def __setitem__(self, key, value) -> None: + msg = ( + "NDCollection does not support __setitem__. " + "Use NDCollection.update instead" + ) + raise NotImplementedError(msg) diff --git a/ndcube/ndcube.py b/ndcube/ndcube.py index a9827b474..e2948f9d1 100644 --- a/ndcube/ndcube.py +++ b/ndcube/ndcube.py @@ -33,7 +33,7 @@ from ndcube.visualization import PlotterDescriptor from ndcube.wcs.wrappers import CompoundLowLevelWCS, ResampledLowLevelWCS -__all__ = ['NDCubeABC', 'NDCubeLinkedDescriptor'] +__all__ = ["NDCubeABC", "NDCubeLinkedDescriptor"] # Create mapping to masked array types based on data array type for use in analysis methods. ARRAY_MASK_MAP = {} @@ -95,7 +95,7 @@ def array_axis_physical_types(self) -> Iterable[tuple[str, ...]]: def axis_world_coords(self, *axes: int | str, pixel_corners: bool = False, - wcs: BaseHighLevelWCS | ExtraCoordsABC | None = None + wcs: BaseHighLevelWCS | ExtraCoordsABC | None = None, ) -> Iterable[Any]: """ Returns objects representing the world coordinates of pixel centers for a desired axes. @@ -117,6 +117,7 @@ def axis_world_coords(self, ``self.wcs``, ``self.extra_coords``, or ``self.combined_wcs`` combining both the WCS and extra coords. Default=self.wcs + Returns ------- axes_coords: iterable @@ -128,6 +129,7 @@ def axis_world_coords(self, their corresponding array dimensions, unless ``pixel_corners=True`` in which case the length along each axis will be 1 greater than the number of pixels. + Examples -------- >>> NDCube.axis_world_coords('lat', 'lon') # doctest: +SKIP @@ -139,7 +141,7 @@ def axis_world_coords(self, def axis_world_coords_values(self, *axes: int | str, pixel_corners: bool = False, - wcs: BaseHighLevelWCS | ExtraCoordsABC | None = None + wcs: BaseHighLevelWCS | ExtraCoordsABC | None = None, ) -> Iterable[u.Quantity]: """ Returns the world coordinate values of all pixels for desired axes. @@ -237,7 +239,7 @@ def crop_by_values(self, *points: Iterable[u.Quantity | float], units: Iterable[str | u.Unit] | None = None, wcs: BaseHighLevelWCS | ExtraCoordsABC | None = None, - keepdims: bool = False + keepdims: bool = False, ) -> "NDCubeABC": """ Crop using real world coordinates. @@ -290,7 +292,7 @@ class NDCubeLinkedDescriptor: A descriptor which gives the property a reference to the cube to which it is attached. 
""" - def __init__(self, default_type): + def __init__(self, default_type) -> None: self._default_type = default_type self._property_name = None @@ -309,22 +311,25 @@ def __set_name__(self, owner, name): def __get__(self, obj, objtype=None): if obj is None: - return + return None if getattr(obj, self._attribute_name, None) is None and self._default_type is not None: self.__set__(obj, self._default_type) return getattr(obj, self._attribute_name) - def __set__(self, obj, value): + def __set__(self, obj, value) -> None: if isinstance(value, self._default_type): value._ndcube = obj elif issubclass(value, self._default_type): value = value(obj) else: - raise ValueError( + msg = ( f"Unable to set value for {self._property_name} it should " - f"be an instance or subclass of {self._default_type}") + f"be an instance or subclass of {self._default_type}" + ) + raise ValueError( + msg) setattr(obj, self._attribute_name, value) @@ -371,19 +376,21 @@ class NDCubeBase(NDCubeABC, astropy.nddata.NDData, NDCubeSlicingMixin): Default is `False`. """ + # Instances of Extra and Global coords are managed through descriptors _extra_coords = NDCubeLinkedDescriptor(ExtraCoords) _global_coords = NDCubeLinkedDescriptor(GlobalCoords) def __init__(self, data, wcs=None, uncertainty=None, mask=None, meta=None, - unit=None, copy=False, **kwargs): + unit=None, copy=False, **kwargs) -> None: super().__init__(data, wcs=wcs, uncertainty=uncertainty, mask=mask, meta=meta, unit=unit, copy=copy, **kwargs) # Enforce that the WCS object is not None if self.wcs is None: - raise TypeError("The WCS argument can not be None.") + msg = "The WCS argument can not be None." + raise TypeError(msg) # Get existing extra_coords if initializing from an NDCube if hasattr(data, "extra_coords"): @@ -417,7 +424,7 @@ def combined_wcs(self): mapping = list(range(self.wcs.pixel_n_dim)) + list(self.extra_coords.mapping) return HighLevelWCSWrapper( - CompoundLowLevelWCS(self.wcs.low_level_wcs, self._extra_coords.wcs, mapping=mapping) + CompoundLowLevelWCS(self.wcs.low_level_wcs, self._extra_coords.wcs, mapping=mapping), ) @property @@ -474,12 +481,12 @@ def _generate_world_coords(self, pixel_corners, wcs, needed_axes=None, *, units) # First construct a range of pixel indices for this set of coupled dimensions sub_range = [ranges[idx] for idx in pixel_axes_indices] # Then get a set of non correlated dimensions - non_corr_axes = set(list(range(wcs.pixel_n_dim))) - set(pixel_axes_indices) + non_corr_axes = set(range(wcs.pixel_n_dim)) - set(pixel_axes_indices) # And inject 0s for those coordinates for idx in non_corr_axes: sub_range.insert(idx, 0) # Generate a grid of broadcastable pixel indices for all pixel dimensions - grid = np.meshgrid(*sub_range, indexing='ij') + grid = np.meshgrid(*sub_range, indexing="ij") # Convert to world coordinates world = wcs.pixel_to_world_values(*grid) # TODO: this isinstance check is to mitigate https://github.com/spacetelescope/gwcs/pull/332 @@ -494,7 +501,7 @@ def _generate_world_coords(self, pixel_corners, wcs, needed_axes=None, *, units) world_coords[idx] = tmp_world if units: - for i, (coord, unit) in enumerate(zip(world_coords, wcs.world_axis_units)): + for i, (coord, unit) in enumerate(zip(world_coords, wcs.world_axis_units, strict=False)): world_coords[i] = coord << u.Unit(unit) return world_coords @@ -509,7 +516,7 @@ def axis_world_coords(self, *axes, pixel_corners=False, wcs=None): if isinstance(wcs, ExtraCoords): wcs = wcs.wcs if not wcs: - return tuple() + return () object_names = np.array([wao_comp[0] for 
wao_comp in wcs.world_axis_object_components]) unique_obj_names = utils.misc.unique_sorted(object_names) @@ -523,7 +530,7 @@ def axis_world_coords(self, *axes, pixel_corners=False, wcs=None): world_indices = utils.wcs.calculate_world_indices_from_axes(wcs, axes) object_indices = utils.misc.unique_sorted( - [world_index_to_object_index[world_index] for world_index in world_indices] + [world_index_to_object_index[world_index] for world_index in world_indices], ) axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, world_indices, units=False) @@ -581,24 +588,29 @@ def _get_crop_item(self, *points, wcs=None, keepdims=False): # Quit out early if we are no-op if no_op: return tuple([slice(None)] * wcs.pixel_n_dim) - else: - comp = [c[0] for c in wcs.world_axis_object_components] - # Trim to unique component names - `np.unique(..., return_index=True) - # keeps sorting alphabetically, set() seems just nondeterministic. - for k, c in enumerate(comp): - if comp.count(c) > 1: - comp.pop(k) - classes = [wcs.world_axis_object_classes[c][0] for c in comp] - for i, point in enumerate(points): - if len(point) != len(comp): - raise ValueError(f"{len(point)} components in point {i} do not match " - f"WCS with {len(comp)} components.") - for j, value in enumerate(point): - if not (value is None or isinstance(value, classes[j])): - raise TypeError(f"{type(value)} of component {j} in point {i} is " - f"incompatible with WCS component {comp[j]} " - f"{classes[j]}.") - return utils.cube.get_crop_item_from_points(points, wcs, False, keepdims=keepdims) + comp = [c[0] for c in wcs.world_axis_object_components] + # Trim to unique component names - `np.unique(..., return_index=True) + # keeps sorting alphabetically, set() seems just nondeterministic. + for k, c in enumerate(comp): + if comp.count(c) > 1: + comp.pop(k) + classes = [wcs.world_axis_object_classes[c][0] for c in comp] + for i, point in enumerate(points): + if len(point) != len(comp): + msg = ( + f"{len(point)} components in point {i} do not match " + f"WCS with {len(comp)} components." + ) + raise ValueError(msg) + for j, value in enumerate(point): + if not (value is None or isinstance(value, classes[j])): + msg = ( + f"{type(value)} of component {j} in point {i} is " + f"incompatible with WCS component {comp[j]} " + f"{classes[j]}." + ) + raise TypeError(msg) + return utils.cube.get_crop_item_from_points(points, wcs, False, keepdims=keepdims) def crop_by_values(self, *points, units=None, wcs=None, keepdims=False): # The docstring is defined in NDCubeABC @@ -618,31 +630,41 @@ def _get_crop_by_values_item(self, *points, units=None, wcs=None, keepdims=False if units is None: units = [None] * n_coords elif len(units) != n_coords: - raise ValueError(f"Units must be None or have same length {n_coords} as corner inputs.") + msg = f"Units must be None or have same length {n_coords} as corner inputs." + raise ValueError(msg) types_with_units = (u.Quantity, type(None)) for i, point in enumerate(points): if len(point) != wcs.world_n_dim: - raise ValueError(f"{len(point)} dimensions in point {i} do not match " - f"WCS with {wcs.world_n_dim} world dimensions.") - for j, (value, unit) in enumerate(zip(point, units)): + msg = ( + f"{len(point)} dimensions in point {i} do not match " + f"WCS with {wcs.world_n_dim} world dimensions." 
+ ) + raise ValueError(msg) + for j, (value, unit) in enumerate(zip(point, units, strict=False)): value_is_float = not isinstance(value, types_with_units) if value_is_float: if unit is None: - raise TypeError( + msg = ( "If an element of a point is not a Quantity or None, " "the corresponding unit must be a valid astropy Unit or unit string." - f"index: {i}; coord type: {type(value)}; unit: {unit}") + f"index: {i}; coord type: {type(value)}; unit: {unit}" + ) + raise TypeError( + msg) points[i][j] = u.Quantity(value, unit=unit) if value is not None: try: - points[i][j] = points[i][j].to(wcs.world_axis_units[j]) + points[i][j] = point[j].to(wcs.world_axis_units[j]) except UnitsError as err: - raise UnitsError(f"Unit '{points[i][j].unit}' of coordinate object {j} in point {i} is " - f"incompatible with WCS unit '{wcs.world_axis_units[j]}'") from err + msg = ( + f"Unit '{point[j].unit}' of coordinate object {j} in point {i} is " + f"incompatible with WCS unit '{wcs.world_axis_units[j]}'" + ) + raise UnitsError(msg) from err return utils.cube.get_crop_item_from_points(points, wcs, True, keepdims=keepdims) - def __str__(self): + def __str__(self) -> str: return textwrap.dedent(f"""\ NDCube ------ @@ -651,8 +673,8 @@ def __str__(self): Unit: {self.unit} Data Type: {self.data.dtype}""") - def __repr__(self): - return f"{object.__repr__(self)}\n{str(self)}" + def __repr__(self) -> str: + return f"{object.__repr__(self)}\n{self!s}" def explode_along_axis(self, axis): """ @@ -687,7 +709,7 @@ def explode_along_axis(self, axis): # Creating a new NDCubeSequence with the result_cubes and common axis as axis return NDCubeSequence(result_cubes, meta=self.meta) - def reproject_to(self, target_wcs, algorithm='interpolation', shape_out=None, return_footprint=False, **reproject_args): + def reproject_to(self, target_wcs, algorithm="interpolation", shape_out=None, return_footprint=False, **reproject_args): """ Reprojects the instance to the coordinates described by another WCS object. @@ -730,7 +752,6 @@ def reproject_to(self, target_wcs, algorithm='interpolation', shape_out=None, re See Also -------- - reproject.reproject_interp reproject.reproject_adaptive reproject.reproject_exact @@ -744,7 +765,8 @@ def reproject_to(self, target_wcs, algorithm='interpolation', shape_out=None, re from reproject import reproject_adaptive, reproject_exact, reproject_interp from reproject.wcs_utils import has_celestial except ModuleNotFoundError: - raise ImportError(f"The {type(self).__name__}.reproject_to method requires the `reproject` library to be installed.") + msg = f"The {type(self).__name__}.reproject_to method requires the `reproject` library to be installed." + raise ImportError(msg) algorithms = { "interpolation": reproject_interp, @@ -752,34 +774,41 @@ def reproject_to(self, target_wcs, algorithm='interpolation', shape_out=None, re "exact": reproject_exact, } - if algorithm not in algorithms.keys(): - raise ValueError(f"{algorithm=} is not valid, it must be one of {', '.join(algorithms.keys())}.") + if algorithm not in algorithms: + msg = f"{algorithm=} is not valid, it must be one of {', '.join(algorithms.keys())}." + raise ValueError(msg) if isinstance(target_wcs, Mapping): target_wcs = WCS(header=target_wcs) - low_level_target_wcs = utils.wcs.get_low_level_wcs(target_wcs, 'target_wcs') + low_level_target_wcs = utils.wcs.get_low_level_wcs(target_wcs, "target_wcs") # 'adaptive' and 'exact' algorithms work only on 2D celestial WCS. 
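Reviewer note: dropping .keys() from the membership test and folding the repeated equality comparison into `algorithm in ("adaptive", "exact")` are ruff's SIM118 and repeated-comparison rewrites; both should be behaviour-preserving. A tiny illustration with made-up values:

    algorithms = {"interpolation": None, "adaptive": None, "exact": None}

    # `in` on a dict already tests the keys, so .keys() is redundant.
    assert ("adaptive" in algorithms) == ("adaptive" in algorithms.keys())

    # One membership test replaces a chain of == comparisons.
    algorithm = "exact"
    assert (algorithm in ("adaptive", "exact")) == (
        algorithm == "adaptive" or algorithm == "exact"
    )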
- if algorithm == 'adaptive' or algorithm == 'exact': + if algorithm in ("adaptive", "exact"): if low_level_target_wcs.pixel_n_dim != 2 or low_level_target_wcs.world_n_dim != 2: - raise ValueError('For adaptive and exact algorithms, target_wcs must be 2D.') + msg = "For adaptive and exact algorithms, target_wcs must be 2D." + raise ValueError(msg) if not has_celestial(target_wcs): - raise ValueError('For adaptive and exact algorithms, ' - 'target_wcs must contain celestial axes only.') + msg = ( + "For adaptive and exact algorithms, " + "target_wcs must contain celestial axes only." + ) + raise ValueError(msg) if not utils.wcs.compare_wcs_physical_types(self.wcs, target_wcs): - raise ValueError('Given target_wcs is not compatible with this NDCube, the physical types do not match.') + msg = "Given target_wcs is not compatible with this NDCube, the physical types do not match." + raise ValueError(msg) # TODO: Upstream this check into reproject # If shape_out is not specified explicitly, # try to extract it from the low level WCS if not shape_out: - if hasattr(low_level_target_wcs, 'array_shape') and low_level_target_wcs.array_shape is not None: + if hasattr(low_level_target_wcs, "array_shape") and low_level_target_wcs.array_shape is not None: shape_out = low_level_target_wcs.array_shape else: - raise ValueError("shape_out must be specified if target_wcs does not have the array_shape attribute.") + msg = "shape_out must be specified if target_wcs does not have the array_shape attribute." + raise ValueError(msg) data = algorithms[algorithm](self, output_projection=target_wcs, @@ -845,6 +874,7 @@ class NDCube(NDCubeBase): Default is False. """ + # Enabling the NDCube reflected operators is a bit subtle. The NDCube # reflected operator will be used only if the Quantity non-reflected operator # returns NotImplemented. The Quantity operator strips the unit from the @@ -867,14 +897,13 @@ class NDCube(NDCubeBase): def _as_mpl_axes(self): if hasattr(self.plotter, "_as_mpl_axes"): return self.plotter._as_mpl_axes() - else: - warn_user(f"The current plotter {self.plotter} does not have a '_as_mpl_axes' method. " - "The default MatplotlibPlotter._as_mpl_axes method will be used instead.") + warn_user(f"The current plotter {self.plotter} does not have a '_as_mpl_axes' method. " + "The default MatplotlibPlotter._as_mpl_axes method will be used instead.") - from ndcube.visualization.mpl_plotter import MatplotlibPlotter + from ndcube.visualization.mpl_plotter import MatplotlibPlotter - plotter = MatplotlibPlotter(self) - return plotter._as_mpl_axes() + plotter = MatplotlibPlotter(self) + return plotter._as_mpl_axes() def plot(self, *args, **kwargs): """ @@ -886,21 +915,24 @@ def plot(self, *args, **kwargs): """ if self.plotter is None: - raise NotImplementedError( + msg = ( "This NDCube object does not have a .plotter defined so " - "no default plotting functionality is available.") + "no default plotting functionality is available." 
+ ) + raise NotImplementedError( + msg) return self.plotter.plot(*args, **kwargs) def _new_instance(self, **kwargs): - keys = ('unit', 'wcs', 'mask', 'meta', 'uncertainty', 'psf') + keys = ("unit", "wcs", "mask", "meta", "uncertainty", "psf") new_kwargs = {k: deepcopy(getattr(self, k, None)) for k in keys} # To support old versions of astropy, we need to make sure # we only pass in the parameters that are valid for the NDData params = list(inspect.signature(astropy.nddata.NDData).parameters) full_kwargs = {x: new_kwargs.pop(x) for x in params & new_kwargs.keys()} # We Explicitly DO NOT deepcopy any data - full_kwargs['data'] = self.data + full_kwargs["data"] = self.data full_kwargs.update(kwargs) new_cube = type(self)(**full_kwargs) if self.extra_coords is not None: @@ -913,20 +945,21 @@ def __neg__(self): return self._new_instance(data=-self.data) def __add__(self, value): - if hasattr(value, 'unit'): + if hasattr(value, "unit"): if isinstance(value, u.Quantity): # NOTE: if the cube does not have units, we cannot # perform arithmetic between a unitful quantity. # This forces a conversion to a dimensionless quantity # so that an error is thrown if value is not dimensionless - cube_unit = u.Unit('') if self.unit is None else self.unit + cube_unit = u.Unit("") if self.unit is None else self.unit new_data = self.data + value.to_value(cube_unit) else: # NOTE: This explicitly excludes other NDCube objects and NDData objects # which could carry a different WCS than the NDCube return NotImplemented elif self.unit not in (None, u.Unit("")): - raise TypeError("Cannot add a unitless object to an NDCube with a unit.") + msg = "Cannot add a unitless object to an NDCube with a unit." + raise TypeError(msg) else: new_data = self.data + value return self._new_instance(data=new_data) @@ -941,12 +974,12 @@ def __rsub__(self, value): return self.__neg__().__add__(value) def __mul__(self, value): - if hasattr(value, 'unit'): + if hasattr(value, "unit"): if isinstance(value, u.Quantity): # NOTE: if the cube does not have units, set the unit # to dimensionless such that we can perform arithmetic # between the two. - cube_unit = u.Unit('') if self.unit is None else self.unit + cube_unit = u.Unit("") if self.unit is None else self.unit value_unit = value.unit value = value.to_value() new_unit = cube_unit * value_unit @@ -957,8 +990,7 @@ def __mul__(self, value): new_data = self.data * value new_uncertainty = (type(self.uncertainty)(self.uncertainty.array * value) if self.uncertainty is not None else None) - new_cube = self._new_instance(data=new_data, unit=new_unit, uncertainty=new_uncertainty) - return new_cube + return self._new_instance(data=new_data, unit=new_unit, uncertainty=new_uncertainty) def __rmul__(self, value): return self.__mul__(value) @@ -985,7 +1017,7 @@ def __pow__(self, value): new_uncertainty = None warn_user(f"{e.args[0]} Setting uncertainties to None.") else: - raise e + raise return self._new_instance(data=new_data, unit=new_unit, uncertainty=new_uncertainty) @@ -1138,11 +1170,15 @@ def my_propagate(uncertainty, data, mask, **kwargs): data_shape = self.shape naxes = len(data_shape) if len(bin_shape) != naxes: - raise ValueError("bin_shape must have an entry for each array axis.") + msg = "bin_shape must have an entry for each array axis." + raise ValueError(msg) if (np.mod(data_shape, bin_shape) != 0).any(): - raise ValueError( + msg = ( "bin shape must be an integer fraction of the data shape in each dimension. 
" - f"data shape: {data_shape}; bin shape: {bin_shape}") + f"data shape: {data_shape}; bin shape: {bin_shape}" + ) + raise ValueError( + msg) # Reshape array so odd dimensions represent pixels to be binned # then apply function over those axes. @@ -1160,7 +1196,7 @@ def my_propagate(uncertainty, data, mask, **kwargs): new_data = new_data.data if handle_mask is None: new_mask = None - elif isinstance(self.mask, (type(None), bool)): # Preserve original mask type. + elif isinstance(self.mask, type(None) | bool): # Preserve original mask type. new_mask = self.mask else: reshaped_mask = self.mask.reshape(reshape) @@ -1192,7 +1228,7 @@ def my_propagate(uncertainty, data, mask, **kwargs): # in each bin can be iterated (all bins being treated in parallel) and # their uncertainties propagated. bin_size = bin_shape.prod() - flat_shape = [bin_size] + list(new_shape) + flat_shape = [bin_size, *list(new_shape)] dummy_axes = tuple(range(1, len(reshape), 2)) flat_data = np.moveaxis(reshaped_data, dummy_axes, tuple(range(naxes))) flat_data = flat_data.reshape(flat_shape) @@ -1222,7 +1258,7 @@ def my_propagate(uncertainty, data, mask, **kwargs): uncertainty=new_uncertainty, mask=new_mask, meta=self.meta, - unit=new_unit + unit=new_unit, ) new_cube._global_coords = self._global_coords # Reconstitute extra coords @@ -1257,11 +1293,13 @@ def squeeze(self, axis=None): axis = (axis,) axis = np.asarray(axis) if not (shape[axis] == 1).all(): - raise ValueError("Cannot select any axis to squeeze out, as none of them has size equal to one.") + msg = "Cannot select any axis to squeeze out, as none of them has size equal to one." + raise ValueError(msg) item[axis] = 0 # Scalar NDCubes are not supported, so we raise error as the operation would cause all the axes to be squeezed. if (item == 0).all(): - raise ValueError("All axes are of length 1, therefore we will not squeeze NDCube to become a scalar. Use `axis=` keyword to specify a subset of axes to squeeze.") + msg = "All axes are of length 1, therefore we will not squeeze NDCube to become a scalar. Use `axis=` keyword to specify a subset of axes to squeeze." + raise ValueError(msg) return self[tuple(item)] @@ -1269,11 +1307,10 @@ def _create_masked_array_for_rebinning(data, mask, operation_ignores_mask): m = None if (mask is None or mask is False or operation_ignores_mask) else mask if m is None: return data, m + for array_type, masked_type in ARRAY_MASK_MAP.items(): + if isinstance(data, array_type): + break else: - for array_type, masked_type in ARRAY_MASK_MAP.items(): - if isinstance(data, array_type): - break - else: - masked_type = np.ma.masked_array - warn_user("data and mask arrays of different or unrecognized types. Casting them into a numpy masked array.") - return masked_type(data, m), m + masked_type = np.ma.masked_array + warn_user("data and mask arrays of different or unrecognized types. Casting them into a numpy masked array.") + return masked_type(data, m), m diff --git a/ndcube/ndcube_sequence.py b/ndcube/ndcube_sequence.py index 35d541c7c..65ae5fd4e 100644 --- a/ndcube/ndcube_sequence.py +++ b/ndcube/ndcube_sequence.py @@ -34,7 +34,7 @@ class NDCubeSequenceBase: were a single cube concatenated along the common axis. 
""" - def __init__(self, data_list, meta=None, common_axis=None, **kwargs): + def __init__(self, data_list, meta=None, common_axis=None, **kwargs) -> None: self.data = data_list self.meta = meta if common_axis is not None: @@ -56,7 +56,7 @@ def shape(self): @property def _shape(self): - dimensions = [len(self.data)] + list(self.data[0].data.shape) + dimensions = [len(self.data), *list(self.data[0].data.shape)] if len(dimensions) > 1: # If there is a common axis, length of cube's along it may not # be the same. Therefore if the lengths are different, @@ -74,7 +74,7 @@ def array_axis_physical_types(self): """ The physical types associated with each array axis, including the sequence axis. """ - return [("meta.obs.sequence",)] + self.data[0].array_axis_physical_types + return [("meta.obs.sequence",), *self.data[0].array_axis_physical_types] @property def cube_like_dimensions(self): @@ -83,7 +83,8 @@ def cube_like_dimensions(self): """ warn_deprecated("Replaced by ndcube.NDCubeSequence.cube_like_shape") if not isinstance(self._common_axis, int): - raise TypeError("Common axis must be set.") + msg = "Common axis must be set." + raise TypeError(msg) dimensions = list(self._dimensions) cube_like_dimensions = list(self._shape[1:]) if dimensions[self._common_axis + 1].isscalar: @@ -92,8 +93,7 @@ def cube_like_dimensions(self): else: cube_like_dimensions[self._common_axis] = sum(dimensions[self._common_axis + 1]) # Combine into single Quantity - cube_like_dimensions = u.Quantity(cube_like_dimensions, unit=u.pix) - return cube_like_dimensions + return u.Quantity(cube_like_dimensions, unit=u.pix) @property def cube_like_shape(self): @@ -101,7 +101,8 @@ def cube_like_shape(self): The length of each array axis as if all cubes were concatenated along the common axis. """ if not isinstance(self._common_axis, int): - raise TypeError("Common axis must be set.") + msg = "Common axis must be set." + raise TypeError(msg) dimensions = list(self.shape) cube_like_shape = list(self._shape[1:]) if isinstance(dimensions[self._common_axis + 1], numbers.Integral): @@ -116,7 +117,8 @@ def cube_like_array_axis_physical_types(self): The physical types associated with each array axis, omitting the sequence axis. """ if self._common_axis is None: - raise ValueError("Common axis must be set.") + msg = "Common axis must be set." + raise ValueError(msg) return self.data[0].array_axis_physical_types def __getitem__(self, item): @@ -158,7 +160,8 @@ def index_as_cube(self): >>> cs.index_as_cube[3:6, 0, :] # doctest: +SKIP """ if self._common_axis is None: - raise ValueError("common_axis cannot be None") + msg = "common_axis cannot be None" + raise ValueError(msg) return _IndexAsCubeSlicer(self) @property @@ -208,8 +211,8 @@ def sequence_axis_coords(self): # Collect names of global coords common to all cubes. global_names = set.intersection(*[set(cube.global_coords.keys()) for cube in self.data]) # For each coord, combine values from each cube's global coords property. 
- return dict([(name, [cube.global_coords[name] for cube in self.data]) - for name in global_names]) + return {name: [cube.global_coords[name] for cube in self.data] + for name in global_names} def explode_along_axis(self, axis): """ @@ -368,7 +371,7 @@ def _get_sequence_crop_item(self, *points, wcses=None, crop_by_values=False, uni wcses = "wcs" if isinstance(wcses, str): wcses = [wcses] * n_cubes - for i, (cube, wcs) in enumerate(zip(self.data, wcses)): + for i, (cube, wcs) in enumerate(zip(self.data, wcses, strict=False)): # For each cube, determine the range of array indices in each dimension # corresponding to the input world corners. if isinstance(wcs, str): @@ -385,9 +388,9 @@ def _get_sequence_crop_item(self, *points, wcses=None, crop_by_values=False, uni starts = starts.min(axis=0) stops = stops.max(axis=0) return tuple( - [slice(0, n_cubes)] + [slice(start, stop) for start, stop in zip(starts, stops)]) + [slice(0, n_cubes)] + [slice(start, stop) for start, stop in zip(starts, stops, strict=False)]) - def __str__(self): + def __str__(self) -> str: return (textwrap.dedent(f"""\ NDCubeSequence -------------- @@ -395,10 +398,10 @@ def __str__(self): Physical Types of Axes: {self.array_axis_physical_types} Common Cube Axis: {self._common_axis}""")) - def __repr__(self): - return f"{object.__repr__(self)}\n{str(self)}" + def __repr__(self) -> str: + return f"{object.__repr__(self)}\n{self!s}" - def __len__(self): + def __len__(self) -> int: return len(self.data) def __iter__(self): @@ -434,6 +437,7 @@ class NDCubeSequence(NDCubeSequenceBase): `ndcube.NDCubeSequence.index_as_cube` which slices the sequence as though it were a single cube concatenated along the common axis. """ + # We special case the default mpl plotter here so that we can only import # matplotlib when `.plotter` is accessed and raise an ImportError at the # last moment. @@ -449,20 +453,26 @@ def plot(self, *args, **kwargs): """ if self.plotter is None: - raise NotImplementedError( + msg = ( "This NDCubeSequence object does not have a .plotter defined so " - "no default plotting functionality is available.") + "no default plotting functionality is available." + ) + raise NotImplementedError( + msg) return self.plotter.plot(*args, **kwargs) def plot_as_cube(self, *args, **kwargs): - raise NotImplementedError( + msg = ( "NDCubeSequence plot_as_cube is no longer supported.\n" "To learn why or to tell us why it should be re-instated, " "read and comment on issue #315:\n\nhttps://github.com/sunpy/ndcube/issues/315\n\n" "To see a introductory guide on how to make your own NDCubeSequence plots, " "see the docs:\n\n" - "https://docs.sunpy.org/projects/ndcube/en/stable/ndcubesequence.html#plotting") + "https://docs.sunpy.org/projects/ndcube/en/stable/ndcubesequence.html#plotting" + ) + raise NotImplementedError( + msg) """ @@ -481,7 +491,7 @@ class _IndexAsCubeSlicer: Object of NDCubeSequence. """ - def __init__(self, seq): + def __init__(self, seq) -> None: self.seq = seq def __getitem__(self, item): @@ -492,7 +502,7 @@ def __getitem__(self, item): # If item is iint or slice, turn into a tuple, filling in items # for unincluded axes with slice(None). This ensures it is # treated the same as tuple items. - if isinstance(item, (numbers.Integral, slice)): + if isinstance(item, numbers.Integral | slice): item = [item] + [slice(None)] * n_uncommon_cube_dims else: # Item must therefore be tuple. Ensure it has an entry for each axis. 
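Reviewer note: the isinstance rewrites in this file, such as `isinstance(item, numbers.Integral | slice)` just above, come from ruff's UP038 and rely on PEP 604 unions being accepted by isinstance on Python 3.10+. A quick standalone check of the equivalence:

    import numbers

    # A PEP 604 union behaves the same as the old tuple of types.
    for candidate in (3, slice(1, 4), "not an index"):
        assert isinstance(candidate, numbers.Integral | slice) == isinstance(
            candidate, (numbers.Integral, slice)
        )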
@@ -500,7 +510,7 @@ def __getitem__(self, item): # If common axis item is slice(None), result is trivial as common_axis is not changed. if item[common_axis] == slice(None): # Create item for slicing through the default API and slice. - return self.seq[tuple([slice(None)] + item)] + return self.seq[(slice(None), *item)] if isinstance(item[common_axis], numbers.Integral): # If common_axis item is an int or return an NDCube with dimensionality of N-1 sequence_index, common_axis_index = \ @@ -510,25 +520,24 @@ def __getitem__(self, item): cube_item = copy.deepcopy(item) cube_item[common_axis] = common_axis_index return self.seq.data[sequence_index][tuple(cube_item)] - else: - # item can now only be a tuple whose common axis item is a non-None slice object. - # Convert item into iterable of SequenceItems and slice each cube appropriately. - # item for common_axis must always be a slice for every cube, - # even if it is only a length-1 slice. - # Thus NDCubeSequence.index_as_cube can only slice away common axis if - # item is int or item's first item is an int. - # i.e. NDCubeSequence.index_as_cube cannot cause common_axis to become None - # since in all cases where the common_axis is sliced away involve an NDCube - # is returned, not an NDCubeSequence. - # common_axis of returned sequence must be altered if axes in front of it - # are sliced away. - sequence_items = utils.sequence.cube_like_tuple_item_to_sequence_items( - item, common_axis, common_axis_lengths, n_cube_dims) - # Work out new common axis value if axes in front of it are sliced away. - new_common_axis = common_axis - sum([isinstance(i, numbers.Integral) - for i in item[:common_axis]]) - # Copy sequence and alter the data and common axis. - result = type(self.seq)([], meta=self.seq.meta, common_axis=new_common_axis) - result.data = [self.seq.data[sequence_item.sequence_index][sequence_item.cube_item] - for sequence_item in sequence_items] - return result + # item can now only be a tuple whose common axis item is a non-None slice object. + # Convert item into iterable of SequenceItems and slice each cube appropriately. + # item for common_axis must always be a slice for every cube, + # even if it is only a length-1 slice. + # Thus NDCubeSequence.index_as_cube can only slice away common axis if + # item is int or item's first item is an int. + # i.e. NDCubeSequence.index_as_cube cannot cause common_axis to become None + # since all cases where the common_axis is sliced away return an NDCube, + # not an NDCubeSequence. + # common_axis of returned sequence must be altered if axes in front of it + # are sliced away. + sequence_items = utils.sequence.cube_like_tuple_item_to_sequence_items( + item, common_axis, common_axis_lengths, n_cube_dims) + # Work out new common axis value if axes in front of it are sliced away. + new_common_axis = common_axis - sum([isinstance(i, numbers.Integral) + for i in item[:common_axis]]) + # Copy sequence and alter the data and common axis.
+ result = type(self.seq)([], meta=self.seq.meta, common_axis=new_common_axis) + result.data = [self.seq.data[sequence_item.sequence_index][sequence_item.cube_item] + for sequence_item in sequence_items] + return result diff --git a/ndcube/tests/helpers.py b/ndcube/tests/helpers.py index e82215db0..f1549247e 100644 --- a/ndcube/tests/helpers.py +++ b/ndcube/tests/helpers.py @@ -20,13 +20,13 @@ from ndcube import NDCube, NDCubeSequence -__all__ = ['figure_test', - 'get_hash_library_name', - 'assert_extra_coords_equal', - 'assert_metas_equal', - 'assert_cubes_equal', - 'assert_cubesequences_equal', - 'assert_wcs_are_equal'] +__all__ = ["figure_test", + "get_hash_library_name", + "assert_extra_coords_equal", + "assert_metas_equal", + "assert_cubes_equal", + "assert_cubesequences_equal", + "assert_wcs_are_equal"] def get_hash_library_name(): @@ -34,9 +34,9 @@ def get_hash_library_name(): Generate the hash library name for this env. """ ft2_version = f"{mpl.ft2font.__freetype_version__.replace('.', '')}" - animators_version = "dev" if (("dev" in mpl_animators.__version__) or ("rc" in mpl_animators.__version__)) else mpl_animators.__version__.replace('.', '') - mpl_version = "dev" if (("dev" in mpl.__version__) or ("rc" in mpl.__version__)) else mpl.__version__.replace('.', '') - astropy_version = "dev" if (("dev" in astropy.__version__) or ("rc" in astropy.__version__)) else astropy.__version__.replace('.', '') + animators_version = "dev" if (("dev" in mpl_animators.__version__) or ("rc" in mpl_animators.__version__)) else mpl_animators.__version__.replace(".", "") + mpl_version = "dev" if (("dev" in mpl.__version__) or ("rc" in mpl.__version__)) else mpl.__version__.replace(".", "") + astropy_version = "dev" if (("dev" in astropy.__version__) or ("rc" in astropy.__version__)) else astropy.__version__.replace(".", "") return f"figure_hashes_mpl_{mpl_version}_ft_{ft2_version}_astropy_{astropy_version}_animators_{animators_version}.json" @@ -54,8 +54,8 @@ def figure_test(test_function): @pytest.mark.remote_data @pytest.mark.mpl_image_compare(hash_library=hash_library_file.resolve(), - savefig_kwargs={'metadata': {'Software': None}}, - style='default') + savefig_kwargs={"metadata": {"Software": None}}, + style="default") @wraps(test_function) def test_wrapper(*args, **kwargs): ret = test_function(*args, **kwargs) @@ -78,7 +78,7 @@ def assert_extra_coords_equal(test_input, extra_coords): if not isinstance(ec_table, tuple): test_table = (test_table,) ec_table = (ec_table,) - for test_tab, ec_tab in zip(test_table, ec_table): + for test_tab, ec_tab in zip(test_table, ec_table, strict=False): if ec_tab.isscalar: assert test_tab == ec_tab else: @@ -107,8 +107,8 @@ def assert_cubes_equal(test_input, expected_cube, check_data=True): assert np.all(test_input.shape == expected_cube.shape) assert_metas_equal(test_input.meta, expected_cube.meta) if type(test_input.extra_coords) is not type(expected_cube.extra_coords): - raise AssertionError("NDCube extra_coords not of same type: {0} != {1}".format( - type(test_input.extra_coords), type(expected_cube.extra_coords))) + msg = f"NDCube extra_coords not of same type: {type(test_input.extra_coords)} != {type(expected_cube.extra_coords)}" + raise AssertionError(msg) if test_input.extra_coords is not None: assert_extra_coords_equal(test_input.extra_coords, expected_cube.extra_coords) @@ -129,7 +129,6 @@ def assert_wcs_are_equal(wcs1, wcs2): Also checks if both the wcs objects are instance of `~astropy.wcs.wcsapi.SlicedLowLevelWCS`. 
""" - if not isinstance(wcs1, BaseLowLevelWCS): wcs1 = wcs1.low_level_wcs if not isinstance(wcs2, BaseLowLevelWCS): @@ -154,7 +153,6 @@ def create_sliced_wcs(wcs, item, dim): """ Creates a sliced `SlicedFITSWCS` object from the given slice item """ - # Sanitize the slices item = sanitize_slices(item, dim) return SlicedFITSWCS(wcs, item) @@ -163,7 +161,7 @@ def create_sliced_wcs(wcs, item, dim): def assert_collections_equal(collection1, collection2, check_data=True): assert collection1.keys() == collection2.keys() assert collection1.aligned_axes == collection2.aligned_axes - for cube1, cube2 in zip(collection1.values(), collection2.values()): + for cube1, cube2 in zip(collection1.values(), collection2.values(), strict=False): # Check cubes are same type. assert type(cube1) is type(cube2) if isinstance(cube1, NDCube): @@ -171,4 +169,5 @@ def assert_collections_equal(collection1, collection2, check_data=True): elif isinstance(cube1, NDCubeSequence): assert_cubesequences_equal(cube1, cube2, check_data=check_data) else: - raise TypeError(f"Unsupported Type in NDCollection: {type(cube1)}") + msg = f"Unsupported Type in NDCollection: {type(cube1)}" + raise TypeError(msg) diff --git a/ndcube/tests/test_global_coords.py b/ndcube/tests/test_global_coords.py index 38453ba47..75115f92f 100644 --- a/ndcube/tests/test_global_coords.py +++ b/ndcube/tests/test_global_coords.py @@ -20,47 +20,47 @@ def gc(): def gc_coords(gc): coord1 = 1 * u.m coord2 = 2 * u.s - gc.add('name1', 'custom:physical_type1', coord1) - gc.add('name2', 'custom:physical_type2', coord2) + gc.add("name1", "custom:physical_type1", coord1) + gc.add("name2", "custom:physical_type2", coord2) return gc def test_add(gc): coord1 = 1 * u.m coord2 = 2 * u.s - gc.add('name1', 'custom:physical_type1', coord1) - gc.add('name2', 'custom:physical_type2', coord2) - assert gc.keys() == {'name1', 'name2'} - assert gc.physical_types == dict((('name1', 'custom:physical_type1'), ('name2', 'custom:physical_type2'))) + gc.add("name1", "custom:physical_type1", coord1) + gc.add("name2", "custom:physical_type2", coord2) + assert gc.keys() == {"name1", "name2"} + assert gc.physical_types == {"name1": "custom:physical_type1", "name2": "custom:physical_type2"} def test_remove(gc_coords): - gc_coords.remove('name2') + gc_coords.remove("name2") assert len(gc_coords) == 1 - assert gc_coords.keys() == {'name1'} - assert gc_coords.physical_types == {'name1': 'custom:physical_type1'} + assert gc_coords.keys() == {"name1"} + assert gc_coords.physical_types == {"name1": "custom:physical_type1"} def test_overwrite(gc_coords): with pytest.raises(ValueError): coord2 = 2 * u.s - gc_coords.add('name1', 'custom:physical_type2', coord2) + gc_coords.add("name1", "custom:physical_type2", coord2) def test_iterating(gc_coords): for i, gc_item in enumerate(gc_coords): if i == 0: - assert gc_item == 'name1' + assert gc_item == "name1" if i == 1: - assert gc_item == 'name2' + assert gc_item == "name2" def test_slicing(gc_coords): - assert u.allclose(gc_coords['name1'], u.Quantity(1., u.m)) + assert u.allclose(gc_coords["name1"], u.Quantity(1., u.m)) def test_physical_types(gc_coords): - assert gc_coords.physical_types == dict((('name1', 'custom:physical_type1'), ('name2', 'custom:physical_type2'))) + assert gc_coords.physical_types == {"name1": "custom:physical_type1", "name2": "custom:physical_type2"} def test_len(gc_coords): @@ -68,29 +68,29 @@ def test_len(gc_coords): def test_keys(gc_coords): - assert gc_coords.keys() == {'name1', 'name2'} + assert gc_coords.keys() == {"name1", 
"name2"} def test_values(gc_coords): - for value, expected in zip(gc_coords.values(), (1 * u.m, 2 * u.s)): + for value, expected in zip(gc_coords.values(), (1 * u.m, 2 * u.s), strict=False): assert u.allclose(value, expected) def test_items(gc_coords): - assert gc_coords.items() == {('name1', 1 * u.m), ('name2', 2 * u.s)} + assert gc_coords.items() == {("name1", 1 * u.m), ("name2", 2 * u.s)} def test_filter(gc_coords): - filtered = gc_coords.filter_by_physical_type('custom:physical_type1') + filtered = gc_coords.filter_by_physical_type("custom:physical_type1") assert isinstance(filtered, GlobalCoords) assert len(filtered) == 1 - assert 'name1' in filtered - assert u.allclose(filtered['name1'], 1 * u.m) - assert filtered.physical_types == {'name1': 'custom:physical_type1'} + assert "name1" in filtered + assert u.allclose(filtered["name1"], 1 * u.m) + assert filtered.physical_types == {"name1": "custom:physical_type1"} def test_dropped_to_global(ndcube_4d_ln_l_t_lt): - ndcube_4d_ln_l_t_lt.wcs.wcs.cname = ['lat', 'time', 'wavelength', 'lon'] + ndcube_4d_ln_l_t_lt.wcs.wcs.cname = ["lat", "time", "wavelength", "lon"] sub = ndcube_4d_ln_l_t_lt[0, 0, :, 0] gc = sub.global_coords assert len(gc) == 2 @@ -153,22 +153,22 @@ def axis_correlation_matrix(self): @property def world_axis_physical_types(self): - return ['pos.eq.ra', 'pos.eq.dec'] + return ["pos.eq.ra", "pos.eq.dec"] @property def world_axis_units(self): - return ['deg', 'deg'] + return ["deg", "deg"] @property def world_axis_object_components(self): - return [('test', 0, 'value'), ('test2', 0, 'value')] + return [("test", 0, "value"), ("test2", 0, "value")] @property def world_axis_object_classes(self): - return {'test': ('astropy.units.Quantity', (), - {'unit': ('astropy.units.Unit', ('deg',), {})}), - 'test2': ('astropy.units.Quantity', (), - {'unit': ('astropy.units.Unit', ('deg',), {})})} + return {"test": ("astropy.units.Quantity", (), + {"unit": ("astropy.units.Unit", ("deg",), {})}), + "test2": ("astropy.units.Quantity", (), + {"unit": ("astropy.units.Unit", ("deg",), {})})} def test_serialized_classes(): @@ -177,7 +177,7 @@ def test_serialized_classes(): class MultiCoord(list): - def __init__(self, *args): + def __init__(self, *args) -> None: super().__init__(args) @property @@ -212,22 +212,22 @@ def axis_correlation_matrix(self): @property def world_axis_physical_types(self): - return ['pos.eq.ra', 'pos.eq.dec', 'pos.distance.x'] + return ["pos.eq.ra", "pos.eq.dec", "pos.distance.x"] @property def world_axis_units(self): - return ['deg', 'deg', 'm'] + return ["deg", "deg", "m"] @property def world_axis_object_components(self): - return [('test', 0, 'value'), - ('test', 1, 'value'), - ('distance', 0, 'value')] + return [("test", 0, "value"), + ("test", 1, "value"), + ("distance", 0, "value")] @property def world_axis_object_classes(self): - return {'test': (MultiCoord, (), {}), - 'distance': (u.Quantity, (), {'unit': 'm'})} + return {"test": (MultiCoord, (), {}), + "distance": (u.Quantity, (), {"unit": "m"})} def test_non_skycoord_multi_object(): diff --git a/ndcube/tests/test_ndcollection.py b/ndcube/tests/test_ndcollection.py index 8ec12630f..ebfb178fc 100644 --- a/ndcube/tests/test_ndcollection.py +++ b/ndcube/tests/test_ndcollection.py @@ -14,15 +14,15 @@ # Define WCS object for all cubes. 
wcs_input_dict = { - 'CTYPE1': 'WAVE ', 'CUNIT1': 'Angstrom', 'CDELT1': 0.2, 'CRPIX1': 0, 'CRVAL1': 10, 'NAXIS1': 5, - 'CTYPE2': 'HPLT-TAN', 'CUNIT2': 'deg', 'CDELT2': 0.5, 'CRPIX2': 2, 'CRVAL2': 0.5, 'NAXIS2': 4, - 'CTYPE3': 'HPLN-TAN', 'CUNIT3': 'deg', 'CDELT3': 0.4, 'CRPIX3': 2, 'CRVAL3': 1, 'NAXIS3': 3} + "CTYPE1": "WAVE ", "CUNIT1": "Angstrom", "CDELT1": 0.2, "CRPIX1": 0, "CRVAL1": 10, "NAXIS1": 5, + "CTYPE2": "HPLT-TAN", "CUNIT2": "deg", "CDELT2": 0.5, "CRPIX2": 2, "CRVAL2": 0.5, "NAXIS2": 4, + "CTYPE3": "HPLN-TAN", "CUNIT3": "deg", "CDELT3": 0.4, "CRPIX3": 2, "CRVAL3": 1, "NAXIS3": 3} input_wcs = astropy.wcs.WCS(wcs_input_dict) wcs_input_dict1 = { - 'CTYPE3': 'WAVE ', 'CUNIT3': 'Angstrom', 'CDELT3': 0.2, 'CRPIX3': 0, 'CRVAL3': 10, 'NAXIS3': 5, - 'CTYPE1': 'HPLT-TAN', 'CUNIT1': 'deg', 'CDELT1': 0.5, 'CRPIX1': 2, 'CRVAL1': 0.5, 'NAXIS1': 4, - 'CTYPE2': 'HPLN-TAN', 'CUNIT2': 'deg', 'CDELT2': 0.4, 'CRPIX2': 2, 'CRVAL2': 1, 'NAXIS2': 3} + "CTYPE3": "WAVE ", "CUNIT3": "Angstrom", "CDELT3": 0.2, "CRPIX3": 0, "CRVAL3": 10, "NAXIS3": 5, + "CTYPE1": "HPLT-TAN", "CUNIT1": "deg", "CDELT1": 0.5, "CRPIX1": 2, "CRVAL1": 0.5, "NAXIS1": 4, + "CTYPE2": "HPLN-TAN", "CUNIT2": "deg", "CDELT2": 0.4, "CRPIX2": 2, "CRVAL2": 1, "NAXIS2": 3} input_wcs1 = astropy.wcs.WCS(wcs_input_dict1) # Define cubes. @@ -42,7 +42,7 @@ seq_collection = NDCollection([("seq0", sequence02), ("seq1", sequence20)], aligned_axes="all") -@pytest.mark.parametrize("item,collection,expected", [ +@pytest.mark.parametrize(("item", "collection", "expected"), [ (0, cube_collection, NDCollection([("cube0", cube0[:, 0]), ("cube1", cube1[:, :, 0]), ("cube2", cube2[:, 0])], aligned_axes=((1,), (0,), (1,)))), @@ -75,13 +75,13 @@ ((slice(None), 1, slice(1, 3)), seq_collection, NDCollection([("seq0", sequence02[:, 1, 1:3]), ("seq1", sequence20[:, 1, 1:3])], - aligned_axes=((0, 1, 2), (0, 1, 2)))) + aligned_axes=((0, 1, 2), (0, 1, 2)))), ]) def test_collection_slicing(item, collection, expected): helpers.assert_collections_equal(collection[item], expected) -@pytest.mark.parametrize("item,collection,expected", [("cube1", cube_collection, cube1)]) +@pytest.mark.parametrize(("item", "collection", "expected"), [("cube1", cube_collection, cube1)]) def test_slice_cube_from_collection(item, collection, expected): helpers.assert_cubes_equal(collection[item], expected) @@ -91,7 +91,7 @@ def test_collection_copy(): helpers.assert_collections_equal(unaligned_collection.copy(), unaligned_collection) -@pytest.mark.parametrize("collection,popped_key,expected_popped,expected_collection", [ +@pytest.mark.parametrize(("collection", "popped_key", "expected_popped", "expected_collection"), [ (cube_collection, "cube0", cube0, NDCollection([("cube1", cube1), ("cube2", cube2)], aligned_axes=aligned_axes[1:])), (unaligned_collection, "cube0", cube0, NDCollection([("cube1", cube1), ("cube2", cube2)]))]) @@ -102,7 +102,7 @@ def test_collection_pop(collection, popped_key, expected_popped, expected_collec helpers.assert_collections_equal(popped_collection, expected_collection) -@pytest.mark.parametrize("collection,key,expected", [ +@pytest.mark.parametrize(("collection", "key", "expected"), [ (cube_collection, "cube0", NDCollection([("cube1", cube1), ("cube2", cube2)], aligned_axes=aligned_axes[1:]))]) def test_del_collection(collection, key, expected): @@ -111,7 +111,7 @@ def test_del_collection(collection, key, expected): helpers.assert_collections_equal(del_collection, expected) -@pytest.mark.parametrize("collection,key,data,aligned_axes,expected", [ 
+@pytest.mark.parametrize(("collection", "key", "data", "aligned_axes", "expected"), [ (cube_collection, "cube1", cube2, aligned_axes[2], NDCollection( [("cube0", cube0), ("cube1", cube2), ("cube2", cube2)], aligned_axes=((1, 2), (1, 2), (1, 2)))), @@ -144,23 +144,22 @@ def test_collection_update_without_aligned_axes(): helpers.assert_collections_equal(orig_collection, expected) -@pytest.mark.parametrize("collection, expected_aligned_dimensions", [ +@pytest.mark.parametrize(("collection", "expected_aligned_dimensions"), [ (cube_collection, [4, 5]), (seq_collection, [2, 3, 4, 5])]) def test_aligned_dimensions(collection, expected_aligned_dimensions): assert np.all(collection.aligned_dimensions == expected_aligned_dimensions) -@pytest.mark.parametrize("collection, expected", [ - (cube_collection, [('custom:pos.helioprojective.lat', 'custom:pos.helioprojective.lon'), - ('em.wl',)]), - (seq_collection, [('meta.obs.sequence',), - ('custom:pos.helioprojective.lat', 'custom:pos.helioprojective.lon'), - ('custom:pos.helioprojective.lat', 'custom:pos.helioprojective.lon'), - ('em.wl',)])]) +@pytest.mark.parametrize(("collection", "expected"), [ + (cube_collection, [("custom:pos.helioprojective.lat", "custom:pos.helioprojective.lon"), + ("em.wl",)]), + (seq_collection, [("meta.obs.sequence",), + ("custom:pos.helioprojective.lat", "custom:pos.helioprojective.lon"), + ("custom:pos.helioprojective.lat", "custom:pos.helioprojective.lon"), + ("em.wl",)])]) def test_aligned_axis_physical_types(collection, expected): output = collection.aligned_axis_physical_types - print(output) assert len(output) == len(expected) - for output_axis_types, expect_axis_types in zip(output, expected): + for output_axis_types, expect_axis_types in zip(output, expected, strict=False): assert set(output_axis_types) == set(expect_axis_types) diff --git a/ndcube/tests/test_ndcube.py b/ndcube/tests/test_ndcube.py index 5d8a25800..d22a4c87d 100644 --- a/ndcube/tests/test_ndcube.py +++ b/ndcube/tests/test_ndcube.py @@ -34,18 +34,15 @@ def test_wcs_object(all_ndcubes): assert isinstance(all_ndcubes.wcs, BaseHighLevelWCS) -@pytest.mark.parametrize("ndc, item", - ( +@pytest.mark.parametrize(("ndc", "item"), + [ ("ndcube_3d_ln_lt_l", np.s_[:, :, 0]), ("ndcube_3d_ln_lt_l", np.s_[..., 0]), ("ndcube_3d_ln_lt_l", np.s_[1:2, 1:2, 0]), - ("ndcube_3d_ln_lt_l", np.s_[..., 0]), - ("ndcube_3d_ln_lt_l", np.s_[:, :, 0]), - ("ndcube_3d_ln_lt_l", np.s_[1:2, 1:2, 0]), ("ndcube_4d_ln_lt_l_t", np.s_[:, :, 0, 0]), ("ndcube_4d_ln_lt_l_t", np.s_[..., 0, 0]), ("ndcube_4d_ln_lt_l_t", np.s_[1:2, 1:2, 1, 1]), - ), + ], indirect=("ndc",)) def test_slicing_ln_lt(ndc, item): sndc = ndc[item] @@ -70,18 +67,15 @@ def test_slicing_ln_lt(ndc, item): assert np.allclose(sndc.wcs.axis_correlation_matrix, np.ones(2, dtype=bool)) -@pytest.mark.parametrize("ndc, item", - ( - ("ndcube_3d_ln_lt_l", np.s_[0, 0, :]), - ("ndcube_3d_ln_lt_l", np.s_[0, 0, ...]), - ("ndcube_3d_ln_lt_l", np.s_[1, 1, 1:2]), +@pytest.mark.parametrize(("ndc", "item"), + [ ("ndcube_3d_ln_lt_l", np.s_[0, 0, :]), ("ndcube_3d_ln_lt_l", np.s_[0, 0, ...]), ("ndcube_3d_ln_lt_l", np.s_[1, 1, 1:2]), ("ndcube_4d_ln_lt_l_t", np.s_[0, 0, :, 0]), ("ndcube_4d_ln_lt_l_t", np.s_[0, 0, ..., 0]), ("ndcube_4d_ln_lt_l_t", np.s_[1, 1, 1:2, 1]), - ), + ], indirect=("ndc",)) def test_slicing_wave(ndc, item): sndc = ndc[item] @@ -105,18 +99,16 @@ def test_slicing_wave(ndc, item): assert np.allclose(sndc.wcs.axis_correlation_matrix, np.ones(1, dtype=bool)) -@pytest.mark.parametrize("ndc, item", - ( 
+@pytest.mark.parametrize(("ndc", "item"), + [ ("ndcube_3d_ln_lt_l", np.s_[0, :, :]), ("ndcube_3d_ln_lt_l", np.s_[0, ...]), ("ndcube_3d_ln_lt_l", np.s_[1, 1:2]), - ("ndcube_3d_ln_lt_l", np.s_[0, :, :]), - ("ndcube_3d_ln_lt_l", np.s_[0, ...]), ("ndcube_3d_ln_lt_l", np.s_[1, :, 1:2]), ("ndcube_4d_ln_lt_l_t", np.s_[0, :, :, 0]), ("ndcube_4d_ln_lt_l_t", np.s_[0, ..., 0]), ("ndcube_4d_ln_lt_l_t", np.s_[1, 1:2, 1:2, 1]), - ), + ], indirect=("ndc",)) def test_slicing_split_celestial(ndc, item): sndc = ndc[item] @@ -146,7 +138,7 @@ def test_slicing_split_celestial(ndc, item): def test_slicing_preserves_global_coords(ndcube_3d_ln_lt_l): ndc = ndcube_3d_ln_lt_l - ndc.global_coords.add('distance', 'pos.distance', 1 * u.m) + ndc.global_coords.add("distance", "pos.distance", 1 * u.m) sndc = ndc[0] assert sndc._global_coords._internal_coords == ndc._global_coords._internal_coords @@ -179,7 +171,7 @@ def test_slicing_removed_world_coords(ndcube_3d_ln_lt_l): def test_axis_world_coords_wave_ec(ndcube_3d_l_ln_lt_ectime): cube = ndcube_3d_l_ln_lt_ectime - coords = cube.axis_world_coords('em.wl') + coords = cube.axis_world_coords("em.wl") assert u.allclose(coords, [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09, 1.10e-09, 1.12e-09, 1.14e-09, 1.16e-09, 1.18e-09, 1.20e-09] * u.m) @@ -206,9 +198,9 @@ def test_axis_world_coords_empty_ec(ndcube_3d_l_ln_lt_ectime): # slice the cube so extra_coords is empty, and then try and run axis_world_coords awc = sub_cube.axis_world_coords(wcs=sub_cube.extra_coords) - assert awc == tuple() + assert awc == () sub_cube._generate_world_coords(pixel_corners=False, wcs=sub_cube.extra_coords, units=True) - assert awc == tuple() + assert awc == () @pytest.mark.xfail(reason=">1D Tables not supported") @@ -218,7 +210,7 @@ def test_axis_world_coords_complex_ec(ndcube_4d_ln_lt_l_t): data = np.arange(np.prod(ec_shape)).reshape(ec_shape) * u.m / u.s # The lookup table has to be in world order so transpose it. 
- cube.extra_coords.add('velocity', (2, 1), data.T) + cube.extra_coords.add("velocity", (2, 1), data.T) coords = cube.axis_world_coords(wcs=cube.extra_coords) assert len(coords) == 1 @@ -230,7 +222,7 @@ def test_axis_world_coords_complex_ec(ndcube_4d_ln_lt_l_t): assert u.allclose(coords[3], data) -@pytest.mark.parametrize("axes", ([-1], [2], ["em"])) +@pytest.mark.parametrize("axes", [[-1], [2], ["em"]]) def test_axis_world_coords_single(axes, ndcube_3d_ln_lt_l): coords = ndcube_3d_ln_lt_l.axis_world_coords_values(*axes) assert len(coords) == 1 @@ -243,7 +235,7 @@ def test_axis_world_coords_single(axes, ndcube_3d_ln_lt_l): assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m) -@pytest.mark.parametrize("axes", ([-1], [2], ["em"])) +@pytest.mark.parametrize("axes", [[-1], [2], ["em"]]) def test_axis_world_coords_single_pixel_corners(axes, ndcube_3d_ln_lt_l): coords = ndcube_3d_ln_lt_l.axis_world_coords_values(*axes, pixel_corners=True) assert u.allclose(coords, [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09] * u.m) @@ -252,11 +244,11 @@ def test_axis_world_coords_single_pixel_corners(axes, ndcube_3d_ln_lt_l): assert u.allclose(coords, [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09] * u.m) -@pytest.mark.parametrize("ndc, item", - ( +@pytest.mark.parametrize(("ndc", "item"), + [ ("ndcube_3d_ln_lt_l", np.s_[0, 0, :]), ("ndcube_3d_ln_lt_l", np.s_[0, 0, ...]), - ), + ], indirect=("ndc",)) def test_axis_world_coords_sliced_all_3d(ndc, item): coords = ndc[item].axis_world_coords_values() @@ -266,11 +258,11 @@ def test_axis_world_coords_sliced_all_3d(ndc, item): assert u.allclose(coords, [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m) -@pytest.mark.parametrize("ndc, item", - ( +@pytest.mark.parametrize(("ndc", "item"), + [ ("ndcube_4d_ln_lt_l_t", np.s_[0, 0, :, 0]), ("ndcube_4d_ln_lt_l_t", np.s_[0, 0, ..., 0]), - ), + ], indirect=("ndc",)) def test_axis_world_coords_sliced_all_4d(ndc, item): coords = ndc[item].axis_world_coords_values() @@ -295,12 +287,12 @@ def test_axis_world_coords_all_4d_split(ndcube_4d_ln_l_t_lt): 1.2e-10, 1.4e-10, 1.6e-10, 1.8e-10, 2.0e-10] * u.m) -@pytest.mark.parametrize('wapt', ( - ('custom:pos.helioprojective.lon', 'custom:pos.helioprojective.lat', 'em.wl'), - ('custom:pos.helioprojective.lat', 'em.wl'), +@pytest.mark.parametrize("wapt", [ + ("custom:pos.helioprojective.lon", "custom:pos.helioprojective.lat", "em.wl"), + ("custom:pos.helioprojective.lat", "em.wl"), (0, 1), - (0, 1, 3) -)) + (0, 1, 3), +]) def test_axis_world_coords_all_4d_split_sub(ndcube_4d_ln_l_t_lt, wapt): coords = ndcube_4d_ln_l_t_lt.axis_world_coords(*wapt) assert len(coords) == 2 @@ -328,14 +320,14 @@ def test_axis_world_coords_all(ndcube_3d_ln_lt_l): def test_axis_world_coords_wave(ndcube_3d_ln_lt_l): - coords = ndcube_3d_ln_lt_l.axis_world_coords('em.wl') + coords = ndcube_3d_ln_lt_l.axis_world_coords("em.wl") assert len(coords) == 1 assert isinstance(coords[0], u.Quantity) assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m) -@pytest.mark.parametrize('wapt', ('custom:pos.helioprojective.lon', - 'custom:pos.helioprojective.lat')) +@pytest.mark.parametrize("wapt", ["custom:pos.helioprojective.lon", + "custom:pos.helioprojective.lat"]) def test_axis_world_coords_sky(ndcube_3d_ln_lt_l, wapt): coords = ndcube_3d_ln_lt_l.axis_world_coords(wapt) assert len(coords) == 1 @@ -372,14 +364,14 @@ def test_axis_world_coords_values_all(ndcube_3d_ln_lt_l): def test_axis_world_coords_values_wave(ndcube_3d_ln_lt_l): - coords = 
ndcube_3d_ln_lt_l.axis_world_coords_values('em.wl') + coords = ndcube_3d_ln_lt_l.axis_world_coords_values("em.wl") assert len(coords) == 1 assert isinstance(coords[0], u.Quantity) assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m) def test_axis_world_coords_values_lon(ndcube_3d_ln_lt_l): - coords = ndcube_3d_ln_lt_l.axis_world_coords_values('custom:pos.helioprojective.lon') + coords = ndcube_3d_ln_lt_l.axis_world_coords_values("custom:pos.helioprojective.lon") assert len(coords) == 1 assert all(isinstance(c, u.Quantity) for c in coords) @@ -388,7 +380,7 @@ def test_axis_world_coords_values_lon(ndcube_3d_ln_lt_l): def test_axis_world_coords_values_lat(ndcube_3d_ln_lt_l): - coords = ndcube_3d_ln_lt_l.axis_world_coords_values('custom:pos.helioprojective.lat') + coords = ndcube_3d_ln_lt_l.axis_world_coords_values("custom:pos.helioprojective.lat") assert len(coords) == 1 assert all(isinstance(c, u.Quantity) for c in coords) assert u.allclose(coords[0], [[-0.00555556, -0.00416667, -0.00277778], @@ -397,12 +389,12 @@ def test_axis_world_coords_values_lat(ndcube_3d_ln_lt_l): def test_array_axis_physical_types(ndcube_3d_ln_lt_l): expected = [ - ('custom:pos.helioprojective.lon', 'custom:pos.helioprojective.lat', 'custom:PIXEL'), - ('custom:pos.helioprojective.lon', 'custom:pos.helioprojective.lat', 'custom:PIXEL'), - ('em.wl', 'custom:PIXEL')] + ("custom:pos.helioprojective.lon", "custom:pos.helioprojective.lat", "custom:PIXEL"), + ("custom:pos.helioprojective.lon", "custom:pos.helioprojective.lat", "custom:PIXEL"), + ("em.wl", "custom:PIXEL")] output = ndcube_3d_ln_lt_l.array_axis_physical_types for i in range(len(expected)): - assert all([physical_type in expected[i] for physical_type in output[i]]) + assert all(physical_type in expected[i] for physical_type in output[i]) def test_crop(ndcube_4d_ln_lt_l_t): @@ -474,7 +466,7 @@ def test_crop_scalar_valuerror(ndcube_2d_ln_lt): cube = ndcube_2d_ln_lt frame = astropy.wcs.utils.wcs_to_celestial_frame(cube.wcs) point = SkyCoord(Tx=359.99667, Ty=-0.0011111111, unit="deg", frame=frame) - with pytest.raises(ValueError, match=r'Input points causes cube to be cropped to a single pix'): + with pytest.raises(ValueError, match=r"Input points causes cube to be cropped to a single pix"): cube.crop(point) @@ -484,7 +476,7 @@ def test_crop_missing_dimensions(ndcube_4d_ln_lt_l_t): interval0 = cube.wcs.array_index_to_world([1, 2], [0, 1], [0, 1], [0, 2])[0] lower_corner = [interval0[0], None] upper_corner = [interval0[-1], None] - with pytest.raises(ValueError, match=r'2 components in point 0 do not match WCS with 3'): + with pytest.raises(ValueError, match=r"2 components in point 0 do not match WCS with 3"): cube.crop(lower_corner, upper_corner) @@ -504,8 +496,8 @@ def test_crop_by_values(ndcube_4d_ln_lt_l_t): cube = ndcube_4d_ln_lt_l_t intervals = cube.wcs.array_index_to_world_values([1, 2], [0, 1], [0, 1], [0, 2]) units = [u.min, u.m, u.deg, u.deg] - lower_corner = [coord[0] * unit for coord, unit in zip(intervals, units)] - upper_corner = [coord[-1] * unit for coord, unit in zip(intervals, units)] + lower_corner = [coord[0] * unit for coord, unit in zip(intervals, units, strict=False)] + upper_corner = [coord[-1] * unit for coord, unit in zip(intervals, units, strict=False)] # Ensure some quantities are in units different from each other # and those stored in the WCS. 
lower_corner[0] = lower_corner[0].to(units[0]) @@ -520,8 +512,8 @@ def test_crop_by_values_keepdims(ndcube_4d_ln_lt_l_t): cube = ndcube_4d_ln_lt_l_t intervals = list(cube.wcs.array_index_to_world_values([1, 2], [0], [0, 1], [0, 2])) units = [u.min, u.m, u.deg, u.deg] - lower_corner = [coord[0] * unit for coord, unit in zip(intervals, units)] - upper_corner = [coord[-1] * unit for coord, unit in zip(intervals, units)] + lower_corner = [coord[0] * unit for coord, unit in zip(intervals, units, strict=False)] + upper_corner = [coord[-1] * unit for coord, unit in zip(intervals, units, strict=False)] expected = cube[1:3, 0:1, 0:2, 0:3] output = cube.crop_by_values(lower_corner, upper_corner, keepdims=True) assert output.shape == (2, 1, 2, 3) @@ -579,19 +571,19 @@ def test_crop_by_values_valueerror1(ndcube_4d_ln_lt_l_t): lower_corner[0] = 0.5 upper_corner = [None] * 4 upper_corner[0] = 1.1 - with pytest.raises(ValueError, match=r'Units must be None or have same length 4 as corner inp'): + with pytest.raises(ValueError, match=r"Units must be None or have same length 4 as corner inp"): ndcube_4d_ln_lt_l_t.crop_by_values(lower_corner, upper_corner, units=["m"]) def test_crop_by_values_valueerror2(ndcube_4d_ln_lt_l_t): """Test upper and lower coordinates not being the same length""" - with pytest.raises(ValueError, match=r'All points must have same number of coordinate objects'): + with pytest.raises(ValueError, match=r"All points must have same number of coordinate objects"): ndcube_4d_ln_lt_l_t.crop_by_values([0], [1, None]) def test_crop_by_values_missing_dimensions(ndcube_4d_ln_lt_l_t): """Test bbox coordinates not being the same length as cube WCS""" - with pytest.raises(ValueError, match=r'3 dimensions in point 0 do not match WCS with 4'): + with pytest.raises(ValueError, match=r"3 dimensions in point 0 do not match WCS with 4"): ndcube_4d_ln_lt_l_t.crop_by_values([0, None, None], [1, None, None]) @@ -727,7 +719,7 @@ def test_crop_rotated_celestial(ndcube_4d_ln_lt_l_t): def test_initialize_from_ndcube(ndcube_3d_l_ln_lt_ectime): cube = ndcube_3d_l_ln_lt_ectime - cube.global_coords.add('distance', 'pos.distance', 1 * u.m) + cube.global_coords.add("distance", "pos.distance", 1 * u.m) cube2 = NDCube(cube) assert cube.global_coords is cube2.global_coords @@ -747,7 +739,7 @@ def test_initialize_from_ndcube(ndcube_3d_l_ln_lt_ectime): def test_reproject_interpolation(ndcube_4d_ln_l_t_lt, wcs_4d_lt_t_l_ln): target_wcs_header = wcs_4d_lt_t_l_ln.low_level_wcs.to_header() - target_wcs_header['CDELT3'] = 0.1 # original value = 0.2 + target_wcs_header["CDELT3"] = 0.1 # original value = 0.2 target_wcs = astropy.wcs.WCS(header=target_wcs_header) shape_out = (5, 20, 12, 8) @@ -773,7 +765,7 @@ def test_reproject_with_header(ndcube_4d_ln_l_t_lt, wcs_4d_lt_t_l_ln): def test_reproject_return_footprint(ndcube_4d_ln_l_t_lt, wcs_4d_lt_t_l_ln): target_wcs_header = wcs_4d_lt_t_l_ln.low_level_wcs.to_header() - target_wcs_header['CDELT3'] = 0.1 # original value = 0.2 + target_wcs_header["CDELT3"] = 0.1 # original value = 0.2 target_wcs = astropy.wcs.WCS(header=target_wcs_header) shape_out = (5, 20, 12, 8) @@ -986,14 +978,14 @@ def test_rebin_specutils(): # Tests for https://github.com/sunpy/ndcube/issues/717 y = np.arange(4000)*u.ct x = np.arange(200, 4200)*u.nm - spec = Spectrum1D(flux=y, spectral_axis=x, bin_specification='centers', mask=x > 2000*u.nm) + spec = Spectrum1D(flux=y, spectral_axis=x, bin_specification="centers", mask=x > 2000*u.nm) output = spec.rebin((10,), operation=np.sum, operation_ignores_mask=False) 
assert output.shape == (400,) def test_reproject_adaptive(ndcube_2d_ln_lt, wcs_2d_lt_ln): shape_out = (10, 12) - resampled_cube = ndcube_2d_ln_lt.reproject_to(wcs_2d_lt_ln, algorithm='adaptive', + resampled_cube = ndcube_2d_ln_lt.reproject_to(wcs_2d_lt_ln, algorithm="adaptive", shape_out=shape_out) assert ndcube_2d_ln_lt.data.shape == (10, 12) @@ -1002,7 +994,7 @@ def test_reproject_adaptive(ndcube_2d_ln_lt, wcs_2d_lt_ln): def test_reproject_exact(ndcube_2d_ln_lt, wcs_2d_lt_ln): shape_out = (10, 12) - resampled_cube = ndcube_2d_ln_lt.reproject_to(wcs_2d_lt_ln, algorithm='exact', + resampled_cube = ndcube_2d_ln_lt.reproject_to(wcs_2d_lt_ln, algorithm="exact", shape_out=shape_out) assert ndcube_2d_ln_lt.data.shape == (10, 12) @@ -1011,29 +1003,29 @@ def test_reproject_exact(ndcube_2d_ln_lt, wcs_2d_lt_ln): def test_reproject_invalid_algorithm(ndcube_4d_ln_l_t_lt, wcs_4d_lt_t_l_ln): with pytest.raises(ValueError): - _ = ndcube_4d_ln_l_t_lt.reproject_to(wcs_4d_lt_t_l_ln, algorithm='my_algorithm', + _ = ndcube_4d_ln_l_t_lt.reproject_to(wcs_4d_lt_t_l_ln, algorithm="my_algorithm", shape_out=(5, 10, 12, 8)) def test_reproject_adaptive_incompatible_wcs(ndcube_4d_ln_l_t_lt, wcs_4d_lt_t_l_ln, wcs_1d_l, ndcube_1d_l): with pytest.raises(ValueError): - _ = ndcube_1d_l.reproject_to(wcs_1d_l, algorithm='adaptive', + _ = ndcube_1d_l.reproject_to(wcs_1d_l, algorithm="adaptive", shape_out=(10,)) with pytest.raises(ValueError): - _ = ndcube_4d_ln_l_t_lt.reproject_to(wcs_4d_lt_t_l_ln, algorithm='adaptive', + _ = ndcube_4d_ln_l_t_lt.reproject_to(wcs_4d_lt_t_l_ln, algorithm="adaptive", shape_out=(5, 10, 12, 8)) def test_reproject_exact_incompatible_wcs(ndcube_4d_ln_l_t_lt, wcs_4d_lt_t_l_ln, wcs_1d_l, ndcube_1d_l): with pytest.raises(ValueError): - _ = ndcube_1d_l.reproject_to(wcs_1d_l, algorithm='exact', + _ = ndcube_1d_l.reproject_to(wcs_1d_l, algorithm="exact", shape_out=(10,)) with pytest.raises(ValueError): - _ = ndcube_4d_ln_l_t_lt.reproject_to(wcs_4d_lt_t_l_ln, algorithm='exact', + _ = ndcube_4d_ln_l_t_lt.reproject_to(wcs_4d_lt_t_l_ln, algorithm="exact", shape_out=(5, 10, 12, 8)) @@ -1050,7 +1042,7 @@ def check_arithmetic_value_and_units(cube_new, data_expected): assert u.allclose(cube_quantity, data_expected) -@pytest.mark.parametrize('value', [ +@pytest.mark.parametrize("value", [ 10 * u.ct, u.Quantity([10], u.ct), u.Quantity(np.random.rand(12), u.ct), @@ -1063,7 +1055,7 @@ def test_cube_arithmetic_add(ndcube_2d_ln_lt_units, value): check_arithmetic_value_and_units(new_cube, cube_quantity + value) -@pytest.mark.parametrize('value', [ +@pytest.mark.parametrize("value", [ 10 * u.ct, u.Quantity([10], u.ct), u.Quantity(np.random.rand(12), u.ct), @@ -1075,7 +1067,7 @@ def test_cube_arithmetic_radd(ndcube_2d_ln_lt_units, value): check_arithmetic_value_and_units(new_cube, value + cube_quantity) -@pytest.mark.parametrize('value', [ +@pytest.mark.parametrize("value", [ 10 * u.ct, u.Quantity([10], u.ct), u.Quantity(np.random.rand(12), u.ct), @@ -1087,7 +1079,7 @@ def test_cube_arithmetic_subtract(ndcube_2d_ln_lt_units, value): check_arithmetic_value_and_units(new_cube, cube_quantity - value) -@pytest.mark.parametrize('value', [ +@pytest.mark.parametrize("value", [ 10 * u.ct, u.Quantity([10], u.ct), u.Quantity(np.random.rand(12), u.ct), @@ -1099,7 +1091,7 @@ def test_cube_arithmetic_rsubtract(ndcube_2d_ln_lt_units, value): check_arithmetic_value_and_units(new_cube, value - cube_quantity) -@pytest.mark.parametrize('value', [ +@pytest.mark.parametrize("value", [ 10 * u.ct, u.Quantity([10], u.ct), 
u.Quantity(np.random.rand(12), u.ct), @@ -1115,7 +1107,7 @@ def test_cube_arithmetic_multiply(ndcube_2d_ln_lt_units, value): # TODO: test that uncertainties scale correctly -@pytest.mark.parametrize('value', [ +@pytest.mark.parametrize("value", [ 10 * u.ct, u.Quantity([10], u.ct), u.Quantity(np.random.rand(12), u.ct), @@ -1130,7 +1122,7 @@ def test_cube_arithmetic_rmultiply(ndcube_2d_ln_lt_units, value): check_arithmetic_value_and_units(new_cube, value * cube_quantity) -@pytest.mark.parametrize('value', [ +@pytest.mark.parametrize("value", [ 10 * u.ct, u.Quantity([10], u.ct), u.Quantity([2], u.s), @@ -1145,18 +1137,18 @@ def test_cube_arithmetic_divide(ndcube_2d_ln_lt_units, value): new_cube = ndcube_2d_ln_lt_units / value check_arithmetic_value_and_units(new_cube, cube_quantity / value) -@pytest.mark.parametrize('value', [1, 2, -1]) +@pytest.mark.parametrize("value", [1, 2, -1]) def test_cube_arithmetic_rdivide(ndcube_2d_ln_lt_units, value): cube_quantity = u.Quantity(ndcube_2d_ln_lt_units.data, ndcube_2d_ln_lt_units.unit) - with np.errstate(divide='ignore'): + with np.errstate(divide="ignore"): new_cube = value / ndcube_2d_ln_lt_units check_arithmetic_value_and_units(new_cube, value / cube_quantity) -@pytest.mark.parametrize('value', [1, 2, -1]) +@pytest.mark.parametrize("value", [1, 2, -1]) def test_cube_arithmetic_rdivide_uncertainty(ndcube_4d_unit_uncertainty, value): cube_quantity = u.Quantity(ndcube_4d_unit_uncertainty.data, ndcube_4d_unit_uncertainty.unit) with pytest.warns(NDCubeUserWarning, match="UnknownUncertainty does not support uncertainty propagation with correlation. Setting uncertainties to None."): - with np.errstate(divide='ignore'): + with np.errstate(divide="ignore"): new_cube = value / ndcube_4d_unit_uncertainty check_arithmetic_value_and_units(new_cube, value / cube_quantity) @@ -1183,33 +1175,33 @@ def test_cube_arithmetic_multiply_notimplementederror(ndcube_2d_ln_lt_units): -@pytest.mark.parametrize('power', [2, -2, 10, 0.5]) +@pytest.mark.parametrize("power", [2, -2, 10, 0.5]) def test_cube_arithmetic_power(ndcube_2d_ln_lt, power): cube_quantity = u.Quantity(ndcube_2d_ln_lt.data, ndcube_2d_ln_lt.unit) - with np.errstate(divide='ignore'): + with np.errstate(divide="ignore"): new_cube = ndcube_2d_ln_lt ** power check_arithmetic_value_and_units(new_cube, cube_quantity**power) -@pytest.mark.parametrize('power', [2, -2, 10, 0.5]) +@pytest.mark.parametrize("power", [2, -2, 10, 0.5]) def test_cube_arithmetic_power_unknown_uncertainty(ndcube_4d_unit_uncertainty, power): cube_quantity = u.Quantity(ndcube_4d_unit_uncertainty.data, ndcube_4d_unit_uncertainty.unit) with pytest.warns(NDCubeUserWarning, match="UnknownUncertainty does not support uncertainty propagation with correlation. Setting uncertainties to None."): - with np.errstate(divide='ignore'): + with np.errstate(divide="ignore"): new_cube = ndcube_4d_unit_uncertainty ** power check_arithmetic_value_and_units(new_cube, cube_quantity**power) -@pytest.mark.parametrize('power', [2, -2, 10, 0.5]) +@pytest.mark.parametrize("power", [2, -2, 10, 0.5]) def test_cube_arithmetic_power_std_uncertainty(ndcube_2d_ln_lt_uncert, power): cube_quantity = u.Quantity(ndcube_2d_ln_lt_uncert.data, ndcube_2d_ln_lt_uncert.unit) with pytest.warns(NDCubeUserWarning, match=r" does not support propagation of uncertainties for power. 
Setting uncertainties to None."): - with np.errstate(divide='ignore'): + with np.errstate(divide="ignore"): new_cube = ndcube_2d_ln_lt_uncert ** power check_arithmetic_value_and_units(new_cube, cube_quantity**power) -@pytest.mark.parametrize('new_unit', [u.mJ, 'mJ']) +@pytest.mark.parametrize("new_unit", [u.mJ, "mJ"]) def test_to(ndcube_1d_l, new_unit): cube = ndcube_1d_l expected_factor = 1000 diff --git a/ndcube/tests/test_ndcubesequence.py b/ndcube/tests/test_ndcubesequence.py index 1dee3e096..4571c1b93 100644 --- a/ndcube/tests/test_ndcubesequence.py +++ b/ndcube/tests/test_ndcubesequence.py @@ -14,7 +14,7 @@ def derive_sliced_cube_dims(orig_cube_dims, tuple_item): len_cube_item = len(tuple_item) - 1 if len_cube_item > 0: cube_item = tuple_item[1:] - for i, s in zip(np.arange(len_cube_item)[::-1], cube_item[::-1]): + for i, s in zip(np.arange(len_cube_item)[::-1], cube_item[::-1], strict=False): if isinstance(s, int): del expected_cube_dims[i] else: @@ -22,14 +22,14 @@ def derive_sliced_cube_dims(orig_cube_dims, tuple_item): return tuple(expected_cube_dims) -@pytest.mark.parametrize("ndc, item", - ( - ("ndcubesequence_4c_ln_lt_l_cax1", np.s_[0:1], ), +@pytest.mark.parametrize(("ndc", "item"), + [ + ("ndcubesequence_4c_ln_lt_l_cax1", np.s_[0:1] ), ("ndcubesequence_4c_ln_lt_l_cax1", np.s_[0:1, 0:2]), ("ndcubesequence_4c_ln_lt_l_cax1", np.s_[0:1, 1]), - ("ndcubesequence_4c_ln_lt_l_cax1", np.s_[1:3, 1, 0:2]) - ), - indirect=('ndc',)) + ("ndcubesequence_4c_ln_lt_l_cax1", np.s_[1:3, 1, 0:2]), + ], + indirect=("ndc",)) def test_slice_sequence_axis(ndc, item): # Calculate expected dimensions of cubes with sequence after slicing. tuple_item = item if isinstance(item, tuple) else (item,) @@ -41,13 +41,13 @@ def test_slice_sequence_axis(ndc, item): assert np.all(sliced_sequence[0].shape == expected_cube0_dims) -@pytest.mark.parametrize("ndc, item", - ( +@pytest.mark.parametrize(("ndc", "item"), + [ ("ndcubesequence_4c_ln_lt_l_cax1", np.s_[0]), ("ndcubesequence_4c_ln_lt_l_cax1", np.s_[1, 0:1]), ("ndcubesequence_4c_ln_lt_l_cax1", np.s_[2, 1]), - ("ndcubesequence_4c_ln_lt_l_cax1", np.s_[3, 1, 0:2]) - ), + ("ndcubesequence_4c_ln_lt_l_cax1", np.s_[3, 1, 0:2]), + ], indirect=("ndc",)) def test_extract_ndcube(ndc, item): cube = ndc[item] @@ -57,41 +57,41 @@ def test_extract_ndcube(ndc, item): assert np.all(cube.shape == expected_cube_dims) -@pytest.mark.parametrize("ndc, item, expected_common_axis", - ( +@pytest.mark.parametrize(("ndc", "item", "expected_common_axis"), + [ ("ndcubesequence_4c_ln_lt_l_cax1", np.s_[:, 0], 0), ("ndcubesequence_4c_ln_lt_l_cax1", np.s_[:, 0:1, 0:2], 1), ("ndcubesequence_4c_ln_lt_l_cax1", np.s_[:, :, :, 1], 1), - ("ndcubesequence_4c_ln_lt_l_cax1", np.s_[:, :, 0], None) - ), + ("ndcubesequence_4c_ln_lt_l_cax1", np.s_[:, :, 0], None), + ], indirect=("ndc",)) def test_slice_common_axis(ndc, item, expected_common_axis): sliced_sequence = ndc[item] assert sliced_sequence._common_axis == expected_common_axis -@pytest.mark.parametrize("ndc, item, expected_shape", - ( +@pytest.mark.parametrize(("ndc", "item", "expected_shape"), + [ ("ndcubesequence_4c_ln_lt_l_cax1", np.s_[:, 1:7], (3, 2, (2, 3, 1), 4)), ("ndcubesequence_4c_ln_lt_l_cax1", np.s_[0, 1:7], (3, (2, 3, 1), 4)), ("ndcubesequence_4c_ln_lt_l_cax1", np.s_[:, 2:4], (2, 2, 1, 4)), ("ndcubesequence_4c_ln_lt_l_cax1", np.s_[:, 0:6], (2, 2, 3, 4)), ("ndcubesequence_4c_ln_lt_l_cax1", np.s_[0, 0:6], (2, 3, 4)), - ), + ], indirect=("ndc",)) def test_index_as_cube(ndc, item, expected_shape): assert (ndc.index_as_cube[item].shape == 
expected_shape) -@pytest.mark.parametrize("ndc, axis, expected_shape", - ( +@pytest.mark.parametrize(("ndc", "axis", "expected_shape"), + [ ("ndcubesequence_4c_ln_lt_l", 0, (8, 3, 4)), ("ndcubesequence_4c_ln_lt_l_cax1", 1, (12, 2, - 4)) - ), + 4)), + ], indirect=("ndc",)) def test_explode_along_axis_common_axis_None(ndc, axis, expected_shape): exploded_sequence = ndc.explode_along_axis(axis) @@ -99,74 +99,74 @@ def test_explode_along_axis_common_axis_None(ndc, axis, expected_shape): assert exploded_sequence._common_axis is None -@pytest.mark.parametrize("ndc", (('ndcubesequence_4c_ln_lt_l_cax1',)), indirect=("ndc",)) +@pytest.mark.parametrize("ndc", (["ndcubesequence_4c_ln_lt_l_cax1"]), indirect=("ndc",)) def test_explode_along_axis_common_axis_same(ndc): exploded_sequence = ndc.explode_along_axis(2) assert exploded_sequence.shape == (16, 2, 3) assert exploded_sequence._common_axis == ndc._common_axis -@pytest.mark.parametrize("ndc", (('ndcubesequence_4c_ln_lt_l_cax1',)), indirect=("ndc",)) +@pytest.mark.parametrize("ndc", (["ndcubesequence_4c_ln_lt_l_cax1"]), indirect=("ndc",)) def test_explode_along_axis_common_axis_changed(ndc): exploded_sequence = ndc.explode_along_axis(0) assert exploded_sequence.shape == (8, 3, 4) assert exploded_sequence._common_axis == ndc._common_axis - 1 -@pytest.mark.parametrize("ndc, expected_shape", - ( +@pytest.mark.parametrize(("ndc", "expected_shape"), + [ ("ndcubesequence_4c_ln_lt_l_cax1", (4, 2., 3., 4.)), - ), + ], indirect=("ndc",)) def test_shape(ndc, expected_shape): np.testing.assert_array_equal(ndc.shape, expected_shape) -@pytest.mark.parametrize("ndc, expected_shape", - ( +@pytest.mark.parametrize(("ndc", "expected_shape"), + [ ("ndcubesequence_4c_ln_lt_l_cax1", [2., 12, 4]), - ), + ], indirect=("ndc",)) def test_cube_like_shape(ndc, expected_shape): assert np.all(ndc.cube_like_shape == expected_shape) -@pytest.mark.parametrize("ndc", (("ndcubesequence_4c_ln_lt_l",)), indirect=("ndc",)) +@pytest.mark.parametrize("ndc", (["ndcubesequence_4c_ln_lt_l"]), indirect=("ndc",)) def test_cube_like_shape_error(ndc): with pytest.raises(TypeError): ndc.cube_like_shape -@pytest.mark.parametrize("ndc", (("ndcubesequence_3c_l_ln_lt_cax1",)), indirect=("ndc",)) +@pytest.mark.parametrize("ndc", (["ndcubesequence_3c_l_ln_lt_cax1"]), indirect=("ndc",)) def test_common_axis_coords(ndc): # Construct expected skycoord - common_coords = [cube.axis_world_coords('lon') for cube in ndc] + common_coords = [cube.axis_world_coords("lon") for cube in ndc] expected_skycoords = [] for cube_coords in common_coords: expected_skycoords += [cube_coords[0][i] for i in range(len(cube_coords[0]))] # Construct expected Times - base_time = Time('2000-01-01', format='fits', scale='utc') - expected_times = [base_time + TimeDelta(60*i, format='sec') for i in range(15)] + base_time = Time("2000-01-01", format="fits", scale="utc") + expected_times = [base_time + TimeDelta(60*i, format="sec") for i in range(15)] # Run test function. output = ndc.common_axis_coords # Check right number of coords returned. assert len(output) == 2 output_skycoords, output_times = output # Check SkyCoords are equal. 
- for output_coord, expected_coord in zip(output_skycoords, expected_skycoords): + for output_coord, expected_coord in zip(output_skycoords, expected_skycoords, strict=False): assert all(output_coord == expected_coord) # Check times are equal - for output_time, expected_time in zip(output_times, expected_times): + for output_time, expected_time in zip(output_times, expected_times, strict=False): td = output_time - expected_time assert u.allclose(td.to(u.s), 0*u.s, atol=1e-10*u.s) -@pytest.mark.parametrize("ndc", (("ndcubesequence_3c_l_ln_lt_cax1",)), indirect=("ndc",)) +@pytest.mark.parametrize("ndc", (["ndcubesequence_3c_l_ln_lt_cax1"]), indirect=("ndc",)) def test_sequence_axis_coords(ndc): - expected = {'distance': [1*u.m, 2*u.m, 3*u.m]} + expected = {"distance": [1*u.m, 2*u.m, 3*u.m]} output = ndc.sequence_axis_coords assert output == expected @@ -185,8 +185,8 @@ def test_crop_by_values(ndcubesequence_4c_ln_lt_l): seq = ndcubesequence_4c_ln_lt_l intervals = seq[0].wcs.array_index_to_world_values([1, 2], [0, 1], [0, 2]) units = [u.m, u.deg, u.deg] - lower_corner = [coord[0] * unit for coord, unit in zip(intervals, units)] - upper_corner = [coord[-1] * unit for coord, unit in zip(intervals, units)] + lower_corner = [coord[0] * unit for coord, unit in zip(intervals, units, strict=False)] + upper_corner = [coord[-1] * unit for coord, unit in zip(intervals, units, strict=False)] # Ensure some quantities are in units different from each other # and those stored in the WCS. lower_corner[0] = lower_corner[0].to(units[0]) diff --git a/ndcube/utils/__init__.py b/ndcube/utils/__init__.py index a61744f25..f351718da 100644 --- a/ndcube/utils/__init__.py +++ b/ndcube/utils/__init__.py @@ -1,3 +1,3 @@ from . import collection, cube, misc, sequence, wcs -__all__ = ['collection', 'cube', 'misc', 'sequence', 'wcs'] +__all__ = ["collection", "cube", "misc", "sequence", "wcs"] diff --git a/ndcube/utils/collection.py b/ndcube/utils/collection.py index 452344c8d..582c509f4 100644 --- a/ndcube/utils/collection.py +++ b/ndcube/utils/collection.py @@ -2,27 +2,28 @@ import numpy as np -__all__ = ['assert_aligned_axes_compatible'] +__all__ = ["assert_aligned_axes_compatible"] def _sanitize_aligned_axes(keys, data, aligned_axes): if aligned_axes is None: return None # If aligned_axes set to "all", assume all axes are aligned in order. - elif isinstance(aligned_axes, str) and aligned_axes.lower() == "all": + if isinstance(aligned_axes, str) and aligned_axes.lower() == "all": # Check all cubes are of same shape cube0_dims = data[0].shape - cubes_same_shape = all([all([d.shape[i] == dim for i, dim in enumerate(cube0_dims)]) - for d in data]) + cubes_same_shape = all(all(d.shape[i] == dim for i, dim in enumerate(cube0_dims)) + for d in data) if cubes_same_shape is not True: + msg = "All cubes in data not of same shape. Please set aligned_axes kwarg." raise ValueError( - "All cubes in data not of same shape. Please set aligned_axes kwarg.") + msg) sanitized_axes = tuple([tuple(range(len(cube0_dims)))] * len(data)) else: # Else, sanitize user-supplied aligned axes. sanitized_axes = _sanitize_user_aligned_axes(data, aligned_axes) - return dict(zip(keys, sanitized_axes)) + return dict(zip(keys, sanitized_axes, strict=False)) def _sanitize_user_aligned_axes(data, aligned_axes): @@ -49,8 +50,8 @@ def _sanitize_user_aligned_axes(data, aligned_axes): if not isinstance(aligned_axes, tuple): raise ValueError(aligned_axes_error_message) # Check type of each element. 
- axes_all_ints = all([isinstance(axis, numbers.Integral) for axis in aligned_axes]) - axes_all_tuples = all([isinstance(axis, tuple) for axis in aligned_axes]) + axes_all_ints = all(isinstance(axis, numbers.Integral) for axis in aligned_axes) + axes_all_tuples = all(isinstance(axis, tuple) for axis in aligned_axes) # If all elements are int, duplicate tuple so there is one for each cube. n_cubes = len(data) if axes_all_ints: @@ -61,7 +62,8 @@ def _sanitize_user_aligned_axes(data, aligned_axes): # all elements of each sub-tuple are ints. elif axes_all_tuples: if len(aligned_axes) != n_cubes: - raise ValueError("aligned_axes must have a tuple for each element in data.") + msg = "aligned_axes must have a tuple for each element in data." + raise ValueError(msg) n_aligned_axes = len(aligned_axes[0]) @@ -71,21 +73,25 @@ def _sanitize_user_aligned_axes(data, aligned_axes): # and the dimensions of the aligned axes in each cube are the same. subtuples_are_ints = [False] * n_cubes aligned_axes_same_lengths = [False] * n_cubes - if not all([len(axes) == n_aligned_axes for axes in aligned_axes]): - raise ValueError("Each element in aligned_axes must have same length.") + if not all(len(axes) == n_aligned_axes for axes in aligned_axes): + msg = "Each element in aligned_axes must have same length." + raise ValueError(msg) for i in range(n_cubes): # Check each cube has at least as many dimensions as there are aligned axes # and that all cubes have enough dimensions to accommodate aligned axes. n_cube_dims = len(data[i].shape) max_aligned_axis = max(aligned_axes[i]) if n_cube_dims < max([max_aligned_axis, n_aligned_axes]): - raise ValueError( + msg = ( "Each cube in data must have at least as many axes as aligned axes " "and aligned axis indices must be less than number of cube axes.\n" f"Cube number: {i};\n" f"Number of cube dimensions: {n_cube_dims};\n" f"No. aligned axes: {n_aligned_axes};\n" - f"Highest aligned axis: {max_aligned_axis}") + f"Highest aligned axis: {max_aligned_axis}" + ) + raise ValueError( + msg) subtuple_types = [False] * n_aligned_axes cube_lengths_equal = [False] * n_aligned_axes for j, axis in enumerate(aligned_axes[i]): @@ -96,16 +102,18 @@ def _sanitize_user_aligned_axes(data, aligned_axes): if not all(subtuples_are_ints): raise ValueError(aligned_axes_error_message) if not all(aligned_axes_same_lengths): - raise ValueError("Aligned cube/sequence axes must be of same length.") + msg = "Aligned cube/sequence axes must be of same length." + raise ValueError(msg) else: raise ValueError(aligned_axes_error_message) # Ensure all aligned axes are of same length. - check_dimensions = set([len(set([cube.shape[cube_aligned_axes[j]] - for cube, cube_aligned_axes in zip(data, aligned_axes)])) - for j in range(n_aligned_axes)]) + check_dimensions = {len({cube.shape[cube_aligned_axes[j]] + for cube, cube_aligned_axes in zip(data, aligned_axes, strict=False)}) + for j in range(n_aligned_axes)} if check_dimensions != {1}: - raise ValueError("Aligned axes are not all of same length.") + msg = "Aligned axes are not all of same length." 
+ raise ValueError(msg) return aligned_axes @@ -118,7 +126,7 @@ def _update_aligned_axes(drop_aligned_axes_indices, aligned_axes, first_key): new_aligned_axes = None else: new_aligned_axes = [] - for key in aligned_axes.keys(): + for key in aligned_axes: cube_aligned_axes = np.array(aligned_axes[key]) for drop_axis_index in drop_aligned_axes_indices: drop_axis = cube_aligned_axes[drop_axis_index] @@ -150,14 +158,19 @@ def assert_aligned_axes_compatible(data_dimensions1, data_dimensions2, data_axes """ # If one set of aligned axes is None and the other isn't, they are defined as not compatible. if (data_axes1 is None and data_axes2 is not None) or (data_axes1 is not None and data_axes2 is None): - raise ValueError("Both collections must use aligned_axes or both axes must not use aligned_axes." - f"Currently {data_axes1} != {data_axes2}") + msg = ( + "Both collections must use aligned_axes or both must not use aligned_axes. " + f"Currently {data_axes1} != {data_axes2}" + ) + raise ValueError(msg) # Aligned_axes are being used for both collections if data_axes1 is not None: # Confirm same number of aligned axes. if len(data_axes1) != len(data_axes2): - raise ValueError(f"Number of aligned axes must be equal: {len(data_axes1)} != {len(data_axes2)}") + msg = f"Number of aligned axes must be equal: {len(data_axes1)} != {len(data_axes2)}" + raise ValueError(msg) # Confirm dimension lengths of each aligned axis is the same. if not all(np.array(data_dimensions1)[np.array(data_axes1)] == np.array(data_dimensions2)[np.array(data_axes2)]): - raise ValueError("All corresponding aligned axes between cubes must be of same length.") + msg = "All corresponding aligned axes between cubes must be of same length." + raise ValueError(msg) diff --git a/ndcube/utils/cube.py b/ndcube/utils/cube.py index af6485d24..3973f3e95 100644 --- a/ndcube/utils/cube.py +++ b/ndcube/utils/cube.py @@ -32,23 +32,26 @@ def sanitize_wcs(func): def wcs_wrapper(*args, **kwargs): sig = inspect.signature(func) params = sig.bind(*args, **kwargs) - wcs = params.arguments.get('wcs', None) - self = params.arguments['self'] + wcs = params.arguments.get("wcs", None) + self = params.arguments["self"] if wcs is None: wcs = self.wcs - if not isinstance(wcs, ExtraCoords): - if not wcs.pixel_n_dim == self.data.ndim: - raise ValueError("The supplied WCS must have the same number of " - "pixel dimensions as the NDCube object. " - "If you specified `cube.extra_coords.wcs` " - "please just pass `cube.extra_coords`.") + if not isinstance(wcs, ExtraCoords) and wcs.pixel_n_dim != self.data.ndim: + msg = ( + "The supplied WCS must have the same number of " + "pixel dimensions as the NDCube object. " + "If you specified `cube.extra_coords.wcs` " + "please just pass `cube.extra_coords`." + ) + raise ValueError(msg) - if not isinstance(wcs, (BaseHighLevelWCS, ExtraCoords)): - raise TypeError("wcs argument must be a High Level WCS or an ExtraCoords object.") + if not isinstance(wcs, BaseHighLevelWCS | ExtraCoords): + msg = "wcs argument must be a High Level WCS or an ExtraCoords object."
+ raise TypeError(msg) - params.arguments['wcs'] = wcs + params.arguments["wcs"] = wcs return func(*params.args, **params.kwargs) @@ -67,7 +70,7 @@ def sanitize_crop_inputs(points, wcs): values_are_none = [False] * n_points for i, point in enumerate(points): # Ensure each point is a list - if isinstance(point, (tuple, list)): + if isinstance(point, tuple | list): points[i] = list(point) else: points[i] = [point] @@ -75,7 +78,7 @@ def sanitize_crop_inputs(points, wcs): # Later we will ensure all points have same number of objects. n_coords[i] = len(points[i]) # Confirm whether point contains at least one None entry. - if all([coord is None for coord in points[i]]): + if all(coord is None for coord in points[i]): values_are_none[i] = True # If no points contain a coord, i.e. if all entries in all points are None, # set no-op flag to True and exit. @@ -83,8 +86,11 @@ def sanitize_crop_inputs(points, wcs): return True, points, wcs # If not all points are of same length, error. if len(set(n_coords)) != 1: - raise ValueError("All points must have same number of coordinate objects." - f"Number of objects in each point: {n_coords}") + msg = ( + "All points must have same number of coordinate objects. " + f"Number of objects in each point: {n_coords}" + ) + raise ValueError(msg) # Import must be here to avoid circular import. from ndcube.extra_coords.extra_coords import ExtraCoords if isinstance(wcs, ExtraCoords): @@ -181,7 +187,7 @@ def get_crop_item_from_points(points, wcs, crop_by_values, keepdims): # If returned value is a 0-d array, convert to a length-1 tuple. if isinstance(point_array_indices, np.ndarray) and point_array_indices.ndim == 0: point_array_indices = (point_array_indices.item(),) - for axis, index in zip(array_axes_with_input, point_array_indices): + for axis, index in zip(array_axes_with_input, point_array_indices, strict=False): combined_points_array_idx[axis] = combined_points_array_idx[axis] + [index] # Define slice item with which to slice cube. item = [] @@ -200,8 +206,11 @@ def get_crop_item_from_points(points, wcs, crop_by_values, keepdims): result_is_scalar = False # If item will result in a scalar cube, raise an error as this is not currently supported. if result_is_scalar: - raise ValueError("Input points causes cube to be cropped to a single pixel. " - "This is not supported.") + msg = ( + "Input points causes cube to be cropped to a single pixel. " + "This is not supported." + ) + raise ValueError(msg) return tuple(item) @@ -250,17 +259,18 @@ def propagate_rebin_uncertainties(uncertainty, data, mask, operation, operation_ first dimension. """ flat_axis = 0 - operation_is_mean = True if operation in {np.mean, np.nanmean} else False - operation_is_nantype = True if operation in {np.nansum, np.nanmean, np.nanprod} else False + operation_is_mean = operation in {np.mean, np.nanmean} + operation_is_nantype = operation in {np.nansum, np.nanmean, np.nanprod} # If propagation_operation kwarg not set manually, try to set it based on operation kwarg. if not propagation_operation: if operation in {np.sum, np.nansum, np.mean, np.nanmean}: propagation_operation = np.add # TODO: product was renamed to prod for numpy 2.0 - elif operation in {np.prod, np.nanprod, np.product if hasattr(np, "product") else np.prod}: + elif operation in {np.prod, np.nanprod, np.prod if hasattr(np, "product") else np.prod}: propagation_operation = np.multiply else: - raise ValueError("propagation_operation not recognized.") + msg = "propagation_operation not recognized."
+ raise ValueError(msg) # Build mask if not provided. new_uncertainty = uncertainty[0] # Define uncertainty for initial iteration step. if operation_ignores_mask or mask is None: diff --git a/ndcube/utils/misc.py b/ndcube/utils/misc.py index b41b0ae88..b037f0fae 100644 --- a/ndcube/utils/misc.py +++ b/ndcube/utils/misc.py @@ -1,6 +1,6 @@ import astropy.units as u -__all__ = ['unique_sorted', 'convert_quantities_to_units'] +__all__ = ["unique_sorted", "convert_quantities_to_units"] def unique_sorted(iterable): @@ -31,4 +31,4 @@ def convert_quantities_to_units(coords, units): Non-quantity types remain. """ return [coord.to(unit) if isinstance(coord, u.Quantity) else coord - for coord, unit in zip(coords, units)] + for coord, unit in zip(coords, units, strict=False)] diff --git a/ndcube/utils/sequence.py b/ndcube/utils/sequence.py index 05af52585..66051c471 100644 --- a/ndcube/utils/sequence.py +++ b/ndcube/utils/sequence.py @@ -8,9 +8,9 @@ import numpy as np -__all__ = ['SequenceItem', - 'cube_like_index_to_sequence_and_common_axis_indices', - 'cube_like_tuple_item_to_sequence_items'] +__all__ = ["SequenceItem", + "cube_like_index_to_sequence_and_common_axis_indices", + "cube_like_tuple_item_to_sequence_items"] SequenceItem = namedtuple("SequenceItem", "sequence_index cube_item") @@ -48,10 +48,7 @@ def cube_like_index_to_sequence_and_common_axis_indices(cube_like_index, common_ """ cumul_lengths = np.cumsum(common_axis_lengths) sequence_index = np.arange(len(cumul_lengths))[cumul_lengths > cube_like_index][0] - if sequence_index == 0: - common_axis_index = cube_like_index - else: - common_axis_index = cube_like_index - cumul_lengths[sequence_index - 1] + common_axis_index = cube_like_index if sequence_index == 0 else cube_like_index - cumul_lengths[sequence_index - 1] return sequence_index, common_axis_index @@ -84,26 +81,27 @@ def cube_like_tuple_item_to_sequence_items(item, common_axis, common_axis_length via the cube-like API. """ if not hasattr(item, "__len__"): - raise TypeError("item must be an iterable of slices and/or ints.") + msg = "item must be an iterable of slices and/or ints." + raise TypeError(msg) if len(item) <= common_axis: - raise ValueError("item must be include an entry for the common axis, " - "i.e. length of item must be > common_axis.") + msg = ( + "item must be include an entry for the common axis, " + "i.e. length of item must be > common_axis." + ) + raise ValueError(msg) if not isinstance(item[common_axis], slice): - raise TypeError("This function should only be used when the common axis entry " - "of item is a slice object.") + msg = ( + "This function should only be used when the common axis entry " + "of item is a slice object." + ) + raise TypeError(msg) # Define default item for slicing the cubes default_cube_item = list(item) default_cube_item[common_axis] = slice(None) # Convert start and stop cube-like indices to sequence and cube indices. 
- if item[common_axis].start is None: - common_axis_start = 0 - else: - common_axis_start = item[common_axis].start - if item[common_axis].stop is None: - common_axis_stop = sum(common_axis_lengths) - else: - common_axis_stop = item[common_axis].stop + common_axis_start = 0 if item[common_axis].start is None else item[common_axis].start + common_axis_stop = sum(common_axis_lengths) if item[common_axis].stop is None else item[common_axis].stop item[common_axis] = slice(common_axis_start, common_axis_stop) start_sequence_index, start_common_axis_index = \ cube_like_index_to_sequence_and_common_axis_indices( diff --git a/ndcube/utils/sphinx/code_context.py b/ndcube/utils/sphinx/code_context.py index 1048dcdd0..7385a8455 100644 --- a/ndcube/utils/sphinx/code_context.py +++ b/ndcube/utils/sphinx/code_context.py @@ -15,35 +15,36 @@ class ExpandingCodeBlock(CodeBlock): It behaves like the code-block directive, with the addition of a ``:summary:`` option, which sets the unexpanded text. """ + has_content = True required_arguments = 0 optional_arguments = 1 final_argument_whitespace = False option_spec = { - 'force': directives.flag, - 'linenos': directives.flag, - 'dedent': int, - 'lineno-start': int, - 'emphasize-lines': directives.unchanged_required, - 'caption': directives.unchanged_required, - 'class': directives.class_option, - 'name': directives.unchanged, - 'summary': directives.unchanged_required, + "force": directives.flag, + "linenos": directives.flag, + "dedent": int, + "lineno-start": int, + "emphasize-lines": directives.unchanged_required, + "caption": directives.unchanged_required, + "class": directives.class_option, + "name": directives.unchanged, + "summary": directives.unchanged_required, } def run(self): source, lineno = self.state_machine.get_source_and_line(self.lineno) - summary_text = self.options.get('summary', 'Show setup code') + summary_text = self.options.get("summary", "Show setup code") opening_details = f"""\
<details> <summary>{summary_text}</summary> """ - open_raw_node = nodes.raw('', opening_details, format='html') + open_raw_node = nodes.raw("", opening_details, format="html") open_raw_node.source, open_raw_node.line = source, lineno - close_raw_node = nodes.raw('', "</details>", format='html') + close_raw_node = nodes.raw("", "</details>", format="html")
close_raw_node.source, close_raw_node.line = source, lineno literal = super().run()[0] @@ -52,6 +53,6 @@ def run(self): def setup(app): - app.add_directive('expanding-code-block', ExpandingCodeBlock) + app.add_directive("expanding-code-block", ExpandingCodeBlock) - return {'parallel_read_safe': True, 'parallel_write_safe': True} + return {"parallel_read_safe": True, "parallel_write_safe": True} diff --git a/ndcube/utils/tests/test_utils_collection.py b/ndcube/utils/tests/test_utils_collection.py index 63e7d7000..394fb333f 100644 --- a/ndcube/utils/tests/test_utils_collection.py +++ b/ndcube/utils/tests/test_utils_collection.py @@ -4,7 +4,7 @@ from ndcube.utils import collection as collection_utils -@pytest.mark.parametrize("data_dimensions1,data_dimensions2,data_axes1,data_axes2", [ +@pytest.mark.parametrize(("data_dimensions1", "data_dimensions2", "data_axes1", "data_axes2"), [ ([3., 4., 5.], [3., 5., 15.], (0, 2), (0, 1))]) def test_assert_aligned_axes_compatible(data_dimensions1, data_dimensions2, data_axes1, data_axes2): @@ -12,7 +12,7 @@ def test_assert_aligned_axes_compatible(data_dimensions1, data_dimensions2, data_axes1, data_axes2) -@pytest.mark.parametrize("data_dimensions1,data_dimensions2,data_axes1,data_axes2", [ +@pytest.mark.parametrize(("data_dimensions1", "data_dimensions2", "data_axes1", "data_axes2"), [ ([3., 4., 5.], [3., 5., 15.], (0, 1), (0, 1)), ([3., 4., 5.], [3., 5., 15.], (0, 1), (0, 1, 2)), ([3., 4., 5.], [3., 5., 15.], (0, 1), None)]) diff --git a/ndcube/utils/tests/test_utils_sequence.py b/ndcube/utils/tests/test_utils_sequence.py index 373b4baf8..960257f68 100644 --- a/ndcube/utils/tests/test_utils_sequence.py +++ b/ndcube/utils/tests/test_utils_sequence.py @@ -14,9 +14,9 @@ @pytest.mark.parametrize( - "cube_like_index, common_axis, common_axis_lengths, expected_seq_idx, expected_common_idx", + ("cube_like_index", "common_axis", "common_axis_lengths", "expected_seq_idx", "expected_common_idx"), [(3, 1, [4, 4], 0, 3), - (3, 1, [2, 2], 1, 1)] + (3, 1, [2, 2], 1, 1)], ) def test_cube_like_index_to_sequence_and_common_axis_indices( cube_like_index, common_axis, common_axis_lengths, expected_seq_idx, expected_common_idx): @@ -28,14 +28,14 @@ def test_cube_like_index_to_sequence_and_common_axis_indices( @pytest.mark.parametrize( - "item, common_axis, common_axis_lengths, n_cube_dims, expected_sequence_items", [ + ("item", "common_axis", "common_axis_lengths", "n_cube_dims", "expected_sequence_items"), [ ((slice(None), slice(4, 6)), 1, [3, 3], 4, [SequenceItem(sequence_index=1, cube_item=slice(1, 3))]), ((slice(None), slice(None)), 1, [3, 3, 3], 4, [SequenceItem(sequence_index=0, cube_item=slice(0, None)), SequenceItem(sequence_index=1, cube_item=slice(None)), - SequenceItem(sequence_index=2, cube_item=slice(None, 9))])] + SequenceItem(sequence_index=2, cube_item=slice(None, 9))])], ) def test_cube_like_tuple_item_to_sequence_items( item, common_axis, common_axis_lengths, n_cube_dims, expected_sequence_items): diff --git a/ndcube/utils/tests/test_utils_wcs.py b/ndcube/utils/tests/test_utils_wcs.py index f21244ca4..dd44abc49 100644 --- a/ndcube/utils/tests/test_utils_wcs.py +++ b/ndcube/utils/tests/test_utils_wcs.py @@ -7,25 +7,25 @@ from ndcube import utils ht_with_celestial = { - 'CTYPE4': 'HPLN-TAN', 'CUNIT4': 'deg', 'CDELT4': 1, 'CRPIX4': 0, 'CRVAL4': 0, 'NAXIS4': 1, - 'CNAME4': 'redundant axis', 'CROTA4': 0, - 'CTYPE3': 'HPLT-TAN', 'CUNIT3': 'deg', 'CDELT3': 0.5, 'CRPIX3': 0, 'CRVAL3': 0, 
'NAXIS3': 2, - 'CTYPE2': 'WAVE ', 'CUNIT2': 'Angstrom', 'CDELT2': 0.2, 'CRPIX2': 0, 'CRVAL2': 0, - 'NAXIS2': 3, - 'CTYPE1': 'TIME ', 'CUNIT1': 'min', 'CDELT1': 0.4, 'CRPIX1': 0, 'CRVAL1': 0, 'NAXIS1': 4} - -hm = {'CTYPE1': 'WAVE ', 'CUNIT1': 'Angstrom', 'CDELT1': 0.2, 'CRPIX1': 0, 'CRVAL1': 10, - 'NAXIS1': 4, - 'CTYPE2': 'HPLT-TAN', 'CUNIT2': 'deg', 'CDELT2': 0.5, 'CRPIX2': 2, 'CRVAL2': 0.5, - 'NAXIS2': 3, - 'CTYPE3': 'HPLN-TAN', 'CUNIT3': 'deg', 'CDELT3': 0.4, 'CRPIX3': 2, 'CRVAL3': 1, 'NAXIS3': 2} + "CTYPE4": "HPLN-TAN", "CUNIT4": "deg", "CDELT4": 1, "CRPIX4": 0, "CRVAL4": 0, "NAXIS4": 1, + "CNAME4": "redundant axis", "CROTA4": 0, + "CTYPE3": "HPLT-TAN", "CUNIT3": "deg", "CDELT3": 0.5, "CRPIX3": 0, "CRVAL3": 0, "NAXIS3": 2, + "CTYPE2": "WAVE ", "CUNIT2": "Angstrom", "CDELT2": 0.2, "CRPIX2": 0, "CRVAL2": 0, + "NAXIS2": 3, + "CTYPE1": "TIME ", "CUNIT1": "min", "CDELT1": 0.4, "CRPIX1": 0, "CRVAL1": 0, "NAXIS1": 4} + +hm = {"CTYPE1": "WAVE ", "CUNIT1": "Angstrom", "CDELT1": 0.2, "CRPIX1": 0, "CRVAL1": 10, + "NAXIS1": 4, + "CTYPE2": "HPLT-TAN", "CUNIT2": "deg", "CDELT2": 0.5, "CRPIX2": 2, "CRVAL2": 0.5, + "NAXIS2": 3, + "CTYPE3": "HPLN-TAN", "CUNIT3": "deg", "CDELT3": 0.4, "CRPIX3": 2, "CRVAL3": 1, "NAXIS3": 2} wm = WCS(header=hm) hm_reindexed_102 = { - 'CTYPE2': 'WAVE ', 'CUNIT2': 'Angstrom', 'CDELT2': 0.2, 'CRPIX2': 0, 'CRVAL2': 10, - 'NAXIS2': 4, - 'CTYPE1': 'HPLT-TAN', 'CUNIT1': 'deg', 'CDELT1': 0.5, 'CRPIX1': 2, 'CRVAL1': 0.5, 'NAXIS1': 3, - 'CTYPE3': 'HPLN-TAN', 'CUNIT3': 'deg', 'CDELT3': 0.4, 'CRPIX3': 2, 'CRVAL3': 1, 'NAXIS3': 2} + "CTYPE2": "WAVE ", "CUNIT2": "Angstrom", "CDELT2": 0.2, "CRPIX2": 0, "CRVAL2": 10, + "NAXIS2": 4, + "CTYPE1": "HPLT-TAN", "CUNIT1": "deg", "CDELT1": 0.5, "CRPIX1": 2, "CRVAL1": 0.5, "NAXIS1": 3, + "CTYPE3": "HPLN-TAN", "CUNIT3": "deg", "CDELT3": 0.4, "CRPIX3": 2, "CRVAL3": 1, "NAXIS3": 2} wm_reindexed_102 = WCS(header=hm_reindexed_102) @@ -47,9 +47,9 @@ def test_wcs(): class WCSTest: - def __init__(self): + def __init__(self) -> None: self.world_axis_physical_types = [ - 'custom:pos.helioprojective.lon', 'custom:pos.helioprojective.lat', 'em.wl', 'time'] + "custom:pos.helioprojective.lon", "custom:pos.helioprojective.lat", "em.wl", "time"] self.axis_correlation_matrix = _axis_correlation_matrix() @@ -75,21 +75,21 @@ def test_world_axis_to_pixel_axes(axis_correlation_matrix): def test_pixel_axis_to_physical_types(test_wcs): output = utils.wcs.pixel_axis_to_physical_types(0, test_wcs) - expected = np.array(['custom:pos.helioprojective.lon', - 'custom:pos.helioprojective.lat', 'time']) + expected = np.array(["custom:pos.helioprojective.lon", + "custom:pos.helioprojective.lat", "time"]) assert all(output == expected) def test_physical_type_to_pixel_axes(test_wcs): - output = utils.wcs.physical_type_to_pixel_axes('lon', test_wcs) + output = utils.wcs.physical_type_to_pixel_axes("lon", test_wcs) expected = np.array([0, 1]) assert all(output == expected) -@pytest.mark.parametrize("test_input,expected", [('wl', 2), ('em.wl', 2)]) +@pytest.mark.parametrize(("test_input", "expected"), [("wl", 2), ("em.wl", 2)]) def test_physical_type_to_world_axis(test_input, expected): - world_axis_physical_types = ['custom:pos.helioprojective.lon', - 'custom:pos.helioprojective.lat', 'em.wl', 'time'] + world_axis_physical_types = ["custom:pos.helioprojective.lon", + "custom:pos.helioprojective.lat", "em.wl", "time"] output = utils.wcs.physical_type_to_world_axis(test_input, world_axis_physical_types) assert output == expected @@ -114,8 +114,8 @@ def 
test_get_dependent_world_axes(axis_correlation_matrix): def test_get_dependent_physical_types(test_wcs): output = utils.wcs.get_dependent_physical_types("time", test_wcs) - expected = np.array(['custom:pos.helioprojective.lon', - 'custom:pos.helioprojective.lat', 'time']) + expected = np.array(["custom:pos.helioprojective.lon", + "custom:pos.helioprojective.lat", "time"]) assert all(output == expected) @@ -124,15 +124,15 @@ def test_array_indices_for_world_objects(wcs_4d_t_l_lt_ln): assert len(array_indices) == 3 assert array_indices == ((3,), (2,), (0, 1)) - array_indices = utils.wcs.array_indices_for_world_objects(wcs_4d_t_l_lt_ln, ('time',)) + array_indices = utils.wcs.array_indices_for_world_objects(wcs_4d_t_l_lt_ln, ("time",)) assert len(array_indices) == 1 assert array_indices == ((3,),) - array_indices = utils.wcs.array_indices_for_world_objects(wcs_4d_t_l_lt_ln, ('time', 'em.wl')) + array_indices = utils.wcs.array_indices_for_world_objects(wcs_4d_t_l_lt_ln, ("time", "em.wl")) assert len(array_indices) == 2 assert array_indices == ((3,), (2,)) - array_indices = utils.wcs.array_indices_for_world_objects(wcs_4d_t_l_lt_ln, ('lat',)) + array_indices = utils.wcs.array_indices_for_world_objects(wcs_4d_t_l_lt_ln, ("lat",)) assert len(array_indices) == 1 assert array_indices == ((0, 1),) @@ -142,15 +142,15 @@ def test_array_indices_for_world_objects_2(wcs_4d_lt_t_l_ln): assert len(array_indices) == 3 assert array_indices == ((0, 3), (2,), (1,)) - array_indices = utils.wcs.array_indices_for_world_objects(wcs_4d_lt_t_l_ln, ('lat',)) + array_indices = utils.wcs.array_indices_for_world_objects(wcs_4d_lt_t_l_ln, ("lat",)) assert len(array_indices) == 1 assert array_indices == ((0, 3),) - array_indices = utils.wcs.array_indices_for_world_objects(wcs_4d_lt_t_l_ln, ('lat', 'time')) + array_indices = utils.wcs.array_indices_for_world_objects(wcs_4d_lt_t_l_ln, ("lat", "time")) assert len(array_indices) == 2 assert array_indices == ((0, 3), (2,)) - array_indices = utils.wcs.array_indices_for_world_objects(wcs_4d_lt_t_l_ln, ('lon', 'time')) + array_indices = utils.wcs.array_indices_for_world_objects(wcs_4d_lt_t_l_ln, ("lon", "time")) assert len(array_indices) == 2 assert array_indices == ((0, 3), (2,)) @@ -164,8 +164,8 @@ def test_identify_invariant_axes(wcs_3d_l_lt_ln): source_wcs = wcs_3d_l_lt_ln target_wcs_header = wcs_3d_l_lt_ln.low_level_wcs.to_header().copy() - target_wcs_header['CDELT2'] = 10 - target_wcs_header['CDELT3'] = 20 + target_wcs_header["CDELT2"] = 10 + target_wcs_header["CDELT3"] = 20 target_wcs = WCS(header=target_wcs_header) invariant_axes = utils.wcs.identify_invariant_axes(source_wcs, target_wcs, (4, 4, 4)) diff --git a/ndcube/utils/wcs.py b/ndcube/utils/wcs.py index 9268676e2..9822c271a 100644 --- a/ndcube/utils/wcs.py +++ b/ndcube/utils/wcs.py @@ -10,14 +10,14 @@ from astropy.wcs.utils import pixel_to_pixel from astropy.wcs.wcsapi import BaseHighLevelWCS, BaseLowLevelWCS, low_level_api -__all__ = ['array_indices_for_world_objects', 'convert_between_array_and_pixel_axes', - 'calculate_world_indices_from_axes', 'wcs_ivoa_mapping', - 'pixel_axis_to_world_axes', 'world_axis_to_pixel_axes', - 'pixel_axis_to_physical_types', 'physical_type_to_pixel_axes', - 'physical_type_to_world_axis', 'get_dependent_pixel_axes', - 'get_dependent_array_axes', 'get_dependent_world_axes', - 'get_dependent_physical_types', 'array_indices_for_world_objects', - 'validate_physical_types'] +__all__ = ["array_indices_for_world_objects", "convert_between_array_and_pixel_axes", + 
"calculate_world_indices_from_axes", "wcs_ivoa_mapping", + "pixel_axis_to_world_axes", "world_axis_to_pixel_axes", + "pixel_axis_to_physical_types", "physical_type_to_pixel_axes", + "physical_type_to_world_axis", "get_dependent_pixel_axes", + "get_dependent_array_axes", "get_dependent_world_axes", + "get_dependent_physical_types", "array_indices_for_world_objects", + "validate_physical_types"] class TwoWayDict(UserDict): @@ -49,7 +49,7 @@ def inv(self): "HECH": "pos.bodyrc.alt", } wcs_ivoa_mapping = TwoWayDict() -for key in wcs_to_ivoa.keys(): +for key in wcs_to_ivoa: wcs_ivoa_mapping[key] = wcs_to_ivoa[key] @@ -74,18 +74,22 @@ def convert_between_array_and_pixel_axes(axis, naxes): """ # Check type of input. if not isinstance(axis, np.ndarray): - raise TypeError(f"input must be of array type. Got type: {type(axis)}") - if axis.dtype.char not in np.typecodes['AllInteger']: - raise TypeError(f"input dtype must be of int type. Got dtype: {axis.dtype})") + msg = f"input must be of array type. Got type: {type(axis)}" + raise TypeError(msg) + if axis.dtype.char not in np.typecodes["AllInteger"]: + msg = f"input dtype must be of int type. Got dtype: {axis.dtype})" + raise TypeError(msg) # Convert negative indices to positive equivalents. axis[axis < 0] += naxes if any(axis > naxes - 1): - raise IndexError("Axis out of range. " - f"Number of axes = {naxes}; Axis numbers requested = {axis}") + msg = ( + "Axis out of range. " + f"Number of axes = {naxes}; Axis numbers requested = {axis}" + ) + raise IndexError(msg) # Reflect axis about center of number of axes. - reflected_axis = naxes - 1 - axis + return naxes - 1 - axis - return reflected_axis def pixel_axis_to_world_axes(pixel_axis, axis_correlation_matrix): @@ -200,11 +204,14 @@ def physical_type_to_world_axis(physical_type, world_axis_physical_types): for world_axis_physical_type in world_axis_physical_types] widx = np.arange(len(world_axis_physical_types))[widx] if len(widx) != 1: - raise ValueError( + msg = ( "Input does not uniquely correspond to a physical type." f" Expected unique substring of one of {world_axis_physical_types}." f" Got: {physical_type}" ) + raise ValueError( + msg, + ) # Return axes with duplicates removed. return widx[0] @@ -242,8 +249,7 @@ def get_dependent_pixel_axes(pixel_axis, axis_correlation_matrix): # To do this we take a column from the matrix and find if there are # any entries in common with all other columns in the matrix. world_dep = axis_correlation_matrix[:, pixel_axis:pixel_axis + 1] - dependent_pixel_axes = np.sort(np.nonzero((world_dep & axis_correlation_matrix).any(axis=0))[0]) - return dependent_pixel_axes + return np.sort(np.nonzero((world_dep & axis_correlation_matrix).any(axis=0))[0]) def get_dependent_array_axes(array_axis, axis_correlation_matrix): @@ -308,8 +314,7 @@ def get_dependent_world_axes(world_axis, axis_correlation_matrix): # To do this we take a row from the matrix and find if there are # any entries in common with all other rows in the matrix. 
pixel_dep = axis_correlation_matrix[world_axis:world_axis + 1] - dependent_world_axes = np.sort(np.nonzero((pixel_dep & axis_correlation_matrix).any(axis=1))[0]) - return dependent_world_axes + return np.sort(np.nonzero((pixel_dep & axis_correlation_matrix).any(axis=1))[0]) def get_dependent_physical_types(physical_type, wcs): @@ -332,8 +337,7 @@ def get_dependent_physical_types(physical_type, wcs): world_axis_physical_types = wcs.world_axis_physical_types world_axis = physical_type_to_world_axis(physical_type, world_axis_physical_types) dependent_world_axes = get_dependent_world_axes(world_axis, wcs.axis_correlation_matrix) - dependent_physical_types = np.array(world_axis_physical_types)[dependent_world_axes] - return dependent_physical_types + return np.array(world_axis_physical_types)[dependent_world_axes] def validate_physical_types(physical_types): @@ -343,12 +347,15 @@ def validate_physical_types(physical_types): try: low_level_api.validate_physical_types(physical_types) except ValueError as e: - invalid_type = str(e).split(':')[1].strip() - raise ValueError( + invalid_type = str(e).split(":")[1].strip() + msg = ( f"'{invalid_type}' is not a valid IVOA UCD1+ physical type. " "It must be a string specified in the list (http://www.ivoa.net/documents/latest/UCDlist.html) " "or if no matching type exists it can be any string prepended with 'custom:'." ) + raise ValueError( + msg, ) def calculate_world_indices_from_axes(wcs, axes): @@ -371,8 +378,11 @@ def calculate_world_indices_from_axes(wcs, axes): # If axis is str, it is a physical type or substring of a physical type. world_indices.append(physical_type_to_world_axis(axis, wcs.world_axis_physical_types)) else: - raise TypeError(f"Unrecognized axis type: {axis, type(axis)}. " "Must be of type (numbers.Integral, str)") + msg = ( f"Unrecognized axis type: {axis, type(axis)}. " "Must be of type (numbers.Integral, str)" ) + raise TypeError(msg) # Use inferred world axes to extract the desired coord value # and corresponding physical types. return np.unique(np.array(world_indices, dtype=int)) @@ -408,10 +418,7 @@ def array_indices_for_world_objects(wcs, axes=None): coordinates. The array indices will be returned in the sub-tuple in array index order, i.e. ascending. """ - if axes: - world_indices = calculate_world_indices_from_axes(wcs, axes) - else: - world_indices = np.arange(wcs.world_n_dim) + world_indices = calculate_world_indices_from_axes(wcs, axes) if axes else np.arange(wcs.world_n_dim) object_names = np.array([wao_comp[0] for wao_comp in wcs.low_level_wcs.world_axis_object_components]) array_indices = [[]] * len(object_names) @@ -431,7 +438,7 @@ def array_indices_for_world_objects(wcs, axes=None): return tuple(ai for ai in array_indices if ai) -def get_low_level_wcs(wcs, name='wcs'): +def get_low_level_wcs(wcs, name="wcs"): """ Returns a low level WCS object from a low level or high level WCS. 
@@ -447,13 +454,12 @@ def get_low_level_wcs(wcs, name='wcs'): ------- wcs: `astropy.wcs.wcsapi.BaseLowLevelWCS` """ - if isinstance(wcs, BaseHighLevelWCS): return wcs.low_level_wcs - elif isinstance(wcs, BaseLowLevelWCS): + if isinstance(wcs, BaseLowLevelWCS): return wcs - else: - raise ValueError(f'{name} must implement either BaseHighLevelWCS or BaseLowLevelWCS') + msg = f"{name} must implement either BaseHighLevelWCS or BaseLowLevelWCS" + raise ValueError(msg) def compare_wcs_physical_types(source_wcs, target_wcs): @@ -472,9 +478,8 @@ def compare_wcs_physical_types(source_wcs, target_wcs): ------- result : `bool` """ - - source_wcs = get_low_level_wcs(source_wcs, 'source_wcs') - target_wcs = get_low_level_wcs(target_wcs, 'target_wcs') + source_wcs = get_low_level_wcs(source_wcs, "source_wcs") + target_wcs = get_low_level_wcs(target_wcs, "target_wcs") return source_wcs.world_axis_physical_types == target_wcs.world_axis_physical_types @@ -505,10 +510,9 @@ def identify_invariant_axes(source_wcs, target_wcs, input_shape, atol=1e-6, rtol A list of booleans denoting whether the axis is invariant or not. Follows the WCS ordering. """ - input_pixel_coords = np.meshgrid(*[np.arange(n) for n in input_shape]) output_pixel_coords = pixel_to_pixel(source_wcs, target_wcs, *input_pixel_coords) return [np.allclose(input_coord, output_coord, atol=atol, rtol=rtol) - for input_coord, output_coord in zip(input_pixel_coords, output_pixel_coords)] + for input_coord, output_coord in zip(input_pixel_coords, output_pixel_coords, strict=False)] diff --git a/ndcube/version.py b/ndcube/version.py index 515c2f0af..90d4111be 100644 --- a/ndcube/version.py +++ b/ndcube/version.py @@ -14,4 +14,4 @@ ) del warnings - version = '0.0.0' + version = "0.0.0" diff --git a/ndcube/visualization/__init__.py b/ndcube/visualization/__init__.py index b1087e642..cc68c3c7f 100644 --- a/ndcube/visualization/__init__.py +++ b/ndcube/visualization/__init__.py @@ -1,4 +1,4 @@ from .base import BasePlotter from .descriptor import PlotterDescriptor -__all__ = ['BasePlotter', "PlotterDescriptor"] +__all__ = ["BasePlotter", "PlotterDescriptor"] diff --git a/ndcube/visualization/base.py b/ndcube/visualization/base.py index d981242b0..5284f93b1 100644 --- a/ndcube/visualization/base.py +++ b/ndcube/visualization/base.py @@ -6,7 +6,7 @@ class BasePlotter(abc.ABC): Base class for NDCube plotter objects. """ - def __init__(self, ndcube=None): + def __init__(self, ndcube=None) -> None: self._ndcube = ndcube @abc.abstractmethod diff --git a/ndcube/visualization/descriptor.py b/ndcube/visualization/descriptor.py index 06926fe80..08307e56a 100644 --- a/ndcube/visualization/descriptor.py +++ b/ndcube/visualization/descriptor.py @@ -5,11 +5,11 @@ MISSING_ANIMATORS_ERROR_MSG = ("mpl_animators cannot be imported, so the default plotting " "functionality is disabled. 
Please install mpl_animators") -__all__ = ['PlotterDescriptor', 'MISSING_MATPLOTLIB_ERROR_MSG', 'MISSING_ANIMATORS_ERROR_MSG'] +__all__ = ["PlotterDescriptor", "MISSING_MATPLOTLIB_ERROR_MSG", "MISSING_ANIMATORS_ERROR_MSG"] class PlotterDescriptor: - def __init__(self, default_type=None): + def __init__(self, default_type=None) -> None: self._default_type = default_type def __set_name__(self, owner, name): @@ -47,31 +47,32 @@ def _resolve_default_type(self, raise_error=True): except ImportError as e: if raise_error: raise ImportError(MISSING_ANIMATORS_ERROR_MSG) from e + return None - elif self._default_type is not None: + if self._default_type is not None: return self._default_type # If we have no default type then just return None - else: - return + return None def __get__(self, obj, objtype=None): if obj is None: - return + return None if getattr(obj, self._attribute_name, None) is None: plotter_type = self._resolve_default_type() if plotter_type is None: - return + return None self.__set__(obj, plotter_type) return getattr(obj, self._attribute_name) - def __set__(self, obj, value): + def __set__(self, obj, value) -> None: if not isinstance(value, type): + msg = "Plotter attribute can only be set with an uninitialised plotter object." raise TypeError( - "Plotter attribute can only be set with an uninitialised plotter object.") + msg) setattr(obj, self._attribute_name, value(obj)) # here obj is the ndcube object and value is the plotter type diff --git a/ndcube/visualization/mpl_plotter.py b/ndcube/visualization/mpl_plotter.py index 3bedbf510..ad5a4fcdd 100644 --- a/ndcube/visualization/mpl_plotter.py +++ b/ndcube/visualization/mpl_plotter.py @@ -12,7 +12,7 @@ from .base import BasePlotter from .descriptor import MISSING_ANIMATORS_ERROR_MSG -__all__ = ['MatplotlibPlotter'] +__all__ = ["MatplotlibPlotter"] class MatplotlibPlotter(BasePlotter): @@ -77,12 +77,12 @@ def plot(self, axes=None, plot_axes=None, axes_coordinates=None, len(self._ndcube.shape), plot_wcs, plot_axes, axes_coordinates, axes_units) with warnings.catch_warnings(): - warnings.simplefilter('ignore', AstropyUserWarning) + warnings.simplefilter("ignore", AstropyUserWarning) if naxis == 1: ax = self._plot_1D_cube(plot_wcs, axes, axes_coordinates, axes_units, data_unit, **kwargs) - elif naxis == 2 and 'y' in plot_axes: + elif naxis == 2 and "y" in plot_axes: ax = self._plot_2D_cube(plot_wcs, axes, plot_axes, axes_coordinates, axes_units, data_unit, **kwargs) else: @@ -96,7 +96,7 @@ def _not_visible_coords(self, axes, axes_coordinates): """ Based on an axes object and axes_coords, work out which coords should not be visible. """ - visible_coords = set(item[1] for item in axes.coords._aliases.items() if item[0] in axes_coordinates) + visible_coords = {item[1] for item in axes.coords._aliases.items() if item[0] in axes_coordinates} return set(axes.coords._aliases.values()).difference(visible_coords) def _apply_axes_coordinates(self, axes, axes_coordinates): @@ -122,8 +122,11 @@ def _plot_1D_cube(self, wcs, axes=None, axes_coordinates=None, axes_units=None, if self._ndcube.unit is None: if data_unit is not None: - raise TypeError("Can only set y-axis unit if self._ndcube.unit is set to a " - "compatible unit.") + msg = ( + "Can only set y-axis unit if self._ndcube.unit is set to a " + "compatible unit." 
+ ) + raise TypeError(msg) else: if data_unit is not None: ydata = u.Quantity(ydata, unit=self._ndcube.unit).to_value(data_unit) @@ -169,13 +172,14 @@ def _plot_2D_cube(self, wcs, axes=None, plot_axes=None, axes_coordinates=None, if data_unit is not None: # If user set data_unit, convert dat to desired unit if self._ndcube.unit set. if self._ndcube.unit is None: - raise TypeError("Can only set data_unit if NDCube.unit is set.") + msg = "Can only set data_unit if NDCube.unit is set." + raise TypeError(msg) data = u.Quantity(self._ndcube.data, unit=self._ndcube.unit).to_value(data_unit) if self._ndcube.mask is not None: data = np.ma.masked_array(data, self._ndcube.mask) - if plot_axes.index('x') > plot_axes.index('y'): + if plot_axes.index("x") > plot_axes.index("y"): data = data.T # Plot data @@ -208,12 +212,9 @@ def _animate_cube(self, wcs, plot_axes=None, axes_coordinates=None, # This changes the parameters for future iterations for hidden in self._not_visible_coords(ax.axes, axes_coordinates): - if hidden in ax.coord_params: - param = ax.coord_params[hidden] - else: - param = {} + param = ax.coord_params.get(hidden, {}) - param['ticks'] = False + param["ticks"] = False ax.coord_params[hidden] = param return ax @@ -230,10 +231,10 @@ def _as_mpl_axes(self): and this will generate a plot with the correct WCS coordinates on the axes. See https://wcsaxes.readthedocs.io for more information. """ - kwargs = {'wcs': self._ndcube.wcs} + kwargs = {"wcs": self._ndcube.wcs} n_dim = len(self._ndcube.shape) if n_dim > 2: - kwargs['slices'] = ['x', 'y'] + [None] * (n_dim - 2) + kwargs["slices"] = ["x", "y"] + [None] * (n_dim - 2) return WCSAxes, kwargs def _prep_animate_args(self, wcs, plot_axes, axes_units, data_unit): @@ -249,14 +250,14 @@ def _prep_animate_args(self, wcs, plot_axes, axes_units, data_unit): coord_params = {} if axes_units is not None: - for axis_unit, coord_name in zip(axes_units, wcs.world_axis_physical_types): - coord_params[coord_name] = {'format_unit': axis_unit} + for axis_unit, coord_name in zip(axes_units, wcs.world_axis_physical_types, strict=False): + coord_params[coord_name] = {"format_unit": axis_unit} # TODO: Add support for transposing the array. - if 'y' in plot_axes and plot_axes.index('y') < plot_axes.index('x'): + if "y" in plot_axes and plot_axes.index("y") < plot_axes.index("x"): warn_user( "Animating a NDCube does not support transposing the array. The world axes " - "may not display as expected because the array will not be transposed." + "may not display as expected because the array will not be transposed.", ) plot_axes = [p if p is not None else 0 for p in plot_axes] diff --git a/ndcube/visualization/mpl_sequence_plotter.py b/ndcube/visualization/mpl_sequence_plotter.py index 5fff46939..56ceb7fa3 100644 --- a/ndcube/visualization/mpl_sequence_plotter.py +++ b/ndcube/visualization/mpl_sequence_plotter.py @@ -5,7 +5,7 @@ from .base import BasePlotter from .plotting_utils import prep_plot_kwargs -__all__ = ['MatplotlibSequencePlotter', 'SequenceAnimator'] +__all__ = ["MatplotlibSequencePlotter", "SequenceAnimator"] class MatplotlibSequencePlotter(BasePlotter): @@ -31,9 +31,9 @@ def plot(self, sequence_axis_coords=None, sequence_axis_unit=None, **kwargs): """ sequence_dims = self._ndcube.shape if len(sequence_dims) == 2: - raise NotImplementedError("Visualizing sequences of 1-D cubes not currently supported.") - else: - return self.animate(sequence_axis_coords, sequence_axis_unit, **kwargs) + msg = "Visualizing sequences of 1-D cubes not currently supported." 
+ raise NotImplementedError(msg) + return self.animate(sequence_axis_coords, sequence_axis_unit, **kwargs) def animate(self, sequence_axis_coords=None, sequence_axis_unit=None, **kwargs): """ @@ -80,11 +80,13 @@ class SequenceAnimator(ArrayAnimatorWCS): The unit in which to display the sequence_axis_coords. """ - def __init__(self, sequence, sequence_axis_coords=None, sequence_axis_unit=None, **kwargs): + def __init__(self, sequence, sequence_axis_coords=None, sequence_axis_unit=None, **kwargs) -> None: if sequence_axis_coords is not None: - raise NotImplementedError("Setting sequence_axis_coords not yet supported.") + msg = "Setting sequence_axis_coords not yet supported." + raise NotImplementedError(msg) if sequence_axis_unit is not None: - raise NotImplementedError("Setting sequence_axis_unit not yet supported.") + msg = "Setting sequence_axis_unit not yet supported." + raise NotImplementedError(msg) # Store sequence data self._cubes = sequence.data diff --git a/ndcube/visualization/plotting_utils.py b/ndcube/visualization/plotting_utils.py index 8e60c5722..067c08b75 100644 --- a/ndcube/visualization/plotting_utils.py +++ b/ndcube/visualization/plotting_utils.py @@ -1,12 +1,13 @@ import astropy.units as u -__all__ = ['prep_plot_kwargs', 'set_wcsaxes_format_units'] +__all__ = ["prep_plot_kwargs", "set_wcsaxes_format_units"] def _expand_ellipsis(ndim, plist): if Ellipsis in plist: if plist.count(Ellipsis) > 1: - raise IndexError("Only single ellipsis ('...') is permitted.") + msg = "Only single ellipsis ('...') is permitted." + raise IndexError(msg) # Replace the Ellipsis with the correct number of slice(None)s e_ind = plist.index(Ellipsis) @@ -22,7 +23,8 @@ def _expand_ellipsis(ndim, plist): def _expand_ellipsis_axis_coordinates(plist, wapt): if Ellipsis in plist: if plist.count(Ellipsis) > 1: - raise IndexError("Only single ellipsis ('...') is permitted.") + msg = "Only single ellipsis ('...') is permitted." + raise IndexError(msg) # Replace the Ellipsis with the correct number of slice(None)s e_ind = plist.index(Ellipsis) @@ -43,22 +45,23 @@ def prep_plot_kwargs(naxis, wcs, plot_axes, axes_coordinates, axes_units): """ # If plot_axes, axes_coordinates, axes_units are not None and not lists, # convert to lists for consistent indexing behaviour. - if (not isinstance(plot_axes, (tuple, list))) and (plot_axes is not None): + if (not isinstance(plot_axes, tuple | list)) and (plot_axes is not None): plot_axes = [plot_axes] - if (not isinstance(axes_coordinates, (tuple, list))) and (axes_coordinates is not None): + if (not isinstance(axes_coordinates, tuple | list)) and (axes_coordinates is not None): axes_coordinates = [axes_coordinates] - if (not isinstance(axes_units, (tuple, list))) and (axes_units is not None): + if (not isinstance(axes_units, tuple | list)) and (axes_units is not None): axes_units = [axes_units] # Set default value of plot_axes if not set by user. if plot_axes is None: - plot_axes = [..., 'y', 'x'] + plot_axes = [..., "y", "x"] # We flip the plot axes here so they are in the right order for WCSAxes plot_axes = plot_axes[::-1] plot_axes = _expand_ellipsis(naxis, plot_axes) - if 'x' not in plot_axes: - raise ValueError("'x' must be in plot_axes.") + if "x" not in plot_axes: + msg = "'x' must be in plot_axes." 
+ raise ValueError(msg) if axes_coordinates is not None: axes_coordinates = _expand_ellipsis_axis_coordinates(axes_coordinates, wcs.world_axis_physical_types) @@ -68,21 +71,25 @@ def prep_plot_kwargs(naxis, wcs, plot_axes, axes_coordinates, axes_units): if isinstance(axis_coordinate, str): # coordinates can be accessed by either name or type if axis_coordinate not in set(wcs.world_axis_physical_types).union(set(wcs.world_axis_names)): - raise ValueError(f"{axis_coordinate} is not one of this cubes world axis physical types.") + msg = f"{axis_coordinate} is not one of this cubes world axis physical types." + raise ValueError(msg) if not isinstance(axis_coordinate, ax_coord_types): - raise TypeError(f"axes_coordinates must be one of {ax_coord_types} or list of those, not {type(axis_coordinate)}.") + msg = f"axes_coordinates must be one of {ax_coord_types} or list of those, not {type(axis_coordinate)}." + raise TypeError(msg) if axes_units is not None: axes_units = _expand_ellipsis(wcs.world_n_dim, axes_units) if len(axes_units) != wcs.world_n_dim: - raise ValueError(f"The length of the axes_units argument must be {wcs.world_n_dim}.") + msg = f"The length of the axes_units argument must be {wcs.world_n_dim}." + raise ValueError(msg) # Convert all non-None elements to astropy units - axes_units = list(map(lambda x: u.Unit(x) if x is not None else None, axes_units))[::-1] + axes_units = [u.Unit(x) if x is not None else None for x in axes_units][::-1] for i, axis_unit in enumerate(axes_units): wau = wcs.world_axis_units[i] if axis_unit is not None and not axis_unit.is_equivalent(wau): + msg = f"Specified axis unit '{axis_unit}' is not convertible to world axis unit '{wau}'" raise u.UnitsError( - f"Specified axis unit '{axis_unit}' is not convertible to world axis unit '{wau}'") + msg) return plot_axes, axes_coordinates, axes_units diff --git a/ndcube/visualization/tests/test_plotting.py b/ndcube/visualization/tests/test_plotting.py index 39672331f..c24b3ad52 100644 --- a/ndcube/visualization/tests/test_plotting.py +++ b/ndcube/visualization/tests/test_plotting.py @@ -22,16 +22,16 @@ def test_plot_1D_cube(ndcube_1d_l): @figure_test @pytest.mark.parametrize(("ndcube_4d", "cslice", "kwargs"), - ( + [ ("ln_lt_l_t", np.s_[0, 0, 0, :], {}), ("ln_lt_l_t", np.s_[0, 0, :, 0], {}), ("ln_lt_l_t", np.s_[0, :, 0, 0], {}), ("ln_lt_l_t", np.s_[:, 0, 0, 0], {}), ("uncertainty", np.s_[0, 0, 0, :], {}), - ("unit_uncertainty", np.s_[0, 0, 0, :], {'data_unit': u.mJ}), + ("unit_uncertainty", np.s_[0, 0, 0, :], {"data_unit": u.mJ}), - ("mask", np.s_[0, 0, 0, :], {'marker': 'o'}),), + ("mask", np.s_[0, 0, 0, :], {"marker": "o"})], indirect=["ndcube_4d"]) def test_plot_1D_cube_from_slice(ndcube_4d, cslice, kwargs): # TODO: The output for the spatial plots is inconsistent between the lat @@ -74,18 +74,18 @@ def test_plot_2D_cube_custom_axis(ndcube_2d_ln_lt): def test_plot_2D_cube_custom_axis_plot_axes(ndcube_2d_ln_lt): fig = plt.figure() ax = fig.add_subplot(111, projection=ndcube_2d_ln_lt.wcs) - ndcube_2d_ln_lt.plot(axes=ax, plot_axes=('x', 'y')) + ndcube_2d_ln_lt.plot(axes=ax, plot_axes=("x", "y")) return fig @figure_test @pytest.mark.parametrize(("ndcube_4d", "cslice", "kwargs"), - ( + [ ("ln_lt_l_t", np.s_[0, 0, :, :], {}), ("ln_lt_l_t", np.s_[0, :, :, 0], {}), ("ln_lt_l_t", np.s_[:, :, 0, 0], {}), - ("unit_uncertainty", np.s_[0, 0, :, :], {'data_unit': u.mJ}), - ("mask", np.s_[0, :, 0, :], {}),), + ("unit_uncertainty", np.s_[0, 0, :, :], {"data_unit": u.mJ}), + ("mask", np.s_[0, :, 0, :], {})], 
indirect=["ndcube_4d"]) def test_plot_2D_cube_from_slice(ndcube_4d, cslice, kwargs): fig = plt.figure() @@ -100,7 +100,7 @@ def test_plot_2D_cube_from_slice(ndcube_4d, cslice, kwargs): @figure_test def test_animate_2D_cube(ndcube_2d_ln_lt): cube = ndcube_2d_ln_lt - ax = cube.plot(plot_axes=[None, 'x']) + ax = cube.plot(plot_axes=[None, "x"]) assert isinstance(ax, mpl_animators.ArrayAnimatorWCS) return ax.fig @@ -108,22 +108,19 @@ def test_animate_2D_cube(ndcube_2d_ln_lt): @figure_test @pytest.mark.parametrize(("ndcube_4d", "cslice", "kwargs"), - ( + [ ("ln_lt_l_t", np.s_[:, :, 0, :], {}), - ("ln_lt_l_t", np.s_[:, :, 0, :], {'plot_axes': [..., 'x']}), + ("ln_lt_l_t", np.s_[:, :, 0, :], {"plot_axes": [..., "x"]}), ("ln_lt_l_t", None, {}), - ("ln_lt_l_t", None, {"plot_axes": [0, 0, 'x', 'y'], "axes_units": [None, None, u.pm, None]}), - ("ln_lt_l_t", None, {"plot_axes": [0, 'x', 0, 'y']}), + ("ln_lt_l_t", None, {"plot_axes": [0, 0, "x", "y"], "axes_units": [None, None, u.pm, None]}), + ("ln_lt_l_t", None, {"plot_axes": [0, "x", 0, "y"]}), ("ln_lt_l_t", np.s_[0, :, :, :], {}), ("ln_lt_l_t", np.s_[:, :, :, :], {}), - ("unit_uncertainty", np.s_[0, :, :, :], {'data_unit': u.mJ}), - ("mask", np.s_[:, :, :, :], {}),), + ("unit_uncertainty", np.s_[0, :, :, :], {"data_unit": u.mJ}), + ("mask", np.s_[:, :, :, :], {})], indirect=["ndcube_4d"]) def test_animate_cube_from_slice(ndcube_4d, cslice, kwargs): - if cslice: - sub = ndcube_4d[cslice] - else: - sub = ndcube_4d + sub = ndcube_4d[cslice] if cslice else ndcube_4d ax = sub.plot(**kwargs) assert isinstance(ax, mpl_animators.ArrayAnimatorWCS) diff --git a/ndcube/visualization/tests/test_plotting_utils.py b/ndcube/visualization/tests/test_plotting_utils.py index e82c9a3d2..4aec6a445 100644 --- a/ndcube/visualization/tests/test_plotting_utils.py +++ b/ndcube/visualization/tests/test_plotting_utils.py @@ -5,16 +5,16 @@ import ndcube.visualization.plotting_utils as utils -@pytest.mark.parametrize("ndim, plist, output", ( - (2, ['x', 'y'], ['x', 'y']), - (2, [..., 'x', 'y'], ['x', 'y']), - (2, ['x', 'y', ...], ['x', 'y']), - (3, ['x', ...], ['x', None, None]), - (4, ['x', ..., 'y'], ['x', None, None, 'y']), - (5, [..., 'x'], [None, None, None, None, 'x']), - (5, [..., 'x', None], [None, None, None, 'x', None]), - (5, [None, ..., 'x', None, 'y'], [None, None, 'x', None, 'y']), -)) +@pytest.mark.parametrize(("ndim", "plist", "output"), [ + (2, ["x", "y"], ["x", "y"]), + (2, [..., "x", "y"], ["x", "y"]), + (2, ["x", "y", ...], ["x", "y"]), + (3, ["x", ...], ["x", None, None]), + (4, ["x", ..., "y"], ["x", None, None, "y"]), + (5, [..., "x"], [None, None, None, None, "x"]), + (5, [..., "x", None], [None, None, None, "x", None]), + (5, [None, ..., "x", None, "y"], [None, None, "x", None, "y"]), +]) def test_expand_ellipsis(ndim, plist, output): result = utils._expand_ellipsis(ndim, plist) assert result == output @@ -22,7 +22,7 @@ def test_expand_ellipsis(ndim, plist, output): def test_expand_ellipsis_error(): with pytest.raises(IndexError): - utils._expand_ellipsis(1, (..., 'x', ...)) + utils._expand_ellipsis(1, (..., "x", ...)) def test_prep_plot_kwargs_errors(ndcube_4d_ln_lt_l_t): @@ -31,7 +31,7 @@ def test_prep_plot_kwargs_errors(ndcube_4d_ln_lt_l_t): """ # plot_axes has incorrect length with pytest.raises(ValueError): - utils.prep_plot_kwargs(4, ndcube_4d_ln_lt_l_t.wcs, ['wibble'], None, None) + utils.prep_plot_kwargs(4, ndcube_4d_ln_lt_l_t.wcs, ["wibble"], None, None) # axes_coordinates is not in world_axis_physical_types with pytest.raises(ValueError): @@ 
-43,7 +43,7 @@ def test_prep_plot_kwargs_errors(ndcube_4d_ln_lt_l_t): # axes_units has incorrect length with pytest.raises(ValueError): - utils.prep_plot_kwargs(4, ndcube_4d_ln_lt_l_t.wcs, None, None, ['m']) + utils.prep_plot_kwargs(4, ndcube_4d_ln_lt_l_t.wcs, None, None, ["m"]) # axes_units has incorrect type with pytest.raises(TypeError): @@ -53,17 +53,17 @@ def test_prep_plot_kwargs_errors(ndcube_4d_ln_lt_l_t): utils.prep_plot_kwargs(4, ndcube_4d_ln_lt_l_t.wcs, None, None, [u.eV, u.m, u.m, u.m]) -@pytest.mark.parametrize("ndcube_2d, args, output", ( +@pytest.mark.parametrize(("ndcube_2d", "args", "output"), [ ("ln_lt", (None, None, None), - (['x', 'y'], None, None)), + (["x", "y"], None, None)), ("ln_lt", - (None, [..., 'custom:pos.helioprojective.lon'], None), - (['x', 'y'], ['custom:pos.helioprojective.lat', 'custom:pos.helioprojective.lon'], None)), + (None, [..., "custom:pos.helioprojective.lon"], None), + (["x", "y"], ["custom:pos.helioprojective.lat", "custom:pos.helioprojective.lon"], None)), ("ln_lt", - (None, None, [u.deg, 'arcsec']), - (['x', 'y'], None, [u.arcsec, u.deg])), -), indirect=['ndcube_2d']) + (None, None, [u.deg, "arcsec"]), + (["x", "y"], None, [u.arcsec, u.deg])), +], indirect=["ndcube_2d"]) def test_prep_plot_kwargs(ndcube_2d, args, output): result = utils.prep_plot_kwargs(2, ndcube_2d.wcs, *args) assert result == output diff --git a/ndcube/wcs/tests/test_tools.py b/ndcube/wcs/tests/test_tools.py index 1aed41c23..1265ff01e 100644 --- a/ndcube/wcs/tests/test_tools.py +++ b/ndcube/wcs/tests/test_tools.py @@ -34,7 +34,7 @@ def test_unwrap_wcs_to_fitswcs(): assert_array_equal(dropped_data_dimensions, np.array([True, True, False, False])) assert isinstance(output_wcs, WCS) assert output_wcs._naxis == [1, 2, 1, 1] - assert list(output_wcs.wcs.ctype) == ['TIME', 'WAVE', 'HPLT-TAN', 'HPLN-TAN'] + assert list(output_wcs.wcs.ctype) == ["TIME", "WAVE", "HPLT-TAN", "HPLN-TAN"] world_values = output_wcs.array_index_to_world_values([0], [0], [0, 1], [0]) assert_array_almost_equal(world_values[0][0], np.array([2700])) assert_array_almost_equal(world_values[1], np.array([1.04e-09, 1.10e-09])) diff --git a/ndcube/wcs/tools.py b/ndcube/wcs/tools.py index 4924bd1e3..4f803db40 100644 --- a/ndcube/wcs/tools.py +++ b/ndcube/wcs/tools.py @@ -48,25 +48,27 @@ def unwrap_wcs_to_fitswcs(wcs): if hasattr(low_level_wrapper, "low_level_wcs"): low_level_wrapper = low_level_wrapper.low_level_wcs if not isinstance(low_level_wrapper, WCS): - raise TypeError(f"Base-level WCS must be type {type(WCS)}. Found: {type(low_level_wrapper)}") + msg = f"Base-level WCS must be type {type(WCS)}. Found: {type(low_level_wrapper)}" + raise TypeError(msg) fitswcs = low_level_wrapper dropped_data_axes = np.zeros(fitswcs.naxis, dtype=bool) # Unwrap each wrapper in reverse order and edit fitswcs. 
for low_level_wrapper in wrapper_chain[::-1]: if isinstance(low_level_wrapper, SlicedLowLevelWCS): slice_items = np.array([slice(None)] * fitswcs.naxis) - slice_items[dropped_data_axes == False] = low_level_wrapper._slices_array # numpy order + slice_items[~dropped_data_axes] = low_level_wrapper._slices_array # numpy order fitswcs, dda = _slice_fitswcs(fitswcs, slice_items, numpy_order=True) dropped_data_axes[dda] = True elif isinstance(low_level_wrapper, ResampledLowLevelWCS): factor = np.ones(fitswcs.naxis) offset = np.zeros(fitswcs.naxis) - kept_wcs_axes = dropped_data_axes[::-1] == False # WCS-order + kept_wcs_axes = ~dropped_data_axes[::-1] # WCS-order factor[kept_wcs_axes] = low_level_wrapper._factor offset[kept_wcs_axes] = low_level_wrapper._offset fitswcs = _resample_fitswcs(fitswcs, factor, offset) else: - raise TypeError("Unrecognized/unsupported WCS Wrapper type: {type(low_level_wrapper)}") + msg = f"Unrecognized/unsupported WCS Wrapper type: {type(low_level_wrapper)}" + raise TypeError(msg) return fitswcs, dropped_data_axes @@ -114,13 +116,19 @@ def negative_index_error_msg(x): return ( shape = shape[::-1] else: if len(shape) != naxis: - raise ValueError("shape kwarg must be same length as number of pixel axes " f"in FITS-WCS, i.e. {naxis}") + msg = ( "shape kwarg must be same length as number of pixel axes " f"in FITS-WCS, i.e. {naxis}" ) + raise ValueError(msg) if not all(isinstance(s, Integral) for s in shape): - raise TypeError("All elements of ``shape`` must be integers. " f"shapes types = {[type(s) for s in shape]}") + msg = ( "All elements of ``shape`` must be integers. " f"shapes types = {[type(s) for s in shape]}" ) + raise TypeError(msg) slice_items = list(slice_items) - for i, (item, len_axis) in enumerate(zip(slice_items, shape)): + for i, (item, len_axis) in enumerate(zip(slice_items, shape, strict=False)): if isinstance(item, Integral): # Mark axis corresponding to int item as dropped from data array. dropped_data_axes[i] = True @@ -142,8 +150,11 @@ def negative_index_error_msg(x): return ( stop = len_axis + item.stop if stop_neg else item.stop slice_items[i] = slice(start, stop, item.step) else: - raise TypeError("All slice_items must be a slice or an int. " f"type(slice_items[{i}]) = {type(slice_items[i])}") + msg = ( "All slice_items must be a slice or an int. " f"type(slice_items[{i}]) = {type(slice_items[i])}" ) + raise TypeError(msg) # Slice WCS sliced_wcs = fitswcs.slice(slice_items, numpy_order=numpy_order) return sliced_wcs, dropped_data_axes @@ -177,10 +188,12 @@ def _resample_fitswcs(fitswcs, factor, offset=0): # Sanitize inputs. factor = np.asarray(factor) if len(factor) != fitswcs.naxis: - raise ValueError(f"Length of factor must equal number of dimensions {fitswcs.naxis}.") + msg = f"Length of factor must equal number of dimensions {fitswcs.naxis}." + raise ValueError(msg) offset = np.asarray(offset) if len(offset) != fitswcs.naxis: - raise ValueError(f"Length of offset must equal number of dimensions {fitswcs.naxis}.") + msg = f"Length of offset must equal number of dimensions {fitswcs.naxis}." + raise ValueError(msg) # Scale plate scale and shift by offset. 
fitswcs.wcs.cdelt *= factor fitswcs.wcs.crpix = (fitswcs.wcs.crpix + offset) / factor diff --git a/ndcube/wcs/wrappers/__init__.py b/ndcube/wcs/wrappers/__init__.py index d040709cf..b26ce11d6 100644 --- a/ndcube/wcs/wrappers/__init__.py +++ b/ndcube/wcs/wrappers/__init__.py @@ -1,3 +1,3 @@ from .compound_wcs import * # NOQA -from .reordered_wcs import * # NOQA -from .resampled_wcs import * # NOQA +from .reordered_wcs import * +from .resampled_wcs import * diff --git a/ndcube/wcs/wrappers/compound_wcs.py b/ndcube/wcs/wrappers/compound_wcs.py index 85aa49e22..ab0255ec0 100644 --- a/ndcube/wcs/wrappers/compound_wcs.py +++ b/ndcube/wcs/wrappers/compound_wcs.py @@ -4,7 +4,7 @@ from astropy.wcs.wcsapi.wrappers.base import BaseWCSWrapper -__all__ = ['CompoundLowLevelWCS'] +__all__ = ["CompoundLowLevelWCS"] def tuplesum(lists): @@ -27,7 +27,7 @@ class Mapping: """ - def __init__(self, mapping): + def __init__(self, mapping) -> None: self.mapping = mapping self.n_inputs = max(mapping) + 1 self.n_outputs = len(mapping) @@ -41,8 +41,8 @@ def inverse(self): for idx in range(self.n_inputs)) return type(self)(mapping) - def __repr__(self): - return f'' + def __repr__(self) -> str: + return f"" class CompoundLowLevelWCS(BaseWCSWrapper): @@ -69,15 +69,16 @@ class CompoundLowLevelWCS(BaseWCSWrapper): ``world_to_pixel`` are the same from all WCSes. """ - def __init__(self, *wcs, mapping=None, pixel_atol=1e-8): + def __init__(self, *wcs, mapping=None, pixel_atol=1e-8) -> None: self._wcs = wcs if not mapping: mapping = tuple(range(self._all_pixel_n_dim)) - if not len(mapping) == self._all_pixel_n_dim: + if len(mapping) != self._all_pixel_n_dim: + msg = "The length of the mapping must equal the total number of pixel dimensions in all input WCSes." raise ValueError( - "The length of the mapping must equal the total number of pixel dimensions in all input WCSes.") + msg) self.mapping = Mapping(mapping) self.atol = pixel_atol @@ -138,10 +139,13 @@ def world_to_pixel_values(self, *world_arrays): for idx_n in idx[1:]: if not np.allclose(pixel_arrays[idx_0], pixel_arrays[idx_n], atol=self.atol, equal_nan=True): - raise ValueError( + msg = ( "The world inputs for shared pixel axes did not result in a pixel " f"coordinate to within {self.atol} relative accuracy." ) + raise ValueError( + msg, + ) return self.mapping.inverse(*pixel_arrays) @property @@ -149,7 +153,7 @@ def world_axis_object_components(self): all_components = [] for iw, w in enumerate(self._wcs): for component in w.world_axis_object_components: - all_components.append((f'{component[0]}_{iw}',) + component[1:]) + all_components.append((f"{component[0]}_{iw}",) + component[1:]) return all_components @property @@ -158,7 +162,7 @@ def world_axis_object_classes(self): all_classes = {} for iw, w in enumerate(self._wcs): for key, value in w.world_axis_object_classes.items(): - all_classes[f'{key}_{iw}'] = value + all_classes[f"{key}_{iw}"] = value return all_classes @property @@ -168,21 +172,25 @@ def pixel_shape(self): out_shape = self.mapping.inverse(*pixel_shape) for i, ix in enumerate(self.mapping.mapping): if out_shape[ix] != pixel_shape[i]: + msg = "The pixel shapes of the supplied WCSes do not match for the dimensions shared by the supplied mapping." 
raise ValueError( - "The pixel shapes of the supplied WCSes do not match for the dimensions shared by the supplied mapping.") + msg) return out_shape + return None @property def pixel_bounds(self): if any(w.pixel_bounds is not None for w in self._wcs): - pixel_bounds = tuplesum(w.pixel_bounds or [tuple() for _ in range(w.pixel_n_dim)] for w in self._wcs) + pixel_bounds = tuplesum(w.pixel_bounds or [() for _ in range(w.pixel_n_dim)] for w in self._wcs) out_bounds = self.mapping.inverse(*pixel_bounds) for i, ix in enumerate(self.mapping.mapping): if pixel_bounds[i] and (out_bounds[ix] != pixel_bounds[i]): + msg = "The pixel bounds of the supplied WCSes do not match for the dimensions shared by the supplied mapping." raise ValueError( - "The pixel bounds of the supplied WCSes do not match for the dimensions shared by the supplied mapping.") + msg) iint = np.iinfo(int) return tuple(o or (iint.min, iint.max) for o in out_bounds) + return None @property def pixel_axis_names(self): @@ -191,7 +199,7 @@ def pixel_axis_names(self): for i, ix in enumerate(self.mapping.mapping): if out_names[ix] != pixel_names[i]: - out_names[ix] = ' / '.join([out_names[ix], pixel_names[i]]) + out_names[ix] = " / ".join([out_names[ix], pixel_names[i]]) return out_names @@ -216,4 +224,4 @@ def axis_correlation_matrix(self): @property def serialized_classes(self): - return any([w.serialized_classes for w in self._wcs]) + return any(w.serialized_classes for w in self._wcs) diff --git a/ndcube/wcs/wrappers/reordered_wcs.py b/ndcube/wcs/wrappers/reordered_wcs.py index 11639e01b..b236a16e1 100644 --- a/ndcube/wcs/wrappers/reordered_wcs.py +++ b/ndcube/wcs/wrappers/reordered_wcs.py @@ -4,7 +4,7 @@ from astropy.wcs.wcsapi.wrappers.base import BaseWCSWrapper -__all__ = ['ReorderedLowLevelWCS'] +__all__ = ["ReorderedLowLevelWCS"] class ReorderedLowLevelWCS(BaseWCSWrapper): @@ -24,11 +24,13 @@ class ReorderedLowLevelWCS(BaseWCSWrapper): new WCS. 
""" - def __init__(self, wcs, pixel_order, world_order): + def __init__(self, wcs, pixel_order, world_order) -> None: if sorted(pixel_order) != list(range(wcs.pixel_n_dim)): - raise ValueError(f'pixel_order should be a permutation of {list(range(wcs.pixel_n_dim))}') + msg = f"pixel_order should be a permutation of {list(range(wcs.pixel_n_dim))}" + raise ValueError(msg) if sorted(world_order) != list(range(wcs.world_n_dim)): - raise ValueError(f'world_order should be a permutation of {list(range(wcs.world_n_dim))}') + msg = f"world_order should be a permutation of {list(range(wcs.world_n_dim))}" + raise ValueError(msg) self._wcs = wcs self._pixel_order = pixel_order self._world_order = world_order @@ -54,14 +56,12 @@ def world_axis_names(self): def pixel_to_world_values(self, *pixel_arrays): pixel_arrays = [pixel_arrays[idx] for idx in self._pixel_order_inv] world_arrays = self._wcs.pixel_to_world_values(*pixel_arrays) - world_arrays = [world_arrays[idx] for idx in self._world_order] - return world_arrays + return [world_arrays[idx] for idx in self._world_order] def world_to_pixel_values(self, *world_arrays): world_arrays = [world_arrays[idx] for idx in self._world_order_inv] pixel_arrays = self._wcs.world_to_pixel_values(*world_arrays) - pixel_arrays = [pixel_arrays[idx] for idx in self._pixel_order] - return pixel_arrays + return [pixel_arrays[idx] for idx in self._pixel_order] @property def world_axis_object_components(self): @@ -71,11 +71,13 @@ def world_axis_object_components(self): def pixel_shape(self): if self._wcs.pixel_shape: return tuple([self._wcs.pixel_shape[idx] for idx in self._pixel_order]) + return None @property def pixel_bounds(self): if self._wcs.pixel_bounds: return tuple([self._wcs.pixel_bounds[idx] for idx in self._pixel_order]) + return None @property def axis_correlation_matrix(self): diff --git a/ndcube/wcs/wrappers/resampled_wcs.py b/ndcube/wcs/wrappers/resampled_wcs.py index 3f3c2836a..a18c8de04 100644 --- a/ndcube/wcs/wrappers/resampled_wcs.py +++ b/ndcube/wcs/wrappers/resampled_wcs.py @@ -2,7 +2,7 @@ from astropy.wcs.wcsapi.wrappers.base import BaseWCSWrapper -__all__ = ['ResampledLowLevelWCS'] +__all__ = ["ResampledLowLevelWCS"] class ResampledLowLevelWCS(BaseWCSWrapper): @@ -25,18 +25,20 @@ class ResampledLowLevelWCS(BaseWCSWrapper): shifted by the same amount in all dimensions. """ - def __init__(self, wcs, factor, offset=0): + def __init__(self, wcs, factor, offset=0) -> None: self._wcs = wcs if np.isscalar(factor): factor = [factor] * self.pixel_n_dim self._factor = np.array(factor) if len(self._factor) != self.pixel_n_dim: - raise ValueError(f"Length of factor must equal number of dimensions {self.pixel_n_dim}.") + msg = f"Length of factor must equal number of dimensions {self.pixel_n_dim}." + raise ValueError(msg) if np.isscalar(offset): offset = [offset] * self.pixel_n_dim self._offset = np.array(offset) if len(self._offset) != self.pixel_n_dim: - raise ValueError(f"Length of offset must equal number of dimensions {self.pixel_n_dim}.") + msg = f"Length of offset must equal number of dimensions {self.pixel_n_dim}." + raise ValueError(msg) def _top_to_underlying_pixels(self, top_pixels): # Convert user-facing pixel indices to the pixel grid of underlying WCS. 
@@ -77,7 +79,7 @@ def pixel_shape(self): atol=np.finfo(float).resolution) pixel_shape = underlying_shape / self._factor return tuple(int(np.rint(i)) if is_int else i - for i, is_int in zip(pixel_shape, int_elements)) + for i, is_int in zip(pixel_shape, int_elements, strict=False)) @property def pixel_bounds(self): diff --git a/ndcube/wcs/wrappers/tests/conftest.py b/ndcube/wcs/wrappers/tests/conftest.py index 329d9a1e3..02d882449 100644 --- a/ndcube/wcs/wrappers/tests/conftest.py +++ b/ndcube/wcs/wrappers/tests/conftest.py @@ -5,7 +5,7 @@ class Celestial2DLowLevelWCS(ApyCelestial2DLowLevelWCS): - def __init__(self): + def __init__(self) -> None: self._pixel_bounds = (-1, 5), (1, 7) @property diff --git a/ndcube/wcs/wrappers/tests/test_compound_wcs.py b/ndcube/wcs/wrappers/tests/test_compound_wcs.py index e3a4ce0bf..bf82d0faa 100644 --- a/ndcube/wcs/wrappers/tests/test_compound_wcs.py +++ b/ndcube/wcs/wrappers/tests/test_compound_wcs.py @@ -51,9 +51,9 @@ def celestial_wcs(request): """.strip() -@pytest.mark.parametrize(('spectral_wcs', 'celestial_wcs'), - product(['spectral_1d_ape14_wcs', 'spectral_1d_fitswcs'], - ['celestial_2d_ape14_wcs', 'celestial_2d_fitswcs']), +@pytest.mark.parametrize(("spectral_wcs", "celestial_wcs"), + product(["spectral_1d_ape14_wcs", "spectral_1d_fitswcs"], + ["celestial_2d_ape14_wcs", "celestial_2d_fitswcs"]), indirect=True) def test_celestial_spectral_ape14(spectral_wcs, celestial_wcs): @@ -61,12 +61,12 @@ def test_celestial_spectral_ape14(spectral_wcs, celestial_wcs): assert wcs.pixel_n_dim == 3 assert wcs.world_n_dim == 3 - assert tuple(wcs.world_axis_physical_types) == ('em.freq', 'pos.eq.ra', 'pos.eq.dec') - assert tuple(wcs.world_axis_units) == ('Hz', 'deg', 'deg') - assert tuple(wcs.pixel_axis_names) == ('', '', '') - assert tuple(wcs.world_axis_names) == ('Frequency', - 'Right Ascension', - 'Declination') + assert tuple(wcs.world_axis_physical_types) == ("em.freq", "pos.eq.ra", "pos.eq.dec") + assert tuple(wcs.world_axis_units) == ("Hz", "deg", "deg") + assert tuple(wcs.pixel_axis_names) == ("", "", "") + assert tuple(wcs.world_axis_names) == ("Frequency", + "Right Ascension", + "Declination") assert_equal(wcs.axis_correlation_matrix, np.array([[1, 0, 0], [0, 1, 1], [0, 1, 1]])) @@ -138,7 +138,7 @@ def test_shared_pixel_axis_compound_1d(spectral_1d_fitswcs, time_1d_fitswcs): assert wcs.pixel_n_dim == 1 assert wcs.pixel_shape is None - assert wcs.pixel_axis_names == ('',) + assert wcs.pixel_axis_names == ("",) assert wcs.pixel_bounds is None world = wcs.pixel_to_world_values(0) @@ -163,7 +163,7 @@ def test_shared_pixel_axis_compound_3d(spectral_cube_3d_fitswcs, time_1d_fitswcs assert wcs.pixel_n_dim == 3 np.testing.assert_allclose(wcs.pixel_shape, (10, 20, 30)) - assert wcs.pixel_axis_names == ('', '', '') + assert wcs.pixel_axis_names == ("", "", "") assert wcs.pixel_bounds == ((-1, 5), (1, 7), (1, 2.5)) np.testing.assert_allclose(wcs.axis_correlation_matrix, [[True, True, False], diff --git a/ndcube/wcs/wrappers/tests/test_reordered_wcs.py b/ndcube/wcs/wrappers/tests/test_reordered_wcs.py index 596333e36..03b456f7b 100644 --- a/ndcube/wcs/wrappers/tests/test_reordered_wcs.py +++ b/ndcube/wcs/wrappers/tests/test_reordered_wcs.py @@ -58,12 +58,12 @@ def test_spectral_cube(spectral_cube_3d_fitswcs): assert wcs.pixel_n_dim == 3 assert wcs.world_n_dim == 3 - assert tuple(wcs.world_axis_physical_types) == ('em.freq', 'pos.eq.ra', 'pos.eq.dec') - assert tuple(wcs.world_axis_units) == ('Hz', 'deg', 'deg') - assert tuple(wcs.pixel_axis_names) == ('', '', '') - 
assert tuple(wcs.world_axis_names) == ('Frequency', - 'Right Ascension', - 'Declination') + assert tuple(wcs.world_axis_physical_types) == ("em.freq", "pos.eq.ra", "pos.eq.dec") + assert tuple(wcs.world_axis_units) == ("Hz", "deg", "deg") + assert tuple(wcs.pixel_axis_names) == ("", "", "") + assert tuple(wcs.world_axis_names) == ("Frequency", + "Right Ascension", + "Declination") assert_equal(wcs.axis_correlation_matrix, np.array([[0, 1, 0], [1, 0, 1], [1, 0, 1]])) @@ -111,15 +111,15 @@ def test_spectral_cube(spectral_cube_3d_fitswcs): assert EXPECTED_SPECTRAL_CUBE_REPR in repr(wcs) -@pytest.mark.parametrize('order', [(1,), (1, 2, 2), (0, 1, 2, 3)]) +@pytest.mark.parametrize("order", [(1,), (1, 2, 2), (0, 1, 2, 3)]) def test_invalid(spectral_cube_3d_fitswcs, order): - with pytest.raises(ValueError, match=re.escape('pixel_order should be a permutation of [0, 1, 2]')): + with pytest.raises(ValueError, match=re.escape("pixel_order should be a permutation of [0, 1, 2]")): ReorderedLowLevelWCS(spectral_cube_3d_fitswcs, pixel_order=order, world_order=[2, 0, 1]) - with pytest.raises(ValueError, match=re.escape('world_order should be a permutation of [0, 1, 2]')): + with pytest.raises(ValueError, match=re.escape("world_order should be a permutation of [0, 1, 2]")): ReorderedLowLevelWCS(spectral_cube_3d_fitswcs, pixel_order=[1, 2, 0], world_order=order) diff --git a/ndcube/wcs/wrappers/tests/test_resampled_wcs.py b/ndcube/wcs/wrappers/tests/test_resampled_wcs.py index 7b7b945e4..3f6bd1b27 100644 --- a/ndcube/wcs/wrappers/tests/test_resampled_wcs.py +++ b/ndcube/wcs/wrappers/tests/test_resampled_wcs.py @@ -62,8 +62,8 @@ def celestial_wcs(request): 1 yes yes """.strip() -@pytest.mark.parametrize('celestial_wcs', - ['celestial_2d_ape14_wcs', 'celestial_2d_fitswcs'], +@pytest.mark.parametrize("celestial_wcs", + ["celestial_2d_ape14_wcs", "celestial_2d_fitswcs"], indirect=True) def test_2d(celestial_wcs): @@ -74,11 +74,11 @@ def test_2d(celestial_wcs): # The following shouldn't change compared to the original WCS assert wcs.pixel_n_dim == 2 assert wcs.world_n_dim == 2 - assert tuple(wcs.world_axis_physical_types) == ('pos.eq.ra', 'pos.eq.dec') - assert tuple(wcs.world_axis_units) == ('deg', 'deg') - assert tuple(wcs.pixel_axis_names) == ('', '') - assert tuple(wcs.world_axis_names) == ('Right Ascension', - 'Declination') + assert tuple(wcs.world_axis_physical_types) == ("pos.eq.ra", "pos.eq.dec") + assert tuple(wcs.world_axis_units) == ("deg", "deg") + assert tuple(wcs.pixel_axis_names) == ("", "") + assert tuple(wcs.world_axis_names) == ("Right Ascension", + "Declination") assert_equal(wcs.axis_correlation_matrix, np.ones((2, 2))) # Shapes and bounds should be floating-point if needed @@ -93,7 +93,7 @@ def test_2d(celestial_wcs): assert_allclose(wcs.world_to_pixel_values(*world_scalar), pixel_scalar) assert_allclose(wcs.world_to_array_index_values(*world_scalar), [1, 2]) - EXPECTED_2D_REPR = EXPECTED_2D_REPR_NUMPY2 if np.__version__ >= '2.0.0' else EXPECTED_2D_REPR_NUMPY1 + EXPECTED_2D_REPR = EXPECTED_2D_REPR_NUMPY2 if np.__version__ >= "2.0.0" else EXPECTED_2D_REPR_NUMPY1 assert str(wcs) == EXPECTED_2D_REPR assert EXPECTED_2D_REPR in repr(wcs) @@ -121,8 +121,8 @@ def test_2d(celestial_wcs): assert_quantity_allclose(celestial.dec, world_array[1] * u.deg) -@pytest.mark.parametrize('celestial_wcs', - ['celestial_2d_ape14_wcs', 'celestial_2d_fitswcs'], +@pytest.mark.parametrize("celestial_wcs", + ["celestial_2d_ape14_wcs", "celestial_2d_fitswcs"], indirect=True) def test_scalar_factor(celestial_wcs): @@ 
-137,8 +137,8 @@ def test_scalar_factor(celestial_wcs): assert_allclose(wcs.world_to_array_index_values(*world_scalar), [4, 2]) -@pytest.mark.parametrize('celestial_wcs', - ['celestial_2d_ape14_wcs', 'celestial_2d_fitswcs'], +@pytest.mark.parametrize("celestial_wcs", + ["celestial_2d_ape14_wcs", "celestial_2d_fitswcs"], indirect=True) def test_offset(celestial_wcs): celestial_wcs.pixel_bounds = None @@ -156,24 +156,24 @@ def test_offset(celestial_wcs): assert_allclose(wcs.world_to_array_index_values(*world_scalar), [4, 2]) -@pytest.mark.parametrize('celestial_wcs', - ['celestial_2d_ape14_wcs'], +@pytest.mark.parametrize("celestial_wcs", + ["celestial_2d_ape14_wcs"], indirect=True) def test_factor_wrong_length_error(celestial_wcs): with pytest.raises(ValueError): ResampledLowLevelWCS(celestial_wcs, [2] * 3) -@pytest.mark.parametrize('celestial_wcs', - ['celestial_2d_ape14_wcs'], +@pytest.mark.parametrize("celestial_wcs", + ["celestial_2d_ape14_wcs"], indirect=True) def test_scalar_wrong_length_error(celestial_wcs): with pytest.raises(ValueError): ResampledLowLevelWCS(celestial_wcs, 2, offset=[1] * 3) -@pytest.mark.parametrize('celestial_wcs', - ['celestial_2d_ape14_wcs', 'celestial_2d_fitswcs'], +@pytest.mark.parametrize("celestial_wcs", + ["celestial_2d_ape14_wcs", "celestial_2d_fitswcs"], indirect=True) def test_int_fraction_pixel_shape(celestial_wcs): # Some fractional factors are not representable by exact floats, e.g. 1/3.
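The comment above is the motivation for the `pixel_shape` handling shown earlier in this diff: dividing the underlying shape by a factor with no exact binary representation can land a hair away from the intended integer, so near-integer results are snapped back with `np.rint`. A small, hypothetical illustration of the idea (not the values this test actually uses):

    import numpy as np

    underlying_size = 7.0
    factor = 0.07                      # has no exact binary float representation

    raw = underlying_size / factor
    print(raw)                         # may print 99.99999999999999 rather than 100.0

    # Same idea as the wrapper's pixel_shape: if the result is numerically
    # indistinguishable from an integer, round it; otherwise keep it fractional.
    is_int = np.isclose(raw, np.rint(raw), atol=np.finfo(float).resolution)
    size = int(np.rint(raw)) if is_int else raw
    print(size)                        # 100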