|
2 | 2 |
|
3 | 3 | _old_style_where = False |
4 | 4 | try: |
5 | | - import scipy._lib.array_api_extra as xpx |
| 5 | + from scipy._lib.array_api_extra import apply_where |
6 | 6 | from scipy.stats._continuous_distns import ( |
7 | 7 | _norm_cdf, |
8 | 8 | _norm_pdf_C, |
@@ -74,11 +74,11 @@ def lhs(x, betaL, betaH, mL, mH): |
74 | 74 | def rhs(x, betaL, betaH, mL, mH): |
75 | 75 | if _old_style_where: |
76 | 76 | return _lazywhere(x < betaH, (-x, betaH, mH), f=core, f2=tail) |
77 | | - return xpx.apply_where(x < betaH, (-x, betaH, mH), core, tail) |
| 77 | + return apply_where(x < betaH, (-x, betaH, mH), core, tail) |
78 | 78 |
|
79 | 79 | if _old_style_where: |
80 | | - N * _lazywhere(x > -betaL, (x, betaL, betaH, mL, mH), f=rhs, f2=lhs) |
81 | | - return N * xpx.apply_where(x > -betaL, (x, betaL, betaH, mL, mH), rhs, lhs) |
| 80 | + return N * _lazywhere(x > -betaL, (x, betaL, betaH, mL, mH), f=rhs, f2=lhs) |
| 81 | + return N * apply_where(x > -betaL, (x, betaL, betaH, mL, mH), rhs, lhs) |
82 | 82 |
|
83 | 83 | def _logpdf(self, x, betaL, betaH, mL, mH): |
84 | 84 | """ |
@@ -106,15 +106,13 @@ def lhs(x, betaL, betaH, mL, mH): |
106 | 106 | def rhs(x, betaL, betaH, mL, mH): |
107 | 107 | if _old_style_where: |
108 | 108 | return _lazywhere(x < betaH, (-x, betaH, mH), f=core, f2=tail) |
109 | | - return xpx.apply_where(x < betaH, (-x, betaH, mH), core, tail) |
| 109 | + return apply_where(x < betaH, (-x, betaH, mH), core, tail) |
110 | 110 |
|
111 | 111 | if _old_style_where: |
112 | 112 | return np.log(N) + _lazywhere( |
113 | 113 | x > -betaL, (x, betaL, betaH, mL, mH), f=rhs, f2=lhs |
114 | 114 | ) |
115 | | - return np.log(N) + xpx.apply_where( |
116 | | - x > -betaL, (x, betaL, betaH, mL, mH), rhs, lhs |
117 | | - ) |
| 115 | + return np.log(N) + apply_where(x > -betaL, (x, betaL, betaH, mL, mH), rhs, lhs) |
118 | 116 |
|
119 | 117 | def _cdf(self, x, betaL, betaH, mL, mH): |
120 | 118 | """ |
@@ -159,11 +157,11 @@ def rhs(x, betaL, betaH, mL, mH): |
159 | 157 | return _lazywhere( |
160 | 158 | x < betaH, (x, betaL, betaH, mL, mH), f=core, f2=hightail |
161 | 159 | ) |
162 | | - return xpx.apply_where(x < betaH, (x, betaL, betaH, mL, mH), core, hightail) |
| 160 | + return apply_where(x < betaH, (x, betaL, betaH, mL, mH), core, hightail) |
163 | 161 |
|
164 | 162 | if _old_style_where: |
165 | | - return _lazywhere(x > -betaL, (x, betaL, betaH, mL, mH), f=rhs, f2=lhs) |
166 | | - return N * xpx.apply_where(x > -betaL, (x, betaL, betaH, mL, mH), rhs, lhs) |
| 163 | + return N * _lazywhere(x > -betaL, (x, betaL, betaH, mL, mH), f=rhs, f2=lhs) |
| 164 | + return N * apply_where(x > -betaL, (x, betaL, betaH, mL, mH), rhs, lhs) |
167 | 165 |
|
168 | 166 | def _ppf(self, p, betaL, betaH, mL, mH): |
169 | 167 | """ |
@@ -215,19 +213,15 @@ def ppf_greater(p, betaL, betaH, mL, mH): |
215 | 213 | return _lazywhere( |
216 | 214 | p > pbetaH, (p, betaL, betaH, mL, mH), f=hightail, f2=core |
217 | 215 | ) |
218 | | - return xpx.apply_where( |
219 | | - p > pbetaH, (p, betaL, betaH, mL, mH), hightail, core |
220 | | - ) |
| 216 | + return apply_where(p > pbetaH, (p, betaL, betaH, mL, mH), hightail, core) |
221 | 217 |
|
222 | 218 | N = 1.0 / (inttail(betaL, mL) + intcore(betaL, betaH) + inttail(betaH, mH)) |
223 | 219 | pbetaL = N * (mL / betaL) * np.exp(-0.5 * betaL * betaL) / (mL - 1) |
224 | 220 | if _old_style_where: |
225 | 221 | return _lazywhere( |
226 | 222 | p < pbetaL, (p, betaL, betaH, mL, mH), f=lowtail, f2=ppf_greater |
227 | 223 | ) |
228 | | - return xpx.apply_where( |
229 | | - p < pbetaL, (p, betaL, betaH, mL, mH), lowtail, ppf_greater |
230 | | - ) |
| 224 | + return apply_where(p < pbetaL, (p, betaL, betaH, mL, mH), lowtail, ppf_greater) |
231 | 225 |
|
232 | 226 | def _munp(self, n, betaL, betaH, mL, mH): |
233 | 227 | """ |
|
0 commit comments