Skip to content

Commit d447c6c

Browse files
JP-4192: Prepare output for resample (#10145)
1 parent a99dea6 commit d447c6c

File tree

3 files changed

+162
-75
lines changed

3 files changed

+162
-75
lines changed

jwst/resample/resample_spec_step.py

Lines changed: 69 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -53,70 +53,79 @@ def process(self, input_data):
5353
SlitModel or MultiSlitModel
5454
The resampled output, one slit per source.
5555
"""
56-
with datamodels.open(input_data) as input_new:
57-
# Check if input_new is a MultiSlitModel
58-
model_is_msm = isinstance(input_new, MultiSlitModel)
59-
60-
# If input is a 3D rateints MultiSlitModel (unsupported) skip the step
61-
if model_is_msm and len((input_new[0]).shape) == 3:
62-
log.warning("Resample spec step will be skipped")
63-
result = input_new.copy()
64-
result.meta.cal_step.resample = "SKIPPED"
65-
66-
return result
67-
68-
# Convert ImageModel to SlitModel (needed for MIRI LRS)
69-
if isinstance(input_new, ImageModel):
70-
input_new = datamodels.SlitModel(input_new)
71-
72-
if isinstance(input_new, ModelContainer):
73-
input_models = input_new
74-
75-
try:
76-
output = input_models.meta.asn_table.products[0].name
77-
except AttributeError:
78-
# NIRSpec MOS data goes through this path, as the container
79-
# is only ModelContainer-like, and doesn't have an asn_table
80-
# attribute attached. Output name handling gets done in
81-
# _process_multislit() via the update method
82-
# TODO: the container-like object should retain asn_table
83-
output = None
84-
else:
85-
input_models = ModelContainer([input_new])
86-
output = input_new.meta.filename
87-
self.blendheaders = False
56+
output_model = self.prepare_output(input_data)
8857

89-
# Setup drizzle-related parameters
90-
kwargs = self.get_drizpars()
91-
kwargs["output"] = output
92-
self.drizpars = kwargs
58+
# Check if input model is a MultiSlitModel
59+
model_is_msm = isinstance(output_model, MultiSlitModel)
9360

94-
# Call resampling
95-
if isinstance(input_models[0], MultiSlitModel):
96-
result = self._process_multislit(input_models)
61+
# If input is a 3D rateints MultiSlitModel (unsupported) skip the step
62+
if model_is_msm and len((output_model[0]).shape) == 3:
63+
log.warning("Resample spec step will be skipped")
64+
output_model.meta.cal_step.resample = "SKIPPED"
65+
return output_model
9766

98-
elif len(input_models[0].data.shape) != 2:
99-
# resample can only handle 2D images, not 3D cubes, etc
100-
raise TypeError(f"Input {input_models[0]} is not a 2D image.")
67+
# Convert ImageModel to SlitModel (may be needed for older MIRI LRS data)
68+
if isinstance(output_model, ImageModel):
69+
slit_model = datamodels.SlitModel(output_model)
70+
if output_model is not input_data:
71+
output_model.close()
72+
output_model = slit_model
10173

102-
else:
103-
# result is a SlitModel
104-
result = self._process_slit(input_models)
105-
106-
# Update ASNTABLE in output
107-
result.meta.cal_step.resample = "COMPLETE"
108-
result.meta.asn.table_name = input_models[0].meta.asn.table_name
109-
result.meta.asn.pool_name = input_models[0].meta.asn.pool_name
110-
111-
# populate the result wavelength attribute for MultiSlitModel
112-
if isinstance(result, MultiSlitModel):
113-
for slit_idx, _slit in enumerate(result.slits):
114-
wl_array = get_wavelengths(result.slits[slit_idx])
115-
result.slits[slit_idx].wavelength = wl_array
116-
else:
117-
# populate the result wavelength attribute for SlitModel
118-
wl_array = get_wavelengths(result)
119-
result.wavelength = wl_array
74+
if isinstance(output_model, ModelContainer):
75+
input_models = output_model
76+
77+
try:
78+
output = input_models.meta.asn_table.products[0].name
79+
except AttributeError:
80+
# NIRSpec MOS data goes through this path, as the container
81+
# is only ModelContainer-like, and doesn't have an asn_table
82+
# attribute attached. Output name handling gets done in
83+
# _process_multislit() via the update method
84+
# TODO: the container-like object should retain asn_table
85+
output = None
86+
else:
87+
input_models = ModelContainer([output_model])
88+
output = output_model.meta.filename
89+
self.blendheaders = False
90+
91+
# Setup drizzle-related parameters
92+
kwargs = self.get_drizpars()
93+
kwargs["output"] = output
94+
self.drizpars = kwargs
95+
96+
# Call resampling
97+
if isinstance(input_models[0], MultiSlitModel):
98+
result = self._process_multislit(input_models)
99+
100+
elif len(input_models[0].data.shape) != 2:
101+
# resample can only handle 2D images, not 3D cubes, etc
102+
raise TypeError(f"Input {input_models[0]} is not a 2D image.")
103+
104+
else:
105+
# result is a SlitModel
106+
result = self._process_slit(input_models)
107+
108+
# Update ASNTABLE in output
109+
result.meta.cal_step.resample = "COMPLETE"
110+
result.meta.asn.table_name = input_models[0].meta.asn.table_name
111+
result.meta.asn.pool_name = input_models[0].meta.asn.pool_name
112+
113+
# populate the result wavelength attribute for MultiSlitModel
114+
if isinstance(result, MultiSlitModel):
115+
for slit_idx, _slit in enumerate(result.slits):
116+
wl_array = get_wavelengths(result.slits[slit_idx])
117+
result.slits[slit_idx].wavelength = wl_array
118+
else:
119+
# populate the result wavelength attribute for SlitModel
120+
wl_array = get_wavelengths(result)
121+
result.wavelength = wl_array
122+
123+
# Output is a new datamodel.
124+
# Clean up the input model(s) if they were opened here.
125+
if output_model is not input_data:
126+
del output_model
127+
if input_models is not input_data:
128+
del input_models
120129

121130
return result
122131

jwst/resample/resample_step.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import logging
22

33
from stdatamodels import filetype
4-
from stdatamodels.jwst import datamodels as dm
54

6-
from jwst.datamodels import ImageModel, ModelLibrary # type: ignore[attr-defined]
5+
from jwst.datamodels import ImageModel, ModelContainer, ModelLibrary # type: ignore[attr-defined]
76
from jwst.lib.pipe_utils import match_nans_and_flags
87
from jwst.resample import resample
98
from jwst.resample.resample_utils import load_custom_wcs
@@ -70,19 +69,25 @@ def process(self, input_data):
7069
WCS parameters such as ``output_shape`` (now computed from by
7170
``output_wcs.bounding_box``) and ``crpix``.
7271
"""
73-
if isinstance(input_data, str):
74-
ext = filetype.check(input_data)
75-
if ext in ("fits", "asdf"):
76-
input_data = dm.open(input_data)
77-
if isinstance(input_data, ModelLibrary):
78-
input_models = input_data
79-
elif isinstance(input_data, (str, dict, list)):
80-
input_models = ModelLibrary(input_data, on_disk=not self.in_memory)
81-
elif isinstance(input_data, ImageModel):
82-
input_models = ModelLibrary([input_data], on_disk=not self.in_memory)
83-
output = input_data.meta.filename
72+
# Make a copy if needed for an input model.
73+
# Don't open filenames if they're not already models --
74+
# leave it to the ModelLibrary call below to open them.
75+
input_model = self.prepare_output(input_data, open_models=False)
76+
77+
if isinstance(input_model, ModelLibrary):
78+
# Input is already a library: leave it alone.
79+
input_models = input_model
80+
elif isinstance(input_model, ImageModel) or (
81+
isinstance(input_model, str) and filetype.check(input_model) in ["fits", "asdf"]
82+
):
83+
# Input is a single file: pass it to ModelLibrary in a list
84+
input_models = ModelLibrary([input_model], on_disk=not self.in_memory)
8485
self.blendheaders = False
86+
elif isinstance(input_model, (str, dict, list, ModelContainer)):
87+
# Input is an association or list of models/files
88+
input_models = ModelLibrary(input_model, on_disk=not self.in_memory)
8589
else:
90+
# Input is not recognized
8691
raise TypeError(f"Input {input_data} is not a 2D image.")
8792

8893
try:
@@ -141,6 +146,13 @@ def process(self, input_data):
141146
)
142147
result = resamp.resample_many_to_one()
143148

