Skip to content

Commit 0867823

Browse files
SiegeLordExjburnim
authored andcommitted
Call convert_to_tensor in tf2jax.cast().
The intention is to get code like `tf2jax.cast(tfp.util.DeferredTensor)` to work correctly. As it is now, it calls `DeferredTensor.__array__`, which is only intended to work with numpy at the moment. It's not clear whether `jnp.array(DT)` should bypass numpy or not. That aside, calling convert_to_tensor at the beginning of functions is pretty standard, so this change is probably a good idea either way. PiperOrigin-RevId: 398321147
1 parent 63bc1fd commit 0867823

File tree

1 file changed

+1
-1
lines changed
  • tensorflow_probability/python/internal/backend/numpy

1 file changed

+1
-1
lines changed

tensorflow_probability/python/internal/backend/numpy/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ def batch_jacobian(self, target, source, # pylint: disable=unused-argument
377377

378378

379379
def _cast(x, dtype):
380-
x = convert_to_tensor(x)
380+
x = np.asarray(x)
381381
if (np.issubdtype(x.dtype, np.complexfloating) and
382382
not np.issubdtype(dtype, np.complexfloating)):
383383
x = np.real(x)

0 commit comments

Comments
 (0)