Skip to content

Commit 3871387

Browse files
oskarfernlundjburnim
authored andcommitted
changed signature of pivoted_cholesky
1 parent 133809b commit 3871387

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

tensorflow_probability/python/math/linalg.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,11 @@ def _invert_permutation(perm): # TODO(b/130217510): Remove this function.
265265
return tf.cast(tf.argsort(perm, axis=-1), perm.dtype)
266266

267267

268-
def pivoted_cholesky(matrix, max_rank, diag_rtol=1e-3, name=None):
268+
def pivoted_cholesky(matrix,
269+
max_rank,
270+
diag_rtol=1e-3,
271+
return_pivoting_order=False,
272+
name=None):
269273
"""Computes the (partial) pivoted cholesky decomposition of `matrix`.
270274
271275
The pivoted Cholesky is a low rank approximation of the Cholesky decomposition
@@ -290,10 +294,13 @@ def pivoted_cholesky(matrix, max_rank, diag_rtol=1e-3, name=None):
290294
diag_rtol: Scalar floating point `Tensor` (same dtype as `matrix`). If the
291295
errors of all diagonal elements of `lr @ lr.T` are each lower than
292296
`element * diag_rtol`, iteration is permitted to terminate early.
297+
return_pivoting_order: If `True`, return an `int` `Tensor` indicating the pivoting
298+
order used to produce `lr` in addition to `lr` (defaults to `False`).
293299
name: Optional name for the op.
294300
295301
Returns:
296302
lr: Low rank pivoted Cholesky approximation of `matrix`.
303+
perm: (Optional) pivoting order used to produce `lr`.
297304
298305
#### References
299306
@@ -405,7 +412,11 @@ def body(m, pchol, perm, matrix_diag):
405412
pchol = tf.linalg.matrix_transpose(pchol)
406413
tensorshape_util.set_shape(
407414
pchol, tensorshape_util.concatenate(matrix_diag.shape, [None]))
408-
return pchol
415+
416+
if return_pivoting_order:
417+
return pchol, perm
418+
else:
419+
return pchol
409420

410421

411422
def low_rank_cholesky(matrix, max_rank, trace_atol=0, trace_rtol=0, name=None):

0 commit comments

Comments
 (0)