Skip to content

Commit 424c2dc

Browse files
committed
Better check for invalid 2F1 in DLMF1583
1 parent 36a5446 commit 424c2dc

File tree

2 files changed

+31
-38
lines changed

2 files changed

+31
-38
lines changed

tsdate/approx.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,15 +111,12 @@ def sufficient_statistics(a_i, b_i, a_j, b_j, y_ij, mu_ij):
111111
112112
:return: normalizing constant, E[t_i], E[log t_i], E[t_j], E[log t_j]
113113
"""
114-
assert y_ij >= 0 and mu_ij > 0, "Invalid edge parameters"
115114

116115
a = a_i + a_j + y_ij
117116
b = a_j
118117
c = a_j + y_ij + 1
119118
t = mu_ij + b_i
120119

121-
assert a > 0 and b > 0 and c > 0, "Invalid local posterior"
122-
123120
log_f, sign_f, da_i, db_i, da_j, db_j = hypergeo._hyp2f1(
124121
a_i, b_i, a_j, b_j, y_ij, mu_ij
125122
)

tsdate/hypergeo.py

Lines changed: 31 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def _hyp2f1_recurrence(a, b, c, z):
255255

256256

257257
@numba.njit(
258-
"UniTuple(float64, 6)(float64, float64, float64, float64, float64, float64)"
258+
"UniTuple(float64, 7)(float64, float64, float64, float64, float64, float64)"
259259
)
260260
def _hyp2f1_dlmf1583_first(a_i, b_i, a_j, b_j, y, mu):
261261
"""
@@ -287,21 +287,26 @@ def _hyp2f1_dlmf1583_first(a_i, b_i, a_j, b_j, y, mu):
287287
)
288288

289289
# 2F1(a, -y; c; z) via backwards recurrence
290-
val, sign, da, _, dc, dz, _ = _hyp2f1_recurrence(a, y, c, z)
290+
val, sign, da, _, dc, dz, d2z = _hyp2f1_recurrence(a, y, c, z)
291291

292292
# map gradient to parameters
293293
da_i = dc - _digamma(a_i + a_j) + _digamma(a_i)
294294
da_j = da + dc - np.log(s) + _digamma(a_j + y + 1) - _digamma(a_i + a_j)
295295
db_i = dz / (b_j - mu) + a_j / (mu + b_i)
296296
db_j = dz * (1 - z) / (b_j - mu) - a_j / s / (mu + b_i)
297297

298+
# needed to verify result
299+
d2b_j = (1 - z) / (b_j - mu) ** 2 * (d2z * (1 - z) - 2 * dz * (1 + a_j)) + (
300+
1 + a_j
301+
) * a_j / (b_j - mu) ** 2
302+
298303
val += scale
299304

300-
return val, sign, da_i, db_i, da_j, db_j
305+
return val, sign, da_i, db_i, da_j, db_j, d2b_j
301306

302307

