Skip to content

Commit e26f9bb

Browse files
emilyfertigtensorflower-gardener
authored andcommitted
Temporarily remove AutoCompositeTensor from bijectors for TFP 0.13 release.
PiperOrigin-RevId: 374929567
1 parent 78cf6f8 commit e26f9bb

File tree

9 files changed

+173
-159
lines changed

9 files changed

+173
-159
lines changed

tensorflow_probability/python/bijectors/bijector.py

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import abc
2222
import contextlib
23-
# import functools
23+
import functools
2424

2525
# Dependency imports
2626
import numpy as np
@@ -1599,39 +1599,33 @@ def _composite_tensor_shape_params(self):
15991599
return ()
16001600

16011601

1602-
# Temporarily disable AutoCT for TFP 0.13 release
1603-
# class AutoCompositeTensorBijector(
1604-
# Bijector, auto_composite_tensor.AutoCompositeTensor):
1605-
# r"""Base for `CompositeTensor` bijectors with auto-generated `TypeSpec`s.
1602+
class AutoCompositeTensorBijector(
1603+
Bijector, auto_composite_tensor.AutoCompositeTensor):
1604+
r"""Base for `CompositeTensor` bijectors with auto-generated `TypeSpec`s.
16061605
1607-
# `CompositeTensor` objects are able to pass in and out of `tf.function` and
1608-
# `tf.while_loop`, or serve as part of the signature of a TF saved model.
1609-
# `Bijector` subclasses that follow the contract of
1610-
# `tfp.experimental.auto_composite_tensor` may be defined as `CompositeTensor`s # pylint: disable=line-too-long
1611-
# by inheriting from `AutoCompositeTensorBijector` and applying a class
1612-
# decorator as shown here:
1606+
`CompositeTensor` objects are able to pass in and out of `tf.function` and
1607+
`tf.while_loop`, or serve as part of the signature of a TF saved model.
1608+
`Bijector` subclasses that follow the contract of
1609+
`tfp.experimental.auto_composite_tensor` may be defined as `CompositeTensor`s
1610+
by inheriting from `AutoCompositeTensorBijector` and applying a class
1611+
decorator as shown here:
16131612
1614-
# ```python
1615-
# @tfp.experimental.auto_composite_tensor(
1616-
# omit_kwargs=('name',), module_name='my_module')
1617-
# class MyBijector(tfb.AutoCompositeTensorBijector):
1618-
1619-
# # The remainder of the subclass implementation is unchanged.
1620-
# ```
1621-
# """
1622-
# pass
1623-
1624-
1625-
# auto_composite_tensor_bijector = functools.partial(
1626-
# auto_composite_tensor.auto_composite_tensor,
1627-
# omit_kwargs=('parameters',),
1628-
# non_identifying_kwargs=('name',),
1629-
# module_name='tfp.bijectors')
1613+
```python
1614+
@tfp.experimental.auto_composite_tensor(
1615+
omit_kwargs=('name',), module_name='my_module')
1616+
class MyBijector(tfb.AutoCompositeTensorBijector):
16301617
1631-
AutoCompositeTensorBijector = Bijector
1618+
# The remainder of the subclass implementation is unchanged.
1619+
```
1620+
"""
1621+
pass
16321622

16331623

1634-
auto_composite_tensor_bijector = lambda cls, **kwargs: cls
1624+
auto_composite_tensor_bijector = functools.partial(
1625+
auto_composite_tensor.auto_composite_tensor,
1626+
omit_kwargs=('parameters',),
1627+
non_identifying_kwargs=('name',),
1628+
module_name='tfp.bijectors')
16351629

16361630

16371631
def check_valid_ndims(ndims, validate=True):

tensorflow_probability/python/bijectors/bijector_properties_test.py

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,14 @@
2828
from tensorflow_probability.python import bijectors as tfb
2929
from tensorflow_probability.python import experimental
3030
from tensorflow_probability.python.bijectors import hypothesis_testlib as bijector_hps
31+
from tensorflow_probability.python.bijectors import invert as invert_lib
3132
from tensorflow_probability.python.internal import hypothesis_testlib as tfp_hps
3233
from tensorflow_probability.python.internal import prefer_static
3334
from tensorflow_probability.python.internal import samplers
3435
from tensorflow_probability.python.internal import tensor_util
3536
from tensorflow_probability.python.internal import tensorshape_util
3637
from tensorflow_probability.python.internal import test_util
37-
# from tensorflow_probability.python.util.deferred_tensor import DeferredTensor
38+
from tensorflow_probability.python.util.deferred_tensor import DeferredTensor
3839

3940

4041
TF2_FRIENDLY_BIJECTORS = (
@@ -185,7 +186,6 @@
185186

186187
COMPOSITE_TENSOR_IS_BROKEN = [
187188
'BatchNormalization', # tf.layers arg
188-
'Inline', # callable
189189
'RationalQuadraticSpline', # TODO(b/185628453): Debug loss of static info.
190190
]
191191

@@ -204,7 +204,7 @@
204204

205205

206206
def is_invert(bijector):
207-
return isinstance(bijector, tfb.Invert)
207+
return isinstance(bijector, (tfb.Invert, invert_lib._Invert))
208208

209209

210210
def is_transform_diagonal(bijector):
@@ -244,8 +244,8 @@ def _constraint(param):
244244

245245

246246
# TODO(b/141098791): Eliminate this.
247-
# @experimental.auto_composite_tensor
248-
class CallableModule(tf.Module): # , experimental.AutoCompositeTensor):
247+
@experimental.auto_composite_tensor
248+
class CallableModule(tf.Module, experimental.AutoCompositeTensor):
249249
"""Convenience object for capturing variables closed over by Inline."""
250250

251251
def __init__(self, fn, varobj):
@@ -887,16 +887,38 @@ def testEquality(self, bijector_name, data):
887887
@hp.given(hps.data())
888888
@tfp_hps.tfp_hp_settings()
889889
def testCompositeTensor(self, bijector_name, data):
890-
# Test that making a composite tensor of this bijector doesn't throw any
891-
# errors.
890+
892891
bijector, event_dim = self._draw_bijector(
893-
bijector_name, data, batch_shape=[],
892+
bijector_name, data,
893+
batch_shape=[],
894+
validate_args=True,
894895
allowed_bijectors=(set(TF2_FRIENDLY_BIJECTORS) -
895896
set(COMPOSITE_TENSOR_IS_BROKEN)))
896-
composite_bij = experimental.as_composite(bijector)
897+
898+
if type(bijector) is invert_lib._Invert: # pylint: disable=unidiomatic-typecheck
899+
if isinstance(bijector.bijector, tf.__internal__.CompositeTensor):
900+
raise TypeError('`_Invert` should wrap only non-`CompositeTensor` '
901+
'bijectors.')
902+
self.skipTest('`_Invert` bijectors are not `CompositeTensor`s.')
903+
904+
# TODO(b/182603117): Remove "if" condition and s/composite_bij/bijector
905+
# when AutoCT is enabled for meta-bijectors and LinearOperator.
906+
if type(bijector).__name__ in AUTO_COMPOSITE_TENSOR_IS_BROKEN:
907+
composite_bij = experimental.as_composite(bijector)
908+
else:
909+
composite_bij = bijector
910+
911+
if not tf.executing_eagerly():
912+
composite_bij = tf.nest.map_structure(
913+
lambda x: (tf.convert_to_tensor(x) # pylint: disable=g-long-lambda
914+
if isinstance(x, DeferredTensor) else x),
915+
composite_bij,
916+
expand_composites=True)
917+
918+
self.assertIsInstance(composite_bij, tf.__internal__.CompositeTensor)
897919
flat = tf.nest.flatten(composite_bij, expand_composites=True)
898-
unflat = tf.nest.pack_sequence_as(composite_bij, flat,
899-
expand_composites=True)
920+
unflat = tf.nest.pack_sequence_as(
921+
composite_bij, flat, expand_composites=True)
900922

901923
# Compare forward maps before and after compositing.
902924
n = 3
@@ -911,6 +933,26 @@ def testCompositeTensor(self, bijector_name, data):
911933
after_xs = unflat.inverse(ys)
912934
self.assertAllClose(*self.evaluate((before_xs, after_xs)))
913935

936+
# Input to tf.function
937+
self.assertAllClose(
938+
before_ys,
939+
tf.function(lambda b: b.forward(xs))(composite_bij),
940+
rtol=COMPOSITE_TENSOR_RTOL[bijector_name],
941+
atol=COMPOSITE_TENSOR_ATOL[bijector_name])
942+
943+
# Forward mapping: Check differentiation through forward mapping with
944+
# respect to the input and parameter variables. Also check that any
945+
# variables are not referenced overmuch.
946+
xs = self._draw_domain_tensor(bijector, data, event_dim)
947+
wrt_vars = [xs] + [v for v in composite_bij.trainable_variables
948+
if v.dtype.is_floating]
949+
with tf.GradientTape() as tape:
950+
tape.watch(wrt_vars)
951+
# TODO(b/73073515): Fix graph mode gradients with bijector caching.
952+
ys = bijector.forward(xs + 0)
953+
grads = tape.gradient(ys, wrt_vars)
954+
assert_no_none_grad(bijector, 'forward', wrt_vars, grads)
955+
914956

915957
def ensure_nonzero(x):
916958
return tf.where(x < 1e-6, tf.constant(1e-3, x.dtype), x)

tensorflow_probability/python/bijectors/bijector_test.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -790,26 +790,25 @@ def _forward_log_det_jacobian(self, _):
790790
return tf.math.log(self._scale)
791791

792792

793-
# Test disabled temporarily for TFP 0.13 release.
794-
# @test_util.test_all_tf_execution_regimes
795-
# class AutoCompositeTensorBijectorTest(test_util.TestCase):
793+
@test_util.test_all_tf_execution_regimes
794+
class AutoCompositeTensorBijectorTest(test_util.TestCase):
796795

797-
# def test_disable_ct_bijector(self):
796+
def test_disable_ct_bijector(self):
798797

799-
# ct_bijector = CompositeForwardBijector()
800-
# self.assertIsInstance(ct_bijector, tf.__internal__.CompositeTensor)
798+
ct_bijector = CompositeForwardBijector()
799+
self.assertIsInstance(ct_bijector, tf.__internal__.CompositeTensor)
801800

802-
# non_ct_bijector = ForwardOnlyBijector()
803-
# self.assertNotIsInstance(non_ct_bijector, tf.__internal__.CompositeTensor)
801+
non_ct_bijector = ForwardOnlyBijector()
802+
self.assertNotIsInstance(non_ct_bijector, tf.__internal__.CompositeTensor)
804803

805-
# flat = tf.nest.flatten(ct_bijector, expand_composites=True)
806-
# unflat = tf.nest.pack_sequence_as(
807-
# ct_bijector, flat, expand_composites=True)
804+
flat = tf.nest.flatten(ct_bijector, expand_composites=True)
805+
unflat = tf.nest.pack_sequence_as(
806+
ct_bijector, flat, expand_composites=True)
808807

809-
# x = tf.constant([2., 3.])
810-
# self.assertAllClose(
811-
# non_ct_bijector.forward(x),
812-
# tf.function(lambda b: b.forward(x))(unflat))
808+
x = tf.constant([2., 3.])
809+
self.assertAllClose(
810+
non_ct_bijector.forward(x),
811+
tf.function(lambda b: b.forward(x))(unflat))
813812

814813

815814
if __name__ == '__main__':

tensorflow_probability/python/bijectors/exp.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from tensorflow_probability.python.bijectors import bijector as bijector_lib
2424
from tensorflow_probability.python.bijectors import invert
2525
from tensorflow_probability.python.bijectors import power_transform
26+
from tensorflow_probability.python.internal import auto_composite_tensor
2627

2728

2829
__all__ = [
@@ -76,7 +77,8 @@ def __init__(self,
7677
# TODO(b/182603117): Remove `AutoCompositeTensor` when `Invert` subclasses
7778
# `AutoCompositeTensor` and ensure `tf.saved_model` still works.
7879
@bijector_lib.auto_composite_tensor_bijector
79-
class Log(invert.Invert):
80+
class Log(invert.Invert,
81+
auto_composite_tensor.AutoCompositeTensor):
8082
"""Compute `Y = log(X)`. This is `Invert(Exp())`."""
8183

8284
def __init__(self, validate_args=False, name='log'):

tensorflow_probability/python/bijectors/expm1.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import tensorflow.compat.v2 as tf
2222
from tensorflow_probability.python.bijectors import bijector
2323
from tensorflow_probability.python.bijectors import invert
24+
from tensorflow_probability.python.internal import auto_composite_tensor
2425

2526

2627
__all__ = [
@@ -94,7 +95,7 @@ def _forward_log_det_jacobian(self, x):
9495
# TODO(b/182603117): Remove `AutoCompositeTensor` when `Invert` subclasses
9596
# `AutoCompositeTensor`.
9697
@bijector.auto_composite_tensor_bijector
97-
class Log1p(invert.Invert):
98+
class Log1p(invert.Invert, auto_composite_tensor.AutoCompositeTensor):
9899
"""Compute `Y = log1p(X)`. This is `Invert(Expm1())`."""
99100

100101
def __init__(self, validate_args=False, name='log1p'):

tensorflow_probability/python/bijectors/invert.py

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@
2121
import tensorflow.compat.v2 as tf
2222

2323
from tensorflow_probability.python.bijectors import bijector as bijector_lib
24-
# from tensorflow_probability.python.internal import auto_composite_tensor
24+
from tensorflow_probability.python.internal import auto_composite_tensor
2525

2626
__all__ = [
2727
'Invert',
2828
]
2929

3030

31-
class Invert(bijector_lib.Bijector):
31+
class _Invert(bijector_lib.Bijector):
3232
"""Bijector which inverts another Bijector.
3333
3434
Example Use: [ExpGammaDistribution (see Background & Context)](
@@ -73,7 +73,7 @@ def __init__(self, bijector, validate_args=False, parameters=None, name=None):
7373
name = name or '_'.join(['invert', bijector.name])
7474
with tf.name_scope(name) as name:
7575
self._bijector = bijector
76-
super(Invert, self).__init__(
76+
super(_Invert, self).__init__(
7777
forward_min_event_ndims=bijector.inverse_min_event_ndims,
7878
inverse_min_event_ndims=bijector.forward_min_event_ndims,
7979
dtype=bijector.dtype,
@@ -138,28 +138,26 @@ def forward_event_ndims(self, event_ndims, **kwargs):
138138
return self.bijector.inverse_event_ndims(event_ndims, **kwargs)
139139

140140

141-
# Temporarily removing AutoCompositeTensor for TFP 0.13 release.
142-
# pylint: disable=line-too-long
143-
# @bijector_lib.auto_composite_tensor_bijector
144-
# class Invert(_Invert, auto_composite_tensor.AutoCompositeTensor):
141+
@bijector_lib.auto_composite_tensor_bijector
142+
class Invert(_Invert, auto_composite_tensor.AutoCompositeTensor):
145143

146-
# def __new__(cls, *args, **kwargs):
147-
# """Returns an `_Invert` instance if `bijector` is not a `CompositeTensor."""
148-
# if cls is Invert:
149-
# if args:
150-
# bijector = args[0]
151-
# elif 'bijector' in kwargs:
152-
# bijector = kwargs['bijector']
153-
# else:
154-
# raise TypeError('`Invert.__new__()` is missing argument `bijector`.')
144+
def __new__(cls, *args, **kwargs):
145+
"""Returns an `_Invert` instance if `bijector` is not a `CompositeTensor."""
146+
if cls is Invert:
147+
if args:
148+
bijector = args[0]
149+
elif 'bijector' in kwargs:
150+
bijector = kwargs['bijector']
151+
else:
152+
raise TypeError('`Invert.__new__()` is missing argument `bijector`.')
155153

156-
# if not isinstance(bijector, tf.__internal__.CompositeTensor):
157-
# return _Invert(*args, **kwargs)
158-
# return super(Invert, cls).__new__(cls)
154+
if not isinstance(bijector, tf.__internal__.CompositeTensor):
155+
return _Invert(*args, **kwargs)
156+
return super(Invert, cls).__new__(cls)
159157

160158

161-
# Invert.__doc__ = _Invert.__doc__ + '/n' + (
162-
# 'When an `Invert` bijector is constructed, if its `bijector` arg is not a '
163-
# '`CompositeTensor` instance, an `_Invert` instance is returned instead. '
164-
# 'Bijectors subclasses that inherit from `Invert` will also inherit from '
165-
# ' `CompositeTensor`.')
159+
Invert.__doc__ = _Invert.__doc__ + '/n' + (
160+
'When an `Invert` bijector is constructed, if its `bijector` arg is not a '
161+
'`CompositeTensor` instance, an `_Invert` instance is returned instead. '
162+
'Bijectors subclasses that inherit from `Invert` will also inherit from '
163+
' `CompositeTensor`.')

tensorflow_probability/python/experimental/composite_tensor_test.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -454,11 +454,7 @@ def test_basics_mixture_same_family(self):
454454
self.evaluate(unflat.log_prob(.5))
455455

456456
def test_already_composite_tensor(self):
457-
AutoScale = tfp.experimental.auto_composite_tensor( # pylint: disable=invalid-name
458-
tfb.Scale, omit_kwargs=('parameters',),
459-
non_identifying_kwargs=('name',),
460-
module_name=('tfp.bijectors'))
461-
b = AutoScale(2.)
457+
b = tfb.Scale(2.)
462458
b2 = tfp.experimental.as_composite(b)
463459
self.assertIsInstance(b, tf.__internal__.CompositeTensor)
464460
self.assertIs(b, b2)

0 commit comments

Comments
 (0)