Skip to content

Commit 636693f

Browse files
davmretensorflower-gardener
authored andcommitted
Support dynamic shape in tfp.math.find_root_secant.
PiperOrigin-RevId: 379291021
1 parent c5ef7e8 commit 636693f

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

tensorflow_probability/python/math/root_search.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -218,18 +218,17 @@ def f(x):
218218
num_iterations = tf.zeros_like(position, dtype=tf.int32)
219219
max_iterations = tf.convert_to_tensor(max_iterations, dtype=tf.int32)
220220
max_iterations = tf.broadcast_to(
221-
max_iterations, name='max_iterations', shape=position.shape)
221+
max_iterations, name='max_iterations', shape=ps.shape(position))
222222

223223
# Compute the step from `next_position` if present. This covers the case where
224224
# a user has two starting points, which bound the root or has a specific step
225225
# size in mind.
226226
if next_position is None:
227-
epsilon = tf.constant(1e-4, dtype=position.dtype, shape=position.shape)
228-
step = position * epsilon + tf.sign(position) * epsilon
227+
step = (position + tf.sign(position)) * 1e-4
229228
else:
230229
step = next_position - initial_position
231230

232-
finished = tf.constant(False, shape=position.shape)
231+
finished = tf.zeros(ps.shape(position), dtype=tf.bool)
233232

234233
# Negate `stopping_condition` to determine if the search should continue.
235234
# This means, in particular, that tf.reduce_*all* will return only when the

tensorflow_probability/python/math/root_search_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import numpy as np
2323
import scipy.optimize as optimize
2424

25+
import tensorflow.compat.v1 as tf1
2526
import tensorflow.compat.v2 as tf
2627
import tensorflow_probability as tfp
2728

@@ -172,6 +173,17 @@ def test_secant_invalid_max_iterations(self):
172173
tfp.math.find_root_secant(
173174
f, guess, max_iterations=-1, validate_args=True))
174175

176+
def test_secant_non_static_shape(self):
177+
if tf.executing_eagerly():
178+
self.skipTest('Test uses dynamic shapes.')
179+
180+
f = lambda x: (x - 1.) * (x + 1)
181+
initial_position = tf1.placeholder_with_default([1., 1., 1.], shape=None)
182+
self.assertAllClose(
183+
tfp.math.find_root_secant(
184+
f, initial_position).objective_at_estimated_root,
185+
[0., 0., 0.])
186+
175187

176188
@test_util.test_all_tf_execution_regimes
177189
class ChandrupatlaRootSearchTest(test_util.TestCase):
@@ -266,6 +278,18 @@ def test_chandrupatla_automatically_selects_bounds(self):
266278
position_tolerance=1e-8)
267279
self.assertAllClose(value_at_roots, tf.zeros_like(value_at_roots))
268280

281+
def test_chandrupatla_non_static_shape(self):
282+
if tf.executing_eagerly():
283+
self.skipTest('Test uses dynamic shapes.')
284+
285+
f = lambda x: (x - 1.) * (x + 1)
286+
low = tf1.placeholder_with_default([-100., -100., -100.], shape=None)
287+
high = tf1.placeholder_with_default([100., 100., 100.], shape=None)
288+
self.assertAllClose(
289+
tfp.math.find_root_chandrupatla(
290+
f, low=low, high=high).objective_at_estimated_root,
291+
[0., 0., 0.])
292+
269293

270294
@test_util.test_all_tf_execution_regimes
271295
class BracketRootTest(test_util.TestCase):

0 commit comments

Comments
 (0)