303308
@numba.njit(
304-
"UniTuple(float64, 6)(float64, float64, float64, float64, float64, float64)"
309+
"UniTuple(float64, 7)(float64, float64, float64, float64, float64, float64)"
305310
)
306311
def _hyp2f1_dlmf1583_second(a_i, b_i, a_j, b_j, y, mu):
307312
"""
@@ -320,18 +325,24 @@ def _hyp2f1_dlmf1583_second(a_i, b_i, a_j, b_j, y, mu):
320325
)
321326

322327
# 2F1(a, y+1; c; z) via series expansion
323-
val, sign, da, _, dc, dz, _ = _hyp2f1_taylor_series(a, y + 1, c, z)
328+
val, sign, da, _, dc, dz, d2z = _hyp2f1_taylor_series(a, y + 1, c, z)
324329

325330
# map gradient to parameters
326331
da_i = da + np.log(z) + dc + _digamma(a_i) - _digamma(a_i + y + 1)
327332
da_j = da + np.log(z) + _digamma(a_j + y + 1) - _digamma(a_j)
328333
db_i = (1 - z) * (dz + a / z) / (b_i + b_j)
329334
db_j = -z * (dz + a / z) / (b_i + b_j)
330335

336+
# needed to verify result
337+
d2b_j = (
338+
z / (b_i + b_j) ** 2 * (d2z * z + 2 * dz * (1 + a))
339+
+ a * (1 + a) / (b_i + b_j) ** 2
340+
)
341+
331342
sign *= (-1) ** (y + 1)
332343
val += scale
333344

334-
return val, sign, da_i, db_i, da_j, db_j
345+
return val, sign, da_i, db_i, da_j, db_j, d2b_j
335346

336347

337348
@numba.njit(
@@ -345,42 +356,15 @@ def _hyp2f1_dlmf1583(a_i, b_i, a_j, b_j, y, mu):
345356
assert 0 <= mu <= b_j
346357
assert y >= 0 and y % 1.0 == 0.0
347358

348-
f_1, s_1, da_i_1, db_i_1, da_j_1, db_j_1 = _hyp2f1_dlmf1583_first(
359+
f_1, s_1, da_i_1, db_i_1, da_j_1, db_j_1, d2b_j_1 = _hyp2f1_dlmf1583_first(
349360
a_i, b_i, a_j, b_j, y, mu
350361
)
351362

352-
f_2, s_2, da_i_2, db_i_2, da_j_2, db_j_2 = _hyp2f1_dlmf1583_second(
363+
f_2, s_2, da_i_2, db_i_2, da_j_2, db_j_2, d2b_j_2 = _hyp2f1_dlmf1583_second(
353364
a_i, b_i, a_j, b_j, y, mu
354365
)
355366

356367
f_0 = max(f_1, f_2)
357-
358-
# 2sum
359-
aa = f_1
360-
bb = -1 * f_0
361-
s = aa + bb
362-
ap = s - bb
363-
bp = s - ap
364-
da = aa - ap
365-
db = bb - bp
366-
t = da + db
367-
print("2sum", s, t)
368-
369-
aa = f_2
370-
bb = -1 * f_0
371-
s = aa + bb
372-
ap = s - bb
373-
bp = s - ap
374-
da = aa - ap
375-
db = bb - bp
376-
t = da + db
377-
print("2sum", s, t)
378-
# /2sum
379-
380-
# if np.abs(f_1 - f_2) < _HYP2F1_TOL:
381-
# # TODO: detect a priori if this will occur
382-
# raise Invalid2F1("Singular hypergeometric function")
383-
384368
f_1 = np.exp(f_1 - f_0) * s_1
385369
f_2 = np.exp(f_2 - f_0) * s_2
386370
f = f_1 + f_2
@@ -389,10 +373,22 @@ def _hyp2f1_dlmf1583(a_i, b_i, a_j, b_j, y, mu):
389373
db_i = (db_i_1 * f_1 + db_i_2 * f_2) / f
390374
da_j = (da_j_1 * f_1 + da_j_2 * f_2) / f
391375
db_j = (db_j_1 * f_1 + db_j_2 * f_2) / f
376+
d2b_j = (d2b_j_1 * f_1 + d2b_j_2 * f_2) / f
392377

393378
sign = np.sign(f)
394379
val = np.log(np.abs(f)) + f_0
395380

381+
# use first/second derivatives to check that result is non-singular
382+
dz = -db_j * (mu + b_i)
383+
d2z = d2b_j * (mu + b_i) ** 2
384+
if (
385+
not _is_valid_2f1(
386+
dz, d2z, a_j, a_i + a_j + y, a_j + y + 1, (mu - b_j) / (mu + b_i)
387+
)
388+
or sign <= 0
389+
):
390+
raise Invalid2F1("Hypergeometric series did not converge")
391+
396392
return val, sign, da_i, db_i, da_j, db_j
397393

398394

0 commit comments

Comments
 (0)