Skip to content

Commit c4c3425

Browse files
Replace region_ps by use_power_series
1 parent 8d631cd commit c4c3425

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

tensorflow_probability/python/math/special.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -388,13 +388,14 @@ def _betainc_partials(a, b, x):
388388

389389
# The partial derivatives of betainc with respect to a and b are computed
390390
# using forward mode.
391-
region_ps = ((x < a / (a + b)) & (b * x <= 1.) & (x <= 0.95) |
391+
use_power_series = ((x < a / (a + b)) & (b * x <= 1.) & (x <= 0.95) |
392392
((x >= a / (a + b)) & (a * (1. - x) <= 1.) & (x >= 0.05)))
393-
ps_grad_a, ps_grad_b = _betainc_der_power_series(a, b, x, dtype, region_ps)
393+
ps_grad_a, ps_grad_b = _betainc_der_power_series(
394+
a, b, x, dtype, use_power_series)
394395
cf_grad_a, cf_grad_b = _betainc_der_continued_fraction(
395-
a, b, x, dtype, ~region_ps)
396-
grad_a = tf.where(region_ps, ps_grad_a, cf_grad_a)
397-
grad_b = tf.where(region_ps, ps_grad_b, cf_grad_b)
396+
a, b, x, dtype, ~use_power_series)
397+
grad_a = tf.where(use_power_series, ps_grad_a, cf_grad_a)
398+
grad_b = tf.where(use_power_series, ps_grad_b, cf_grad_b)
398399

399400
# According to the code accompanying [1], grad_a = grad_b = 0 when x is
400401
# equal to 0 or 1. Under the same condition, grad_x = 0 by its expression.

0 commit comments

Comments
 (0)