Skip to content

Commit e563de1

Browse files
authored
fix: backport "finally fix doublecrystalball for good" (#1377)
* actually check for apply_where import * check for ModuleNotFoundError too * missing return statement here * missing normalization and import typo
1 parent 36d0b31 commit e563de1

File tree

1 file changed

+11
-17
lines changed

1 file changed

+11
-17
lines changed

coffea/lookup_tools/doublecrystalball.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
_old_style_where = False
44
try:
5-
import scipy._lib.array_api_extra as xpx
5+
from scipy._lib.array_api_extra import apply_where
66
from scipy.stats._continuous_distns import (
77
_norm_cdf,
88
_norm_pdf_C,
@@ -74,11 +74,11 @@ def lhs(x, betaL, betaH, mL, mH):
7474
def rhs(x, betaL, betaH, mL, mH):
7575
if _old_style_where:
7676
return _lazywhere(x < betaH, (-x, betaH, mH), f=core, f2=tail)
77-
return xpx.apply_where(x < betaH, (-x, betaH, mH), core, tail)
77+
return apply_where(x < betaH, (-x, betaH, mH), core, tail)
7878

7979
if _old_style_where:
80-
N * _lazywhere(x > -betaL, (x, betaL, betaH, mL, mH), f=rhs, f2=lhs)
81-
return N * xpx.apply_where(x > -betaL, (x, betaL, betaH, mL, mH), rhs, lhs)
80+
return N * _lazywhere(x > -betaL, (x, betaL, betaH, mL, mH), f=rhs, f2=lhs)
81+
return N * apply_where(x > -betaL, (x, betaL, betaH, mL, mH), rhs, lhs)
8282

8383
def _logpdf(self, x, betaL, betaH, mL, mH):
8484
"""
@@ -106,15 +106,13 @@ def lhs(x, betaL, betaH, mL, mH):
106106
def rhs(x, betaL, betaH, mL, mH):
107107
if _old_style_where:
108108
return _lazywhere(x < betaH, (-x, betaH, mH), f=core, f2=tail)
109-
return xpx.apply_where(x < betaH, (-x, betaH, mH), core, tail)
109+
return apply_where(x < betaH, (-x, betaH, mH), core, tail)
110110

111111
if _old_style_where:
112112
return np.log(N) + _lazywhere(
113113
x > -betaL, (x, betaL, betaH, mL, mH), f=rhs, f2=lhs
114114
)
115-
return np.log(N) + xpx.apply_where(
116-
x > -betaL, (x, betaL, betaH, mL, mH), rhs, lhs
117-
)
115+
return np.log(N) + apply_where(x > -betaL, (x, betaL, betaH, mL, mH), rhs, lhs)
118116

119117
def _cdf(self, x, betaL, betaH, mL, mH):
120118
"""
@@ -159,11 +157,11 @@ def rhs(x, betaL, betaH, mL, mH):
159157
return _lazywhere(
160158
x < betaH, (x, betaL, betaH, mL, mH), f=core, f2=hightail
161159
)
162-
return xpx.apply_where(x < betaH, (x, betaL, betaH, mL, mH), core, hightail)
160+
return apply_where(x < betaH, (x, betaL, betaH, mL, mH), core, hightail)
163161

164162
if _old_style_where:
165-
return _lazywhere(x > -betaL, (x, betaL, betaH, mL, mH), f=rhs, f2=lhs)
166-
return N * xpx.apply_where(x > -betaL, (x, betaL, betaH, mL, mH), rhs, lhs)
163+
return N * _lazywhere(x > -betaL, (x, betaL, betaH, mL, mH), f=rhs, f2=lhs)
164+
return N * apply_where(x > -betaL, (x, betaL, betaH, mL, mH), rhs, lhs)
167165

168166
def _ppf(self, p, betaL, betaH, mL, mH):
169167
"""
@@ -215,19 +213,15 @@ def ppf_greater(p, betaL, betaH, mL, mH):
215213
return _lazywhere(
216214
p > pbetaH, (p, betaL, betaH, mL, mH), f=hightail, f2=core
217215
)
218-
return xpx.apply_where(
219-
p > pbetaH, (p, betaL, betaH, mL, mH), hightail, core
220-
)
216+
return apply_where(p > pbetaH, (p, betaL, betaH, mL, mH), hightail, core)
221217

222218
N = 1.0 / (inttail(betaL, mL) + intcore(betaL, betaH) + inttail(betaH, mH))
223219
pbetaL = N * (mL / betaL) * np.exp(-0.5 * betaL * betaL) / (mL - 1)
224220
if _old_style_where:
225221
return _lazywhere(
226222
p < pbetaL, (p, betaL, betaH, mL, mH), f=lowtail, f2=ppf_greater
227223
)
228-
return xpx.apply_where(
229-
p < pbetaL, (p, betaL, betaH, mL, mH), lowtail, ppf_greater
230-
)
224+
return apply_where(p < pbetaL, (p, betaL, betaH, mL, mH), lowtail, ppf_greater)
231225

232226
def _munp(self, n, betaL, betaH, mL, mH):
233227
"""

0 commit comments

Comments
 (0)