diff --git a/changes/2051.multiband_catalog.rst b/changes/2051.multiband_catalog.rst new file mode 100644 index 000000000..426f12f57 --- /dev/null +++ b/changes/2051.multiband_catalog.rst @@ -0,0 +1 @@ +Added PSF-matched photometry for multiband catalogs. diff --git a/romancal/multiband_catalog/catalog_generator.py b/romancal/multiband_catalog/catalog_generator.py new file mode 100644 index 000000000..d7541526b --- /dev/null +++ b/romancal/multiband_catalog/catalog_generator.py @@ -0,0 +1,287 @@ +""" +Helper function for generating per-filter catalogs in multiband catalog +creation. +""" + +import logging + +import numpy as np +from astropy.table import join +from roman_datamodels import datamodels + +from romancal.multiband_catalog.utils import add_filter_to_colnames +from romancal.source_catalog.psf_matching import ( + compute_psf_correction_factors, + create_psf_matched_image, + get_filter_wavelength, +) +from romancal.source_catalog.source_catalog import RomanSourceCatalog +from romancal.source_catalog.utils import get_ee_spline + +log = logging.getLogger(__name__) +log.setLevel(logging.DEBUG) + + +def create_filter_catalog( + model, + filter_name, + ref_filter, + ref_wavelength, + segment_img, + star_kernel_fwhm, + detection_catobj, + ref_model, + ref_filter_catalog, + ref_psf_model, + fit_psf, + get_reference_file_func, +): + """ + Create catalog(s) for a single filter, including PSF matching if needed. + + This function handles three cases: + 1. Reference filter: Only original measurements (no PSF-matched catalog) + 2. Bluer filter: Normal PSF matching (convolve to reference filter) + 3. Redder filter: Synthetic PSF matching via correction factors + + Parameters + ---------- + model : ImageModel or MosaicModel + The input image model for this filter. + + filter_name : str + Name of the filter (e.g., 'F158'). + + ref_filter : str + Name of the reference filter for PSF matching. + + ref_wavelength : int + Wavelength of the reference filter in microns. 
+ + segment_img : ndarray + Segmentation image from detection. + + star_kernel_fwhm : float + FWHM of star kernel for source catalog. + + detection_catobj : RomanSourceCatalog + Detection catalog object. + + ref_model : ImageModel or MosaicModel + The reference filter image model (for computing correction factors). + + ref_filter_catalog : Table or None + The reference filter's catalog (needed for redder filters). + + ref_psf_model : EpsfRefModel + PSF reference model for the reference filter. + + fit_psf : bool + Whether to fit PSFs in the catalog. + + get_reference_file_func : callable + Function to get reference files (self.get_reference_file from step). + + Returns + ------- + result : dict + Dictionary containing: + - 'catalog': Combined catalog table for this filter + - 'ee_fractions': Dictionary of ee_fractions for this filter + + Raises + ------ + ValueError + If trying to process a redder filter before the reference filter. + """ + # Create mask + mask = ~np.isfinite(model.data) | ~np.isfinite(model.err) | (model.err <= 0) + + # Load PSF reference model (needed for PSF matching and PSF fitting) + log.info(f"Creating catalog for {filter_name} image") + ref_file = get_reference_file_func(model, "epsf") + log.info("Using ePSF reference file: %s", ref_file) + psf_model = datamodels.open(ref_file) + + apcorr_ref = get_reference_file_func(model, "apcorr") + ee_spline = get_ee_spline(model, apcorr_ref) + + # Create catalog with original (non-PSF-matched) data + log.info(f"Creating catalog for original {filter_name} image") + catobj_original = RomanSourceCatalog( + model, + segment_img, + None, + star_kernel_fwhm, + fit_psf=fit_psf, + psf_model=psf_model if fit_psf else None, + mask=mask, + detection_cat=detection_catobj, + cat_type="dr_band", + ee_spline=ee_spline, + ) + + # Store reference filter catalog for later use with redder filters + updated_ref_filter_catalog = ref_filter_catalog + if filter_name == ref_filter: + updated_ref_filter_catalog = 
catobj_original.catalog + + # Add the filter name to the column names + cat_original = add_filter_to_colnames(catobj_original.catalog, filter_name) + + # Store ee_fractions for this filter + ee_fractions = {} + ee_fractions[filter_name.lower()] = cat_original.meta["ee_fractions"] + + # Clear filter catalog metadata + cat_original.meta = None + + # Determine if this filter is bluer or redder than reference + filter_wavelength = get_filter_wavelength(filter_name) + + # Create PSF-matched catalog based on filter position + if filter_name == ref_filter: + # Reference filter case - only include original measurements + log.info( + f"Reference filter {filter_name}: including only original measurements" + ) + cat = cat_original + + elif filter_wavelength < ref_wavelength: + # Bluer filter - normal PSF matching (convolve to the reference + # image's PSF) + log.info(f"Creating PSF-matched image for {filter_name}") + psf_matched_model = create_psf_matched_image( + model, + psf_model, + ref_psf_model, + ) + + log.info(f"Creating catalog for PSF-matched {filter_name} image") + catobj_matched = RomanSourceCatalog( + psf_matched_model, + segment_img, + None, + star_kernel_fwhm, + fit_psf=False, # No PSF fitting on matched images + psf_model=None, + mask=mask, + detection_cat=detection_catobj, + cat_type="psf_matched", + ee_spline=ee_spline, + ) + + # Add filter name with "m" suffix for PSF-matched columns + filter_name_matched = f"{filter_name}m" + cat_matched = add_filter_to_colnames( + catobj_matched.catalog, filter_name_matched + ) + cat_matched.meta = None + + # Merge the original and PSF-matched catalogs + cat = join(cat_original, cat_matched, keys="label", join_type="outer") + + else: + # Redder filter - synthetic PSF matching using correction factors + log.info( + f"Creating synthetic PSF-matched catalog for " + f"{filter_name} (redder than reference {ref_filter})" + ) + + if ref_filter_catalog is None: + raise ValueError( + f"Reference filter {ref_filter} catalog not yet 
created. " + "Cannot compute correction factors for redder filter " + f"{filter_name}." + ) + + # Compute correction factors by matching reference image to + # the redder filter + correction_factors = compute_psf_correction_factors( + ref_model=ref_model, + ref_psf_model=ref_psf_model, + ref_catalog=ref_filter_catalog, + target_model=model, + target_psf_model=psf_model, + segment_img=segment_img, + star_kernel_fwhm=star_kernel_fwhm, + detection_cat=detection_catobj, + mask=mask, + ee_spline=ee_spline, + ) + + # Create synthetic PSF-matched catalog by applying + # correction factors to the original catalog + cat_synthetic = cat_original.copy() + + # Remove background aperture flux columns + cols_to_remove = [col for col in cat_synthetic.colnames if "aper_bkg_" in col] + cat_synthetic.remove_columns(cols_to_remove) + + # Apply correction factors to flux columns + for flux_col, correction in correction_factors.items(): + # flux_col names contain the reference filter name, so + # we need to replace it with the current filter name + flux_col_syn = flux_col.replace( + f"_{ref_filter.lower()}_", f"_{filter_name.lower()}_" + ) + + # Apply correction to flux column + if flux_col_syn in cat_synthetic.colnames: + cat_synthetic[flux_col_syn] *= correction + + # Also apply to error columns + err_col_syn = flux_col_syn.replace("_flux", "_flux_err") + if err_col_syn in cat_synthetic.colnames: + cat_synthetic[err_col_syn] *= correction + + # Handle magnitude columns (subtract 2.5 * log10(C)) + mag_col_syn = flux_col_syn.replace("_flux", "_abmag") + if mag_col_syn in cat_synthetic.colnames: + with np.errstate(divide="ignore", invalid="ignore"): + mag_correction = -2.5 * np.log10(correction) + mag_correction = np.where( + ~np.isfinite(mag_correction), + 0.0, + mag_correction, + ) + cat_synthetic[mag_col_syn] += mag_correction + + # Rename columns to add 'm' suffix and keep only flux columns + filter_name_matched = f"{filter_name}m" + cols_to_keep = ["label"] # Always keep label for 
joining + for colname in list(cat_synthetic.colnames): + # Replace filter name with filter name + 'm' + # e.g., kron_f158_flux -> kron_f158m_flux + # We do not keep columns that *end* with the filter name, e.g., + # sharpness_f158. + if f"_{filter_name.lower()}_" in colname: + new_colname = colname.replace( + f"_{filter_name.lower()}_", + f"_{filter_name_matched.lower()}_", + ) + cat_synthetic.rename_column(colname, new_colname) + cols_to_keep.append(new_colname) + + # Keep only the PSF-matched columns (and label) + cat_synthetic = cat_synthetic[cols_to_keep] + + # Merge the two catalogs + cat = join( + cat_original, + cat_synthetic, + keys="label", + join_type="outer", + ) + log.info( + f"Synthetic PSF-matched catalog created for {filter_name} " + f"using correction factors" + ) + + log.info(f"Completed catalog for {filter_name} filter") + + return { + "catalog": cat, + "ee_fractions": ee_fractions, + "ref_filter_catalog": updated_ref_filter_catalog, + } diff --git a/romancal/multiband_catalog/detection_image.py b/romancal/multiband_catalog/detection_image.py index 049f25b6c..89ec11ce3 100644 --- a/romancal/multiband_catalog/detection_image.py +++ b/romancal/multiband_catalog/detection_image.py @@ -56,9 +56,6 @@ def make_det_image(library, kernel_fwhm): ------- detection_data : 2D `numpy.ndarray` The detection image data. - - detection_error : 2D `numpy.ndarray` - The detection image (standard deviation) error. 
""" if not isinstance(library, ModelLibrary): raise TypeError("library input must be a ModelLibrary object") diff --git a/romancal/multiband_catalog/metadata.py b/romancal/multiband_catalog/metadata.py new file mode 100644 index 000000000..61c016fb3 --- /dev/null +++ b/romancal/multiband_catalog/metadata.py @@ -0,0 +1,78 @@ +from copy import deepcopy + +# Metadata keys to skip when accumulating image metadata +_SKIP_IMAGE_META_KEYS = {"wcs", "individual_image_meta"} + +# Metadata keys to skip when blending metadata +_SKIP_BLEND_KEYS = {"wcsinfo"} + + +def blend_image_metadata( + image_model, + cat_model, + time_means, + exposure_times, +): + """ + Accumulate and blend metadata from an individual filter image into + the catalog. + + This function: + 1. Extracts relevant metadata from the input image model + 2. Appends it to the catalog's image_meta list + 3. Blends metadata values across filters, setting mismatches to None + 4. Handles special cases like coadd_info timing information + 5. Updates file_date to the earliest date + + This function modifies cat_model, time_means, and exposure_times in + place. + + Parameters + ---------- + image_model : ImageModel or MosaicModel + The input image model for a single filter. + + cat_model : MultibandSourceCatalogModel + The multiband catalog model being built. + + time_means : list + List to accumulate mean observation times (modified in place). + + exposure_times : list + List to accumulate exposure times (modified in place). 
+ """ + # Accumulate image metadata + image_meta = { + k: deepcopy(v) + for k, v in image_model["meta"].items() + if k not in _SKIP_IMAGE_META_KEYS + } + cat_model.meta.image_metas.append(image_meta) + + # Blend model with catalog metadata + if image_model.meta.file_date < cat_model.meta.image.file_date: + cat_model.meta.image.file_date = image_model.meta.file_date + + for key, value in image_meta.items(): + if key in _SKIP_BLEND_KEYS: + continue + if not isinstance(value, dict): + # skip blending of single top-level values + continue + if key not in cat_model.meta: + # skip blending if the key is not in the catalog meta + continue + if key == "coadd_info": + cat_model.meta[key]["time_first"] = min( + cat_model.meta[key]["time_first"], value["time_first"] + ) + cat_model.meta[key]["time_last"] = max( + cat_model.meta[key]["time_last"], value["time_last"] + ) + time_means.append(value["time_mean"]) + exposure_times.append(value["exposure_time"]) + else: + # set non-matching metadata values to None + for subkey, subvalue in value.items(): + if cat_model.meta[key].get(subkey, None) != subvalue: + cat_model.meta[key][subkey] = None diff --git a/romancal/multiband_catalog/multiband_catalog.py b/romancal/multiband_catalog/multiband_catalog.py index 5950acf29..711eda404 100644 --- a/romancal/multiband_catalog/multiband_catalog.py +++ b/romancal/multiband_catalog/multiband_catalog.py @@ -16,78 +16,87 @@ from romancal.datamodels import ModelLibrary from romancal.multiband_catalog.background import subtract_background_library +from romancal.multiband_catalog.catalog_generator import create_filter_catalog from romancal.multiband_catalog.detection_image import make_detection_image -from romancal.multiband_catalog.utils import add_filter_to_colnames +from romancal.multiband_catalog.metadata import blend_image_metadata from romancal.source_catalog import injection from romancal.source_catalog.background import RomanBackground from romancal.source_catalog.detection import 
make_segmentation_image +from romancal.source_catalog.psf_matching import ( + get_filter_wavelength, + get_reddest_filter, +) from romancal.source_catalog.source_catalog import RomanSourceCatalog -from romancal.source_catalog.utils import get_ee_spline log = logging.getLogger(__name__) log.setLevel(logging.DEBUG) -_SKIP_IMAGE_META_KEYS = {"wcs", "individual_image_meta"} -_SKIP_BLEND_KEYS = {"wcsinfo"} - - -def multiband_catalog(self, library, example_model, cat_model, ee_spline): +def process_detection_image(self, library, example_model, ee_spline, catalog_model): """ - Create a multiband catalog of sources including photometry and basic - shape measurements. + Create and process the detection image. + + This includes background estimation, source detection via + segmentation, and creation of the detection catalog. Parameters - ----------- - library : `~romancal.datamodels.ModelLibrary` - The library of models. - example_model : `MosaicModel` or `ImageModel` - Example model. - cat_model : `MultibandSourceCatalogModel` - Catalog model. - ee_spline : `astropy.modeling.fitting.SplineSplrepFitter + ---------- + self : object + The multiband catalog step instance. + + library : `romancal.datamodels.ModelLibrary` + The library of models to process. + + example_model : `romancal.datamodels.MosaicImageModel` + An example model from the library for metadata access. + + ee_spline : callable + The encircled energy spline function. + + catalog_model : `romancal.datamodels.MultibandSourceCatalogModel` + The output catalog model (for saving empty results if + needed). Returns ------- - segment_img : `SegmentationImage` or set - The segmentation image. - cat_model : `MultibandSourceCatalogModel` - Updated catalog. - msg : str (optional) - Reason for empty file. 
+ result : dict or tuple + If successful, returns a dictionary with keys: + - 'detection_model': The detection image model + - 'mask': The total mask array + - 'segment_img': The segmentation image + - 'detection_catobj': The detection RomanSourceCatalog object + - 'detection_catalog': The detection catalog table + - 'star_kernel_fwhm': The stellar kernel FWHM + + If detection fails, returns the result of + save_empty_results(). """ - # All input MosaicImages in the ModelLibrary are assumed to have - # the same shape and be pixel aligned. - - log.info("Calculating and subtracting background") - library = subtract_background_library(library, self.bkg_boxsize) - log.info("Creating detection image") + # Define the kernel FWHMs for the detection image # TODO: sensible defaults # TODO: redefine in terms of intrinsic FWHM if self.kernel_fwhms is None: self.kernel_fwhms = [2.0, 5.0] - # TODO: det_img is saved in the MosaicSegmentationMapModel; - # do we also want to save the det_err? - det_img = make_detection_image(library, self.kernel_fwhms) + detection_image = make_detection_image(library, self.kernel_fwhms) # Estimate background rms from detection image to calculate a # threshold for source detection - mask = ~np.isfinite(det_img) + mask = ~np.isfinite(detection_image) - # Return an empty segmentation image and catalog table if all - # pixels are masked in the detection image. + # Return image shape and empty catalog table if all pixels are + # masked in the detection image if np.all(mask): msg = ( "Cannot create source catalog. All " "pixels in the detection image are masked." 
) - return det_img.shape, cat_model, msg + return detection_image.shape, catalog_model, msg + log.info("Calculating background RMS for detection image") bkg = RomanBackground( - det_img, + detection_image, box_size=self.bkg_boxsize, coverage_mask=mask, ) @@ -95,7 +104,7 @@ def multiband_catalog(self, library, example_model, cat_model, ee_spline): log.info("Detecting sources") segment_img = make_segmentation_image( - det_img, + detection_image, snr_threshold=self.snr_threshold, npixels=self.npixels, bkg_rms=bkg_rms, @@ -103,21 +112,23 @@ def multiband_catalog(self, library, example_model, cat_model, ee_spline): mask=mask, ) + # Return image shape and empty catalog table if no sources + # were detected if segment_img is None: # no sources found msg = "Cannot create source catalog. No sources were detected." - return det_img.shape, cat_model, msg + return detection_image.shape, catalog_model, msg - segment_img.detection_image = det_img.copy() + segment_img.detection_image = detection_image.copy() # Define the detection image model - det_model = datamodels.MosaicModel.create_minimal() - det_model.data = det_img - det_model.err = np.ones_like(det_img) + detection_model = datamodels.MosaicModel.create_minimal() + detection_model.data = detection_image + detection_model.err = np.ones_like(detection_image) # TODO: this is a temporary solution to get model attributes # currently needed in RomanSourceCatalog - det_model.weight = example_model.weight - det_model.meta = example_model.meta + detection_model.weight = example_model.weight + detection_model.meta = example_model.meta # The stellar FWHM is needed to define the kernel used for # the DAOStarFinder sharpness and roundness properties. 
@@ -127,10 +138,10 @@ star_kernel_fwhm = np.min(self.kernel_fwhms) log.info("Creating catalog for detection image") - det_catobj = RomanSourceCatalog( - det_model, + detection_catobj = RomanSourceCatalog( + detection_model, segment_img, - det_img, + detection_image, star_kernel_fwhm, fit_psf=False, # not needed for detection image detection_cat=None, @@ -139,107 +150,369 @@ ee_spline=ee_spline, ) - # Generate the catalog for the detection image. - # We need to make this catalog before we pass det_catobj - # to the RomanSourceCatalog constructor. - det_cat = det_catobj.catalog - det_cat.meta["ee_fractions"] = {} + # Generate the catalog for the detection image. The catalog + # is lazily evaluated, so we need to access it before we pass + # detection_catobj to the RomanSourceCatalog constructor. + detection_catalog = detection_catobj.catalog + + return { + "detection_model": detection_model, + "mask": mask, + "segment_img": segment_img, + "detection_catobj": detection_catobj, + "detection_catalog": detection_catalog, + "star_kernel_fwhm": star_kernel_fwhm, + } + + +def join_filter_catalogs(detection_catalog, filter_catalogs): + """ + Join filter catalogs to detection catalog in wavelength order. + + Parameters + ---------- + detection_catalog : Table + The detection catalog to which filter catalogs will be joined. + + filter_catalogs : dict + Dictionary with (filter_name, wavelength) tuples as keys and + catalog tables as values. + + Returns + ------- + result : Table + The detection catalog with all filter catalogs joined in + wavelength order. 
+ """ + # Sort by wavelength (second element of tuple key) + sorted_filter_keys = sorted(filter_catalogs.keys(), key=lambda x: x[1]) + + for filter_key in sorted_filter_keys: + cat = filter_catalogs[filter_key] + # The outer join prevents an empty table if any + # columns have the same name but different values + # (e.g., repeated filter names) + detection_catalog = join( + detection_catalog, cat, keys="label", join_type="outer" + ) + + return detection_catalog + + +def finalize_ee_fractions(detection_catalog, filter_ee_fractions): + """ + Consolidate and finalize encircled energy fractions. + + Accumulates ee_fractions from all filter processing and sorts + them by wavelength in the detection catalog metadata. Only + includes ee_fractions for original (non-PSF-matched) filter + bands. + + The method modifies detection_catalog.meta["ee_fractions"] in + place. + + Parameters + ---------- + detection_catalog : Table + The detection catalog where ee_fractions metadata will be + stored. + + filter_ee_fractions : list of dict + List of ee_fractions dictionaries from each filter, + where each dict maps filter names to ee_fractions arrays. + """ + detection_catalog.meta["ee_fractions"] = {} + + # Accumulate all ee_fractions from filters + for ee_fracs in filter_ee_fractions: + for key, value in ee_fracs.items(): + detection_catalog.meta["ee_fractions"][key] = value + + # Sort ee_fractions dictionary by filter wavelength + if detection_catalog.meta.get("ee_fractions"): + sorted_ee_fractions = dict( + sorted( + detection_catalog.meta["ee_fractions"].items(), + key=lambda item: get_filter_wavelength(item[0]), + ) + ) + detection_catalog.meta["ee_fractions"] = sorted_ee_fractions + + +def prepare_reference_filter(self, library): + """ + Determine and load the reference filter for PSF matching. + + Parameters + ---------- + self : object + The multiband catalog step instance. + + library : `romancal.datamodels.ModelLibrary` + The library of models to process. 
+ + Returns + ------- + result : dict + Dictionary with keys: + - 'ref_filter': The reference filter name (uppercase) + - 'ref_wavelength': The reference filter approx wavelength (nm) + - 'ref_model': The reference filter model + - 'ref_psf_model': The reference PSF model + """ + # Determine reference filter for PSF matching + if self.reference_filter is None: + # Default to reddest filter + ref_filter = get_reddest_filter(library) + log.info(f"Using reddest filter as reference: {ref_filter}") + else: + # User specified reference filter + ref_filter = self.reference_filter.upper() + log.info(f"Using user-specified reference filter: {ref_filter}") + + # Load reference PSF model + ref_model = None + with library: + for model in library: + if model.meta.instrument.optical_element == ref_filter: + ref_model = model + library.shelve(model, modify=False) + if ref_model is not None: + break + + if ref_model is None: + msg = ( + f"Reference filter {ref_filter} not found in library. " + "Cannot perform PSF matching." + ) + raise ValueError(msg) + + ref_psf_file = self.get_reference_file(ref_model, "epsf") + ref_psf_model = datamodels.open(ref_psf_file) + log.info(f"Using reference PSF: {ref_psf_file}") + + # Get reference filter approximate wavelength (nm) based on filter name + ref_wavelength = get_filter_wavelength(ref_filter) + + return { + "ref_filter": ref_filter, + "ref_wavelength": ref_wavelength, + "ref_model": ref_model, + "ref_psf_model": ref_psf_model, + } + + +def prepare_processing_order(library, ref_filter): + """ + Prepare the processing order for filter models. + + Ensures the reference filter is processed first so its catalog is + available when processing redder filters. + + Parameters + ---------- + library : `romancal.datamodels.ModelLibrary` + The library of models to process. + + ref_filter : str + The reference filter name. + + Returns + ------- + result : list of tuple + List of (model_index, filter_name) tuples sorted so reference + filter is first. 
+ """ + # Create list of model indices in processing order with reference + # filter first, then all others + model_indices = [] + with library: + for i, model in enumerate(library): + filter_name = model.meta.instrument.optical_element + model_indices.append((i, filter_name)) + library.shelve(model, modify=False) + + # Sort so reference filter is first + def sort_key(index_filter_tuple): + _, filt = index_filter_tuple + return 0 if filt == ref_filter else 1 + + model_indices.sort(key=sort_key) + + return model_indices + + +def multiband_catalog(self, library, example_model, catalog_model, ee_spline): + """ + Create a multiband catalog of sources including photometry and basic + shape measurements. + + Parameters + ----------- + library : `~romancal.datamodels.ModelLibrary` + The library of models. + example_model : `MosaicModel` or `ImageModel` + Example model. + catalog_model : `MultibandSourceCatalogModel` + Catalog model. + ee_spline : `astropy.modeling.fitting.SplineSplrepFitter + + Returns + ------- + segment_img : `SegmentationImage` or set + The segmentation image. + catalog_model : `MultibandSourceCatalogModel` + Updated catalog. + msg : str (optional) + Reason for empty file. + + Notes + ----- + All input MosaicImages in the ModelLibrary are assumed to have the + same shape and be pixel aligned. 
+ """ + log.info("Calculating and subtracting background") + library = subtract_background_library(library, self.bkg_boxsize) + + # Create detection image, segmentation, and catalog + detection_result = process_detection_image( + self, library, example_model, ee_spline, catalog_model + ) + + # Check if detection failed (step returns save_empty_results) + if isinstance(detection_result, tuple): + log.warning("Detection image processing failed") + return detection_result + # Extract detection results + segment_img = detection_result["segment_img"] + detection_catobj = detection_result["detection_catobj"] + detection_catalog = detection_result["detection_catalog"] + star_kernel_fwhm = detection_result["star_kernel_fwhm"] + + # Setup reference filter for PSF matching + ref_info = prepare_reference_filter(self, library) + ref_filter = ref_info["ref_filter"] + ref_wavelength = ref_info["ref_wavelength"] + ref_model = ref_info["ref_model"] + ref_psf_model = ref_info["ref_psf_model"] + + # Record the PSF match reference filter in metadata + detection_catalog.meta["psf_match_reference_filter"] = ref_filter.lower() + + # Store reference filter's catalog for computing correction factors + ref_filter_catalog = None # defined in processing loop + + # Prepare to accumulate filter catalogs and metadata time_means = [] exposure_times = [] + filter_catalogs = {} + filter_ee_fractions = [] + + # Prepare processing order (reference filter first) + model_indices = prepare_processing_order(library, ref_filter) # Create catalogs for each input image with library: - for model in library: - mask = ~np.isfinite(model.data) | ~np.isfinite(model.err) | (model.err <= 0) - - if self.fit_psf: - filter_name = model.meta.instrument.optical_element - log.info(f"Creating catalog for {filter_name} image") - ref_file = self.get_reference_file(model, "epsf") - log.info("Using ePSF reference file: %s", ref_file) - psf_ref_model = datamodels.open(ref_file) - else: - psf_ref_model = None - - apcorr_ref = 
self.get_reference_file(model, "apcorr") - ee_spline = get_ee_spline(model, apcorr_ref) - - catobj = RomanSourceCatalog( - model, - segment_img, - None, - star_kernel_fwhm, + for model_index, _ in model_indices: + model = library.borrow(model_index) + filter_name = model.meta.instrument.optical_element + + # Create catalog for this filter + result = create_filter_catalog( + model=model, + filter_name=filter_name, + ref_filter=ref_filter, + ref_wavelength=ref_wavelength, + segment_img=segment_img, + star_kernel_fwhm=star_kernel_fwhm, + detection_catobj=detection_catobj, + ref_model=ref_model, + ref_filter_catalog=ref_filter_catalog, + ref_psf_model=ref_psf_model, fit_psf=self.fit_psf, - detection_cat=det_catobj, - mask=mask, - psf_ref_model=psf_ref_model, - cat_type="dr_band", - ee_spline=ee_spline, + get_reference_file_func=self.get_reference_file, ) - # Add the filter name to the column names - filter_name = model.meta.instrument.optical_element - cat = add_filter_to_colnames(catobj.catalog, filter_name) - ee_fractions = cat.meta["ee_fractions"] - - # TODO: what metadata do we want to keep, if any, - # from the filter catalogs? - cat.meta = None - - # Add the filter catalog to the multiband catalog. 
- # The outer join prevents an empty table if any - # columns have the same name but different values - # (e.g., repeated filter names) - det_cat = join(det_cat, cat, keys="label", join_type="outer") - det_cat.meta["ee_fractions"][filter_name.lower()] = ee_fractions - - # accumulate image metadata - image_meta = { - k: copy.deepcopy(v) - for k, v in model["meta"].items() - if k not in _SKIP_IMAGE_META_KEYS - } - cat_model.meta.image_metas.append(image_meta) - - # blend model with catalog metadata - if model.meta.file_date < cat_model.meta.image.file_date: - cat_model.meta.image.file_date = model.meta.file_date - - for key, value in image_meta.items(): - if key in _SKIP_BLEND_KEYS: - continue - if not isinstance(value, dict): - # skip blending of single top-level values - continue - if key not in cat_model.meta: - # skip blending if the key is not in the catalog meta - continue - if key == "coadd_info": - cat_model.meta[key]["time_first"] = min( - cat_model.meta[key]["time_first"], value["time_first"] - ) - cat_model.meta[key]["time_last"] = max( - cat_model.meta[key]["time_last"], value["time_last"] - ) - time_means.append(value["time_mean"]) - exposure_times.append(value["exposure_time"]) - else: - # set non-matching metadata values to None - for subkey, subvalue in value.items(): - if cat_model.meta[key].get(subkey, None) != subvalue: - cat_model.meta[key][subkey] = None + # Store the reference filter catalog. This will be None + # until we process the reference filter. 
+ ref_filter_catalog = result["ref_filter_catalog"] + + # Store ee_fractions for consolidation + filter_ee_fractions.append(result["ee_fractions"]) + + # Store filter catalog for later joining in wavelength order + filter_wavelength = get_filter_wavelength(filter_name) + cat = result["catalog"] + filter_catalogs[(filter_name, filter_wavelength)] = cat + + # Accumulate and blend image metadata + blend_image_metadata(model, catalog_model, time_means, exposure_times) library.shelve(model, modify=False) - # finish blending - cat_model.meta.coadd_info.time_mean = Time(time_means).mean() - cat_model.meta.coadd_info.exposure_time = np.mean(exposure_times) + # Join all filter catalogs to detection catalog + detection_catalog = join_filter_catalogs(detection_catalog, filter_catalogs) + + # Finish blending + catalog_model.meta.coadd_info.time_mean = Time(time_means).mean() + catalog_model.meta.coadd_info.exposure_time = np.mean(exposure_times) + + # Consolidate and sort ee_fractions + finalize_ee_fractions(detection_catalog, filter_ee_fractions) # Put the resulting multiband catalog in the model - cat_model.source_catalog = det_cat + catalog_model.source_catalog = detection_catalog + + return segment_img, catalog_model, None + + +def initialize_catalog_model(library, example_model): + """ + Initialize the multiband source catalog model. - return segment_img, cat_model, None + Creates the catalog model and sets up initial metadata from + the example model and association information. + + Parameters + ---------- + library : `romancal.datamodels.ModelLibrary` + The library of models to process. + + example_model : `romancal.datamodels.MosaicImageModel` + An example model from the library for metadata access. + + Returns + ------- + result : `romancal.datamodels.MultibandSourceCatalogModel` + The initialized catalog model with set metadata. + """ + # Initialize the source catalog model, copying the metadata + # from the example model. 
Some of this may be overwritten + # during metadata blending. + catalog_model = datamodels.MultibandSourceCatalogModel.create_minimal( + {"meta": example_model.meta} + ) + catalog_model.meta["image"] = { + # try to record association name else fall back to example model + # filename + "filename": library.asn.get("table_name", example_model.meta.filename), + # file_date may be overwritten during metadata blending + "file_date": example_model.meta.file_date, + } + catalog_model.meta["image_metas"] = [] + + # copy over data_release_id, ideally this will come from the association + if "data_release_id" in example_model.meta: + catalog_model.meta.data_release_id = example_model.meta.data_release_id + + # Define the output filename for the source catalog model + try: + catalog_model.meta.filename = library.asn["products"][0]["name"] + except (AttributeError, KeyError): + catalog_model.meta.filename = "multiband_catalog" + + return catalog_model def make_source_injected_library(library): diff --git a/romancal/multiband_catalog/multiband_catalog_step.py b/romancal/multiband_catalog/multiband_catalog_step.py index 83fe58638..9acb5314e 100644 --- a/romancal/multiband_catalog/multiband_catalog_step.py +++ b/romancal/multiband_catalog/multiband_catalog_step.py @@ -8,10 +8,9 @@ import logging from typing import TYPE_CHECKING -from roman_datamodels import datamodels - from romancal.datamodels import ModelLibrary from romancal.multiband_catalog.multiband_catalog import ( + initialize_catalog_model, make_source_injected_library, match_recovered_sources, multiband_catalog, @@ -53,6 +52,7 @@ class MultibandCatalogStep(RomanStep): deblend = boolean(default=False) # deblend sources? suffix = string(default='cat') # Default suffix for output files fit_psf = boolean(default=True) # fit source PSFs for accurate astrometry? 
+ reference_filter = string(default=None) # reference filter for PSF matching inject_sources = boolean(default=False) # Inject sources into images save_debug_info = boolean(default=False) # Include image data and other data for testing @@ -70,45 +70,27 @@ def process(self, library): example_model = library.borrow(0) library.shelve(example_model, modify=False) - # Initialize the source catalog model, copying the metadata - # from the example model. Some of this may be overwritten - # during metadata blending. - cat_model = datamodels.MultibandSourceCatalogModel.create_minimal( - {"meta": example_model.meta} - ) - cat_model.meta["image"] = { - # try to record association name else fall back to example model filename - "filename": library.asn.get("table_name", example_model.meta.filename), - "file_date": example_model.meta.file_date, - # this may be overwritten during metadata blending - } - cat_model.meta["image_metas"] = [] - # copy over data_release_id, ideally this will come from the association - if "data_release_id" in example_model.meta: - cat_model.meta.data_release_id = example_model.meta.data_release_id + # Initialize the source catalog model + cat_model = initialize_catalog_model(library, example_model) log.info("Creating ee_fractions model for first image") apcorr_ref = self.get_reference_file(example_model, "apcorr") ee_spline = get_ee_spline(example_model, apcorr_ref) - # Define the output filename for the source catalog model - try: - cat_model.meta.filename = library.asn["products"][0]["name"] - except (AttributeError, KeyError): - cat_model.meta.filename = "multiband_catalog" - # Set up source injection library and injection catalog if self.inject_sources: si_library, si_cat = make_source_injected_library(library) - # Create catalog of library images - segment_img, cat_model, msg = multiband_catalog( + # Create the multiband catalog + *results, msg = multiband_catalog( self, library, example_model, cat_model, ee_spline ) - # The results are empty - if 
msg is not None: - return save_empty_results(self, segment_img, cat_model, msg=msg) + # Save empty results if there was an error + if msg is None: + segment_img, cat_model = results + else: + return save_empty_results(self, *results, msg=msg) # Source Injection if self.inject_sources: @@ -116,15 +98,13 @@ def process(self, library): si_example_model = si_library.borrow(0) si_library.shelve(si_example_model, modify=False) - si_ee_spline = get_ee_spline(si_example_model, apcorr_ref) - # Create catalog of source injected images si_segment_img, si_cat_model, _ = multiband_catalog( self, si_library, si_example_model, copy.deepcopy(cat_model), - si_ee_spline, + ee_spline, ) # Match sources diff --git a/romancal/multiband_catalog/tests/test_multiband_catalog.py b/romancal/multiband_catalog/tests/test_multiband_catalog.py index d89bb4952..80d11c50a 100644 --- a/romancal/multiband_catalog/tests/test_multiband_catalog.py +++ b/romancal/multiband_catalog/tests/test_multiband_catalog.py @@ -95,17 +95,41 @@ def shared_tests( library_model.shelve(input_model, modify=False) assert len(cat.meta["aperture_radii"]["circle_pix"]) > 0 + + # Check for original (non-PSF-matched) columns assert sum( 1 for name in cat.colnames if match(r"^aper\d+_f158_flux$", name) ) == len(cat.meta["aperture_radii"]["circle_pix"]) assert sum( 1 for name in cat.colnames if match(r"^aper\d+_f184_flux$", name) ) == len(cat.meta["aperture_radii"]["circle_pix"]) + + # Check for PSF-matched columns + # Target filter (F184) should NOT have PSF-matched columns + assert sum( + 1 for name in cat.colnames if match(r"^aper\d+_f158m_flux$", name) + ) == len(cat.meta["aperture_radii"]["circle_pix"]) + assert sum(1 for name in cat.colnames if match(r"^aper\d+_f184m_flux$", name)) == 0 assert "ee_fractions" in cat.meta assert isinstance(cat.meta["ee_fractions"], dict) + # f158, f184 (no PSF-matched bands) assert len(cat.meta["ee_fractions"]) == 2 + + # check that sharpness, roundness1, is_extended, and + # 
fluxfrac_radius_50 columns are not present for matched bands + assert all( + f"{param}_f158m" not in cat.colnames + for param in ["sharpness", "roundness1", "is_extended", "fluxfrac_radius_50"] + ) + + # Check metadata for PSF match reference filter + assert "psf_match_reference_filter" in cat.meta + assert cat.meta["psf_match_reference_filter"] == "f184" assert "f158" in cat.meta["ee_fractions"] assert "f184" in cat.meta["ee_fractions"] + # PSF-matched bands should not have ee_fractions + assert "f158m" not in cat.meta["ee_fractions"] + assert "f184m" not in cat.meta["ee_fractions"] for value in cat.meta["ee_fractions"].values(): assert len(value) == len(cat.meta["aperture_radii"]["circle_pix"]) @@ -379,7 +403,7 @@ def test_multiband_source_injection_catalog( result, cat, library_model2, - test_multiband_catalog, + save_results, function_jail, shape=(5000, 5000), ) @@ -453,3 +477,194 @@ def test_match_recovered_sources(): # Test columns included or excluded as expected assert "one" in rec_table.colnames assert "empty" not in rec_table.colnames + + +@pytest.fixture +def library_model_three_filters(mosaic_model): + """ + Library with F062, F158, F184. + """ + model1 = mosaic_model.copy() + model1.meta.instrument.optical_element = "F062" + + model2 = mosaic_model.copy() + model2.meta.instrument.optical_element = "F158" + + model3 = mosaic_model.copy() + model3.meta.instrument.optical_element = "F184" + + # input models not in wavelength order to test sorting + return ModelLibrary([model2, model3, model1]) + + +@pytest.mark.parametrize("fit_psf", (True, False)) +def test_multiband_catalog_reddest_reference( + library_model_three_filters, fit_psf, function_jail +): + """ + Test PSF matching when reference filter is the reddest. + + This tests the case where we have [F062, F158, F184] and F184 is + selected as the reference (longest wavelength). F062 and F158 get + PSF-matched to F184. Because F184 is the reference, no PSF-matched + columns are created for F184. 
+ """ + step = MultibandCatalogStep() + + result = step.call( + library_model_three_filters, + bkg_boxsize=50, + snr_threshold=3, + npixels=10, + fit_psf=fit_psf, + deblend=True, + save_results=False, + ) + + cat = result.source_catalog + assert isinstance(cat, Table) + assert len(cat) == 7 + + # Check metadata for PSF match reference filter + assert "psf_match_reference_filter" in cat.meta + assert cat.meta["psf_match_reference_filter"] == "f184" + + # F062 and F158 should have PSF-matched columns (matched to F184) + assert sum( + 1 for name in cat.colnames if match(r"^aper\d+_f062m_flux$", name) + ) == len(cat.meta["aperture_radii"]["circle_pix"]) + assert sum( + 1 for name in cat.colnames if match(r"^aper\d+_f158m_flux$", name) + ) == len(cat.meta["aperture_radii"]["circle_pix"]) + + # F184 (reference) should NOT have PSF-matched columns + assert sum(1 for name in cat.colnames if match(r"^aper\d+_f184m_flux$", name)) == 0 + + # Check ee_fractions - should have 3 keys: f062, f158, f184 (no + # PSF-matched bands) + assert "ee_fractions" in cat.meta + assert len(cat.meta["ee_fractions"]) == 3 + assert "f062" in cat.meta["ee_fractions"] + assert "f158" in cat.meta["ee_fractions"] + assert "f184" in cat.meta["ee_fractions"] + # PSF-matched bands should not have ee_fractions + assert "f062m" not in cat.meta["ee_fractions"] + assert "f158m" not in cat.meta["ee_fractions"] + assert "f184m" not in cat.meta["ee_fractions"] + + +def test_multiband_catalog_middle_reference(library_model_three_filters, function_jail): + """ + Test PSF matching when reference filter is in the middle. + + This tests the case where we have [F062, F158, F184] and F158 is + selected as the reference (middle wavelength). F062 gets normal + PSF-matching (convolved to F158), while F184 gets synthetic + PSF-matching (correction factors). F158 has no PSF-matched columns + as it is the reference. 
+ """ + step = MultibandCatalogStep() + + result = step.call( + library_model_three_filters, + bkg_boxsize=50, + snr_threshold=3, + npixels=10, + fit_psf=False, + deblend=True, + reference_filter="F158", + save_results=False, + ) + + cat = result.source_catalog + assert isinstance(cat, Table) + assert len(cat) == 7 + + # Check metadata for PSF match reference filter + assert "psf_match_reference_filter" in cat.meta + assert cat.meta["psf_match_reference_filter"] == "f158" + + # F062 (bluer) should have PSF-matched columns (normal convolution + # to F158) + assert sum( + 1 for name in cat.colnames if match(r"^aper\d+_f062m_flux$", name) + ) == len(cat.meta["aperture_radii"]["circle_pix"]) + + # F158 (reference) should NOT have PSF-matched columns + assert sum(1 for name in cat.colnames if match(r"^aper\d+_f158m_flux$", name)) == 0 + + # F184 (redder) should have PSF-matched columns (synthetic via + # correction factors) + assert sum( + 1 for name in cat.colnames if match(r"^aper\d+_f184m_flux$", name) + ) == len(cat.meta["aperture_radii"]["circle_pix"]) + + # Check ee_fractions - should have 3 keys: f062, f158, f184 (no + # PSF-matched bands) + assert "ee_fractions" in cat.meta + assert len(cat.meta["ee_fractions"]) == 3 + assert "f062" in cat.meta["ee_fractions"] + assert "f158" in cat.meta["ee_fractions"] + assert "f184" in cat.meta["ee_fractions"] + # PSF-matched bands should not have ee_fractions + assert "f062m" not in cat.meta["ee_fractions"] + assert "f158m" not in cat.meta["ee_fractions"] + assert "f184m" not in cat.meta["ee_fractions"] + + +def test_multiband_catalog_bluest_reference(library_model_three_filters, function_jail): + """ + Test PSF matching when reference filter is the bluest. + + This tests the case where we have [F062, F158, F184] and F062 is + selected as the reference (shortest wavelength). Both F158 and + F184 get synthetic PSF-matching (correction factors). F062 has no + PSF-matched columns as it is the reference. 
+ """ + step = MultibandCatalogStep() + + result = step.call( + library_model_three_filters, + bkg_boxsize=50, + snr_threshold=3, + npixels=10, + fit_psf=False, + deblend=True, + reference_filter="F062", + save_results=False, + ) + + cat = result.source_catalog + assert isinstance(cat, Table) + assert len(cat) == 7 + + # Check metadata for PSF match reference filter + assert "psf_match_reference_filter" in cat.meta + assert cat.meta["psf_match_reference_filter"] == "f062" + + # F062 (reference) should NOT have PSF-matched columns + assert sum(1 for name in cat.colnames if match(r"^aper\d+_f062m_flux$", name)) == 0 + + # F158 (redder) should have PSF-matched columns (synthetic via + # correction factors) + assert sum( + 1 for name in cat.colnames if match(r"^aper\d+_f158m_flux$", name) + ) == len(cat.meta["aperture_radii"]["circle_pix"]) + + # F184 (reddest) should have PSF-matched columns (synthetic via + # correction factors) + assert sum( + 1 for name in cat.colnames if match(r"^aper\d+_f184m_flux$", name) + ) == len(cat.meta["aperture_radii"]["circle_pix"]) + + # Check ee_fractions - should have 3 keys: f062, f158, f184 (no + # PSF-matched bands) + assert "ee_fractions" in cat.meta + assert len(cat.meta["ee_fractions"]) == 3 + assert "f062" in cat.meta["ee_fractions"] + assert "f158" in cat.meta["ee_fractions"] + assert "f184" in cat.meta["ee_fractions"] + # PSF-matched bands should not have ee_fractions + assert "f062m" not in cat.meta["ee_fractions"] + assert "f158m" not in cat.meta["ee_fractions"] + assert "f184m" not in cat.meta["ee_fractions"] diff --git a/romancal/source_catalog/psf_matching.py b/romancal/source_catalog/psf_matching.py new file mode 100644 index 000000000..cfd31cf21 --- /dev/null +++ b/romancal/source_catalog/psf_matching.py @@ -0,0 +1,337 @@ +""" +Module for PSF matching of images. 
+""" + +import logging + +import numpy as np +from astropy.convolution import convolve_fft +from roman_datamodels.datamodels import ImageModel, MosaicModel + +from romancal.multiband_catalog.utils import add_filter_to_colnames +from romancal.source_catalog.psf import ( + create_convolution_kernel, + create_l3_psf_model, +) +from romancal.source_catalog.utils import copy_model_arrays + +log = logging.getLogger(__name__) +log.setLevel(logging.DEBUG) + + +def create_psf_matched_image( + model, + psf_model, + target_psf_model, + min_fft_power_ratio=1e-5, +): + """ + Create a PSF-matched version of the input model. + + This function convolves the input image with a matching kernel to + transform its PSF to match a target PSF. + + Parameters + ---------- + model : `ImageModel` or `MosaicModel` + The input image to PSF-match. The image is assumed to be + background subtracted. If the input model data and error arrays + have units, they will be preserved in the output. + + psf_model : EpsfRefModel + The PSF model for the input ``model`` image. + + target_psf_model : EpsfRefModel + The PSF model for the reference/target PSF (typically the + broadest PSF in a multiband set). + + min_fft_power_ratio : float, optional + Regularization parameter for the matching kernel calculation. + Controls the scale of regularization in terms of the peak power + of the input PSF's FFT. Larger values correspond to stronger + regularization. Default is 1e-5. + + Returns + ------- + matched_model : `ImageModel` or `MosaicModel` + A copy of the input model with PSF-matched data and error + arrays. 
+ """ + if not isinstance(model, (ImageModel, MosaicModel)): + raise ValueError("model must be an ImageModel or MosaicModel") + + # Get filter information + input_filter = model.meta.instrument.optical_element + target_filter = target_psf_model.meta.instrument.optical_element + + log.info(f"Creating PSF-matched image: {input_filter} -> {target_filter}") + + # Check if input PSF is the same or broader than target PSF by + # comparing filter wavelengths (longer wavelength -> broader PSF) + input_wavelength = get_filter_wavelength(input_filter) + target_wavelength = get_filter_wavelength(target_filter) + if input_wavelength >= target_wavelength: + log.warning( + f"Input filter ({input_filter}, {input_wavelength} μm) has the " + f"same or longer wavelength than target filter ({target_filter}, " + f"{target_wavelength} μm). No PSF matching performed." + ) + + # Return the input model + return model + + # Create the matching kernel + log.info("Computing PSF matching kernel") + + # Create L3 PSF models to access PSF data arrays + input_l3_psf_model = create_l3_psf_model(psf_model) + target_l3_psf_model = create_l3_psf_model(target_psf_model) + + # Check oversampling factors -- create_convolution_kernel requires + # a single oversampling factor. + input_oversampling = input_l3_psf_model.oversampling + if input_oversampling[0] != input_oversampling[1]: + msg = ( + "Input PSF model has different oversampling factors in x and y directions." + ) + raise ValueError(msg) + target_oversampling = target_l3_psf_model.oversampling + if np.all(input_oversampling != target_oversampling): + msg = "Input and target PSF models have different oversampling factors." 
+        raise ValueError(msg)
+    oversampling = input_oversampling[0]
+
+    matching_kernel = create_convolution_kernel(
+        input_l3_psf_model.data,
+        target_l3_psf_model.data,
+        downsample=oversampling,
+    )
+
+    # Convolve the data
+    # np.asarray() is needed to convert to an ndarray from
+    # asdf.tags.core.ndarray.NDArrayType, which can cause issues with
+    # convolve_fft
+    log.info("Convolving image data with matching kernel")
+    matched_data = convolve_fft(
+        np.asanyarray(model.data),  # use np.asanyarray to preserve units
+        matching_kernel,
+        preserve_nan=True,
+        normalize_kernel=True,
+    )
+
+    # Propagate errors - convolve variance with kernel^2
+    log.info("Propagating errors")
+    if model.err is not None:
+        matched_variance = convolve_fft(
+            model.err**2,
+            matching_kernel**2,
+            preserve_nan=True,
+            normalize_kernel=False,
+        )
+        matched_err = np.sqrt(np.abs(matched_variance))
+    else:
+        matched_err = None
+
+    # Create output model with copied data and err arrays
+    matched_model = copy_model_arrays(model)
+    matched_model.data = matched_data
+    if matched_err is not None:
+        matched_model.err = matched_err
+
+    log.info(f"Created PSF-matched image: {input_filter} -> {target_filter}")
+
+    return matched_model
+
+
+def get_filter_wavelength(filter_name):
+    """
+    Extract approximate wavelength from filter name.
+
+    Parameters
+    ----------
+    filter_name : str
+        Filter name (e.g., 'F158', 'F184', 'f158', 'f158m')
+
+    Returns
+    -------
+    wavelength : float
+        Approximate wavelength in microns (e.g., 1.58, 1.84), or 0 if
+        cannot parse.
+    """
+    try:
+        # Remove 'm' suffix if present (for PSF-matched column names)
+        name = filter_name.rstrip("m")
+        return int(name[1:]) / 100.0  # Convert to microns
+    except (ValueError, IndexError):
+        return 0
+
+
+def get_reddest_filter(library):
+    """
+    Get the reddest filter in the model library.
+
+    The reddest filter is typically the one with the broadest PSF.
+ + Parameters + ---------- + library : `~romancal.datamodels.ModelLibrary` + The model library containing images from different filters. + + Returns + ------- + reddest_filter : str + The name of the reddest filter. + + Notes + ----- + This function uses a simple heuristic that extracts filter + wavelengths from the filter names. + """ + # Get list of filters in the library + filters = [] + with library: + for model in library: + filter_name = model.meta.instrument.optical_element + if filter_name not in filters: + filters.append(filter_name) + library.shelve(model, modify=False) + + # Simple heuristic to extract wavelength from filter name + # Roman WFI filters: F062, F087, F106, F129, F146, F158, F184, F213 + filter_wavelengths = {filt: get_filter_wavelength(filt) for filt in filters} + + # Return the longest wavelength filter + reference_filter = max(filter_wavelengths, key=filter_wavelengths.get) + + log.info( + f"Selected {reference_filter} as reference filter for PSF matching " + f"(longest wavelength among {filters})" + ) + + return reference_filter + + +def compute_psf_correction_factors( + ref_model, + ref_psf_model, + ref_catalog, + target_model, + target_psf_model, + segment_img, + star_kernel_fwhm, + detection_cat, + mask, + ee_spline, +): + """ + Compute correction factors for PSF-matched photometry on the + reference filter. + + When the PSF match reference filter is not the reddest filter, we + need to create synthetic PSF-matched photometry by: + + 1. PSF-matching the reference image to the target (redder) filter's PSF. + 2. Measuring photometry on the PSF-matched reference image. + 3. Computing correction factors C = flux_original / flux_matched. + These factors are typically larger than 1.0 because the + PSF-matched image has broader PSF and thus lower flux in a fixed + aperture. These factors are multiplied to the original measured + fluxes in the target (redder) filter. 
+
+    Parameters
+    ----------
+    ref_model : `ImageModel` or `MosaicModel`
+        The reference filter image (original, not PSF-matched).
+
+    ref_psf_model : EpsfRefModel
+        PSF model for the reference filter.
+
+    ref_catalog : `~astropy.table.Table`
+        Catalog from the original reference image.
+
+    target_model : `ImageModel` or `MosaicModel`
+        The target (redder) filter image that will be matched.
+
+    target_psf_model : EpsfRefModel
+        PSF model for the target filter.
+
+    segment_img : `~photutils.segmentation.SegmentationImage`
+        Segmentation map for source extraction.
+
+    star_kernel_fwhm : float
+        FWHM for star detection kernel.
+
+    detection_cat : RomanSourceCatalog
+        Detection catalog.
+
+    mask : array-like
+        Boolean mask array where True indicates invalid (masked) pixels
+        to exclude from the measurements.
+
+    ee_spline : callable
+        The encircled energy spline function.
+
+    Returns
+    -------
+    correction_factors : dict
+        Dictionary mapping flux column names to arrays of correction
+        factors. Values are arrays with one correction factor per
+        source.
+ """ + from romancal.source_catalog.source_catalog import RomanSourceCatalog + + ref_filter = ref_model.meta.instrument.optical_element + target_filter = target_model.meta.instrument.optical_element + log.info( + f"Computing PSF correction factors for {ref_filter} matched to {target_filter}" + ) + + # PSF-match the reference image to the target filter + psf_matched_ref = create_psf_matched_image( + ref_model, + ref_psf_model, + target_psf_model, + ) + + # Measure photometry on the PSF-matched reference image + catobj_matched = RomanSourceCatalog( + psf_matched_ref, + segment_img, + None, + star_kernel_fwhm, + fit_psf=False, + psf_model=None, + mask=mask, + detection_cat=detection_cat, + cat_type="psf_matched", + ee_spline=ee_spline, + ) + + # add filter name to catalog column names + cat_matched = add_filter_to_colnames(catobj_matched.catalog, ref_filter) + + # Compute correction factors C = flux_original / flux_matched + # for each flux type (segment, kron, aper*) + correction_factors = {} + + # Flux column patterns to compute corrections for + flux_columns = [col for col in ref_catalog.colnames if col.endswith("_flux")] + + for flux_col in flux_columns: + if flux_col not in ref_catalog.colnames or flux_col not in cat_matched.colnames: + continue + + flux_original = ref_catalog[flux_col] + flux_matched = cat_matched[flux_col] + + # Compute correction factor, avoiding division by zero + with np.errstate(divide="ignore", invalid="ignore"): + correction = flux_original / flux_matched + # Set correction to 1.0 where original flux is zero or + # invalid + correction = np.where( + (flux_matched == 0) | ~np.isfinite(correction), 1.0, correction + ) + + correction_factors[flux_col] = correction + + return correction_factors diff --git a/romancal/source_catalog/segment.py b/romancal/source_catalog/segment.py index a8069c954..ec20f9412 100644 --- a/romancal/source_catalog/segment.py +++ b/romancal/source_catalog/segment.py @@ -49,8 +49,11 @@ class SegmentCatalog: detection 
catalog centroids and shape properties will also be used to perform aperture photometry (i.e., circular and Kron). - flux_unit : str, optional - The unit of the flux density. Default is 'nJy'. + cat_type : str, optional + The type of catalog to create. The default is 'prompt'. Allowed + values are 'prompt', 'dr_det', 'dr_band', 'psf_matched', + 'forced_full', and 'forced_det'. This determines which + properties are calculated. Notes ----- @@ -68,7 +71,7 @@ def __init__( pixel_area, wcs_angle, detection_cat=None, - flux_unit="nJy", + cat_type="prompt", ): self.model = model self.segment_img = segment_img @@ -76,7 +79,7 @@ def __init__( self.pixel_area = pixel_area self.wcs_angle = wcs_angle self.detection_cat = detection_cat - self.flux_unit = flux_unit + self.cat_type = cat_type self.names = [] self.wcs = self.model.meta.wcs @@ -88,9 +91,22 @@ def __init__( # calculate the segment properties self.calc_segment_properties() - # lazyproperties are not set until accessed so we need to - # manually append them + # lazy properties are not set until accessed so we need to + # manually append them. + for name in self._lazyproperties: + # For "dr_band" and "psf_matched" catalogs, we do not need + # the orientation_sky lazy property. + skip_orientation = self.cat_type in ("dr_band", "psf_matched") + if skip_orientation and name == "orientation_sky": + continue + + # For "psf_matched" catalogs, we do not need the + # fluxfrac_radius_50 lazy property. + skip_fluxfrac = self.cat_type in ("psf_matched",) + if skip_fluxfrac and name == "fluxfrac_radius_50": + continue + self.names.append(name) # add the placeholder attributes @@ -177,33 +193,47 @@ def calc_segment_properties(self): # Extract the properties from the segment catalog. These # names are the SourceCatalog property names and the order # is not important. 
- photutils_names = ( - "label", - "xcentroid", - "ycentroid", - "xcentroid_win", - "ycentroid_win", - "sky_centroid", - "sky_centroid_win", - "bbox_xmin", - "bbox_xmax", - "bbox_ymin", - "bbox_ymax", - "area", - "semimajor_sigma", - "semiminor_sigma", - "fwhm", - "orientation", - "ellipticity", - "cxx", - "cxy", - "cyy", - "kron_radius", - "segment_flux", - "segment_fluxerr", - "kron_flux", - "kron_fluxerr", - ) + + # For dr_band and psf_matched catalogs (multiband filter and + # PSF-matched catalogs), only calculate minimal properties + # needed for photometry + if self.cat_type in ("dr_band", "psf_matched"): + photutils_names = [ + "label", + "segment_flux", + "segment_fluxerr", + "kron_flux", + "kron_fluxerr", + ] + else: + # Full catalog includes all core properties + photutils_names = [ + "label", + "xcentroid", + "ycentroid", + "xcentroid_win", + "ycentroid_win", + "sky_centroid", + "sky_centroid_win", + "bbox_xmin", + "bbox_xmax", + "bbox_ymin", + "bbox_ymax", + "area", + "kron_radius", + "segment_flux", + "segment_fluxerr", + "kron_flux", + "kron_fluxerr", + "semimajor_sigma", + "semiminor_sigma", + "fwhm", + "orientation", + "ellipticity", + "cxx", + "cxy", + "cyy", + ] # if needed, map names from photutils to the output catalog names name_map = {} @@ -223,7 +253,10 @@ def calc_segment_properties(self): new_name = name_map.get(name, name) value = getattr(segm_cat, name) - # handle any unit conversions + # handle any unit conversions needed for specific columns + # photutils -> romancal + + # unitless -> pix if new_name in ( "x_centroid", "y_centroid", @@ -248,7 +281,7 @@ def calc_segment_properties(self): if new_name == "ellipticity": value = value.value - # change the photutils dtypes + # change dtypes for romancal catalog if new_name not in ("sky_centroid", "sky_centroid_win"): if np.issubdtype(value.dtype, np.integer): value = value.astype(np.int32) diff --git a/romancal/source_catalog/source_catalog.py b/romancal/source_catalog/source_catalog.py index 
264920577..384370c1e 100644 --- a/romancal/source_catalog/source_catalog.py +++ b/romancal/source_catalog/source_catalog.py @@ -32,7 +32,8 @@ class RomanSourceCatalog: ---------- model : `ImageModel` or `MosaicModel` The input data model. The image data is assumed to be background - subtracted. + subtracted. For PSF-matched photometry in multiband catalogs, + input the PSF-matched model. segment_image : `~photutils.segmentation.SegmentationImage` A 2D segmentation image, with the same shape as the input data, @@ -70,14 +71,19 @@ class RomanSourceCatalog: flux_unit : str, optional The unit of the flux density. Default is 'nJy'. - cat_type : {'prompt', 'dr_det', 'dr_band', 'forced_full', 'forced_det'}, optional - The type of catalog to create. The default is 'prompt'. This - determines the columns in the output catalog. The 'dr_det' and - 'dr_band' catalogs are band-specific catalogs for the - multiband source detection. + cat_type : str, optional + The type of catalog to create. The default is 'prompt'. Allowed + values are 'prompt', 'dr_det', 'dr_band', 'psf_matched', + 'forced_full', and 'forced_det'. This determines the columns + in the output catalog. The 'dr_det' and 'dr_band' catalogs are + band-specific catalogs for the multiband source detection. The + 'psf_matched' catalog is similar to 'dr_band' but excludes + sharpness, roundness1, is_extended, and fluxfrac_radius_50 + properties for performance optimization. - ee_spline : `~astropy.modeling.models.Spline1D` or `None` - The PSF aperture correction model, built from the reference file. + ee_spline : `~astropy.modeling.models.Spline1D` or `None`, optional + The PSF aperture correction model, built from the reference + file. 
Notes ----- @@ -95,8 +101,8 @@ def __init__( kernel_fwhm, *, fit_psf=True, + psf_model=None, mask=None, - psf_ref_model=None, detection_cat=None, flux_unit="nJy", cat_type="prompt", @@ -110,10 +116,11 @@ def __init__( self.convolved_data = convolved_data self.kernel_fwhm = kernel_fwhm self.fit_psf = fit_psf + self.psf_model = psf_model self.mask = mask - self.psf_ref_model = psf_ref_model self.detection_cat = detection_cat - self.flux_unit = flux_unit + self.flux_unit_str = flux_unit + self.flux_unit = u.Unit(self.flux_unit_str) self.cat_type = cat_type self.ee_spline = ee_spline @@ -127,11 +134,9 @@ def __init__( self.l2_to_sb = self.model.meta.photometry.conversion_megajanskys else: self.l2_to_sb = 1.0 - self.sb_to_flux = (1.0 * (u.MJy / u.sr) * self._pixel_area).to( - u.Unit(self.flux_unit) - ) + self.sb_to_flux = (1.0 * (u.MJy / u.sr) * self._pixel_area).to(self.flux_unit) - if self.fit_psf and self.psf_ref_model is None: + if self.fit_psf and self.psf_model is None: log.error( "PSF fitting is requested but no PSF reference model is provided. Skipping PSF photometry." ) @@ -145,10 +150,10 @@ def convert_l2_to_sb(self): Convert level-2 data from units of DN/s to MJy/sr (surface brightness). """ - # the conversion in done in-place to avoid making copies of the data; + # the conversion is done in-place to avoid making copies of the data; # use dictionary syntax to set the value to avoid on-the-fly validation - self.model["data"] *= self.l2_to_sb - self.model["err"] *= self.l2_to_sb + for attr in ("data", "err"): + self.model[attr] *= self.l2_to_sb if self.convolved_data is not None: self.convolved_data *= self.l2_to_sb @@ -159,12 +164,11 @@ def convert_sb_to_flux_density(self): The flux density unit is defined by self.flux_unit. 
""" - # the conversion in done in-place to avoid making copies of the data; + # the conversion is done in-place to avoid making copies of the data; # use dictionary syntax to set the value to avoid on-the-fly validation - self.model["data"] *= self.sb_to_flux.value - self.model["data"] <<= self.sb_to_flux.unit - self.model["err"] *= self.sb_to_flux.value - self.model["err"] <<= self.sb_to_flux.unit + for attr in ("data", "err"): + self.model[attr] *= self.sb_to_flux.value + self.model[attr] <<= self.sb_to_flux.unit if self.convolved_data is not None: self.convolved_data *= self.sb_to_flux.value self.convolved_data <<= self.sb_to_flux.unit @@ -271,7 +275,7 @@ def calc_segment_properties(self): self._pixel_area, self._wcs_angle, detection_cat=self.detection_cat, - flux_unit=self.flux_unit, + cat_type=self.cat_type, ) self.meta.update(segment_cat.meta) @@ -305,7 +309,7 @@ def calc_psf_photometry(self): The results are set as dynamic attributes on the class instance. """ - psf_cat = PSFCatalog(self.model, self.psf_ref_model, self._xypos, self.mask) + psf_cat = PSFCatalog(self.model, self.psf_model, self._xypos, self.mask) for name in psf_cat.names: setattr(self, name, getattr(psf_cat, name)) @@ -613,7 +617,7 @@ def flux_colnames(self): ] psf_colnames = ["psf_flux", "psf_flux_err"] - if self.cat_type in ("prompt", "forced_full", "dr_band"): + if self.cat_type in ("prompt", "forced_full", "dr_band", "psf_matched"): flux_colnames = self.aper_colnames if self.fit_psf: flux_colnames.extend(psf_colnames) @@ -771,6 +775,12 @@ def column_names(self): elif self.cat_type == "dr_band": colnames = band_colnames.copy() + elif self.cat_type == "psf_matched": + # Similar to dr_band but without the othershape_colnames + # (sharpness, roundness1, is_extended, fluxfrac_radius_50) + colnames = ["label"] + colnames.extend(self.flux_colnames) + return colnames def _prefix_forced(self, catalog): @@ -802,15 +812,138 @@ def _prefix_forced(self, catalog): return catalog + @staticmethod + def 
_get_compatible_unit(*arrays): + """ + Check if multiple arrays have compatible units and return the + common unit. + + This function verifies that either all arrays are plain NumPy + ndarrays (no units) or that they are all Astropy Quantity + objects with the same units. + + Parameters + ---------- + *arrays : `numpy.ndarray` or `astropy.units.Quantity` + The data arrays to check. + + Returns + ------- + result : astropy.units.Unit or None + The common `astropy.units.Unit` object if all are Quantity + arrays with the same unit. Otherwise, returns `None`. + + Raises + ------ + ValueError + If one input is a Quantity and another is not, or if they + are all Quantities but with different units. + """ + if len(arrays) == 0: + return None + + # Filter out None values + arrays = [arr for arr in arrays if arr is not None] + if len(arrays) == 0: + return None + + # Check if first array is a Quantity + is_quantity = [isinstance(arr, u.Quantity) for arr in arrays] + + # All must be quantities or all must not be quantities + if all(is_quantity): + # Check that all have the same unit + first_unit = arrays[0].unit + for i, arr in enumerate(arrays[1:], start=1): + if arr.unit != first_unit: + raise ValueError( + f"Incompatible units: array 0 has unit '{first_unit}' " + f"but array {i} has unit '{arr.unit}'." + ) + return first_unit + elif not any(is_quantity): + return None + else: + # Mixed types + raise ValueError( + "Incompatible types: some arrays have units while others do not." + ) + + def _validate_and_convert_units(self): + """ + Validate that model data arrays have compatible units and + convert them to flux density units if needed. + + This method checks that model.data and model.err have the same + units, then converts both them and convolved_data (if not None) + to the desired flux density unit (self.flux_unit). 
+ + Note: convolved_data is allowed to have different units from + model.data and model.err because it may come from a separate + source (e.g., a detection image from forced photometry). + + For models without units: + - Level-2 (ImageModel): DN/s -> MJy/sr -> flux density + - Level-3 (MosaicModel): MJy/sr -> flux density + + For models with units: + - Converts to self.flux_unit if the unit is compatible + + Raises + ------ + ValueError + If model.data and model.err have incompatible units. + """ + # Check that model.data and model.err have compatible units + unit = self._get_compatible_unit(self.model.data, self.model.err) + + if unit is None: + # No units present - convert to flux density units + if isinstance(self.model, ImageModel): + # Level-2: DN/s -> MJy/sr + self.convert_l2_to_sb() + # Level-2 or Level-3: MJy/sr -> flux density + self.convert_sb_to_flux_density() + else: + # Units present - check compatibility and convert + if unit.is_equivalent(self.flux_unit): + # Convert to desired flux unit + self.model["data"] = self.model["data"].to(self.flux_unit) + self.model["err"] = self.model["err"].to(self.flux_unit) + else: + raise ValueError( + f"Incompatible units: model data has unit '{unit}' " + f"which is not equivalent to the desired flux unit " + f"'{self.flux_unit}'." 
+ ) + + # Handle convolved_data separately as it may have different units + if self.convolved_data is not None: + conv_unit = self._get_compatible_unit(self.convolved_data) + if conv_unit is None: + # No units - apply same conversion as model data + if isinstance(self.model, ImageModel): + self.convolved_data *= self.l2_to_sb + self.convolved_data *= self.sb_to_flux.value + self.convolved_data <<= self.sb_to_flux.unit + else: + if conv_unit.is_equivalent(self.flux_unit): + # Convert to desired flux unit + self.convolved_data = self.convolved_data.to(self.flux_unit) + else: + raise ValueError( + f"Incompatible units: convolved_data has unit " + f"'{conv_unit}' which is not equivalent to the " + f"desired flux unit '{self.flux_unit}'." + ) + @lazyproperty def catalog(self): """ The final source catalog as an Astropy Table. """ - # convert data to flux units - if isinstance(self.model, ImageModel): - self.convert_l2_to_sb() - self.convert_sb_to_flux_density() + # Validate and convert units for all data arrays + self._validate_and_convert_units() # make measurements - the order of these calculations is important log.info("Calculating segment properties") diff --git a/romancal/source_catalog/source_catalog_step.py b/romancal/source_catalog/source_catalog_step.py index 4277b3f46..c3c58e4d3 100644 --- a/romancal/source_catalog/source_catalog_step.py +++ b/romancal/source_catalog/source_catalog_step.py @@ -19,7 +19,7 @@ from romancal.source_catalog.psf import add_jitter from romancal.source_catalog.save_utils import save_all_results, save_empty_results from romancal.source_catalog.source_catalog import RomanSourceCatalog -from romancal.source_catalog.utils import get_ee_spline +from romancal.source_catalog.utils import copy_model_arrays, get_ee_spline from romancal.stpipe import RomanStep if TYPE_CHECKING: @@ -68,10 +68,10 @@ def process(self, step_input): if self.fit_psf: self.ref_file = self.get_reference_file(input_model, "epsf") log.info("Using ePSF reference file: 
%s", self.ref_file) - psf_ref_model = datamodels.open(self.ref_file) - psf_ref_model.psf = add_jitter(psf_ref_model, input_model) + psf_model = datamodels.open(self.ref_file) + psf_model.psf = add_jitter(psf_model, input_model) else: - psf_ref_model = None + psf_model = None # Define a boolean mask for pixels to be excluded mask = ( @@ -80,16 +80,11 @@ def process(self, step_input): | (input_model.err <= 0) ) - # Copy the data and error arrays to avoid modifying the input - # model. The metadata and dq and weight arrays are not copied - # because they are not modified in this step. - if isinstance(input_model, ImageModel): - model = ImageModel() - model.meta = input_model.meta - model.data = input_model.data.copy() - model.err = input_model.err.copy() - model.dq = input_model.dq + # Copy the data and error arrays to avoid modifying the input model + model = copy_model_arrays(input_model) + # Create a DQ mask for ImageModel + if isinstance(input_model, ImageModel): # Create a DQ mask for pixels to be excluded; currently all # pixels with any DQ flag are excluded from the source catalog # except for those in ignored_dq_flags. 
@@ -100,12 +95,6 @@ def process(self, step_input): # TODO: to set the mask to True for *only* dq_flags use: # dq_mask = np.any(model.dq[..., None] & dq_flags, axis=-1) mask |= dq_mask - elif isinstance(input_model, MosaicModel): - model = MosaicModel() - model.meta = input_model.meta - model.data = input_model.data.copy() - model.err = input_model.err.copy() - model.weight = input_model.weight # Initialize the source catalog model, copying the metadata # from the input model @@ -200,8 +189,8 @@ def process(self, step_input): detection_image, self.kernel_fwhm, fit_psf=fit_psf, + psf_model=psf_model, mask=mask, - psf_ref_model=psf_ref_model, cat_type=cat_type, ee_spline=ee_spline, ) @@ -218,8 +207,8 @@ def process(self, step_input): forced_detection_image, self.kernel_fwhm, fit_psf=self.fit_psf, + psf_model=psf_model, mask=mask, - psf_ref_model=psf_ref_model, cat_type="forced_full", ee_spline=ee_spline, ) diff --git a/romancal/source_catalog/tests/test_psf_matching.py b/romancal/source_catalog/tests/test_psf_matching.py new file mode 100644 index 000000000..5f84947c6 --- /dev/null +++ b/romancal/source_catalog/tests/test_psf_matching.py @@ -0,0 +1,461 @@ +""" +Unit tests for PSF matching functionality. +""" + +from unittest.mock import patch + +import numpy as np +import pytest +from astropy.modeling.fitting import TRFLSQFitter +from astropy.modeling.models import Gaussian2D +from roman_datamodels.datamodels import ImageModel + +from romancal.source_catalog.psf_matching import ( + create_psf_matched_image, + get_filter_wavelength, + get_reddest_filter, +) + + +@pytest.fixture +def mock_image_model(): + """ + Create a mock ImageModel for testing. 
+ """ + # Create a simple test image with a Gaussian-like source + ny, nx = 100, 100 + y, x = np.mgrid[:ny, :nx] + cy, cx = ny // 2, nx // 2 + + # Create a simple Gaussian-like source + sigma = 3.0 + gauss = Gaussian2D( + amplitude=1.0, x_mean=cx, y_mean=cy, x_stddev=sigma, y_stddev=sigma + ) + data = gauss(x, y) + rng = np.random.default_rng(42) + data += 0.001 * rng.standard_normal((ny, nx)) # Add small noise + + err = 0.01 * np.ones_like(data) + + # Use create_fake_data to get proper metadata structure + model = ImageModel.create_fake_data(shape=(ny, nx)) + model.data[:] = data + model.err[:] = err + model.meta.instrument.optical_element = "F087" + + return model + + +@pytest.fixture +def mock_psf_ref_model(): + """ + Create a mock PSF reference model factory for testing. + """ + + def _create_mock_psf(filter_name, oversampling=4): + """ + Create a mock PSF reference model. + """ + + class MockPSFRef: + """ + Mock PSF reference model. + """ + + def __init__(self, filter_name, oversampling): + # Set metadata + self.meta = type("Meta", (), {})() + self.meta.instrument = type("Instrument", (), {})() + self.meta.instrument.optical_element = filter_name + self.oversample = oversampling + self.oversampling = oversampling + + # Create a simple PSF stamp for create_l3_psf_model + stamp_size = 41 + y, x = np.mgrid[:stamp_size, :stamp_size] + cy, cx = stamp_size // 2, stamp_size // 2 + + # Create Gaussian PSF - broader for longer wavelengths + wavelength = get_filter_wavelength(filter_name) + # Scale FWHM with wavelength + fwhm = 2.0 + wavelength # Simple scaling + sigma = fwhm / 2.355 + psf = np.exp(-((x - cx) ** 2 + (y - cy) ** 2) / (2 * sigma**2)) + psf = psf / psf.sum() + + self.psf_data = psf + + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + return MockPSFRef(filter_name, oversampling) + + return _create_mock_psf + + +@pytest.fixture +def mock_create_l3_psf_patch(): + """ + Create a mock patch for create_l3_psf_model. 
+ + Returns a context manager that can be used to patch create_l3_psf_model + in tests that need to perform PSF matching. + """ + + def mock_create_l3_psf(psf_ref_model): + """Mock create_l3_psf_model.""" + + class MockL3PSF: + def __init__(self): + self.data = psf_ref_model.psf_data + self.oversampling = np.array([4, 4]) + + return MockL3PSF() + + return patch( + "romancal.source_catalog.psf_matching.create_l3_psf_model", + side_effect=mock_create_l3_psf, + ) + + +def test_get_filter_wavelength(): + """ + Test filter wavelength extraction. + """ + assert get_filter_wavelength("F062") == 0.62 + assert get_filter_wavelength("F087") == 0.87 + assert get_filter_wavelength("F106") == 1.06 + assert get_filter_wavelength("F129") == 1.29 + assert get_filter_wavelength("F158") == 1.58 + assert get_filter_wavelength("F184") == 1.84 + assert get_filter_wavelength("F213") == 2.13 + + # Test with 'm' suffix (PSF-matched column names) + assert get_filter_wavelength("F158m") == 1.58 + assert get_filter_wavelength("f184m") == 1.84 + + # Test invalid filter names + assert get_filter_wavelength("invalid") == 0 + assert get_filter_wavelength("") == 0 + + +def test_create_psf_matched_image_basic( + mock_image_model, mock_psf_ref_model, mock_create_l3_psf_patch +): + """ + Test basic PSF matching functionality. 
+ """ + # Create input and target PSF models + input_psf_ref = mock_psf_ref_model("F087") # Narrower PSF + target_psf_ref = mock_psf_ref_model("F184") # Broader PSF + + with mock_create_l3_psf_patch: + # Create matched image + matched_model = create_psf_matched_image( + mock_image_model, input_psf_ref, target_psf_ref + ) + + # Check that output is an ImageModel + assert isinstance(matched_model, ImageModel) + + # Check that data shape is preserved + assert matched_model.data.shape == mock_image_model.data.shape + + # Check that errors are propagated + assert matched_model.err is not None + assert matched_model.err.shape == mock_image_model.err.shape + assert np.all(matched_model.err > 0) + assert np.all(np.isfinite(matched_model.err)) + + +def test_create_psf_matched_image_same_filter(mock_image_model, mock_psf_ref_model): + """ + Test that PSF matching is skipped when filters are the same. + """ + # Create PSF models with same filter + input_psf_ref = mock_psf_ref_model("F087") + target_psf_ref = mock_psf_ref_model("F087") + + matched_model = create_psf_matched_image( + mock_image_model, input_psf_ref, target_psf_ref + ) + + # Check that matching was skipped - returns the input model + assert matched_model is mock_image_model + + +def test_create_psf_matched_image_broader_input(mock_image_model, mock_psf_ref_model): + """ + Test that PSF matching is skipped when input PSF is already broader. 
+ """ + # Create PSF models where input is broader than target + input_psf_ref = mock_psf_ref_model("F184") # Broader + target_psf_ref = mock_psf_ref_model("F087") # Narrower + + mock_image_model.meta.instrument.optical_element = "F184" + + matched_model = create_psf_matched_image( + mock_image_model, input_psf_ref, target_psf_ref + ) + + # Check that matching was skipped - returns the input model + assert matched_model is mock_image_model + + +def test_create_psf_matched_image_params(mock_image_model, mock_psf_ref_model): + """ + Test that all required parameters must be provided. + """ + input_psf_ref = mock_psf_ref_model("F087") + + match = "missing 1 required positional argument" + with pytest.raises(TypeError, match=match): + create_psf_matched_image(mock_image_model, input_psf_ref) + + +def test_create_psf_matched_image_flux_conservation( + mock_image_model, mock_psf_ref_model, mock_create_l3_psf_patch +): + """ + Test that total flux is roughly preserved after PSF matching. + """ + input_psf_ref = mock_psf_ref_model("F087") + target_psf_ref = mock_psf_ref_model("F184") + + input_flux = np.sum(mock_image_model.data) + + with mock_create_l3_psf_patch: + matched_model = create_psf_matched_image( + mock_image_model, input_psf_ref, target_psf_ref + ) + + matched_flux = np.sum(matched_model.data) + + # Flux should be approximately preserved + assert np.isclose(matched_flux, input_flux, rtol=0.05) + + +def test_get_reddest_filter(): + """ + Test automatic reddest filter selection. 
+ """ + + # Create a mock library with multiple filters + class MockLibrary: + def __init__(self, models): + self.models = models + self._index = 0 + + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + def __iter__(self): + return iter(self.models) + + def shelve(self, model, modify=False): + pass + + # Create models with different filters + model_f087 = ImageModel.create_fake_data(shape=(10, 10)) + model_f087.meta.instrument.optical_element = "F087" + + model_f129 = ImageModel.create_fake_data(shape=(10, 10)) + model_f129.meta.instrument.optical_element = "F129" + + model_f184 = ImageModel.create_fake_data(shape=(10, 10)) + model_f184.meta.instrument.optical_element = "F184" + + library = MockLibrary([model_f087, model_f129, model_f184]) + + reddest_filter = get_reddest_filter(library) + + # Should select F184 (longest wavelength) + assert reddest_filter == "F184" + + +def test_create_psf_matched_image_invalid_input(mock_psf_ref_model): + """ + Test that invalid model type raises ValueError. + """ + invalid_model = "not a model" + input_psf_ref = mock_psf_ref_model("F087") + target_psf_ref = mock_psf_ref_model("F184") + + match = "model must be an ImageModel or MosaicModel" + with pytest.raises(ValueError, match=match): + create_psf_matched_image(invalid_model, input_psf_ref, target_psf_ref) + + +def test_psf_matching_kernel_validation(): + """ + Validation test for PSF matching with known Gaussian PSFs. + + This test creates a noiseless image with a single centered Gaussian + source (sigma=3) and PSF (sigma=3). The target PSF is a Gaussian + with sigma=5. The matching kernel should be a Gaussian with sigma=4, + and the final PSF-matched source should have sigma=5. 
+ + This validates that: + - The matching kernel has the expected sigma=4 + - The output source has the expected sigma=5 + """ + # Create noiseless image with centered Gaussian source + ny, nx = 201, 201 # Larger to avoid edge effects + y, x = np.mgrid[:ny, :nx] + cy, cx = ny // 2, nx // 2 + + # Source with sigma=3 + source_sigma = 3.0 + gauss_source = Gaussian2D( + amplitude=1000.0, + x_mean=cx, + y_mean=cy, + x_stddev=source_sigma, + y_stddev=source_sigma, + ) + data = gauss_source(x, y) + + # Create model + model = ImageModel.create_fake_data(shape=(ny, nx)) + model.data[:] = data + model.err = 0.01 * np.ones_like(data) + model.meta.instrument.optical_element = "F087" + + # Create PSF models with known Gaussian PSFs + # Note: PSFs will be oversampled by a factor of 4 + oversample = 4 + + class MockPSFRef: + """Mock PSF reference model with Gaussian PSF.""" + + def __init__(self, filter_name, sigma): + self.meta = type("Meta", (), {})() + self.meta.instrument = type("Instrument", (), {})() + self.meta.instrument.optical_element = filter_name + self.oversample = oversample + self.oversampling = oversample + + # Create Gaussian PSF stamp (oversampled) + # sigma in the oversampled space + sigma_oversampled = sigma * oversample + stamp_size = 41 * oversample # Larger stamp for oversampled PSF + y_psf, x_psf = np.mgrid[:stamp_size, :stamp_size] + cy_psf, cx_psf = stamp_size // 2, stamp_size // 2 + + gauss_psf = Gaussian2D( + amplitude=1.0, + x_mean=cx_psf, + y_mean=cy_psf, + x_stddev=sigma_oversampled, + y_stddev=sigma_oversampled, + ) + psf = gauss_psf(x_psf, y_psf) + psf = psf / psf.sum() # Normalize + + self.psf_data = psf + + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + # Input PSF: sigma=3 (same as source) + input_psf_sigma = 3.0 + input_psf_ref = MockPSFRef("F087", input_psf_sigma) + + # Target PSF: sigma=5 + target_psf_sigma = 5.0 + target_psf_ref = MockPSFRef("F184", target_psf_sigma) + + # Expected matching kernel sigma: 
sqrt(5^2 - 3^2) = sqrt(25 - 9) = 4 + expected_kernel_sigma = np.sqrt(target_psf_sigma**2 - input_psf_sigma**2) + + # Mock create_l3_psf_model to capture the kernel + captured_kernel = {} + + def mock_create_l3_psf(psf_ref_model): + """Mock create_l3_psf_model.""" + + class MockL3PSF: + def __init__(self): + self.data = psf_ref_model.psf_data + self.oversampling = np.array([4, 4]) + + return MockL3PSF() + + # Also patch create_convolution_kernel to capture it + def mock_create_kernel(input_psf, target_psf, downsample=1): + """Mock that captures the kernel.""" + # Import the real function + from romancal.source_catalog.psf import create_convolution_kernel + + kernel = create_convolution_kernel(input_psf, target_psf, downsample=downsample) + captured_kernel["kernel"] = kernel + return kernel + + with ( + patch( + "romancal.source_catalog.psf_matching.create_l3_psf_model", + side_effect=mock_create_l3_psf, + ), + patch( + "romancal.source_catalog.psf_matching.create_convolution_kernel", + side_effect=mock_create_kernel, + ), + ): + matched_model = create_psf_matched_image(model, input_psf_ref, target_psf_ref) + + # Check that we captured the kernel + assert "kernel" in captured_kernel + kernel = captured_kernel["kernel"] + + # Fit a 2D Gaussian to the matching kernel to extract sigma + y_k, x_k = np.mgrid[: kernel.shape[0], : kernel.shape[1]] + cy_k, cx_k = kernel.shape[0] // 2, kernel.shape[1] // 2 + + # Initial guess for Gaussian fit + gauss_init = Gaussian2D( + amplitude=kernel.max(), + x_mean=cx_k, + y_mean=cy_k, + x_stddev=expected_kernel_sigma, + y_stddev=expected_kernel_sigma, + ) + + fitter = TRFLSQFitter() + gauss_fit = fitter(gauss_init, x_k, y_k, kernel) + + # Check kernel sigma (should be ~4) + fitted_kernel_sigma = (gauss_fit.x_stddev.value + gauss_fit.y_stddev.value) / 2 + assert np.isclose(fitted_kernel_sigma, expected_kernel_sigma, rtol=0.01), ( + f"Kernel sigma {fitted_kernel_sigma:.2f} != " + f"expected {expected_kernel_sigma:.2f}" + ) + + # Fit a 2D 
Gaussian to the output source to extract sigma + gauss_output_init = Gaussian2D( + amplitude=matched_model.data.max(), + x_mean=cx, + y_mean=cy, + x_stddev=target_psf_sigma, + y_stddev=target_psf_sigma, + ) + + gauss_output_fit = fitter(gauss_output_init, x, y, matched_model.data) + + # Check output source sigma (should be ~5) + fitted_output_sigma = ( + gauss_output_fit.x_stddev.value + gauss_output_fit.y_stddev.value + ) / 2 + assert np.isclose(fitted_output_sigma, target_psf_sigma, rtol=0.01), ( + f"Output source sigma {fitted_output_sigma:.2f} != " + f"expected {target_psf_sigma:.2f}" + ) diff --git a/romancal/source_catalog/tests/test_utils.py b/romancal/source_catalog/tests/test_utils.py new file mode 100644 index 000000000..68e12084d --- /dev/null +++ b/romancal/source_catalog/tests/test_utils.py @@ -0,0 +1,148 @@ +""" +Unit tests for the source catalog utils module. +""" + +import numpy as np +import pytest +from roman_datamodels.datamodels import ImageModel, MosaicModel + +from romancal.source_catalog.utils import copy_model_arrays + +rng = np.random.default_rng(12345) + + +class TestCopyModelArrays: + """ + Tests for the copy_model_arrays function. + """ + + @pytest.mark.parametrize("model_class", [ImageModel, MosaicModel]) + def test_data_copied(self, model_class): + """ + Test that data array is copied independently. + """ + model = model_class.create_fake_data(shape=(50, 50)) + model.data[:] = rng.normal(0, 1, size=(50, 50)) + + copied = copy_model_arrays(model) + + # Verify data is copied (different memory addresses) + assert id(model.data) != id(copied.data) + # Verify data values are equal + assert np.array_equal(model.data, copied.data) + + @pytest.mark.parametrize("model_class", [ImageModel, MosaicModel]) + def test_err_copied(self, model_class): + """ + Test that err array is copied independently. 
+ """ + model = model_class.create_fake_data(shape=(50, 50)) + model.err = np.ones((50, 50)) * 0.5 + + copied = copy_model_arrays(model) + + # Verify err is copied (different memory addresses) + assert id(model.err) != id(copied.err) + # Verify err values are equal + assert np.array_equal(model.err, copied.err) + + @pytest.mark.parametrize("model_class", [ImageModel, MosaicModel]) + def test_meta_preserved(self, model_class): + """ + Test that metadata is preserved. + """ + model = model_class.create_fake_data(shape=(50, 50)) + + copied = copy_model_arrays(model) + + # Verify meta has same values (note: datamodels may copy the dict) + assert model.meta.model_type == copied.meta.model_type + assert model.meta.telescope == copied.meta.telescope + + @pytest.mark.parametrize("model_class", [ImageModel, MosaicModel]) + def test_data_modification_independent(self, model_class): + """ + Test that modifying copied data doesn't affect original. + """ + model = model_class.create_fake_data(shape=(50, 50)) + model.data[:] = rng.normal(0, 1, size=(50, 50)) + original_value = model.data[0, 0] + + copied = copy_model_arrays(model) + copied.data[0, 0] = 999.0 + + # Original should be unchanged + assert model.data[0, 0] == original_value + # Copied should have new value + assert copied.data[0, 0] == 999.0 + + @pytest.mark.parametrize("model_class", [ImageModel, MosaicModel]) + def test_err_modification_independent(self, model_class): + """ + Test that modifying copied err doesn't affect original. 
+ """ + model = model_class.create_fake_data(shape=(50, 50)) + model.err = np.ones((50, 50)) * 0.5 + original_err_value = model.err[10, 10] + + copied = copy_model_arrays(model) + copied.err[10, 10] = 123.456 + + # Original should be unchanged + assert model.err[10, 10] == original_err_value + # Copied should have new value + assert copied.err[10, 10] == 123.456 + + @pytest.mark.parametrize("model_class", [ImageModel, MosaicModel]) + def test_returns_correct_type(self, model_class): + """ + Test that function returns the same model type. + """ + model = model_class.create_fake_data(shape=(50, 50)) + copied = copy_model_arrays(model) + assert isinstance(copied, model_class) + + @pytest.mark.parametrize("model_class", [ImageModel, MosaicModel]) + def test_preserves_data_shape(self, model_class): + """ + Test that copied model preserves data shape. + """ + shape = (100, 75) + model = model_class.create_fake_data(shape=shape) + copied = copy_model_arrays(model) + assert copied.data.shape == shape + + def test_image_model_dq_shared(self): + """ + Test that ImageModel dq array is shared (not copied). + """ + model = ImageModel.create_fake_data(shape=(50, 50)) + + copied = copy_model_arrays(model) + + # Verify dq is shared (same memory address) + assert id(model.dq) == id(copied.dq) + + def test_mosaic_model_weight_shared(self): + """ + Test that MosaicModel weight array is shared (not copied). + """ + model = MosaicModel.create_fake_data(shape=(50, 50)) + model.weight = np.ones((50, 50)) + + copied = copy_model_arrays(model) + + # Verify weight is shared (same memory address) + assert id(model.weight) == id(copied.weight) + + def test_invalid_model_type(self): + """ + Test that invalid model type raises TypeError. 
+ """ + + class FakeModel: + pass + + expected_msg = "model must be an ImageModel or MosaicModel" + with pytest.raises(TypeError, match=expected_msg): + copy_model_arrays(FakeModel()) diff --git a/romancal/source_catalog/utils.py b/romancal/source_catalog/utils.py index c1789b5a6..db01f04d0 100644 --- a/romancal/source_catalog/utils.py +++ b/romancal/source_catalog/utils.py @@ -2,6 +2,7 @@ from astropy.modeling.fitting import SplineSplrepFitter from astropy.modeling.models import Spline1D from roman_datamodels import datamodels +from roman_datamodels.datamodels import ImageModel, MosaicModel def copy_mosaic_meta(model, cat_model): @@ -41,7 +42,7 @@ def get_ee_spline(input_model, apcorr_file): Parameters ---------- - input'_model : `~roman_datamodels.datamodels.ImageModel` or `~roman_datamodels.datamodels.MosaicModel` + input_model : `~roman_datamodels.datamodels.ImageModel` or `~roman_datamodels.datamodels.MosaicModel` The input data model. """ @@ -53,3 +54,45 @@ def get_ee_spline(input_model, apcorr_file): # Fit a spline model to the ee_fraction vs radius data so that we # can interpolate the values for arbitrary radii return SplineSplrepFitter()(Spline1D(), ee_radii, ee_fractions) + + +def copy_model_arrays(model): + """ + Create a shallow copy of ImageModel or MosaicModel with data and err + copied. + + This function creates a new model instance that shares the metadata + with the input model but has independent copies of the data and err + arrays. Other arrays (dq, weight) are shared references. + + Parameters + ---------- + model : ImageModel or MosaicModel + The input data model to copy. + + Returns + ------- + copied_model : ImageModel or MosaicModel + A new model with copied data and err arrays. + + Notes + ----- + The metadata and dq/weight arrays are not copied because they are + not modified in source catalog operations. 
+ """ + if isinstance(model, ImageModel): + copied_model = ImageModel() + copied_model.meta = model.meta + copied_model.data = model.data.copy() + copied_model.err = model.err.copy() + copied_model.dq = model.dq + elif isinstance(model, MosaicModel): + copied_model = MosaicModel() + copied_model.meta = model.meta + copied_model.data = model.data.copy() + copied_model.err = model.err.copy() + copied_model.weight = model.weight + else: + raise TypeError("model must be an ImageModel or MosaicModel") + + return copied_model