File tree Expand file tree Collapse file tree 1 file changed +15
-0
lines changed Expand file tree Collapse file tree 1 file changed +15
-0
lines changed Original file line number Diff line number Diff line change @@ -60,6 +60,21 @@ def sqrtmh(self: Any, a: Tensor) -> Tensor:
6060 e = self .sqrt (e )
6161 return v @ self .diagflat (e ) @ self .adjoint (v )
6262
63+ def sqrtmhpos (self : Any , a : Tensor ) -> Tensor :
64+ """
65+ Return the sqrtm of a known-to-be PSD Hermitian matrix ``a``.
66+
67+ :param a: tensor in matrix form
68+ :type a: Tensor
69+ :return: sqrtm of ``a`` after setting the negative eigenvalues (if they exist) to zero
70+ :rtype: Tensor
71+ """
72+ # maybe friendly for AD and also cosidering that several backend has no support for native sqrtm
73+ e , v = self .eigh (a )
74+ e = self .relu (e )
75+ e = self .sqrt (e )
76+ return v @ self .diagflat (e ) @ self .adjoint (v )
77+
6378 def eigvalsh (self : Any , a : Tensor ) -> Tensor :
6479 """
6580 Get the eigenvalues of matrix ``a``.
You can’t perform that action at this time.
0 commit comments