Skip to content

Commit d3697a9

Browse files
Rename variable where
1 parent 2aac7fd commit d3697a9

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

tensorflow_probability/python/math/special.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def odd_partial_numerator():
147147
iteration_is_even, even_partial_numerator, odd_partial_numerator)
148148

149149

150-
def _betainc_modified_lentz_method(a, b, x, dtype, where):
150+
def _betainc_modified_lentz_method(a, b, x, dtype, use_continued_fraction):
151151
"""Returns the continued fraction for betainc by modified Lentz's method."""
152152
numpy_dtype = dtype_util.as_numpy_dtype(dtype)
153153
one = tf.constant(1., dtype=dtype)
@@ -192,7 +192,8 @@ def continued_fraction_evaluation(should_stop, iteration, values, gradients):
192192

193193
# Assume all input Tensors have the same shape. The extra dimension is
194194
# needed to compute the gradients with respect to a and b.
195-
a, b, x, where = (z[..., tf.newaxis] for z in (a, b, x, where))
195+
a, b, x, use_continued_fraction = [
196+
z[..., tf.newaxis] for z in (a, b, x, use_continued_fraction)]
196197

197198
apb = a + b
198199
ap1 = a + one
@@ -214,7 +215,7 @@ def continued_fraction_evaluation(should_stop, iteration, values, gradients):
214215
cond=lambda stop, *_: tf.reduce_any(~stop),
215216
body=continued_fraction_evaluation,
216217
loop_vars=(
217-
~where,
218+
~use_continued_fraction,
218219
tf.constant(2., dtype=dtype),
219220
initial_values,
220221
initial_gradients),
@@ -227,7 +228,7 @@ def continued_fraction_evaluation(should_stop, iteration, values, gradients):
227228
return f, f_grad_a, f_grad_b
228229

229230

230-
def _betainc_der_continued_fraction(a, b, x, dtype, where):
231+
def _betainc_der_continued_fraction(a, b, x, dtype, use_continued_fraction):
231232
"""Returns the partial derivatives of betainc with respect to a and b."""
232233
# This function evaluates betainc(a, b, x) by its continued fraction
233234
# expansion given here: https://dlmf.nist.gov/8.17.E22
@@ -247,7 +248,7 @@ def _betainc_der_continued_fraction(a, b, x, dtype, where):
247248
x = tf.where(use_symmetry_relation, one - x, x)
248249

249250
cf, cf_grad_a, cf_grad_b = _betainc_modified_lentz_method(
250-
a, b, x, dtype, where)
251+
a, b, x, dtype, use_continued_fraction)
251252

252253
normalization = tf.math.exp(
253254
tf.math.xlogy(a, x) + tf.math.xlog1py(b, -x) -
@@ -268,7 +269,7 @@ def _betainc_der_continued_fraction(a, b, x, dtype, where):
268269
return grad_a, grad_b
269270

270271

271-
def _betainc_der_power_series(a, b, x, dtype, where):
272+
def _betainc_der_power_series(a, b, x, dtype, use_power_series):
272273
"""Returns the partial derivatives of betainc with respect to a and b."""
273274
# This function evaluates betainc(a, b, x) by its series representation:
274275
# x ** a * 2F1(a, 1 - b; a + 1; x) / (a * B(a, b)) ,
@@ -284,9 +285,9 @@ def _betainc_der_power_series(a, b, x, dtype, where):
284285

285286
# Avoid returning NaN or infinity when the input does not satisfy either
286287
# C1 or C2.
287-
safe_a = tf.where(where, a, half)
288-
safe_b = tf.where(where, b, half)
289-
safe_x = tf.where(where, x, half)
288+
safe_a = tf.where(use_power_series, a, half)
289+
safe_b = tf.where(use_power_series, b, half)
290+
safe_x = tf.where(use_power_series, x, half)
290291

291292
# When the condition C1 is false, we apply the symmetry relation given
292293
# here: http://dlmf.nist.gov/8.17.E4
@@ -336,7 +337,7 @@ def power_series_evaluation(should_stop, values, gradients):
336337
cond=lambda stop, *_: tf.reduce_any(~stop),
337338
body=power_series_evaluation,
338339
loop_vars=(
339-
~where,
340+
~use_power_series,
340341
initial_values,
341342
initial_gradients),
342343
maximum_iterations=max_iterations)

0 commit comments

Comments
 (0)