Skip to content

Commit 798bfea

Browse files
committed
fix failing test
1 parent 6700269 commit 798bfea

File tree

1 file changed

+66
-14
lines changed

1 file changed

+66
-14
lines changed

py/minimint/mist_interpolator.py

Lines changed: 66 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -466,8 +466,18 @@ def getLogAgeFromEEP(self, mass, eep, feh, returnJac=False):
466466
l1feh[bad] = 0
467467
eep1[bad] = 0
468468

469-
wf, ifehs = utils._get_cubic_coeffs(feh, self.ufeh, l1feh)
470-
wm, imasses = utils._get_cubic_coeffs(mass, self.umass, l1mass)
469+
feh_arr = np.atleast_1d(feh)
470+
mass_arr = np.atleast_1d(mass)
471+
l1feh_arr = np.atleast_1d(l1feh)
472+
l1mass_arr = np.atleast_1d(l1mass)
473+
474+
wf, ifehs = utils._get_cubic_coeffs(feh_arr, self.ufeh, l1feh_arr)
475+
wm, imasses = utils._get_cubic_coeffs(mass_arr, self.umass,
476+
l1mass_arr)
477+
wf = np.atleast_2d(wf)
478+
ifehs = np.atleast_2d(ifehs)
479+
wm = np.atleast_2d(wm)
480+
imasses = np.atleast_2d(imasses)
471481
ueep = np.arange(neep)
472482
we, ieeps = utils._get_cubic_coeffs(eep, ueep, eep1)
473483

@@ -547,11 +557,42 @@ def getMaxMass(self, logage, feh):
547557
im1, im2 = curm, im2
548558
if im2 - im1 == 1:
549559
break
550-
ret = self._getMaxMassBox(logage, feh, l1feh, l1feh + 1, im1, im2)
551-
if not (np.isfinite(ret)):
552-
return self.umass[im1] # the edge
553-
else:
554-
return ret * (1 - 1e-10)
560+
lo = self.umass[im1]
561+
hi = self.umass[im2]
562+
563+
def _isfinite_mass(m):
564+
return np.isfinite(self(m, logage, feh)['logl'][0])
565+
566+
# Ensure lo is valid and hi is invalid for the refinement.
567+
if not _isfinite_mass(lo):
568+
idx = im1
569+
while idx > 0 and not _isfinite_mass(self.umass[idx]):
570+
idx -= 1
571+
lo = self.umass[idx]
572+
hi = self.umass[min(idx + 1, len(self.umass) - 1)]
573+
elif _isfinite_mass(hi):
574+
idx = im2
575+
while idx + 1 < len(self.umass) and _isfinite_mass(
576+
self.umass[idx + 1]):
577+
idx += 1
578+
if idx + 1 >= len(self.umass):
579+
return self.umass[idx]
580+
lo = self.umass[idx]
581+
hi = self.umass[idx + 1]
582+
583+
# Refine the boundary so that lo is valid and hi is invalid.
584+
# Use a strict tolerance so that lo is finite but lo+tol is not.
585+
tol = 1e-7
586+
for _ in range(40):
587+
if hi - lo <= tol:
588+
break
589+
mid = 0.5 * (lo + hi)
590+
if _isfinite_mass(mid):
591+
lo = mid
592+
else:
593+
hi = mid
594+
595+
return lo
555596

556597
def _get_eep_coeffs(self, mass, logage, feh):
557598
"""
@@ -578,8 +619,14 @@ def _get_eep_coeffs(self, mass, logage, feh):
578619
l1feh[bads] = 0
579620
l2feh[bads] = 1
580621

581-
wf, ifehs = utils._get_cubic_coeffs(feh, self.ufeh, l1feh)
582-
wm, imasses = utils._get_cubic_coeffs(mass, self.umass, l1mass)
622+
feh_arr = np.atleast_1d(feh)
623+
mass_arr = np.atleast_1d(mass)
624+
l1feh_arr = np.atleast_1d(l1feh)
625+
l1mass_arr = np.atleast_1d(l1mass)
626+
627+
wf, ifehs = utils._get_cubic_coeffs(feh_arr, self.ufeh, l1feh_arr)
628+
wm, imasses = utils._get_cubic_coeffs(mass_arr, self.umass,
629+
l1mass_arr)
583630

584631
def getAge(cureep, subset):
585632
return utils._interpolator_bicubic(
@@ -643,9 +690,14 @@ def _isvalid(self, mass, logage, feh, l1feh=None):
643690
or (l1mass < 0) or (l1feh < 0)):
644691
return False
645692

646-
C11, C12, C21, C22 = _get_polylin_coeff(feh, self.ufeh, mass,
647-
self.umass, l1feh, l2feh,
648-
l1mass, l2mass)
693+
feh_arr = np.atleast_1d(feh)
694+
mass_arr = np.atleast_1d(mass)
695+
l1feh_arr = np.atleast_1d(l1feh)
696+
l1mass_arr = np.atleast_1d(l1mass)
697+
698+
wf, ifehs = utils._get_cubic_coeffs(feh_arr, self.ufeh, l1feh_arr)
699+
wm, imasses = utils._get_cubic_coeffs(mass_arr, self.umass,
700+
l1mass_arr)
649701

650702
# we want to find there is a point i in the age grid
651703
# where grid[i]<=logage<grid[i+1]
@@ -656,8 +708,8 @@ def _isvalid(self, mass, logage, feh, l1feh=None):
656708
i1, i2 = 0, self.neep - 1
657709

658710
def getAge(cureep):
659-
return _interpolator(self.logage_grid_unfilled, C11, C12, C21, C22,
660-
l1feh, l2feh, l1mass, l2mass, cureep)
711+
return utils._interpolator_bicubic(self.logage_grid_unfilled, wf,
712+
ifehs, wm, imasses, cureep)
661713

662714
# check invariants on edges
663715
if not getAge(i1) <= logage:

0 commit comments

Comments
 (0)