@@ -147,7 +147,7 @@ def odd_partial_numerator():
147
147
iteration_is_even , even_partial_numerator , odd_partial_numerator )
148
148
149
149
150
- def _betainc_modified_lentz_method (a , b , x , dtype , where ):
150
+ def _betainc_modified_lentz_method (a , b , x , dtype , use_continued_fraction ):
151
151
"""Returns the continued fraction for betainc by modified Lentz's method."""
152
152
numpy_dtype = dtype_util .as_numpy_dtype (dtype )
153
153
one = tf .constant (1. , dtype = dtype )
@@ -192,7 +192,8 @@ def continued_fraction_evaluation(should_stop, iteration, values, gradients):
192
192
193
193
# Assume all input Tensors have the same shape. The extra dimension is
194
194
# needed to compute the gradients with respect to a and b.
195
- a , b , x , where = (z [..., tf .newaxis ] for z in (a , b , x , where ))
195
+ a , b , x , use_continued_fraction = [
196
+ z [..., tf .newaxis ] for z in (a , b , x , use_continued_fraction )]
196
197
197
198
apb = a + b
198
199
ap1 = a + one
@@ -214,7 +215,7 @@ def continued_fraction_evaluation(should_stop, iteration, values, gradients):
214
215
cond = lambda stop , * _ : tf .reduce_any (~ stop ),
215
216
body = continued_fraction_evaluation ,
216
217
loop_vars = (
217
- ~ where ,
218
+ ~ use_continued_fraction ,
218
219
tf .constant (2. , dtype = dtype ),
219
220
initial_values ,
220
221
initial_gradients ),
@@ -227,7 +228,7 @@ def continued_fraction_evaluation(should_stop, iteration, values, gradients):
227
228
return f , f_grad_a , f_grad_b
228
229
229
230
230
- def _betainc_der_continued_fraction (a , b , x , dtype , where ):
231
+ def _betainc_der_continued_fraction (a , b , x , dtype , use_continued_fraction ):
231
232
"""Returns the partial derivatives of betainc with respect to a and b."""
232
233
# This function evaluates betainc(a, b, x) by its continued fraction
233
234
# expansion given here: https://dlmf.nist.gov/8.17.E22
@@ -247,7 +248,7 @@ def _betainc_der_continued_fraction(a, b, x, dtype, where):
247
248
x = tf .where (use_symmetry_relation , one - x , x )
248
249
249
250
cf , cf_grad_a , cf_grad_b = _betainc_modified_lentz_method (
250
- a , b , x , dtype , where )
251
+ a , b , x , dtype , use_continued_fraction )
251
252
252
253
normalization = tf .math .exp (
253
254
tf .math .xlogy (a , x ) + tf .math .xlog1py (b , - x ) -
@@ -268,7 +269,7 @@ def _betainc_der_continued_fraction(a, b, x, dtype, where):
268
269
return grad_a , grad_b
269
270
270
271
271
- def _betainc_der_power_series (a , b , x , dtype , where ):
272
+ def _betainc_der_power_series (a , b , x , dtype , use_power_series ):
272
273
"""Returns the partial derivatives of betainc with respect to a and b."""
273
274
# This function evaluates betainc(a, b, x) by its series representation:
274
275
# x ** a * 2F1(a, 1 - b; a + 1; x) / (a * B(a, b)) ,
@@ -284,9 +285,9 @@ def _betainc_der_power_series(a, b, x, dtype, where):
284
285
285
286
# Avoid returning NaN or infinity when the input does not satisfy either
286
287
# C1 or C2.
287
- safe_a = tf .where (where , a , half )
288
- safe_b = tf .where (where , b , half )
289
- safe_x = tf .where (where , x , half )
288
+ safe_a = tf .where (use_power_series , a , half )
289
+ safe_b = tf .where (use_power_series , b , half )
290
+ safe_x = tf .where (use_power_series , x , half )
290
291
291
292
# When the condition C1 is false, we apply the symmetry relation given
292
293
# here: http://dlmf.nist.gov/8.17.E4
@@ -336,7 +337,7 @@ def power_series_evaluation(should_stop, values, gradients):
336
337
cond = lambda stop , * _ : tf .reduce_any (~ stop ),
337
338
body = power_series_evaluation ,
338
339
loop_vars = (
339
- ~ where ,
340
+ ~ use_power_series ,
340
341
initial_values ,
341
342
initial_gradients ),
342
343
maximum_iterations = max_iterations )
0 commit comments