Skip to content

Conversation

@abhinavd
Copy link

This is a bug coming from a commit on JAX, namely 012c5bd. The eigenvalues and eigenvectors are wrapped in a container object that needs to be unpacked. Otherwise, it causes an error down the line, like the following:

Custom VJP fwd rule jaxeigh_fwd for function adaware_eigh must produce a pair (list or tuple of length two) where the first element represents the primal output (equal to the output of the custom_vjp-decorated function adaware_eigh) and the second element represents residuals (i.e. values stored from the forward pass for use on the backward pass), but instead the fwd rule output's first element had container/pytree structure:
(float32[321, complex64[32, 32])
while the custom_vjp-decorated function adaware_eigh had output container/pytree structure:
EighResult (eigenvalues=float32[32], eigenvectors=complex64[32,32]) .

def adaware_eigh(A: Array) -> Array:
return jnp.linalg.eigh(A)
result = jnp.linalg.eigh(A)
e = result.eigenvalues
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One problem of the change proposed here is that it is not compatible with older versions of Jax, where the result is just a tuple?
What about e, v = jnp.linalg.eigh(A); return e, v, will this change be compatible with both tuple and namedtuple?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, yes, that would work! Shall I update it?

@codecov
Copy link

codecov bot commented Jan 18, 2025

Codecov Report

Attention: Patch coverage is 57.14286% with 3 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
tensorcircuit/backends/jax_ops.py 57.14% 3 Missing ⚠️
Files with missing lines Coverage Δ
tensorcircuit/backends/jax_ops.py 82.72% <57.14%> (-1.09%) ⬇️

Copy link
Member

@refraction-ray refraction-ray left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you for the contribution!

@refraction-ray refraction-ray merged commit 7df72d0 into tensorcircuit:master Jan 18, 2025
1 of 2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants