Commit 7089d1a
Call convert_to_tensor in tf2jax.cast().
The intention is to get code like `tf2jax.cast(tfp.util.DeferredTensor)` to work correctly. As it is now, it calls `DeferredTensor.__array__`, which is only intended to work with numpy at the moment. It's not clear whether `jnp.array(DT)` should bypass numpy or not.
That aside, calling convert_to_tensor at the beginning of functions is pretty standard, so this change is probably a good idea either way.
PiperOrigin-RevId: 3983639661 parent 36c6167 commit 7089d1a
1 file changed
+1
-1
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
377 | 377 | | |
378 | 378 | | |
379 | 379 | | |
380 | | - | |
| 380 | + | |
381 | 381 | | |
382 | 382 | | |
383 | 383 | | |
| |||
0 commit comments