Skip to content

Commit 9e940f3

Browse files
Performance PR (#123)
* improved polynomial performance, fixing minor issues, standardized the fetching of meta data, and introducing numba * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * added missing files, checked something for RAvila * fixed doc string --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 2130111 commit 9e940f3

File tree

17 files changed

+350
-620
lines changed

17 files changed

+350
-620
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ dependencies = [
3838
'h5py>=3.11',
3939
'drizzlepac>=3.9.1',
4040
'networkx>=3.3',
41+
'numba>=0.62.0',
4142
'numpy>=2.0',
4243
'matplotlib>=3.9',
4344
'pandas>=2.2',

slitlessutils/core/modules/extract/multi/matrix.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,12 @@ class Matrix:
5252
INT = np.uint64 # DO NOT CHANGE THIS
5353
FLOAT = np.float64 # DO NOT CHANGE THIS
5454

55-
def __init__(self, extorders, mskorders=None, invmethod='lsqr', path='tables',
56-
minunc=1e-10):
55+
def __init__(self, extorders, **kwargs):
56+
# extract the defaults
57+
mskorders = kwargs.get('mskorders', None)
58+
invmethod = kwargs.get('invmethod', 'lsqr')
59+
path = kwargs.get('path', 'su_tables')
60+
minunc = kwargs.get('minunc', 1e-10)
5761

5862
LOGGER.debug("must finish documentation")
5963

@@ -309,21 +313,22 @@ def build_matrix(self, data, sources, group=0):
309313
# now do a default damping target
310314
LOGGER.debug('damping target might be suspect for 2d objects')
311315
LOGGER.info("Building Damping Target")
316+
312317
target = np.zeros(self.nunknowns, dtype=float)
313318
for segid, source in sources.items():
314319

315320
for sedkey, region in source.items():
316321
extid = self.sedkeys.index(sedkey)
322+
if extid in self.ri:
323+
g1 = self.ri[extid]
324+
g2 = self.lamids[g1]
317325

318-
g1 = self.ri[extid]
319-
g2 = self.lamids[g1]
320-
321-
if hasattr(source, 'extpars'):
322-
waves = source.extpars.wavelengths()
323-
else:
324-
waves = self.defpars.wavelengths()
326+
if hasattr(source, 'extpars'):
327+
waves = source.extpars.wavelengths()
328+
else:
329+
waves = self.defpars.wavelengths()
325330

326-
target[g1] = region.sed(waves[g2])
331+
target[g1] = region.sed(waves[g2])
327332

328333
self.set_damping_target(target)
329334

slitlessutils/core/modules/extract/multi/multi.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from .....config import SUFFIXES, Config
88
from .....logger import LOGGER
9-
from ....utilities import as_iterable, headers
9+
from ....utilities import as_iterable, get_metadata, headers
1010
from ...group import GroupCollection
1111
from ...module import Module
1212
from .matrix import Matrix
@@ -60,16 +60,22 @@ class Multi(Module):
6060
DESCRIPTION = "Extracting (multi)"
6161

6262
def __init__(self, extorders, logdamp, mskorders=None, algorithm='golden',
63-
**kwargs):
63+
root=None, **kwargs):
6464
Module.__init__(self, self.extract, **kwargs, multiprocess=False)
6565

66+
# file root name
67+
if not isinstance(root, str):
68+
meta = get_metadata()
69+
self.root = meta['Name']
70+
else:
71+
self.root = root
72+
6673
self.extorders = as_iterable(extorders)
6774
self.mskorders = as_iterable(mskorders)
68-
6975
self.optimizer = optimizer(algorithm, logdamp)
7076
self.matrix = Matrix(self.extorders, **kwargs)
7177

72-
def extract(self, data, sources, groups=None, root=None):
78+
def extract(self, data, sources, groups=None):
7379
"""
7480
Method to do the mult-ended spectral extraction
7581
@@ -85,10 +91,6 @@ def extract(self, data, sources, groups=None, root=None):
8591
The collection of groups. If set as None, then no grouping
8692
will be performed. Default is None
8793
88-
root : str, optional
89-
The root name of the output products. If None, then
90-
'slitlessutils' is used. Default is None
91-
9294
Notes
9395
-----
9496
This is likely not to be directly called.
@@ -116,9 +118,9 @@ def extract(self, data, sources, groups=None, root=None):
116118
hdul1.append(phdu)
117119
hdul3.append(phdu)
118120

119-
# file root name
120-
if not isinstance(root, str):
121-
root = __package__
121+
# get the package meta data
122+
meta = get_metadata()
123+
package = meta.get('Name', '')
122124

123125
# put loops over groups here
124126
if groups:
@@ -128,16 +130,16 @@ def extract(self, data, sources, groups=None, root=None):
128130
groups = GroupCollection()
129131

130132
# open a PDF to write Grouping images
131-
pdffile = f'{root}_{SUFFIXES["L-curve"]}.pdf'
133+
pdffile = f'{self.root}_{SUFFIXES["L-curve"]}.pdf'
132134
LOGGER.info(f'Writing grouped L-curve figure: {pdffile}')
133135
with PdfPages(pdffile) as pdf:
134136
# add some info to the PDF
135137
d = pdf.infodict()
136138
d['Title'] = 'L-Curve Results'
137139
d['Author'] = getpass.getuser()
138-
d['Subject'] = f'L-Curve results for grouped data from {__package__}'
139-
d['Keywords'] = f'{__package__} WFSS L-curve groups'
140-
d['Producer'] = __package__
140+
d['Subject'] = f'L-Curve results for grouped data from {package}'
141+
d['Keywords'] = f'{package} WFSS L-curve groups'
142+
d['Producer'] = package
141143

142144
for grpid, srcdict in enumerate(groups(sources)):
143145

@@ -241,13 +243,13 @@ def extract(self, data, sources, groups=None, root=None):
241243
#
242244
#
243245
#
244-
# self.matrix.lcurve.plot(f'{root}_{SUFFIXES["L-curve"]}.pdf')
246+
# self.matrix.lcurve.plot(f'{self.root}_{SUFFIXES["L-curve"]}.pdf')
245247

246248
# write out the files
247249
if len(hdul1) > 1:
248250
hdul1[0].header['FILETYPE'] = '1d spectra'
249251

250-
x1dfile = f'{root}_{SUFFIXES["1d spectra"]}.fits'
252+
x1dfile = f'{self.root}_{SUFFIXES["1d spectra"]}.fits'
251253
LOGGER.info(f'Writing 1d extractions: {x1dfile}')
252254
hdul1.writeto(x1dfile, overwrite=True)
253255
else:
@@ -256,12 +258,12 @@ def extract(self, data, sources, groups=None, root=None):
256258
if len(hdul3) > 1:
257259
hdul3[0].header['FILETYPE'] = '3d spectra'
258260

259-
x3dfile = f'{root}_{SUFFIXES["3d spectra"]}.fits'
261+
x3dfile = f'{self.root}_{SUFFIXES["3d spectra"]}.fits'
260262
LOGGER.info(f'Writing 3d extractions: {x3dfile}')
261263
hdul1.writeto(x3dfile, overwrite=True)
262264
else:
263265
LOGGER.warning('No 3d spectra written')
264266

265-
lcvfile = f'{root}_{SUFFIXES["L-curve"]}.fits'
267+
lcvfile = f'{self.root}_{SUFFIXES["L-curve"]}.fits'
266268
LOGGER.info(f'Writing L-curve tabular data {lcvfile}')
267269
hdulL.writeto(lcvfile, overwrite=True)

slitlessutils/core/modules/extract/single/single.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .....config import SUFFIXES, Config
88
from .....logger import LOGGER
99
from ....tables import PDTFile
10-
from ....utilities import headers
10+
from ....utilities import get_metadata, headers
1111
from ...module import Module
1212
from .boxcar import boxcar
1313
from .contamination import Contamination
@@ -130,7 +130,8 @@ def __init__(self, extorders, mskorders='all', savecont=False, root=None,
130130

131131
# output file names
132132
if root is None:
133-
root = __package__
133+
root = get_metadata()['Name']
134+
134135
self.filename = os.path.join(self.outpath, f"{root}_{SUFFIXES[self.FILETYPE]}.fits")
135136

136137
if self.savecont:

slitlessutils/core/modules/simulate/simulate.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from datetime import datetime
2-
from importlib.metadata import metadata
32

43
import numpy as np
54
from astropy.io import fits
@@ -77,17 +76,13 @@ def simulate(self, data, sources, **kwargs):
7776
The filename of the image created
7877
7978
"""
80-
8179
if sources.sedfile is None:
8280
LOGGER.critical("No SEDFile found.")
8381
return
8482

8583
# grab the time
8684
t0 = datetime.now()
8785

88-
# grab the inputs
89-
# insconf,insdata = data
90-
9186
# open the table for reading
9287
with PDTFile(data, path=self.path, mode='r') as h5:
9388

@@ -125,7 +120,6 @@ def simulate(self, data, sources, **kwargs):
125120

126121
# process each detector
127122
for detname, detdata in data.items():
128-
129123
# load a detector
130124
# detconf=detdata.config
131125
# detconf=insconf[detname]
@@ -202,6 +196,7 @@ def simulate(self, data, sources, **kwargs):
202196
# the function make_HDUs takes the noiseless sci, creates
203197
# all ancillary data, packages into HDUs, and adds noise
204198
# hdus= detdata.make_HDUs(sci,self.noisepars)
199+
205200
hdus = detdata.make_HDUs(sci, addnoise=self.addnoise)
206201

207202
# put the HDUs into the list, but first update some header info
@@ -217,18 +212,14 @@ def simulate(self, data, sources, **kwargs):
217212
# append the HDU
218213
hdul.append(hdu)
219214

220-
# record the end time
221-
t1 = datetime.now()
222-
dt = t1 - t0
215+
# compute runtime
216+
dt = datetime.now() - t0
223217

224218
# put some times into the header
225-
version = metadata(__package__).get('Version', 'unknown')
226-
phdu.header.set('ORIGIN', value=f'{__package__} v{version}',
227-
after='NAXIS')
228-
phdu.header.set('DATE', value=t1.strftime('%Y-%m-%d'), after='ORIGIN',
229-
comment='date this file was written (yyyy-mm-dd)')
230-
phdu.header.set('RUNTIME', value=dt.total_seconds(), after='DATE',
231-
comment='run time of this file in s')
219+
headers.add_preamble(phdu.header,
220+
runtime=(dt.total_seconds(), 'runtime in s'))
221+
headers.add_software_log(phdu.header)
222+
sources.update_header(phdu.header) # source props
232223

233224
# put the primary header in the HDUL
234225
hdul.insert(0, phdu)

slitlessutils/core/photometry/sed.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -493,9 +493,6 @@ def __call__(self, wave, fnu=False, **kwargs):
493493
flam = self.flam[g]
494494
flux = np.interp(wave, lamb, flam, **kwargs, left=flam[0], right=flam[-1])
495495

496-
# flux=np.interp(wave,self.lamb[g],self.flam[g],**kwargs,
497-
# left=self.flam[g[0]],right=self.flam[g[-1]])
498-
499496
if fnu:
500497
flux *= ((wave / c) * (wave / 1e10))
501498

@@ -626,7 +623,7 @@ def from_file(cls, filename, exten=1):
626623
return obj
627624

628625
@staticmethod
629-
def get_from_CDBS(atlas, filename):
626+
def get_from_CDBS(atlas, filename, outpath=''):
630627
"""
631628
Staticmethod to retrieve a spectrum from the Calibration Database
632629
System (CDBS)
@@ -639,6 +636,9 @@ def get_from_CDBS(atlas, filename):
639636
filename : str
640637
The filename in the atlas
641638
639+
outpath : str
640+
The path for the output file. Default is current working dir
641+
642642
"""
643643
# base URL for CDBS
644644
url = 'https://archive.stsci.edu/hlsps/reference-atlases/cdbs/grid/'
@@ -660,6 +660,13 @@ def get_from_CDBS(atlas, filename):
660660
return
661661
shutil.move(tmpfile, filename)
662662

663+
if isinstance(outpath, str):
664+
newfile = os.path.join(outpath, filename)
665+
shutil.move(filename, newfile)
666+
return newfile
667+
else:
668+
return filename
669+
663670
@classmethod
664671
def from_CDBS(obj, atlas, filename, cleanup=True):
665672
"""

slitlessutils/core/preprocess/astrometry/list_wcs.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,16 @@ def list_wcs(arg):
2929
obs = (h0['TELESCOP'], h0['INSTRUME'], h0['DETECTOR'])
3030
if obs == ('HST', 'ACS', 'WFC'):
3131
filtername = h0['FILTER1']
32+
33+
elif obs == ('HST', 'ACS', 'SBC'):
34+
filtername = h0['FILTER1']
35+
3236
elif obs == ('HST', 'WFC3', 'UVIS'):
3337
filtername = h0['FILTER']
38+
3439
elif obs == ('HST', 'WFC3', 'IR'):
3540
filtername = h0['FILTER']
41+
3642
else:
3743
raise ValueError(f"Cannot find observation {obs}")
3844

slitlessutils/core/sources/source.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def __init__(self, segid, img, seg, hdr, reg=None, zeropoint=26., grpid=0,
117117
self.ltv = [hdr.get('LTV1', 0.), hdr.get('LTV2', 0.)]
118118

119119
# record some things
120-
self.segid = segid # hdr['SEGID']
120+
self.segid = segid
121121
self.grpid = grpid
122122
# self.whttype = whttype
123123
self.backsize = max(backsize, 0)
@@ -225,6 +225,8 @@ def __init__(self, segid, img, seg, hdr, reg=None, zeropoint=26., grpid=0,
225225
# region image
226226
ri = indices.reverse(r, ignore=(0,))
227227
for regid, pixid in ri.items():
228+
regid = int(regid)
229+
228230
# pixels and weights for this region
229231
xx = x[pixid]
230232
yy = y[pixid]
@@ -642,6 +644,8 @@ def load_sedlib(self, sedlib, throughput=None):
642644

643645
region.sed = sed
644646

647+
sed.write_file(f'spectra/{self.segid}_{region.regid}.csv')
648+
645649
def load_sed_images(self):
646650
pass
647651

@@ -664,8 +668,6 @@ def load_sed_array(self, waves, flam):
664668
"""
665669

666670
for regkey, region in self.items():
667-
668-
# region.sed.set_sed(waves,flam*np.sum(region.w))
669671
region.sed.append(waves, flam * np.sum(region.w))
670672

671673
def write_seds(self, filetype='sed', path=None, **kwargs):

slitlessutils/core/sources/sourcecollection.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ class SourceCollection(dict):
2929
# the default zeropoint
3030
DEFZERO = 26.
3131

32-
def __init__(self, segfile, detfile, maglim=np.inf, minpix=0, preprocessor=None,
33-
zeropoint=None, throughput=None, sedfile=None,
34-
**kwargs):
32+
def __init__(self, segfile, detfile, maglim=np.inf, minpix=0,
33+
preprocessor=None, zeropoint=None, throughput=None,
34+
sedfile=None, **kwargs):
3535
"""
3636
Initializer
3737
@@ -196,6 +196,8 @@ def _load_classic(self, hdus, hdui, exten=0, **kwargs):
196196
# find pixels for each object
197197
ri = indices.reverse(seg, ignore=(0,))
198198
for segid, (y, x) in ri.items():
199+
segid = int(segid)
200+
199201
# get a bounding box
200202
x0 = np.maximum(np.amin(x) - self.PAD, 0)
201203
x1 = np.minimum(np.amax(x) + self.PAD + 1, nx - 1)

slitlessutils/core/utilities/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22
from . import headers # noqa: F401
33
from . import indices # noqa: F401
44
from .as_iterable import as_iterable # noqa: F401
5+
from .get_metadata import get_metadata # noqa: F401
56
from .pool import Pool # noqa: F401

0 commit comments

Comments
 (0)