
Commit 57fa903

midfield authored and tensorflower-gardener committed
Use foldl in no_pivot_ldl instead of while_loop.
PiperOrigin-RevId: 385835428
1 parent 56aece2 commit 57fa903
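
For reference, a minimal sketch (not part of this commit) of the while_loop-to-foldl rewrite the message describes, shown on a toy running sum: the loop counter becomes the `elems` sequence and the remaining loop state becomes the fold accumulator. All names below (`xs`, `body`, `fn`) are illustrative only.

import tensorflow as tf

xs = tf.constant([1., 2., 3., 4.])

# Before: a counter and the running state are threaded through tf.while_loop.
def body(i, acc):
  return i + 1, acc + xs[i]

_, total_while = tf.while_loop(
    cond=lambda i, _: i < tf.shape(xs)[0],
    body=body,
    loop_vars=(0, tf.constant(0.)))

# After: the counter becomes `elems`, the state becomes the accumulator.
total_foldl = tf.foldl(
    fn=lambda acc, i: acc + xs[i],
    elems=tf.range(tf.shape(xs)[0]),
    initializer=tf.constant(0.))

# Both evaluate to 10.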

File tree

3 files changed: +40, -9 lines


tensorflow_probability/python/experimental/linalg/BUILD
Lines changed: 1 addition & 0 deletions

@@ -83,5 +83,6 @@ py_test(
         "//tensorflow_probability/python/experimental/linalg:no_pivot_ldl",
         "//tensorflow_probability/python/internal:tensorshape_util",
         "//tensorflow_probability/python/internal:test_util",
+        # "//third_party/tensorflow/compiler/jit:xla_cpu_jit",  # DisableOnExport
     ],
 )

tensorflow_probability/python/experimental/linalg/no_pivot_ldl.py
Lines changed: 11 additions & 9 deletions

@@ -59,9 +59,11 @@ def no_pivot_ldl(matrix, name='no_pivot_ldl'):
 
   Performs the LDL factorization, using the outer product algorithm from [1]. No
   pivoting (or block pivoting) is done, so this should be less stable than
-  e.g. Bunch-Kaufman sytrf. This is implemented as a tf.while_loop, so should
-  have gradients and be accelerator-friendly, but is not particularly
-  performant.
+  e.g. Bunch-Kaufman sytrf. This is implemented as a tf.foldl, so should have
+  gradients and be accelerator-friendly, but is not particularly performant.
+
+  If compiling with XLA, make sure any surrounding GradientTape is also
+  XLA-compiled (b/193584244).
 
   #### References
   [1]: Gene H. Golub, Charles F. Van Loan. Matrix Computations, 4th ed., 2013.
@@ -83,7 +85,7 @@ def no_pivot_ldl(matrix, name='no_pivot_ldl'):
     # TODO(b/182276317) Deal with dynamic ranks better.
     slix = _Slice2Idx(triangular_factor)
 
-    def body(i, triangular_factor):
+    def fn(triangular_factor, i):
       column_head = triangular_factor[..., i, i, tf.newaxis]
       column_tail = triangular_factor[..., i+1:, i]
       rescaled_tail = column_tail / column_head
@@ -97,12 +99,12 @@ def body(i, triangular_factor):
           tf.linalg.band_part(
               tf.einsum('...i,...j->...ij', column_tail, rescaled_tail),
               num_lower=-1, num_upper=0))
-      return i+1, triangular_factor
+      return triangular_factor
 
-    _, triangular_factor = tf.while_loop(
-        cond=lambda i, _: i < tf.shape(triangular_factor)[-1],
-        body=body,
-        loop_vars=(0, triangular_factor))
+    triangular_factor = tf.foldl(
+        fn=fn,
+        elems=tf.range(tf.shape(triangular_factor)[-1]),
+        initializer=triangular_factor)
 
     diag = tf.linalg.diag_part(triangular_factor)
     triangular_factor = tf.linalg.set_diag(
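
For readers unfamiliar with the outer-product recurrence the docstring cites from Golub & Van Loan, here is a rough single-matrix NumPy sketch of the same no-pivot LDL^T factorization. It is illustrative only: the names `ldl_no_pivot`, `lower`, and `d` are not from the library, and the real implementation above carries batch dimensions and runs the loop via tf.foldl.

import numpy as np

def ldl_no_pivot(matrix):
  """Returns (unit lower-triangular L, diagonal d) with matrix ~= L @ diag(d) @ L.T."""
  a = np.array(matrix, dtype=np.float64)
  n = a.shape[-1]
  lower = np.eye(n)
  d = np.zeros(n)
  for i in range(n):
    d[i] = a[i, i]
    # Rescale the column tail by the pivot (no pivoting, so d[i] must be nonzero).
    lower[i+1:, i] = a[i+1:, i] / d[i]
    # Outer-product update of the trailing submatrix.
    a[i+1:, i+1:] -= np.outer(a[i+1:, i], lower[i+1:, i])
  return lower, d

sym = np.array([[4., 2., 1.], [2., 3., 0.5], [1., 0.5, 2.]])
lower, d = ldl_no_pivot(sym)
assert np.allclose(lower @ np.diag(d) @ lower.T, sym)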

tensorflow_probability/python/experimental/linalg/no_pivot_ldl_test.py
Lines changed: 28 additions & 0 deletions

@@ -80,6 +80,34 @@ def testSimpleIndefinite(self):
     eigv, _ = self.evaluate(tf.linalg.eigh(reconstruct))
     self.assertAllTrue(eigv > 0.)
 
+  def testXlaCompileBug(self):
+    inp = tf.Variable([[2., 1.], [1., 2.]])
+    self.evaluate(inp.initializer)
+    alt_chol = simple_robustified_cholesky
+    alt_chol_nojit = tf.function(alt_chol, autograph=False, jit_compile=False)
+    alt_chol_jit = tf.function(alt_chol, autograph=False, jit_compile=True)
+    answer = np.array([[1.4142135, 0.], [0.70710677, 1.2247449]])
+
+    self.assertAllClose(self.evaluate(alt_chol(inp)), answer)
+    self.assertAllClose(self.evaluate(alt_chol_nojit(inp)), answer)
+    self.assertAllClose(self.evaluate(alt_chol_jit(inp)), answer)
+
+    with tf.GradientTape():
+      chol_with_grad = alt_chol(inp)
+      chol_nojit_with_grad = alt_chol_nojit(inp)
+      # Not supported by TF-XLA (WAI), see b/193584244
+      # chol_jit_with_grad = alt_chol_jit(inp)
+    self.assertAllClose(self.evaluate(chol_with_grad), answer)
+    self.assertAllClose(self.evaluate(chol_nojit_with_grad), answer)
+
+    # But wrapping the tape in tf.function should work.
+    @tf.function(autograph=False, jit_compile=True)
+    def jit_with_grad(mat):
+      with tf.GradientTape():
+        return alt_chol_jit(mat)
+
+    self.assertAllClose(self.evaluate(jit_with_grad(inp)), answer)
+
 
 if __name__ == '__main__':
   tf.test.main()
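
A hedged usage sketch of the docstring caveat, following the `jit_with_grad` pattern in the test above: when the factorization itself is XLA-compiled, take gradients inside a jit-compiled tf.function rather than around it. The import path and the availability of `simple_robustified_cholesky` in this module are assumptions based on the test, not something this commit states.

import tensorflow as tf
# Assumed import path, based on this module's location in the tree.
from tensorflow_probability.python.experimental.linalg import no_pivot_ldl as npl

mat = tf.Variable([[2., 1.], [1., 2.]])

@tf.function(autograph=False, jit_compile=True)
def chol_and_grad(m):
  # The GradientTape lives inside the jit-compiled function (b/193584244).
  with tf.GradientTape() as tape:
    tape.watch(m)
    chol = npl.simple_robustified_cholesky(m)
    loss = tf.reduce_sum(chol)
  return chol, tape.gradient(loss, m)

chol, grad = chol_and_grad(mat)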
