@@ -388,13 +388,14 @@ def _betainc_partials(a, b, x):
388
388
389
389
# The partial derivatives of betainc with respect to a and b are computed
390
390
# using forward mode.
391
- region_ps = ((x < a / (a + b )) & (b * x <= 1. ) & (x <= 0.95 ) |
391
+ use_power_series = ((x < a / (a + b )) & (b * x <= 1. ) & (x <= 0.95 ) |
392
392
((x >= a / (a + b )) & (a * (1. - x ) <= 1. ) & (x >= 0.05 )))
393
- ps_grad_a , ps_grad_b = _betainc_der_power_series (a , b , x , dtype , region_ps )
393
+ ps_grad_a , ps_grad_b = _betainc_der_power_series (
394
+ a , b , x , dtype , use_power_series )
394
395
cf_grad_a , cf_grad_b = _betainc_der_continued_fraction (
395
- a , b , x , dtype , ~ region_ps )
396
- grad_a = tf .where (region_ps , ps_grad_a , cf_grad_a )
397
- grad_b = tf .where (region_ps , ps_grad_b , cf_grad_b )
396
+ a , b , x , dtype , ~ use_power_series )
397
+ grad_a = tf .where (use_power_series , ps_grad_a , cf_grad_a )
398
+ grad_b = tf .where (use_power_series , ps_grad_b , cf_grad_b )
398
399
399
400
# According to the code accompanying [1], grad_a = grad_b = 0 when x is
400
401
# equal to 0 or 1. Under the same condition, grad_x = 0 by its expression.
0 commit comments