Skip to content

Minor Fix: Changed 'NamedShape' to 'DShapedArray' in layers.py for newer JAX version compatibility#210

Open
obbwins wants to merge 1 commit intosanchit-gandhi:mainfrom
obbwins:fix-jax-compatibility
Open

Minor Fix: Changed 'NamedShape' to 'DShapedArray' in layers.py for newer JAX version compatibility#210
obbwins wants to merge 1 commit intosanchit-gandhi:mainfrom
obbwins:fix-jax-compatibility

Conversation

@obbwins
Copy link

@obbwins obbwins commented Apr 27, 2025

This PR fixes a minor incompatibility with newer JAX versions by updating the type annotation from jax.core.NamedShape to jax.core.DShapedArray in the compute_fans function.

This is a minimal change to address the immediate compatibility issue. I'm not super experienced with JAX so please let me know if an alternative is better :)

Tested with JAX version: 0.6.0

…edArray for compatibility with newer JAX versions
@samuelbradshaw
Copy link

This fix helped me – thanks. @sanchit-gandhi, could you consider merging this? (And also #177)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants