Skip to content

Commit 6fc9a9e

Browse files
srvasude authored and tensorflower-gardener committed
Fix Hager-Zhang linesearch to accept intervals with zero derivative for the right endpoint.
- This improves performance of L-BFGS / BFGS substantially on test problems. PiperOrigin-RevId: 492308325
1 parent ee8fbbe commit 6fc9a9e

File tree

3 files changed

+59
-11
lines changed

3 files changed

+59
-11
lines changed

tensorflow_probability/python/optimizer/linesearch/hager_zhang_test.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -122,27 +122,30 @@ def fdf(x):
122122
sum(r.func_evals for r in results_mapped))
123123

124124
def test_batch_bracket_failures(self):
125-
# To bracket successfully, we must find the narrow window with positive
126-
# derivative -- roughly [1.39, 2.72].
125+
# To bracket successfully, we must find the narrow window with non-negative
126+
# derivative -- good values are roughly [1.39, 2.72].
127127
def _fdf(x):
128128
z = x - 1
129129
return ValueAndGradient(
130130
x=x,
131131
f=tf.math.exp(-z) - tf.math.exp(-z*z),
132132
df=2*z*tf.math.exp(-z*z) - tf.math.exp(-z))
133133

134-
start = tf.convert_to_tensor([0.01, 0.1, 1.0, 1.5, 2.0, 3.0])
134+
start = tf.convert_to_tensor([0.01, 0.1, 1.0, 1.5, 2.0, -5.0])
135135
results = self.evaluate(hager_zhang(
136136
_fdf, initial_step_size=start))
137137

138138
# Bracketing will do something like: check `5^0 * start`, `5^1 * start`,
139-
# `5^2 * start`, ..., looking for a point where the derivative is positive.
140-
# This search will find a point with positive derivative when `start` is
141-
# `0.1`, `1.5`, or `2.0`, but will fail for the other values.
142-
self.assertAllEqual([False, True, False, True, True, False],
143-
results.converged)
144-
self.assertAllEqual([True, False, True, False, False, True],
145-
results.failed)
139+
# `5^2 * start`, ..., looking for a point where the derivative is
140+
# non-negative. The search will start to fail for negative values where the
141+
# function is highly positive and goes steeply downward (but since it's far
142+
# enough out fails to bracket).
143+
self.assertAllEqual(
144+
[True, True, True, True, True, False],
145+
results.converged)
146+
self.assertAllEqual(
147+
[False, False, False, False, False, True],
148+
results.failed)
146149

147150
val_0 = self.evaluate(_fdf(tf.convert_to_tensor(0.0)))
148151
self.assertAllEqual(

tensorflow_probability/python/optimizer/linesearch/internal/hager_zhang_lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -723,7 +723,7 @@ def _is_rising(val):
723723
rising: A Boolean Tensor giving whether this point is a suitable right
724724
end-point for an interval subject to secant subdivision.
725725
"""
726-
return tf.math.is_finite(val.f) & (val.df > 0)
726+
return tf.math.is_finite(val.f) & (val.df >= 0.)
727727

728728

729729
def is_finite(val_1, val_2=None):

tensorflow_probability/python/optimizer/linesearch/internal/hager_zhang_lib_test.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,51 @@ def test_bracket_simple(self):
252252
self.assertLess(result.left.df, 0) # Opposite slopes.
253253
self.assertGreaterEqual(result.right.df, 0)
254254

255+
def test_bracket_accepts_interval_zero_derivative(self):
256+
"""Tests that bracketing accepts f' = 0 for the right endpoint."""
257+
wolfe_threshold = 1e-6
258+
259+
# This example is taken from the unconstrained beale function.
260+
def beale(z):
261+
# Constrain to [-4.5, 4.5]
262+
z = 4.5 * tf.math.sigmoid(z) - 4.5 * tf.math.sigmoid(-z)
263+
x = z[..., 0]
264+
y = z[..., 1]
265+
return ((1.5 - x + x * y)**2 +
266+
(2.25 - x + x * y**2)**2 +
267+
(2.625 - x + x * y**3)**2)
268+
269+
def beale_ls(t):
270+
t = tf.convert_to_tensor(t, dtype=tf.float32)
271+
def _internal_ls(t):
272+
# Choose an initial point and step such that the step goes out towards
273+
# infinity. In that way, we guarantee the gradients are zero at the
274+
# step but aren't a suitable minima as the function increases away
275+
# from the point (3., 0.5).
276+
x = np.array([0.6, 1.35]).astype(np.float32)
277+
# Large step that pushes the function to a flat region of space.
278+
p = np.array([-100., -100.]).astype(np.float32)
279+
return beale(x + t * p)
280+
f, df = value_and_gradient(_internal_ls, t)
281+
return ValueAndGradient(x=t, f=tf.squeeze(f), df=tf.squeeze(df))
282+
283+
val_a = beale_ls(0.0) # Value at zero.
284+
val_b = beale_ls(1.0) # Value at initial step.
285+
f_lim = val_a.f + (wolfe_threshold * tf.abs(val_a.f))
286+
287+
result = self.evaluate(
288+
hzl.bracket(beale_ls, _interval(val_a, val_b), f_lim, max_iterations=5))
289+
290+
# The left endpoint has negative derivative, the right has zero derivative.
291+
# This should be a valid interval a priori.
292+
self.assertFalse(result.failed)
293+
self.assertEqual(result.iteration, 0) # Zero expansion.
294+
self.assertEqual(result.num_evals, 0) # Zero evaluations.
295+
self.assertEqual(result.left.x, 0.)
296+
self.assertEqual(result.right.x, 1.)
297+
self.assertLess(result.left.df, 0) # Opposite slopes.
298+
self.assertGreaterEqual(result.right.df, 0)
299+
255300
def test_bracket_batching(self):
256301
"""Tests that bracketing works in batching mode."""
257302
wolfe_threshold = 1e-6

0 commit comments

Comments (0)