Skip to content

Commit 6f39638

Browse files
Improve names in _betainc_der_power_series
1 parent ba137b3 commit 6f39638

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

tensorflow_probability/python/math/special.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -351,32 +351,32 @@ def _betainc_der_power_series(a, b, x, dtype, use_power_series):
351351
# 2F1(a, 1 - b; a + 1; x) / a
352352
def power_series_evaluation(should_stop, values, gradients):
353353
n, product, series_sum = values
354-
dpdb, da, db = gradients
354+
product_grad_b, da, db = gradients
355355

356356
x_div_n = safe_x / n
357357
factor = (n - safe_b) * x_div_n
358358
apn = safe_a + n
359359

360360
new_product = product * factor
361361
term = new_product / apn
362-
new_dpdb = factor * dpdb - product * x_div_n
362+
new_product_grad_b = factor * product_grad_b - product * x_div_n
363363
new_da = da - new_product / tf.math.square(apn)
364-
new_db = db + new_dpdb / apn
364+
new_db = db + new_product_grad_b / apn
365365

366366
values = n + one, new_product, series_sum + term
367-
gradients = new_dpdb, new_da, new_db
367+
gradients = new_product_grad_b, new_da, new_db
368368

369369
return should_stop | (tf.math.abs(term) <= tolerance), values, gradients
370370

371-
n = one
372-
product = tf.ones_like(safe_a)
373-
series_sum = one / safe_a
374-
initial_values = (n, product, series_sum)
371+
initial_n = one
372+
initial_product = tf.ones_like(safe_a)
373+
initial_series_sum = one / safe_a
374+
initial_values = (initial_n, initial_product, initial_series_sum)
375375

376-
dpdb = tf.zeros_like(safe_b)
377-
da = -tf.math.reciprocal(tf.math.square(safe_a))
378-
db = dpdb
379-
initial_gradients = (dpdb, da, db)
376+
initial_product_grad_b = tf.zeros_like(safe_b)
377+
initial_da = -tf.math.reciprocal(tf.math.square(safe_a))
378+
initial_db = initial_product_grad_b
379+
initial_gradients = (initial_product_grad_b, initial_da, initial_db)
380380

381381
(_, values, gradients) = tf.while_loop(
382382
cond=lambda stop, *_: tf.reduce_any(~stop),

0 commit comments

Comments
 (0)