Skip to content

Commit 4bf8811

Browse files
emilyfertigtensorflower-gardener
authored andcommitted
Add _convert_variables_to_tensors method to TFP library CompositeTensors.
PiperOrigin-RevId: 454701997
1 parent 4107afd commit 4bf8811

File tree

15 files changed

+280
-60
lines changed

15 files changed

+280
-60
lines changed

tensorflow_probability/python/bijectors/bijector_properties_test.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
from tensorflow_probability.python.internal import tensor_util
3333
from tensorflow_probability.python.internal import tensorshape_util
3434
from tensorflow_probability.python.internal import test_util
35-
from tensorflow_probability.python.util.deferred_tensor import DeferredTensor
3635

3736
from tensorflow.python.util import nest # pylint: disable=g-direct-tensorflow-import
3837

@@ -789,13 +788,6 @@ def testCompositeTensor(self, bijector_name, data):
789788
'bijectors.')
790789
self.skipTest('`_Invert` bijectors are not `CompositeTensor`s.')
791790

792-
if not tf.executing_eagerly():
793-
bijector = tf.nest.map_structure(
794-
lambda x: (tf.convert_to_tensor(x) # pylint: disable=g-long-lambda
795-
if isinstance(x, DeferredTensor) else x),
796-
bijector,
797-
expand_composites=True)
798-
799791
self.assertIsInstance(bijector, tf.__internal__.CompositeTensor)
800792
flat = tf.nest.flatten(bijector, expand_composites=True)
801793
unflat = tf.nest.pack_sequence_as(bijector, flat, expand_composites=True)
@@ -833,6 +825,8 @@ def testCompositeTensor(self, bijector_name, data):
833825
grads = tape.gradient(ys, wrt_vars)
834826
assert_no_none_grad(bijector, 'forward', wrt_vars, grads)
835827

828+
self.assertConvertVariablesToTensorsWorks(bijector)
829+
836830

837831
def ensure_nonzero(x):
838832
return tf.where(x < 1e-6, tf.constant(1e-3, x.dtype), x)

tensorflow_probability/python/distributions/gaussian_process.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
from tensorflow_probability.python.internal import tensorshape_util
3939
from tensorflow_probability.python.math.psd_kernels.internal import util as psd_kernels_util
4040
from tensorflow.python.util import deprecation # pylint: disable=g-direct-tensorflow-import
41+
from tensorflow.python.util import variable_utils # pylint: disable=g-direct-tensorflow-import
42+
4143

4244
__all__ = [
4345
'GaussianProcess',
@@ -767,6 +769,14 @@ def _type_spec(self):
767769
omit_kwargs=('parameters', '_check_marginal_cholesky_fn'),
768770
non_identifying_kwargs=('name',))
769771

772+
def _convert_variables_to_tensors(self):
773+
# pylint: disable=protected-access
774+
components = self._type_spec._to_components(self)
775+
tensor_components = variable_utils.convert_variables_to_tensors(
776+
components)
777+
return self._type_spec._from_components(tensor_components)
778+
# pylint: enable=protected-access
779+
770780

771781
@auto_composite_tensor.type_spec_register(
772782
'tfp.distributions.GaussianProcess_ACTTypeSpec')

tensorflow_probability/python/distributions/platform_compatibility_test.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,10 @@
222222
'WishartTriL': 1e-5,
223223
})
224224

225+
COMPOSITE_TENSOR_LOGPROB_RTOL = collections.defaultdict(lambda: 1e-6)
226+
COMPOSITE_TENSOR_LOGPROB_RTOL.update({
227+
'WishartTriL': 1e-5,
228+
})
225229

226230
SKIP_KL_CHECK_DIST_VAR_GRADS = [
227231
'GeneralizedExtremeValue', # TD's KL gradients do not rely on bijector
@@ -419,29 +423,35 @@ def _test_sample_and_log_prob(self, dist_name, dist):
419423
ct_lp1 = unflat.log_prob(sample1)
420424
orig_lp1 = dist.log_prob(sample1)
421425
ct_lp1_, orig_lp1_ = self.evaluate((ct_lp1, orig_lp1))
422-
self.assertAllClose(ct_lp1_, orig_lp1_)
426+
self.assertAllClose(ct_lp1_, orig_lp1_,
427+
rtol=COMPOSITE_TENSOR_LOGPROB_RTOL[dist_name])
423428

424429
# ... and after. (Even though they're supposed to be the same anyway.)
425430
ct_lp2 = unflat.log_prob(sample2)
426431
orig_lp2 = dist.log_prob(sample2)
427432
ct_lp2_, orig_lp2_ = self.evaluate((ct_lp2, orig_lp2))
428-
self.assertAllClose(ct_lp2_, orig_lp2_)
433+
self.assertAllClose(ct_lp2_, orig_lp2_,
434+
rtol=COMPOSITE_TENSOR_LOGPROB_RTOL[dist_name])
429435

430-
# TODO(alexeev): Add coverage for meta distributions, in addition to base
431-
# distributions.
432436
@parameterized.named_parameters(
433-
{'testcase_name': dname, 'dist_name': dname}
434-
for dname in sorted(list(dhps.INSTANTIABLE_BASE_DISTS.keys()))
437+
{'testcase_name': dname, 'dist_name': dname} # pylint: disable=g-complex-comprehension
438+
for dname in sorted(list(dhps.INSTANTIABLE_BASE_DISTS.keys())
439+
+ list(dhps.INSTANTIABLE_META_DISTS))
435440
if dname not in dhps.TF2_UNFRIENDLY_DISTS)
436441
@hp.given(hps.data())
437442
@tfp_hps.tfp_hp_settings()
438443
def testCompositeTensor(self, dist_name, data):
439444
dist = data.draw(
440445
dhps.distributions(
441-
dist_name=dist_name, enable_vars=False, validate_args=False))
446+
dist_name=dist_name, enable_vars=True, validate_args=False,
447+
eligibility_filter=(
448+
lambda d: type(d).__name__ not in dhps.TF2_UNFRIENDLY_DISTS)))
449+
self.evaluate([v.initializer for v in dist.trainable_variables])
442450
with tfp_hps.no_tf_rank_errors():
443451
self._test_sample_and_log_prob(dist_name, dist)
444452