149+
# The output is a new datamodel.
150+
# Clean up the input model(s) if they were opened here.
151+
if input_model is not input_data:
152+
del input_model
153+
if input_models is not input_data:
154+
del input_models
155+
144156
return result
145157

146158
@staticmethod

jwst/resample/tests/test_resample_step.py

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,15 @@ def test_single_image_file_input(nircam_rate, tmp_cwd):
435435
# Create a temporary file with the input data
436436
im = AssignWcsStep.call(nircam_rate, sip_approx=False)
437437
im.meta.filename = "test_input.fits"
438+
439+
# Add a bad pixel in the error plane, not matched with a NaN in data.
440+
# This will test if the match_nans_and_flags call at the beginning
441+
# of the step modifies the input model.
442+
im.err[0, 0] = np.nan
443+
444+
# Save a copy to disk, keep a copy of the data for testing.
438445
im.save("test_input.fits")
446+
im_copy = im.data.copy()
439447

440448
# Run the step on the file
441449
result_from_memory = ResampleStep.call(im)
@@ -449,12 +457,52 @@ def test_single_image_file_input(nircam_rate, tmp_cwd):
449457
# Check that input model was not modified
450458
assert im is not result_from_memory
451459
assert im.meta.cal_step.resample is None
460+
assert_allclose(im.data, im_copy)
452461

