Skip to content

Commit 1d16641

Browse files
srvasudejburnim
authored andcommitted
Fix LinearOperator to ensure that the JAX version doesn't rely on the Numpy backend.
PiperOrigin-RevId: 398070290
1 parent eda25bd commit 1d16641

File tree

1 file changed

+4
-0
lines changed
  • tensorflow_probability/python/internal/backend/jax

1 file changed

+4
-0
lines changed

tensorflow_probability/python/internal/backend/jax/rewrite.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ def main(argv):
4141
contents = contents.replace(
4242
'import tensorflow_probability.substrates.numpy as tfp',
4343
'import tensorflow_probability.substrates.jax as tfp')
44+
# To fix lazy imports in `LinearOperator`.
45+
contents = contents.replace(
46+
'tensorflow_probability.substrates.numpy',
47+
'tensorflow_probability.substrates.jax')
4448
contents = contents.replace('scipy.linalg', 'jax.scipy.linalg')
4549
contents = contents.replace('scipy.special', 'jax.scipy.special')
4650
if FLAGS.rewrite_numpy_import:

0 commit comments

Comments
 (0)