Skip to content

Commit 8d631cd

Browse files
Replace apply_symmetry by use_symmetry_relation
1 parent 7a1cdfb commit 8d631cd

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

tensorflow_probability/python/math/special.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -240,11 +240,11 @@ def _betainc_der_continued_fraction(a, b, x, dtype, where):
240240
# for x < (a - 1) / (a + b - 2). For x >= (a - 1) / (a + b - 2),
241241
# we can obtain an equivalent computation by using the symmetry
242242
# relation given here: https://dlmf.nist.gov/8.17#E4
243-
apply_symmetry = (x >= (a - one) / (a + b - two))
243+
use_symmetry_relation = (x >= (a - one) / (a + b - two))
244244
a_orig = a
245-
a = tf.where(apply_symmetry, b, a)
246-
b = tf.where(apply_symmetry, a_orig, b)
247-
x = tf.where(apply_symmetry, one - x, x)
245+
a = tf.where(use_symmetry_relation, b, a)
246+
b = tf.where(use_symmetry_relation, a_orig, b)
247+
x = tf.where(use_symmetry_relation, one - x, x)
248248

249249
cf, cf_grad_a, cf_grad_b = _betainc_modified_lentz_method(
250250
a, b, x, dtype, where)
@@ -262,8 +262,8 @@ def _betainc_der_continued_fraction(a, b, x, dtype, where):
262262
# If we are taking advantage of the symmetry relation, then we have to
263263
# adjust grad_a and grad_b.
264264
grad_a_orig = grad_a
265-
grad_a = tf.where(apply_symmetry, -grad_b, grad_a)
266-
grad_b = tf.where(apply_symmetry, -grad_a_orig, grad_b)
265+
grad_a = tf.where(use_symmetry_relation, -grad_b, grad_a)
266+
grad_b = tf.where(use_symmetry_relation, -grad_a_orig, grad_b)
267267

268268
return grad_a, grad_b
269269

@@ -290,11 +290,11 @@ def _betainc_der_power_series(a, b, x, dtype, where):
290290

291291
# When the condition C1 is false, we apply the symmetry relation given
292292
# here: http://dlmf.nist.gov/8.17.E4
293-
apply_symmetry = (safe_x >= safe_a / (safe_a + safe_b))
293+
use_symmetry_relation = (safe_x >= safe_a / (safe_a + safe_b))
294294
safe_a_orig = safe_a
295-
safe_a = tf.where(apply_symmetry, safe_b, safe_a)
296-
safe_b = tf.where(apply_symmetry, safe_a_orig, safe_b)
297-
safe_x = tf.where(apply_symmetry, one - safe_x, safe_x)
295+
safe_a = tf.where(use_symmetry_relation, safe_b, safe_a)
296+
safe_b = tf.where(use_symmetry_relation, safe_a_orig, safe_b)
297+
safe_x = tf.where(use_symmetry_relation, one - safe_x, safe_x)
298298

299299
# max_iterations was set by experimentation and tolerance was taken from
300300
# Cephes.
@@ -356,8 +356,8 @@ def power_series_evaluation(should_stop, values, gradients):
356356
# If we are taking advantage of the symmetry relation, then we have to
357357
# adjust grad_a and grad_b.
358358
grad_a_orig = grad_a
359-
grad_a = tf.where(apply_symmetry, -grad_b, grad_a)
360-
grad_b = tf.where(apply_symmetry, -grad_a_orig, grad_b)
359+
grad_a = tf.where(use_symmetry_relation, -grad_b, grad_a)
360+
grad_b = tf.where(use_symmetry_relation, -grad_a_orig, grad_b)
361361

362362
return grad_a, grad_b
363363

0 commit comments

Comments
 (0)