453+
self.assertConvertVariablesToTensorsWorks(dist)
454+
445455

446456
@test_util.test_graph_mode_only
447457
class DistributionXLATest(test_util.TestCase):

tensorflow_probability/python/internal/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -707,6 +707,7 @@ multi_substrate_py_library(
707707
"//tensorflow_probability/python/bijectors:bijector",
708708
"//tensorflow_probability/python/internal:empirical_statistical_testing",
709709
"//tensorflow_probability/python/internal/backend/numpy",
710+
"//tensorflow_probability/python/util:deferred_tensor",
710711
"//tensorflow_probability/python/util:seed_stream",
711712
],
712713
)

tensorflow_probability/python/internal/auto_composite_tensor.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from tensorflow.python.framework import type_spec
2727
from tensorflow.python.ops import resource_variable_ops
2828
from tensorflow.python.util import tf_inspect
29+
from tensorflow.python.util import variable_utils
2930
# pylint: enable=g-direct-tensorflow-import
3031

3132
__all__ = [
@@ -527,6 +528,9 @@ def auto_composite_tensor(
527528
- object.attribute = [tf.constant(1.), [tf.constant(2.)]] # valid
528529
- object.attribute = ['abc', tf.constant(1.)] # invalid
529530
531+
All `__init__` args that may be `ResourceVariable`s must also admit `Tensor`s
532+
(or else `_convert_variables_to_tensors` must be overridden).
533+
530534
If the attribute is a callable, serialization of the `TypeSpec`, and therefore
531535
interoperability with `tf.saved_model`, is not currently supported. As a
532536
workaround, callables that do not contain or close over `Tensor`s may be
@@ -655,6 +659,37 @@ def body(obj):
655659
type_spec_class_name = f'{cls.__name__}_ACTTypeSpec'
656660
type_spec_name = f'{module_name}.{type_spec_class_name}'
657661

662+
def _convert_variables_to_tensors(obj):
663+
"""Recursively converts Variables in the AutoCompositeTensor to Tensors.
664+
665+
This method flattens `obj` into a nested structure of `Tensor`s or
666+
`CompositeTensor`s, converts any `ResourceVariable`s (which are
667+
`CompositeTensor`s) to `Tensor`s, and rebuilds `obj` with `Tensor`s in place
668+
of `ResourceVariable`s.
669+
670+
The usage of `obj._type_spec._from_components` violates the contract of
671+
`CompositeTensor`, since it is called on a different nested structure
672+
(one containing only `Tensor`s) than `obj.type_spec` specifies (one that may
673+
contain `ResourceVariable`s). Since `AutoCompositeTensor`'s
674+
`_from_components` method passes the contents of the nested structure to
675+
`__init__` to rebuild the TFP object, and any TFP object that may be
676+
instantiated with `ResourceVariables` may also be instantiated with
677+
`Tensor`s, this usage is valid.
678+
679+
Args:
680+
obj: An `AutoCompositeTensor` instance.
681+
682+
Returns:
683+
tensor_obj: `obj` with all internal `ResourceVariable`s converted to
684+
`Tensor`s.
685+
"""
686+
# pylint: disable=protected-access
687+
components = obj._type_spec._to_components(obj)
688+
tensor_components = variable_utils.convert_variables_to_tensors(
689+
components)
690+
return obj._type_spec._from_components(tensor_components)
691+
# pylint: enable=protected-access
692+
658693
# If the declared class is already a CompositeTensor subclass, we can avoid
659694
# affecting the actual type of the returned class. Otherwise, we need to
660695
# explicitly mix in the CT type, and hence create and return a newly
@@ -674,7 +709,10 @@ def _type_spec(obj):
674709
return _AlreadyCTTypeSpec.from_instance(
675710
obj, omit_kwargs, non_identifying_kwargs)
676711

677-
cls._type_spec = property(_type_spec) # pylint: disable=protected-access
712+
# pylint: disable=protected-access
713+
cls._type_spec = property(_type_spec)
714+
cls._convert_variables_to_tensors = _convert_variables_to_tensors
715+
# pylint: enable=protected-access
678716
return cls
679717

680718
clsid = (cls.__module__, cls.__name__, omit_kwargs,
@@ -701,6 +739,9 @@ def _type_spec(self):
701739
return _GeneratedCTTypeSpec.from_instance(
702740
self, omit_kwargs, non_identifying_kwargs)
703741

742+
def _convert_variables_to_tensors(self):
743+
return _convert_variables_to_tensors(self)
744+
704745
_AutoCompositeTensor.__name__ = cls.__name__
705746
_registry[clsid] = _AutoCompositeTensor
706747
type_spec_register(type_spec_name)(_GeneratedCTTypeSpec)

tensorflow_probability/python/internal/auto_composite_tensor_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,40 @@ def fn(x):
455455
batch_dist = tf.vectorized_map(fn, tf.convert_to_tensor([1., 2., 3.]))
456456
self.assertAllEqual(batch_dist.batch_shape, [3, 2])
457457

458+
def test_convert_variables_to_tensors(self):
459+
v = tf.Variable(0.)
460+
u = tf.Variable(1.)
461+
var_td = tfd.TransformedDistribution(
462+
tfd.Normal(v, 1.),
463+
bijector=tfb.Shift(u))
464+
tensor_td = var_td._convert_variables_to_tensors()
465+
466+
self.evaluate([x.initializer for x in var_td.trainable_variables])
467+
self.assertIsInstance(var_td, tf.__internal__.CompositeTensor)
468+
self.assertLen(var_td.trainable_variables, 2)
469+
self.assertEmpty(tensor_td.trainable_variables)
470+
self.assertEqual(self.evaluate(var_td.distribution.loc),
471+
self.evaluate(tensor_td.distribution.loc))
472+
self.assertEqual(self.evaluate(var_td.bijector.shift),
473+
self.evaluate(tensor_td.bijector.shift))
474+
475+
def test_automatic_conversion_to_tensor(self):
476+
v = tf.Variable(tf.ones([5]))
477+
d = tfd.Normal(tf.zeros([5]), v)
478+
x = tf.convert_to_tensor([3.])
479+
480+
vectorized_log_prob = tf.vectorized_map(lambda z: z.log_prob(x), d)
481+
log_prob = d.log_prob(x)
482+
self.evaluate(v.initializer)
483+
self.assertAllClose(vectorized_log_prob[:, 0], log_prob)
484+
485+
loc = tf.Variable(0.)
486+
self.evaluate(loc.initializer)
487+
cond_dist = tf.cond(
488+
tf.convert_to_tensor(True),
489+
lambda: tfd.Normal(loc, 1.), lambda: tfd.Normal(0., 1.))
490+
self.assertIsInstance(cond_dist, tfd.Normal)
491+
458492

459493
class _TestTypeSpec(auto_composite_tensor._AutoCompositeTensorTypeSpec):
460494

tensorflow_probability/python/internal/backend/jax/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ FILENAMES = [
6363
"test_lib",
6464
"tf_inspect",
6565
"type_spec",
66+
"variable_utils",
6667
"variables",
6768
"v1",
6869
"v2",

tensorflow_probability/python/internal/backend/numpy/BUILD

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ py_library(
6262
":test_lib",
6363
":tf_inspect",
6464
":type_spec",
65+
":variable_utils",
6566
":variables",
6667
],
6768
)
@@ -408,6 +409,11 @@ py_library(
408409
srcs = ["type_spec.py"],
409410
)
410411

412+
py_library(
413+
name = "variable_utils",
414+
srcs = ["variable_utils.py"],
415+
)
416+
411417
py_library(
412418
name = "variables",
413419
srcs = ["variables.py"],

tensorflow_probability/python/internal/backend/numpy/tensor_spec.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,17 @@
1919
]
2020

2121

22-
class TensorSpec(object):
22+
class DenseSpec(object):
2323

24-
def __init__(self, shape, dtype):
24+
def __init__(self, shape, dtype, name=None):
2525
self.shape = shape
2626
self.dtype = dtype
27+
self.name = name
2728

2829
def __repr__(self):
29-
return f'TensorSpec(shape={self.shape}, dtype={self.dtype})'
30+
return '{}(shape={}, dtype={}, name={})'.format(
31+
type(self).__name__, self.shape, repr(self.dtype), repr(self.name))
32+
33+
34+
class TensorSpec(DenseSpec):
35+
pass
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright 2022 The TensorFlow Probability Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
"""Numpy stub for `variable_utils`."""
16+
17+
__all__ = [
18+
'convert_variables_to_tensors',
19+
]
20+
21+
22+
def convert_variables_to_tensors(x):
23+
return x

0 commit comments

Comments
 (0)