Skip to content

Commit 9a34093

Browse files
emilyfertigtensorflower-gardener
authored andcommitted
Add an option to convert_to_nested_tensor to not convert objects with reference semantics.
PiperOrigin-RevId: 476461510
1 parent 854abf1 commit 9a34093

File tree

5 files changed

+77
-14
lines changed

5 files changed

+77
-14
lines changed

tensorflow_probability/python/internal/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,7 @@ multi_substrate_py_library(
420420
srcs = ["nest_util.py"],
421421
deps = [
422422
":prefer_static",
423+
":tensor_util",
423424
# numpy dep,
424425
# tensorflow dep,
425426
],

tensorflow_probability/python/internal/backend/numpy/numpy_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1813,6 +1813,11 @@ def test_convert_variable_to_tensor(self):
18131813
self.assertEqual(np.float64, x.dtype)
18141814
self.assertAllEqual([0., 1., 2.], x)
18151815

1816+
def test_convert_variable_to_tensor_with_dtype_hint(self):
1817+
v = nptf.Variable(np.int32(0))
1818+
x = nptf.convert_to_tensor(v, dtype_hint=tf.float32)
1819+
self.assertEqual(np.int32, x.dtype)
1820+
18161821
def test_get_static_value(self):
18171822
x = nptf.get_static_value(nptf.zeros((3, 2), dtype=nptf.float32))
18181823
self.assertEqual(onp.ndarray, type(x))

tensorflow_probability/python/internal/backend/numpy/ops.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,9 @@ def _convert_to_tensor(value, dtype=None, dtype_hint=None, name=None): # pylint
141141
"""Emulates tf.convert_to_tensor."""
142142
dtype = utils.numpy_dtype(dtype)
143143
dtype_hint = utils.numpy_dtype(dtype_hint)
144-
if is_tensor(value) and not isinstance(value, Variable):
144+
if isinstance(value, Variable):
145+
value = value.__wrapped__
146+
if is_tensor(value):
145147
# In NumPy mode, we are lenient on the dtype compatibility check because
146148
# some codepaths rely on flexible conversion from int/float64 to 32.
147149
if dtype is not None and value.dtype != dtype:
@@ -619,13 +621,18 @@ def __init__(
619621
v = v.astype(utils.numpy_dtype(dtype))
620622
super(NumpyVariable, self).__init__(v)
621623
self._self_name = name
624+
self._self_trainable = trainable
622625
self.initializer = None
623626
# pylint: enable=unused-argument
624627

625628
@property
626629
def name(self):
627630
return self._self_name if self._self_name is not None else str(id(self))
628631

632+
@property
633+
def trainable(self):
634+
return self._self_trainable
635+
629636
def __array__(self, dtype=None):
630637
if dtype is not None:
631638
dtype = utils.numpy_dtype(dtype)
@@ -658,10 +665,6 @@ def assign_sub(self, value, **_):
658665
jax.core.pytype_aval_mappings[onp.ndarray])
659666

660667

661-
def _convert_variable_to_tensor(value, dtype=None):
662-
return convert_to_tensor(value.__wrapped__, dtype=dtype)
663-
register_tensor_conversion_function(NumpyVariable, _convert_variable_to_tensor)
664-
665668
Variable = NumpyVariable
666669

667670

tensorflow_probability/python/internal/nest_util.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import tensorflow.compat.v2 as tf
2121

2222
from tensorflow_probability.python.internal import prefer_static as ps
23+
from tensorflow_probability.python.internal import tensor_util
2324

2425
from tensorflow.python.util import nest # pylint: disable=g-direct-tensorflow-import
2526

@@ -410,7 +411,7 @@ def call_fn(fn, args):
410411

411412
def convert_to_nested_tensor(value, dtype=None, dtype_hint=None,
412413
allow_packing=False, as_shape_tensor=False,
413-
name=None):
414+
convert_ref=True, name=None):
414415
"""Converts the given `value` to a (structure of) `Tensor`.
415416
416417
This function converts Python objects of various types to a (structure of)
@@ -432,6 +433,8 @@ def convert_to_nested_tensor(value, dtype=None, dtype_hint=None,
432433
as_shape_tensor: Optional boolean when if `True` uses
433434
`prefer_static.convert_to_shape_tensor` instead of `tf.convert_to_tensor`
434435
for JAX compatibility.
436+
convert_ref: Python `bool`, default `True`. If `True`, convert objects with
437+
reference semantics to Tensor.
435438
name: Optional name to use if a new `Tensor` is created. If inputs are
436439
structured, elements are named accoring to '{name}/{path}.{to}.{elem}'.
437440
@@ -467,8 +470,11 @@ def convert_fn(path, value, dtype, dtype_hint, name=None):
467470
# break the Bijector cache on forward/inverse log det jacobian,
468471
# because tf.convert_to_tensor is not a no-op thereon.
469472
return value
470-
else:
473+
elif convert_ref:
471474
return tf.convert_to_tensor(value, dtype, dtype_hint, name=name)
475+
else:
476+
return tensor_util.convert_nonref_to_tensor(
477+
value, dtype, dtype_hint, name=name)
472478

473479
### The following branches only affect naming.
474480
# For unstructured calls, just use the provided name.

tensorflow_probability/python/internal/nest_util_test.py

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,27 @@ def __repr__(self):
6565
self.shape, self.dtype, self.name)
6666

6767

68+
class VariableSpec(TensorSpec):
69+
"""Stub for VariableSpec to simplify tests in np/jax backends."""
70+
71+
def __init__(self, shape, dtype=tf.float32, trainable=True, name=None):
72+
super(VariableSpec, self).__init__(shape, dtype, name=name)
73+
self.trainable = trainable
74+
75+
@classmethod
76+
def from_value(cls, variable):
77+
return cls(
78+
variable.shape, variable.dtype, variable.trainable, variable.name)
79+
80+
def __eq__(self, other):
81+
return (super(VariableSpec, self).__eq__(other)
82+
and self.trainable == other.trainable)
83+
84+
def __repr__(self):
85+
return 'VariableSpec(shape={}, dtype={}, trainable={}, name={})'.format(
86+
self.shape, self.dtype, self.trainable, self.name)
87+
88+
6889
class LeafList(list):
6990
_tfp_nest_expansion_force_leaf = ()
7091

@@ -297,12 +318,13 @@ def fn(arg1, arg2):
297318
'b': TensorSpec([], tf.int64, name='c2t/b')}
298319
},{
299320
'testcase_name': '_tensor_with_hint',
300-
'value': [TensorSpec([], tf.int32)],
321+
'value': [TensorSpec([], tf.int32, name='tensor')],
301322
'dtype_hint': [tf.float32],
302323
'expected': [TensorSpec([], tf.int32, name='tensor')]
303324
},{
304325
'testcase_name': '_tensor_struct',
305-
'value': [TensorSpec([], tf.int32), TensorSpec([], tf.float32)],
326+
'value': [TensorSpec([], tf.int32, name='tensor'),
327+
TensorSpec([], tf.float32, name='tensor')],
306328
'dtype_hint': [tf.float32, tf.float32],
307329
'expected': [TensorSpec([], tf.int32, name='tensor'),
308330
TensorSpec([], tf.float32, name='tensor_1')]
@@ -318,20 +340,46 @@ def fn(arg1, arg2):
318340
'name': None,
319341
'expected': [TensorSpec([], tf.float32, name='Const'),
320342
TensorSpec([], tf.float32, name='Const_1')]
343+
},{
344+
'testcase_name': '_tensor_and_variable_struct',
345+
'value': [TensorSpec([], tf.int32, name='tensor'),
346+
VariableSpec([], tf.float32, trainable=False, name='variable')],
347+
'dtype_hint': [tf.float32, tf.float32],
348+
'convert_ref': False,
349+
'expected': [
350+
TensorSpec([], tf.int32, name='tensor'),
351+
VariableSpec([], tf.float32, trainable=False, name='variable:0')]
352+
},{
353+
'testcase_name': '_tensor_and_variable_struct_convert_ref',
354+
'value': [VariableSpec([], tf.int32, name='variable'),
355+
TensorSpec([], tf.float32, name='tensor')],
356+
'dtype_hint': [tf.float32, tf.float32],
357+
'expected': [TensorSpec([], tf.int32, name='c2t/ReadVariableOp'),
358+
TensorSpec([], tf.float32, name='tensor')]
321359
})
322360
def testConvertToNestedTensor(
323-
self, value, dtype=None, dtype_hint=None, name='c2t', expected=None):
324-
# Convert specs to tensors
361+
self, value, dtype=None, dtype_hint=None, name='c2t', convert_ref=True,
362+
expected=None):
363+
# Convert specs to tensors or variables.
325364
def maybe_spec_to_tensor(x):
365+
if isinstance(x, VariableSpec):
366+
return tf.Variable(
367+
tf.zeros(x.shape, x.dtype), trainable=x.trainable, name=x.name)
326368
if isinstance(x, TensorSpec):
327-
return tf.zeros(x.shape, x.dtype, name='tensor')
369+
return tf.zeros(x.shape, x.dtype, name=x.name)
328370
return x
329371
value = nest.map_structure(maybe_spec_to_tensor, value)
330372

331373
# Grab shape/dtype from convert_to_nested_tensor for comparison.
374+
def spec_from_value(x):
375+
if isinstance(x, tf.Variable):
376+
return VariableSpec.from_value(x)
377+
return TensorSpec.from_tensor(x)
378+
332379
observed = nest.map_structure(
333-
TensorSpec.from_tensor,
334-
nest_util.convert_to_nested_tensor(value, dtype, dtype_hint, name=name))
380+
spec_from_value,
381+
nest_util.convert_to_nested_tensor(
382+
value, dtype, dtype_hint, convert_ref=convert_ref, name=name))
335383
self.assertAllEqualNested(observed, expected)
336384

337385
@parameterized.named_parameters({

0 commit comments

Comments
 (0)