Skip to content

Commit 4edabe9

Browse files
authored
Merge pull request #661 from VChristiaens/master
New (I)PCA features and a minor bug fix
2 parents 79c2ab2 + a101ab0 commit 4edabe9

File tree

14 files changed

+527
-425
lines changed

14 files changed

+527
-425
lines changed

vip_hci/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "1.6.4"
1+
__version__ = "1.6.5"
22

33
from . import preproc
44
from . import config

vip_hci/config/utils_conf.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,25 @@ def pool_map(nproc, fkt, *args, **kwargs):
472472
if not _generator:
473473
res = list(res)
474474
else:
475+
# deactivate multithreading if not yet deactivated
476+
try:
477+
ncpus_mt1 = os.environ["OMP_NUM_THREADS"]
478+
ncpus_mt2 = os.environ["NUMEXPR_NUM_THREADS"]
479+
ncpus_mt3 = os.environ["MKL_NUM_THREADS"]
480+
if ncpus_mt1 != "1" or ncpus_mt2 != "1" or ncpus_mt3 != "1":
481+
os.environ["OMP_NUM_THREADS"] = "1"
482+
os.environ["NUMEXPR_NUM_THREADS"] = "1"
483+
os.environ["MKL_NUM_THREADS"] = "1"
484+
wrongly_set = True
485+
else:
486+
wrongly_set = False
487+
vars_are_set = True
488+
except KeyError: # if the variables are not set, set them manually
489+
os.environ["OMP_NUM_THREADS"] = "1"
490+
os.environ["NUMEXPR_NUM_THREADS"] = "1"
491+
os.environ["MKL_NUM_THREADS"] = "1"
492+
vars_are_set = False
493+
475494
# Check available start methods and pick accordingly (machine-dependent)
476495
avail_methods = multiprocessing.get_all_start_methods()
477496
# if 'forkserver' in avail_methods: # fast and safe, if available
@@ -503,11 +522,6 @@ def pool_map(nproc, fkt, *args, **kwargs):
503522

504523
from multiprocessing import Pool
505524

506-
# deactivate multithreading
507-
os.environ["MKL_NUM_THREADS"] = "1"
508-
os.environ["NUMEXPR_NUM_THREADS"] = "1"
509-
os.environ["OMP_NUM_THREADS"] = "1"
510-
511525
if verbose and msg is not None:
512526
print("{} with {} processes".format(msg, nproc))
513527
pool = Pool(processes=nproc)
@@ -518,11 +532,15 @@ def pool_map(nproc, fkt, *args, **kwargs):
518532
pool.close()
519533
pool.join()
520534

521-
# reactivate multithreading
522-
ncpus = multiprocessing.cpu_count()
523-
os.environ["MKL_NUM_THREADS"] = str(ncpus)
524-
os.environ["NUMEXPR_NUM_THREADS"] = str(ncpus)
525-
os.environ["OMP_NUM_THREADS"] = str(ncpus)
535+
# return back to default behaviour regarding multithreading
536+
if not vars_are_set:
537+
del os.environ["OMP_NUM_THREADS"]
538+
del os.environ["NUMEXPR_NUM_THREADS"]
539+
del os.environ["MKL_NUM_THREADS"]
540+
elif wrongly_set:
541+
os.environ["OMP_NUM_THREADS"] = ncpus_mt1
542+
os.environ["NUMEXPR_NUM_THREADS"] = ncpus_mt2
543+
os.environ["MKL_NUM_THREADS"] = ncpus_mt3
526544

527545
return res
528546

