Skip to content

Commit 7a1cdfb

Browse files
Use tf.math.square
1 parent 48088de commit 7a1cdfb

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tensorflow_probability/python/math/special.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def continued_fraction_evaluation(should_stop, iteration, values, gradients):
179179
delta = new_c * new_d
180180
new_f = f * delta
181181

182-
new_c_grad = (numerator_grad * c - numerator * c_grad) / (c * c)
182+
new_c_grad = (numerator_grad * c - numerator * c_grad) / tf.math.square(c)
183183
new_d_grad = -new_d * new_d * (numerator_grad * d + numerator * d_grad)
184184
new_f_grad = f_grad * delta + (f * new_c_grad * new_d) + (
185185
f * new_d_grad * new_c)
@@ -204,7 +204,7 @@ def continued_fraction_evaluation(should_stop, iteration, values, gradients):
204204
initial_f = initial_d
205205
initial_values = (initial_c, initial_d, initial_f)
206206

207-
initial_d_grad = tf.concat([one - b, ap1], axis=-1) * x / tf.square(
207+
initial_d_grad = tf.concat([one - b, ap1], axis=-1) * x / tf.math.square(
208208
x * apb - ap1)
209209
initial_c_grad = tf.zeros_like(initial_d_grad)
210210
initial_f_grad = initial_d_grad
@@ -314,7 +314,7 @@ def power_series_evaluation(should_stop, values, gradients):
314314
new_product = product * factor
315315
term = new_product / apn
316316
new_dpdb = factor * dpdb - product * x_div_n
317-
new_da = da - new_product / (apn * apn)
317+
new_da = da - new_product / tf.math.square(apn)
318318
new_db = db + new_dpdb / apn
319319

320320
values = n + one, new_product, series_sum + term
@@ -328,7 +328,7 @@ def power_series_evaluation(should_stop, values, gradients):
328328
initial_values = (n, product, series_sum)
329329

330330
dpdb = tf.zeros_like(safe_b)
331-
da = -tf.math.reciprocal(safe_a * safe_a)
331+
da = -tf.math.reciprocal(tf.math.square(safe_a))
332332
db = dpdb
333333
initial_gradients = (dpdb, da, db)
334334

0 commit comments

Comments
 (0)