
Commit 96b024d

emilyfertig authored and tensorflower-gardener committed
tfp.util.TransformedVariable and tfp.util.DeferredTensor subclass CompositeTensor.
PiperOrigin-RevId: 374959903
1 parent 495e3f2 commit 96b024d
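
In practice, subclassing `CompositeTensor` means these objects can be passed across `tf.function`, `tf.nest`, and `tf.vectorized_map` boundaries as structured values instead of relying only on implicit variable capture. A minimal sketch of the intended usage, mirroring the tests added in this commit (the exact values are illustrative):

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

x = tfp.util.DeferredTensor(tf.Variable([0.2, 3.]), tfb.Exp())

@tf.function
def f(x_):
  # `x_` arrives as a composite value; its underlying variable is still tracked.
  return tf.convert_to_tensor(x_)

y = f(x)  # ==> tf.exp([0.2, 3.])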

File tree: 2 files changed (+166 lines, -17 lines)

tensorflow_probability/python/util/deferred_tensor.py

Lines changed: 64 additions & 17 deletions
@@ -17,6 +17,7 @@
 from __future__ import division
 from __future__ import print_function

+import abc
 import functools
 import numpy as np
 import six
@@ -81,10 +82,10 @@ def _tensorize(d, dtype=None, name=None, as_ref=False):
   return d._value(dtype, name, as_ref)  # pylint: disable=protected-access


-class TensorMetaClass(type):
+class TensorMetaClass(abc.ABCMeta):
   """A type of class which will make objects which act like Tensors."""

-  def __new__(mcs, name, bases, attrs):
+  def __new__(mcs, name, bases, attrs):  # pylint: disable=bad-mcs-classmethod-argument
     operators = set(tf.Tensor.OVERLOADABLE_OPERATORS)
     operators.difference_update({'__eq__', '__ne__'})
     operators.update({'__iter__'})
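
The base of `TensorMetaClass` changes from `type` to `abc.ABCMeta` (hence the new `import abc`), presumably because `CompositeTensor` is an abstract class built on `abc.ABCMeta` and a derived class's metaclass must subclass the metaclasses of all of its bases. A stripped-down illustration of the conflict and its resolution, with illustrative names rather than TFP code:

import abc

import six


class MetaA(type):  # like the old TensorMetaClass(type)
  pass


@six.add_metaclass(abc.ABCMeta)
class Abstract(object):  # stands in for CompositeTensor
  pass


# class Bad(six.with_metaclass(MetaA, Abstract)): ...
#   ==> TypeError: metaclass conflict.


class MetaB(abc.ABCMeta):  # like the new TensorMetaClass(abc.ABCMeta)
  pass


class Good(six.with_metaclass(MetaB, Abstract)):  # no conflict
  pass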
@@ -109,15 +110,16 @@ def __new__(mcs, name, bases, attrs):
     attrs.update(
         (attr, getattr(tf.Tensor, attr))
         for attr in {'__bool__', '__array_priority__', '__nonzero__'})
-    cls = super(TensorMetaClass, mcs).__new__(mcs, name, bases, attrs)
+    cls = super(TensorMetaClass, mcs).__new__(mcs, name, bases, attrs)  # pylint: disable=too-many-function-args
     tf.register_tensor_conversion_function(cls, conversion_func=_tensorize)
     return cls


 NONE_SPECIFIED = 'None'


-class DeferredTensor(six.with_metaclass(TensorMetaClass, tf.Module)):
+class DeferredTensor(six.with_metaclass(
+    TensorMetaClass, tf.Module, tf.__internal__.CompositeTensor)):
   """Variable tracking object which applies function upon `convert_to_tensor`.

   #### Example
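
With `tf.__internal__.CompositeTensor` in the bases, `tf.nest` can decompose a `DeferredTensor` into a type spec plus components and rebuild an equivalent object, which is the mechanism that `tf.function` and `tf.vectorized_map` lean on. A rough sketch of the round trip exercised by `test_also_track_through_flatten_unflatten` below, minus the `also_track` argument:

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

x = tfp.util.DeferredTensor(tf.Variable(3.), tfb.Shift(tf.Variable(2.)))

flat = tf.nest.flatten(x, expand_composites=True)
unflat = tf.nest.pack_sequence_as(x, flat, expand_composites=True)
# The rebuilt object still tracks both the input variable and Shift's variable.
len(unflat.trainable_variables)  # ==> 2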
@@ -379,6 +381,39 @@ def __array__(self, dtype=None):
           'numpy array.')
     return np.array(self._value(dtype=dtype))

+  def _get_input_spec(self):
+    if isinstance(self.pretransformed_input, tf.__internal__.CompositeTensor):
+      return self.pretransformed_input._type_spec  # pylint: disable=protected-access
+    if isinstance(self.pretransformed_input, tf.Variable):
+      return resource_variable_ops.VariableSpec(
+          self.pretransformed_input.shape,
+          dtype=self.pretransformed_input.dtype,
+          trainable=self.pretransformed_input.trainable)
+    return tf.TensorSpec.from_tensor(self.pretransformed_input)
+
+  @property
+  def _type_spec(self):
+    input_spec = self._get_input_spec()
+    transform_or_spec = getattr(self._transform_fn, '_type_spec',
+                                self._transform_fn)
+
+    # Extract Variables from also_track.
+    if self.also_track is None:
+      also_track_spec = None
+    else:
+      also_track_vars = tf.nest.flatten(
+          tf.nest.map_structure(
+              lambda x: x.variables if isinstance(x, tf.Module) else x,
+              self.also_track))
+      also_track_spec = tf.nest.map_structure(
+          lambda x: resource_variable_ops.VariableSpec(  # pylint: disable=g-long-lambda
+              x.shape, x.dtype, trainable=x.trainable),
+          also_track_vars)
+
+    return _DeferredTensorSpec(
+        input_spec, transform_or_spec, dtype=self.dtype, shape=self.shape,
+        name=self.name, also_track_spec=also_track_spec)
+

 class TransformedVariable(DeferredTensor):
   """Variable tracking object which applies a bijector upon `convert_to_tensor`.
@@ -455,11 +490,7 @@ def __init__(self, initial_value, bijector, dtype=None, name=None, **kwargs):
         which is the initial value for the `TransformedVariable`. The underlying
         untransformed `tf.Variable` will be initialized with
         `bijector.inverse(initial_value)`. Can also be a callable with no
-        argument that returns the initial value when called. Note: if
-        `initial_value` is a `TransformedVariable` then the instantiated object
-        does not create a new `tf.Variable`, but rather points to the underlying
-        `Variable` and chains the `bijector` arg with the underlying bijector as
-        `tfb.Chain([bijector, initial_value.bijector])`.
+        argument that returns the initial value when called.
       bijector: A `Bijector`-like instance which defines the transformations
         applied to the underlying `tf.Variable`.
       dtype: `tf.dtype.DType` instance or otherwise valid `dtype` value to
@@ -479,16 +510,25 @@ def __init__(self, initial_value, bijector, dtype=None, name=None, **kwargs):

     if callable(initial_value):
       initial_value = initial_value()
-    initial_value = tf.convert_to_tensor(
-        initial_value, dtype_hint=bijector.dtype, dtype=dtype)
+
+    # Extra kwarg that TypeSpec._from_components uses to re-build the object
+    # without re-initializing the variable.
+    pretransformed_input = kwargs.pop('pretransformed_input', None)
+    if pretransformed_input is None:
+      initial_value = tf.convert_to_tensor(
+          initial_value, dtype_hint=bijector.dtype, dtype=dtype)
+      pretransformed_input = tf.Variable(
+          initial_value=bijector.inverse(initial_value),
+          name=name,
+          dtype=dtype,
+          **kwargs)
+      shape = initial_value.shape
+    else:
+      shape = bijector.forward_event_shape(pretransformed_input.shape)
     super(TransformedVariable, self).__init__(
-        pretransformed_input=tf.Variable(
-            initial_value=bijector.inverse(initial_value),
-            name=name,
-            dtype=dtype,
-            **kwargs),
+        pretransformed_input=pretransformed_input,
         transform_fn=bijector,
-        shape=initial_value.shape,
+        shape=shape,
         name=bijector.name)
     self._bijector = bijector

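`__init__` now has two paths: the usual one converts `initial_value`, builds a fresh `tf.Variable` from `bijector.inverse(initial_value)`, and takes the shape from `initial_value`; the new `pretransformed_input` keyword is an internal escape hatch that lets `TypeSpec._from_components` rebuild a `TransformedVariable` around an existing variable without re-initializing it, deriving the shape via `forward_event_shape`. A rough sketch of both paths (the second call imitates the internal reconstruction path and is not typical user code):

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

# Usual path: allocates a new variable holding softplus_inverse(5.).
x = tfp.util.TransformedVariable(5., tfb.Softplus())

# Reconstruction path: wrap an existing underlying variable; no new variable is created.
y = tfp.util.TransformedVariable(
    None, tfb.Softplus(), pretransformed_input=x.pretransformed_input)
tf.convert_to_tensor(y)  # ==> 5.0, read through the shared variable
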
@@ -529,6 +569,13 @@ def assign_sub(self, value, use_locking=False, name=None, read_value=True):
         name=name,
         read_value=read_value)

+  @property
+  def _type_spec(self):
+    input_spec = self._get_input_spec()
+    transform_or_spec = getattr(self.bijector, '_type_spec', self.bijector)
+    return _TransformedVariableSpec(
+        input_spec, transform_or_spec, self.dtype, self.name)
+

 class _DeferredTensorSpecBase(object):
   """Common methods for '_DeferredTensorSpec' and '_TransformedVariableSpec'."""

tensorflow_probability/python/util/deferred_tensor_test.py

Lines changed: 102 additions & 0 deletions
@@ -35,6 +35,24 @@
 JAX_MODE = False


+class NonCompositeTensorExp(object):
+
+  name = 'non_composite_exp'
+  dtype = tf.float32
+
+  def forward(self, x):
+    return tf.math.exp(x)
+
+  def inverse(self, y):
+    return tf.math.log(y)
+
+  def forward_event_shape(self, shape):
+    return shape
+
+  def inverse_event_shape(self, shape):
+    return shape
+
+
 @test_util.test_all_tf_execution_regimes
 class DeferredTensorTest(test_util.TestCase):

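`NonCompositeTensorExp` is a stand-in for a bijector-like object that is not itself a `CompositeTensor`; the parameterized tests use it (and a bare callable) to check that the spec falls back to holding the transform object directly when it has no `_type_spec`. For example, a plain callable remains a valid `transform_fn`:

import tensorflow as tf
import tensorflow_probability as tfp

x = tfp.util.DeferredTensor(tf.Variable([0.2, 3.]), tf.math.exp)
tf.convert_to_tensor(x)  # ==> tf.exp([0.2, 3.])
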
@@ -149,6 +167,55 @@ def test_from_bijector_with_inverted_assignment(self):
     self.assertAllEqual([[-1.], [-2.], [-3.]], v_)
     self.assertAllEqual([[-1., 0.], [-2., 0.], [-3., 0.]], y_)

+  @test_util.disable_test_for_backend(
+      disable_numpy=True, disable_jax=True,
+      reason='JAX and Numpy do not have `CompositeTensor`.')
+  @parameterized.named_parameters(
+      ('transform_fn_is_bijector', tfb.Exp),
+      ('transform_fn_is_bijector_like', NonCompositeTensorExp),
+      ('transform_fn_is_callable', lambda: tf.math.exp))
+  def test_composite_tensor(self, make_transform_fn):
+    initial_value = [0.2, 3.]
+    pretransformed_input = tf.Variable(initial_value, dtype=tf.float32)
+    x = tfp.util.DeferredTensor(pretransformed_input, make_transform_fn())
+
+    @tf.function
+    def f(x_):
+      self.assertLen(x_.trainable_variables, 1)
+      return x_
+
+    y = f(x)
+    self.evaluate([v.initializer for v in x.trainable_variables])
+    self.assertAllClose(self.evaluate(tf.convert_to_tensor(y)),
+                        self.evaluate(tf.math.exp(initial_value)))
+    self.assertLen(x.trainable_variables, 1)
+    self.assertLen(y.trainable_variables,
+                   1 if tf.config.functions_run_eagerly() else 0)
+
+  @test_util.disable_test_for_backend(
+      disable_numpy=True, disable_jax=True,
+      reason=('vectorized_map not implemented in Numpy; '
+              '`DeferredTensor` is not a valid JAX type.'))
+  def test_vectorized_map(self):
+    pretransformed_input = tf.Variable(tf.ones([5, 3]))
+    x = tfp.util.DeferredTensor(pretransformed_input, tfb.Scale([5]))
+    y = tf.vectorized_map(lambda v: v + 2., x)
+    self.evaluate([v.initializer for v in x.trainable_variables])
+    self.assertAllClose(self.evaluate(y), 5. * pretransformed_input + 2.)
+
+  @test_util.disable_test_for_backend(
+      disable_numpy=True, disable_jax=True,
+      reason='JAX and Numpy have no notion of `CompositeTensor`.')
+  def test_also_track_through_flatten_unflatten(self):
+    pretransformed_input = tf.Variable(3.)
+    also_track = tfd.Normal(tf.Variable(0.), scale=1.)
+    x = tfp.util.DeferredTensor(pretransformed_input,
+                                tfb.Shift(tf.Variable(2.)),
+                                also_track=also_track)
+    flat = tf.nest.flatten(x, expand_composites=True)
+    unflat = tf.nest.pack_sequence_as(x, flat, expand_composites=True)
+    self.assertLen(unflat.trainable_variables, 3)
+

 @test_util.test_all_tf_execution_regimes
 class TransformedVariableTest(test_util.TestCase):
@@ -274,6 +341,41 @@ def test_nested_transformed_variable(self):
     self.assertIsNot(x.pretransformed_input, y.pretransformed_input)
     # Different vars have no deps so we needn't test cross-talk.

+  @test_util.disable_test_for_backend(
+      disable_numpy=True, disable_jax=True,
+      reason='JAX and Numpy do not have `CompositeTensor`.')
+  @parameterized.named_parameters(
+      ('composite_bijector', tfb.Softplus),
+      ('non_composite_bijector', NonCompositeTensorExp))
+  def test_composite_tensor(self, make_bijector):
+    x = tfp.util.TransformedVariable(5., make_bijector())
+    add_val = 10.
+
+    @tf.function
+    def f(x_):
+      x_.assign_add(add_val)
+      self.assertLen(x_.trainable_variables, 1)
+      return x_
+
+    y = f(x)
+    self.evaluate([v.initializer for v in x.trainable_variables])
+    self.assertAllClose(self.evaluate(tf.convert_to_tensor(y)), 15.)
+    self.assertAllClose(self.evaluate(tf.convert_to_tensor(x)), 15.)
+    self.assertLen(x.trainable_variables, 1)
+    self.assertLen(y.trainable_variables,
+                   1 if tf.config.functions_run_eagerly() else 0)
+
+  @test_util.disable_test_for_backend(
+      disable_numpy=True, disable_jax=True,
+      reason=('vectorized_map not implemented in Numpy; '
+              '`DeferredTensor` is not a valid JAX type.'))
+  def test_vectorized_map(self):
+    initial_value = tf.ones([5, 3])
+    x = tfp.util.TransformedVariable(initial_value, tfb.Sigmoid())
+    y = tf.vectorized_map(lambda v: v + 2., x)
+    self.evaluate([v.initializer for v in x.trainable_variables])
+    self.assertAllClose(self.evaluate(y), initial_value + 2.)
+

 @test_util.test_all_tf_execution_regimes
 class DeferredTensorBehavesLikeTensorTest(test_util.TestCase):
