|
22 | 22 | import numpy as np
|
23 | 23 | import scipy.optimize as optimize
|
24 | 24 |
|
| 25 | +import tensorflow.compat.v1 as tf1 |
25 | 26 | import tensorflow.compat.v2 as tf
|
26 | 27 | import tensorflow_probability as tfp
|
27 | 28 |
|
@@ -172,6 +173,17 @@ def test_secant_invalid_max_iterations(self):
|
172 | 173 | tfp.math.find_root_secant(
|
173 | 174 | f, guess, max_iterations=-1, validate_args=True))
|
174 | 175 |
|
| 176 | + def test_secant_non_static_shape(self): |
| 177 | + if tf.executing_eagerly(): |
| 178 | + self.skipTest('Test uses dynamic shapes.') |
| 179 | + |
| 180 | + f = lambda x: (x - 1.) * (x + 1) |
| 181 | + initial_position = tf1.placeholder_with_default([1., 1., 1.], shape=None) |
| 182 | + self.assertAllClose( |
| 183 | + tfp.math.find_root_secant( |
| 184 | + f, initial_position).objective_at_estimated_root, |
| 185 | + [0., 0., 0.]) |
| 186 | + |
175 | 187 |
|
176 | 188 | @test_util.test_all_tf_execution_regimes
|
177 | 189 | class ChandrupatlaRootSearchTest(test_util.TestCase):
|
@@ -266,6 +278,18 @@ def test_chandrupatla_automatically_selects_bounds(self):
|
266 | 278 | position_tolerance=1e-8)
|
267 | 279 | self.assertAllClose(value_at_roots, tf.zeros_like(value_at_roots))
|
268 | 280 |
|
| 281 | + def test_chandrupatla_non_static_shape(self): |
| 282 | + if tf.executing_eagerly(): |
| 283 | + self.skipTest('Test uses dynamic shapes.') |
| 284 | + |
| 285 | + f = lambda x: (x - 1.) * (x + 1) |
| 286 | + low = tf1.placeholder_with_default([-100., -100., -100.], shape=None) |
| 287 | + high = tf1.placeholder_with_default([100., 100., 100.], shape=None) |
| 288 | + self.assertAllClose( |
| 289 | + tfp.math.find_root_chandrupatla( |
| 290 | + f, low=low, high=high).objective_at_estimated_root, |
| 291 | + [0., 0., 0.]) |
| 292 | + |
269 | 293 |
|
270 | 294 | @test_util.test_all_tf_execution_regimes
|
271 | 295 | class BracketRootTest(test_util.TestCase):
|
|
0 commit comments