@@ -454,8 +454,9 @@ def low_rank_cholesky(matrix, max_rank, trace_atol=0, trace_rtol=0, name=None):
454454 dtype_hint = tf .float32 )
455455 if not isinstance (matrix , tf .linalg .LinearOperator ):
456456 matrix = tf .convert_to_tensor (matrix , name = 'matrix' , dtype = dtype )
457+ matrix = tf .linalg .LinearOperatorFullMatrix (matrix )
457458
458- mtrace = tf . linalg . trace (matrix )
459+ mtrace = matrix . trace ()
459460 mrank = tensorshape_util .rank (matrix .shape )
460461 batch_dims = mrank - 2
461462
@@ -485,7 +486,7 @@ def lr_cholesky_body(i, lr, residual_diag):
485486 matrix_row = tf .squeeze (matrix .row (max_j ), axis = - 2 )
486487 else :
487488 matrix_row = tf .gather (
488- matrix , max_j , axis = - 1 , batch_dims = batch_dims )[..., 0 ]
489+ matrix . to_dense () , max_j , axis = - 1 , batch_dims = batch_dims )[..., 0 ]
489490 # residual_matrix[max_j, :] = matrix_row[max_j, :] - (lr * lr^t)[max_j, :]
490491 # And (lr * lr^t)[max_j, :] = lr[max_j, :] * lr^t
491492 lr_row_maxj = tf .gather (lr , max_j , axis = - 2 , batch_dims = batch_dims )
@@ -530,7 +531,7 @@ def lr_cholesky_body(i, lr, residual_diag):
530531
531532 lr = tf .zeros (matrix .shape , dtype = matrix .dtype )[..., :max_rank ]
532533
533- mdiag = tf . linalg . diag_part (matrix )
534+ mdiag = matrix . diag_part ()
534535 i , lr , residual_diag = tf .while_loop (
535536 cond = lr_cholesky_cond ,
536537 body = lr_cholesky_body ,
0 commit comments