Skip to content

Commit 7df72d0

Browse files
author
abhinavd
committed
Making compatible with older JAX versions
1 parent edbf1d9 commit 7df72d0

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

tensorcircuit/backends/jax_ops.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -151,16 +151,12 @@ def _QrGradSquareAndDeepMatrices(q: Array, r: Array, dq: Array, dr: Array) -> Ar
151151

152152
@jax.custom_vjp
153153
def adaware_eigh(A: Array) -> Array:
154-
result = jnp.linalg.eigh(A)
155-
e = result.eigenvalues
156-
v = result.eigenvectors
154+
e, v = jnp.linalg.eigh(A)
157155
return e, v
158156

159157

160158
def jaxeigh_fwd(A: Array) -> Array:
161-
result = jnp.linalg.eigh(A)
162-
e = result.eigenvalues
163-
v = result.eigenvectors
159+
e, v = jnp.linalg.eigh(A)
164160
return (e, v), (A, e, v)
165161

166162

0 commit comments

Comments
 (0)