vip_hci/fm/fakecomp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -827,7 +827,8 @@ def cube_planet_free(planet_parameter, cube, angs, psfn, imlib='vip-fft',
827827
parameter must have a shape (n_pl,3) or (3,) -- the latter case assumes
828828
a single planet in the data. For a 4d cube r, theta and flux
829829
must all be 1d arrays with length equal to cube.shape[0]; i.e.
830-
planet_parameter should have shape: (n_pl,3,n_ch).
830+
planet_parameter should have shape: (n_pl,3,n_ch) or (3,n_ch) -- the
831+
latter case assumes a single planet in the data.
831832
cube: numpy ndarray
832833
The cube of fits images expressed as a numpy.array.
833834
angs: numpy ndarray

vip_hci/fm/negfc_fmerit.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -621,7 +621,7 @@ def get_mu_and_sigma(cube, angs, ncomp, annulus_width, aperture_radius, fwhm,
621621
cube_ref=None, wedge=None, svd_mode="lapack", scaling=None,
622622
algo=pca_annulus, delta_rot=1, imlib="vip-fft",
623623
interpolation="lanczos4", collapse="median", weights=None,
624-
algo_options={}, bin_spec=False):
624+
algo_options={}, bin_spec=False, verbose=False):
625625
"""Extract the mean and standard deviation of pixel intensities in an\
626626
annulus of the PCA-ADI image obtained with 'algo', in the part of a defined\
627627
wedge that is not overlapping with PA_pl+-delta_PA.
@@ -648,7 +648,7 @@ def get_mu_and_sigma(cube, angs, ncomp, annulus_width, aperture_radius, fwhm,
648648
The angular position of the center of the circular aperture. This
649649
parameter is NOT the angular position of the candidate associated to the
650650
Markov chain, but should be the fixed initial guess.
651-
f_guess: float, optional
651+
f_guess: float or 1d numpy array, optional
652652
The flux estimate for the companion.
653653
psfn: 2D or 3D numpy ndarray, optional
654654
Normalized psf used to remove the companion if f_guess is provided.
@@ -721,15 +721,23 @@ def get_mu_and_sigma(cube, angs, ncomp, annulus_width, aperture_radius, fwhm,
721721
722722
"""
723723
if f_guess is not None and psfn is not None:
724-
planet_parameter = (r_guess, theta_guess, f_guess)
724+
if np.isscalar(f_guess):
725+
planet_parameter = (r_guess, theta_guess, f_guess)
726+
elif len(f_guess) == 1:
727+
planet_parameter = (r_guess, theta_guess, f_guess[0])
728+
else:
729+
r_all = [r_guess]*len(f_guess)
730+
theta_all = [r_guess]*len(f_guess)
731+
planet_parameter = np.array([r_all, theta_all, f_guess])
725732
array = cube_planet_free(planet_parameter, cube, angs, psfn,
726733
imlib=imlib, interpolation=interpolation,
727734
transmission=None)
728735
else:
729-
msg = "WARNING: f_guess not provided. The companion will not be "
730-
msg += "removed from the cube before estimating mu and sigma. "
731-
msg += "A wedge will be used"
732-
print(msg)
736+
if verbose:
737+
msg = "WARNING: f_guess not provided. The companion will not be "
738+
msg += "removed from the cube before estimating mu and sigma. "
739+
msg += "A wedge will be used"
740+
print(msg)
733741
array = cube.copy()
734742

735743
centy_fr, centx_fr = frame_center(array[0])

vip_hci/fm/negfc_mcmc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -833,7 +833,7 @@ def mcmc_negfc_sampling(cube, angs, psfn, initial_state, algo=pca_annulus,
833833

834834
mu_sig = get_mu_and_sigma(cube, angs, ncomp, annulus_width, aperture_radius,
835835
fwhm, initial_state[0], initial_state[1],
836-
initial_state[2], psfn, cube_ref=cube_ref,
836+
initial_state[2:], psfn, cube_ref=cube_ref,
837837
wedge=wedge, svd_mode=svd_mode, scaling=scaling,
838838
algo=algo, delta_rot=delta_rot, imlib=imlib_rot,
839839
interpolation=interpolation, collapse=collapse,

0 commit comments

Comments
 (0)