@@ -255,7 +255,7 @@ def _hyp2f1_recurrence(a, b, c, z):
255
255
256
256
257
257
@numba .njit (
258
- "UniTuple(float64, 6 )(float64, float64, float64, float64, float64, float64)"
258
+ "UniTuple(float64, 7 )(float64, float64, float64, float64, float64, float64)"
259
259
)
260
260
def _hyp2f1_dlmf1583_first (a_i , b_i , a_j , b_j , y , mu ):
261
261
"""
@@ -287,21 +287,26 @@ def _hyp2f1_dlmf1583_first(a_i, b_i, a_j, b_j, y, mu):
287
287
)
288
288
289
289
# 2F1(a, -y; c; z) via backwards recurrence
290
- val , sign , da , _ , dc , dz , _ = _hyp2f1_recurrence (a , y , c , z )
290
+ val , sign , da , _ , dc , dz , d2z = _hyp2f1_recurrence (a , y , c , z )
291
291
292
292
# map gradient to parameters
293
293
da_i = dc - _digamma (a_i + a_j ) + _digamma (a_i )
294
294
da_j = da + dc - np .log (s ) + _digamma (a_j + y + 1 ) - _digamma (a_i + a_j )
295
295
db_i = dz / (b_j - mu ) + a_j / (mu + b_i )
296
296
db_j = dz * (1 - z ) / (b_j - mu ) - a_j / s / (mu + b_i )
297
297
298
+ # needed to verify result
299
+ d2b_j = (1 - z ) / (b_j - mu ) ** 2 * (d2z * (1 - z ) - 2 * dz * (1 + a_j )) + (
300
+ 1 + a_j
301
+ ) * a_j / (b_j - mu ) ** 2
302
+
298
303
val += scale
299
304
300
- return val , sign , da_i , db_i , da_j , db_j
305
+ return val , sign , da_i , db_i , da_j , db_j , d2b_j
301
306
302
307
303
308
@numba .njit (
304
- "UniTuple(float64, 6 )(float64, float64, float64, float64, float64, float64)"
309
+ "UniTuple(float64, 7 )(float64, float64, float64, float64, float64, float64)"
305
310
)
306
311
def _hyp2f1_dlmf1583_second (a_i , b_i , a_j , b_j , y , mu ):
307
312
"""
@@ -320,18 +325,24 @@ def _hyp2f1_dlmf1583_second(a_i, b_i, a_j, b_j, y, mu):
320
325
)
321
326
322
327
# 2F1(a, y+1; c; z) via series expansion
323
- val , sign , da , _ , dc , dz , _ = _hyp2f1_taylor_series (a , y + 1 , c , z )
328
+ val , sign , da , _ , dc , dz , d2z = _hyp2f1_taylor_series (a , y + 1 , c , z )
324
329
325
330
# map gradient to parameters
326
331
da_i = da + np .log (z ) + dc + _digamma (a_i ) - _digamma (a_i + y + 1 )
327
332
da_j = da + np .log (z ) + _digamma (a_j + y + 1 ) - _digamma (a_j )
328
333
db_i = (1 - z ) * (dz + a / z ) / (b_i + b_j )
329
334
db_j = - z * (dz + a / z ) / (b_i + b_j )
330
335
336
+ # needed to verify result
337
+ d2b_j = (
338
+ z / (b_i + b_j ) ** 2 * (d2z * z + 2 * dz * (1 + a ))
339
+ + a * (1 + a ) / (b_i + b_j ) ** 2
340
+ )
341
+
331
342
sign *= (- 1 ) ** (y + 1 )
332
343
val += scale
333
344
334
- return val , sign , da_i , db_i , da_j , db_j
345
+ return val , sign , da_i , db_i , da_j , db_j , d2b_j
335
346
336
347
337
348
@numba .njit (
@@ -345,42 +356,15 @@ def _hyp2f1_dlmf1583(a_i, b_i, a_j, b_j, y, mu):
345
356
assert 0 <= mu <= b_j
346
357
assert y >= 0 and y % 1.0 == 0.0
347
358
348
- f_1 , s_1 , da_i_1 , db_i_1 , da_j_1 , db_j_1 = _hyp2f1_dlmf1583_first (
359
+ f_1 , s_1 , da_i_1 , db_i_1 , da_j_1 , db_j_1 , d2b_j_1 = _hyp2f1_dlmf1583_first (
349
360
a_i , b_i , a_j , b_j , y , mu
350
361
)
351
362
352
- f_2 , s_2 , da_i_2 , db_i_2 , da_j_2 , db_j_2 = _hyp2f1_dlmf1583_second (
363
+ f_2 , s_2 , da_i_2 , db_i_2 , da_j_2 , db_j_2 , d2b_j_2 = _hyp2f1_dlmf1583_second (
353
364
a_i , b_i , a_j , b_j , y , mu
354
365
)
355
366
356
367
f_0 = max (f_1 , f_2 )
357
-
358
- # 2sum
359
- aa = f_1
360
- bb = - 1 * f_0
361
- s = aa + bb
362
- ap = s - bb
363
- bp = s - ap
364
- da = aa - ap
365
- db = bb - bp
366
- t = da + db
367
- print ("2sum" , s , t )
368
-
369
- aa = f_2
370
- bb = - 1 * f_0
371
- s = aa + bb
372
- ap = s - bb
373
- bp = s - ap
374
- da = aa - ap
375
- db = bb - bp
376
- t = da + db
377
- print ("2sum" , s , t )
378
- # /2sum
379
-
380
- # if np.abs(f_1 - f_2) < _HYP2F1_TOL:
381
- # # TODO: detect a priori if this will occur
382
- # raise Invalid2F1("Singular hypergeometric function")
383
-
384
368
f_1 = np .exp (f_1 - f_0 ) * s_1
385
369
f_2 = np .exp (f_2 - f_0 ) * s_2
386
370
f = f_1 + f_2
@@ -389,10 +373,22 @@ def _hyp2f1_dlmf1583(a_i, b_i, a_j, b_j, y, mu):
389
373
db_i = (db_i_1 * f_1 + db_i_2 * f_2 ) / f
390
374
da_j = (da_j_1 * f_1 + da_j_2 * f_2 ) / f
391
375
db_j = (db_j_1 * f_1 + db_j_2 * f_2 ) / f
376
+ d2b_j = (d2b_j_1 * f_1 + d2b_j_2 * f_2 ) / f
392
377
393
378
sign = np .sign (f )
394
379
val = np .log (np .abs (f )) + f_0
395
380
381
+ # use first/second derivatives to check that result is non-singular
382
+ dz = - db_j * (mu + b_i )
383
+ d2z = d2b_j * (mu + b_i ) ** 2
384
+ if (
385
+ not _is_valid_2f1 (
386
+ dz , d2z , a_j , a_i + a_j + y , a_j + y + 1 , (mu - b_j ) / (mu + b_i )
387
+ )
388
+ or sign <= 0
389
+ ):
390
+ raise Invalid2F1 ("Hypergeometric series did not converge" )
391
+
396
392
return val , sign , da_i , db_i , da_j , db_j
397
393
398
394
0 commit comments