453462
result_from_file.close()
454463
result_from_memory.close()
455464
im.close()
456465

457466

467+
def test_list_model_input(nircam_rate, tmp_cwd):
468+
"""Ensure step can be run on a list of models without modifying them."""
469+
# Create a temporary file with the input data
470+
im = AssignWcsStep.call(nircam_rate, sip_approx=False)
471+
im.meta.filename = "test_input_1.fits"
472+
473+
# Add a bad pixel in the error plane, not matched with a NaN in data.
474+
# This will test if the match_nans_and_flags call at the beginning
475+
# of the step modifies the input model.
476+
im.err[0, 0] = np.nan
477+
478+
# Keep a copy of the data for testing
479+
im_copy = im.data.copy()
480+
481+
# Make a list of input models to run
482+
im2 = im.copy()
483+
im2.meta.filename = "test_input_2.fits"
484+
im_list = [im, im2]
485+
486+
# Run the step on the file
487+
result = ResampleStep.call(im_list)
488+
489+
# Check that the output is as expected
490+
assert isinstance(result, ImageModel)
491+
assert result.meta.cal_step.resample == "COMPLETE"
492+
493+
# Check that input models were not modified
494+
assert result is not im
495+
assert result is not im2
496+
assert im.meta.cal_step.resample is None
497+
assert im2.meta.cal_step.resample is None
498+
assert_allclose(im.data, im_copy)
499+
assert_allclose(im2.data, im_copy)
500+
501+
result.close()
502+
im.close()
503+
im2.close()
504+
505+
458506
@pytest.mark.parametrize("ratio", [0.5, 0.7, 1.0])
459507
def test_pixel_scale_ratio_imaging(nircam_rate, ratio):
460508
im = AssignWcsStep.call(nircam_rate, sip_approx=False)
@@ -1616,15 +1664,33 @@ def test_nirspec_lamp_pixscale(nirspec_lamp, tmp_path):
16161664
result4.close()
16171665

16181666

1619-
def test_spec_output_is_not_input(nirspec_cal):
1620-
im = ResampleSpecStep.call(nirspec_cal)
1667+
@pytest.mark.parametrize("input_list", [True, False])
1668+
def test_spec_input_not_modified(nirspec_cal, input_list):
1669+
# Add a bad pixel in the error plane of one slit, not matched
1670+
# with a NaN in data.
1671+
# This will test if the match_nans_and_flags call at the beginning
1672+
# of the step modifies the input model.
1673+
nirspec_cal.slits[0].err[15, 100] = np.nan
1674+
data_copy = nirspec_cal.slits[0].data.copy()
1675+
1676+
if input_list:
1677+
input_models = [nirspec_cal, nirspec_cal.copy()]
1678+
else:
1679+
input_models = nirspec_cal
1680+
1681+
im = ResampleSpecStep.call(input_models)
16211682

16221683
# Step is complete
16231684
assert im.meta.cal_step.resample == "COMPLETE"
16241685

16251686
# Input is not modified
16261687
assert im is not nirspec_cal
16271688
assert nirspec_cal.meta.cal_step.resample is None
1689+
if input_list:
1690+
for model in input_models:
1691+
assert_allclose(model.slits[0].data, data_copy)
1692+
else:
1693+
assert_allclose(input_models.slits[0].data, data_copy)
16281694

16291695

16301696
def test_spec_skip_cube():

0 commit comments

Comments
 (0)