Skip to content

Commit 36c6167

Browse files
SiegeLordExjburnim
authored andcommitted
Fix infinite loop when converting tf1.Dimension to a nptf.Tensor.
It was caused by using nptf.dimension_value on a tf.Dimension, which did not correctly strip away the Dimension type, since nptf.dimension was checking for type equality against nptf.Dimension (rather than tf.Dimension). PiperOrigin-RevId: 398351764
1 parent 0867823 commit 36c6167

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

tensorflow_probability/python/internal/prefer_static.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,20 @@
4343

4444
JAX_MODE = False
4545

46+
4647
# Enable converting TF TensorShape and Dimension into np.array. This allows TF
4748
# code to pass TensorShapes into prefer_static functions. We can also re-use the
48-
# nptf methods.
49+
# TensorShape conversion logic from nptf, but Dimension needs to be here to make
50+
# sure we use TF's dimension_value.
51+
def _convert_dimension_to_tensor(value, dtype=None):
52+
dtype = dtype or np.int32
53+
if dtype not in (np.int32, np.int64):
54+
raise nptf.ops.TypeConversionError(value, dtype)
55+
return nptf.convert_to_tensor(tf.compat.dimension_value(value), dtype=dtype)
56+
57+
4958
nptf.register_tensor_conversion_function(
50-
tf1.Dimension, nptf.ops._convert_dimension_to_tensor) # pylint: disable=protected-access
59+
tf1.Dimension, _convert_dimension_to_tensor)
5160
nptf.register_tensor_conversion_function(
5261
tf.TensorShape, nptf.ops._convert_tensorshape_to_tensor) # pylint: disable=protected-access
5362

tensorflow_probability/python/internal/prefer_static_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,10 @@ def test_rank_from_shape_scalar(self):
267267
self.evaluate(v.initializer)
268268
self.assertEqual(1, self.evaluate(ps.rank_from_shape(v)))
269269

270+
def test_convert_dimension_to_tensor(self):
271+
v = ps.constant(tf1.Dimension(1))
272+
self.assertEqual(1, v)
273+
270274
def test_rank_from_shape(self):
271275
shape = [2, 4, 3]
272276
expected_rank = len(shape)

0 commit comments

Comments
 (0)