Skip to content

Commit 367d4da

Browse files
brianwa84jburnim
authored andcommitted
[JAX] Fixes for tensorflow_probability for an upcoming change to jnp.array().
An upcoming change to jax.numpy.array() means that, under a transformation like jax.jit(), it will always stage its arrays into the trace. This often breaks if the array is being used for a shape calculation. Make sure we use static shapes in more places to fix test failures. PiperOrigin-RevId: 397849299
1 parent 4746a60 commit 367d4da

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

tensorflow_probability/python/internal/backend/meta/gen_linear_operators.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,8 @@ def gen_module(module_name):
196196
code = code.replace('math_ops.range', 'array_ops.range')
197197
code = code.replace('ops.convert_to_tensor_v2_with_dispatch(',
198198
'ops.convert_to_tensor(')
199+
code = code.replace('ops.convert_to_tensor(dim_value)',
200+
'np.array(dim_value, np.int32)')
199201

200202
code = code.replace('self.dtype.real_dtype', 'dtypes.real_dtype(self.dtype)')
201203
code = code.replace('dtype.real_dtype', 'dtypes.real_dtype(dtype)')

tensorflow_probability/python/internal/backend/numpy/gen/linear_operator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ def _domain_dimension_tensor(self, shape=None):
486486
# more efficient manner, e.g. without excessive Tensor conversions.
487487
dim_value = tensor_shape.dimension_value(self.domain_dimension)
488488
if dim_value is not None:
489-
return ops.convert_to_tensor(dim_value)
489+
return np.array(dim_value, np.int32)
490490
else:
491491
shape = self.shape_tensor() if shape is None else shape
492492
return shape[-1]
@@ -530,7 +530,7 @@ def _range_dimension_tensor(self, shape=None):
530530
# more efficient manner, e.g. without excessive Tensor conversions.
531531
dim_value = tensor_shape.dimension_value(self.range_dimension)
532532
if dim_value is not None:
533-
return ops.convert_to_tensor(dim_value)
533+
return np.array(dim_value, np.int32)
534534
else:
535535
shape = self.shape_tensor() if shape is None else shape
536536
return shape[-2]

0 commit comments

Comments
 (0)