@@ -265,7 +265,11 @@ def _invert_permutation(perm): # TODO(b/130217510): Remove this function.
265265 return tf .cast (tf .argsort (perm , axis = - 1 ), perm .dtype )
266266
267267
268- def pivoted_cholesky (matrix , max_rank , diag_rtol = 1e-3 , name = None ):
268+ def pivoted_cholesky (matrix ,
269+ max_rank ,
270+ diag_rtol = 1e-3 ,
271+ return_pivoting_order = False ,
272+ name = None ):
269273 """Computes the (partial) pivoted cholesky decomposition of `matrix`.
270274
271275 The pivoted Cholesky is a low rank approximation of the Cholesky decomposition
@@ -290,10 +294,13 @@ def pivoted_cholesky(matrix, max_rank, diag_rtol=1e-3, name=None):
290294 diag_rtol: Scalar floating point `Tensor` (same dtype as `matrix`). If the
291295 errors of all diagonal elements of `lr @ lr.T` are each lower than
292296 `element * diag_rtol`, iteration is permitted to terminate early.
297+ return_pivoting_order: If `True`, return an `int` `Tensor` indicating the pivoting
298+ order used to produce `lr` in addition to `lr` (defaults to `False`).
293299 name: Optional name for the op.
294300
295301 Returns:
296302 lr: Low rank pivoted Cholesky approximation of `matrix`.
303+ perm: (Optional) pivoting order used to produce `lr`.
297304
298305 #### References
299306
@@ -405,7 +412,11 @@ def body(m, pchol, perm, matrix_diag):
405412 pchol = tf .linalg .matrix_transpose (pchol )
406413 tensorshape_util .set_shape (
407414 pchol , tensorshape_util .concatenate (matrix_diag .shape , [None ]))
408- return pchol
415+
416+ if return_pivoting_order :
417+ return pchol , perm
418+ else :
419+ return pchol
409420
410421
411422def low_rank_cholesky (matrix , max_rank , trace_atol = 0 , trace_rtol = 0 , name = None ):
0 commit comments