Skip to content

Commit 78cf6f8

Browse files
emilyfertigtensorflower-gardener
authored andcommitted
Temporarily remove AutoCompositeTensor from bijectors for TFP 0.13 release.
PiperOrigin-RevId: 374899448
1 parent b9fc082 commit 78cf6f8

File tree

9 files changed

+159
-173
lines changed

9 files changed

+159
-173
lines changed

tensorflow_probability/python/bijectors/bijector.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import abc
2222
import contextlib
23-
import functools
23+
# import functools
2424

2525
# Dependency imports
2626
import numpy as np
@@ -1599,33 +1599,39 @@ def _composite_tensor_shape_params(self):
15991599
return ()
16001600

16011601

1602-
class AutoCompositeTensorBijector(
1603-
Bijector, auto_composite_tensor.AutoCompositeTensor):
1604-
r"""Base for `CompositeTensor` bijectors with auto-generated `TypeSpec`s.
1602+
# Temporarily disable AutoCT for TFP 0.13 release
1603+
# class AutoCompositeTensorBijector(
1604+
# Bijector, auto_composite_tensor.AutoCompositeTensor):
1605+
# r"""Base for `CompositeTensor` bijectors with auto-generated `TypeSpec`s.
16051606

1606-
`CompositeTensor` objects are able to pass in and out of `tf.function` and
1607-
`tf.while_loop`, or serve as part of the signature of a TF saved model.
1608-
`Bijector` subclasses that follow the contract of
1609-
`tfp.experimental.auto_composite_tensor` may be defined as `CompositeTensor`s
1610-
by inheriting from `AutoCompositeTensorBijector` and applying a class
1611-
decorator as shown here:
1607+
# `CompositeTensor` objects are able to pass in and out of `tf.function` and
1608+
# `tf.while_loop`, or serve as part of the signature of a TF saved model.
1609+
# `Bijector` subclasses that follow the contract of
1610+
# `tfp.experimental.auto_composite_tensor` may be defined as `CompositeTensor`s # pylint: disable=line-too-long
1611+
# by inheriting from `AutoCompositeTensorBijector` and applying a class
1612+
# decorator as shown here:
16121613

1613-
```python
1614-
@tfp.experimental.auto_composite_tensor(
1615-
omit_kwargs=('name',), module_name='my_module')
1616-
class MyBijector(tfb.AutoCompositeTensorBijector):
1614+
# ```python
1615+
# @tfp.experimental.auto_composite_tensor(
1616+
# omit_kwargs=('name',), module_name='my_module')
1617+
# class MyBijector(tfb.AutoCompositeTensorBijector):
16171618

1618-
# The remainder of the subclass implementation is unchanged.
1619-
```
1620-
"""
1621-
pass
1619+
# # The remainder of the subclass implementation is unchanged.
1620+
# ```
1621+
# """
1622+
# pass
1623+
1624+
1625+
# auto_composite_tensor_bijector = functools.partial(
1626+
# auto_composite_tensor.auto_composite_tensor,
1627+
# omit_kwargs=('parameters',),
1628+
# non_identifying_kwargs=('name',),
1629+
# module_name='tfp.bijectors')
1630+
1631+
AutoCompositeTensorBijector = Bijector
16221632

16231633

1624-
auto_composite_tensor_bijector = functools.partial(
1625-
auto_composite_tensor.auto_composite_tensor,
1626-
omit_kwargs=('parameters',),
1627-
non_identifying_kwargs=('name',),
1628-
module_name='tfp.bijectors')
1634+
auto_composite_tensor_bijector = lambda cls, **kwargs: cls
16291635

16301636

16311637
def check_valid_ndims(ndims, validate=True):

tensorflow_probability/python/bijectors/bijector_properties_test.py

Lines changed: 11 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,13 @@
2828
from tensorflow_probability.python import bijectors as tfb
2929
from tensorflow_probability.python import experimental
3030
from tensorflow_probability.python.bijectors import hypothesis_testlib as bijector_hps
31-
from tensorflow_probability.python.bijectors import invert as invert_lib
3231
from tensorflow_probability.python.internal import hypothesis_testlib as tfp_hps
3332
from tensorflow_probability.python.internal import prefer_static
3433
from tensorflow_probability.python.internal import samplers
3534
from tensorflow_probability.python.internal import tensor_util
3635
from tensorflow_probability.python.internal import tensorshape_util
3736
from tensorflow_probability.python.internal import test_util
38-
from tensorflow_probability.python.util.deferred_tensor import DeferredTensor
37+
# from tensorflow_probability.python.util.deferred_tensor import DeferredTensor
3938

4039

4140
TF2_FRIENDLY_BIJECTORS = (
@@ -186,6 +185,7 @@
186185

187186
COMPOSITE_TENSOR_IS_BROKEN = [
188187
'BatchNormalization', # tf.layers arg
188+
'Inline', # callable
189189
'RationalQuadraticSpline', # TODO(b/185628453): Debug loss of static info.
190190
]
191191

@@ -204,7 +204,7 @@
204204

205205

206206
def is_invert(bijector):
207-
return isinstance(bijector, (tfb.Invert, invert_lib._Invert))
207+
return isinstance(bijector, tfb.Invert)
208208

209209

210210
def is_transform_diagonal(bijector):
@@ -244,8 +244,8 @@ def _constraint(param):
244244

245245

246246
# TODO(b/141098791): Eliminate this.
247-
@experimental.auto_composite_tensor
248-
class CallableModule(tf.Module, experimental.AutoCompositeTensor):
247+
# @experimental.auto_composite_tensor
248+
class CallableModule(tf.Module): # , experimental.AutoCompositeTensor):
249249
"""Convenience object for capturing variables closed over by Inline."""
250250

251251
def __init__(self, fn, varobj):
@@ -887,38 +887,16 @@ def testEquality(self, bijector_name, data):
887887
@hp.given(hps.data())
888888
@tfp_hps.tfp_hp_settings()
889889
def testCompositeTensor(self, bijector_name, data):
890-
890+
# Test that making a composite tensor of this bijector doesn't throw any
891+
# errors.
891892
bijector, event_dim = self._draw_bijector(
892-
bijector_name, data,
893-
batch_shape=[],
894-
validate_args=True,
893+
bijector_name, data, batch_shape=[],
895894
allowed_bijectors=(set(TF2_FRIENDLY_BIJECTORS) -
896895
set(COMPOSITE_TENSOR_IS_BROKEN)))
897-
898-
if type(bijector) is invert_lib._Invert: # pylint: disable=unidiomatic-typecheck
899-
if isinstance(bijector.bijector, tf.__internal__.CompositeTensor):
900-
raise TypeError('`_Invert` should wrap only non-`CompositeTensor` '
901-
'bijectors.')
902-
self.skipTest('`_Invert` bijectors are not `CompositeTensor`s.')
903-
904-
# TODO(b/182603117): Remove "if" condition and s/composite_bij/bijector
905-
# when AutoCT is enabled for meta-bijectors and LinearOperator.
906-
if type(bijector).__name__ in AUTO_COMPOSITE_TENSOR_IS_BROKEN:
907-
composite_bij = experimental.as_composite(bijector)
908-
else:
909-
composite_bij = bijector
910-
911-
if not tf.executing_eagerly():
912-
composite_bij = tf.nest.map_structure(
913-
lambda x: (tf.convert_to_tensor(x) # pylint: disable=g-long-lambda
914-
if isinstance(x, DeferredTensor) else x),
915-
composite_bij,
916-
expand_composites=True)
917-
918-
self.assertIsInstance(composite_bij, tf.__internal__.CompositeTensor)
896+
composite_bij = experimental.as_composite(bijector)
919897
flat = tf.nest.flatten(composite_bij, expand_composites=True)
920-
unflat = tf.nest.pack_sequence_as(
921-
composite_bij, flat, expand_composites=True)
898+
unflat = tf.nest.pack_sequence_as(composite_bij, flat,
899+
expand_composites=True)
922900

923901
# Compare forward maps before and after compositing.
924902
n = 3
@@ -933,26 +911,6 @@ def testCompositeTensor(self, bijector_name, data):
933911
after_xs = unflat.inverse(ys)
934912
self.assertAllClose(*self.evaluate((before_xs, after_xs)))
935913

936-
# Input to tf.function
937-
self.assertAllClose(
938-
before_ys,
939-
tf.function(lambda b: b.forward(xs))(composite_bij),
940-
rtol=COMPOSITE_TENSOR_RTOL[bijector_name],
941-
atol=COMPOSITE_TENSOR_ATOL[bijector_name])
942-
943-
# Forward mapping: Check differentiation through forward mapping with
944-
# respect to the input and parameter variables. Also check that any
945-
# variables are not referenced overmuch.
946-
xs = self._draw_domain_tensor(bijector, data, event_dim)
947-
wrt_vars = [xs] + [v for v in composite_bij.trainable_variables
948-
if v.dtype.is_floating]
949-
with tf.GradientTape() as tape:
950-
tape.watch(wrt_vars)
951-
# TODO(b/73073515): Fix graph mode gradients with bijector caching.
952-
ys = bijector.forward(xs + 0)
953-
grads = tape.gradient(ys, wrt_vars)
954-
assert_no_none_grad(bijector, 'forward', wrt_vars, grads)
955-
956914

957915
def ensure_nonzero(x):
958916
return tf.where(x < 1e-6, tf.constant(1e-3, x.dtype), x)

tensorflow_probability/python/bijectors/bijector_test.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -790,25 +790,26 @@ def _forward_log_det_jacobian(self, _):
790790
return tf.math.log(self._scale)
791791

792792

793-
@test_util.test_all_tf_execution_regimes
794-
class AutoCompositeTensorBijectorTest(test_util.TestCase):
793+
# Test disabled temporarily for TFP 0.13 release.
794+
# @test_util.test_all_tf_execution_regimes
795+
# class AutoCompositeTensorBijectorTest(test_util.TestCase):
795796

796-
def test_disable_ct_bijector(self):
797+
# def test_disable_ct_bijector(self):
797798

798-
ct_bijector = CompositeForwardBijector()
799-
self.assertIsInstance(ct_bijector, tf.__internal__.CompositeTensor)
799+
# ct_bijector = CompositeForwardBijector()
800+
# self.assertIsInstance(ct_bijector, tf.__internal__.CompositeTensor)
800801

801-
non_ct_bijector = ForwardOnlyBijector()
802-
self.assertNotIsInstance(non_ct_bijector, tf.__internal__.CompositeTensor)
802+
# non_ct_bijector = ForwardOnlyBijector()
803+
# self.assertNotIsInstance(non_ct_bijector, tf.__internal__.CompositeTensor)
803804

804-
flat = tf.nest.flatten(ct_bijector, expand_composites=True)
805-
unflat = tf.nest.pack_sequence_as(
806-
ct_bijector, flat, expand_composites=True)
805+
# flat = tf.nest.flatten(ct_bijector, expand_composites=True)
806+
# unflat = tf.nest.pack_sequence_as(
807+
# ct_bijector, flat, expand_composites=True)
807808

808-
x = tf.constant([2., 3.])
809-
self.assertAllClose(
810-
non_ct_bijector.forward(x),
811-
tf.function(lambda b: b.forward(x))(unflat))
809+
# x = tf.constant([2., 3.])
810+
# self.assertAllClose(
811+
# non_ct_bijector.forward(x),
812+
# tf.function(lambda b: b.forward(x))(unflat))
812813

813814

814815
if __name__ == '__main__':

tensorflow_probability/python/bijectors/exp.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from tensorflow_probability.python.bijectors import bijector as bijector_lib
2424
from tensorflow_probability.python.bijectors import invert
2525
from tensorflow_probability.python.bijectors import power_transform
26-
from tensorflow_probability.python.internal import auto_composite_tensor
2726

2827

2928
__all__ = [
@@ -77,8 +76,7 @@ def __init__(self,
7776
# TODO(b/182603117): Remove `AutoCompositeTensor` when `Invert` subclasses
7877
# `AutoCompositeTensor` and ensure `tf.saved_model` still works.
7978
@bijector_lib.auto_composite_tensor_bijector
80-
class Log(invert.Invert,
81-
auto_composite_tensor.AutoCompositeTensor):
79+
class Log(invert.Invert):
8280
"""Compute `Y = log(X)`. This is `Invert(Exp())`."""
8381

8482
def __init__(self, validate_args=False, name='log'):

tensorflow_probability/python/bijectors/expm1.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import tensorflow.compat.v2 as tf
2222
from tensorflow_probability.python.bijectors import bijector
2323
from tensorflow_probability.python.bijectors import invert
24-
from tensorflow_probability.python.internal import auto_composite_tensor
2524

2625

2726
__all__ = [
@@ -95,7 +94,7 @@ def _forward_log_det_jacobian(self, x):
9594
# TODO(b/182603117): Remove `AutoCompositeTensor` when `Invert` subclasses
9695
# `AutoCompositeTensor`.
9796
@bijector.auto_composite_tensor_bijector
98-
class Log1p(invert.Invert, auto_composite_tensor.AutoCompositeTensor):
97+
class Log1p(invert.Invert):
9998
"""Compute `Y = log1p(X)`. This is `Invert(Expm1())`."""
10099

101100
def __init__(self, validate_args=False, name='log1p'):

tensorflow_probability/python/bijectors/invert.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@
2121
import tensorflow.compat.v2 as tf
2222

2323
from tensorflow_probability.python.bijectors import bijector as bijector_lib
24-
from tensorflow_probability.python.internal import auto_composite_tensor
24+
# from tensorflow_probability.python.internal import auto_composite_tensor
2525

2626
__all__ = [
2727
'Invert',
2828
]
2929

3030

31-
class _Invert(bijector_lib.Bijector):
31+
class Invert(bijector_lib.Bijector):
3232
"""Bijector which inverts another Bijector.
3333
3434
Example Use: [ExpGammaDistribution (see Background & Context)](
@@ -73,7 +73,7 @@ def __init__(self, bijector, validate_args=False, parameters=None, name=None):
7373
name = name or '_'.join(['invert', bijector.name])
7474
with tf.name_scope(name) as name:
7575
self._bijector = bijector
76-
super(_Invert, self).__init__(
76+
super(Invert, self).__init__(
7777
forward_min_event_ndims=bijector.inverse_min_event_ndims,
7878
inverse_min_event_ndims=bijector.forward_min_event_ndims,
7979
dtype=bijector.dtype,
@@ -138,26 +138,28 @@ def forward_event_ndims(self, event_ndims, **kwargs):
138138
return self.bijector.inverse_event_ndims(event_ndims, **kwargs)
139139

140140

141-
@bijector_lib.auto_composite_tensor_bijector
142-
class Invert(_Invert, auto_composite_tensor.AutoCompositeTensor):
141+
# Temporarily removing AutoCompositeTensor for TFP 0.13 release.
142+
# pylint: disable=line-too-long
143+
# @bijector_lib.auto_composite_tensor_bijector
144+
# class Invert(_Invert, auto_composite_tensor.AutoCompositeTensor):
143145

144-
def __new__(cls, *args, **kwargs):
145-
"""Returns an `_Invert` instance if `bijector` is not a `CompositeTensor."""
146-
if cls is Invert:
147-
if args:
148-
bijector = args[0]
149-
elif 'bijector' in kwargs:
150-
bijector = kwargs['bijector']
151-
else:
152-
raise TypeError('`Invert.__new__()` is missing argument `bijector`.')
146+
# def __new__(cls, *args, **kwargs):
147+
# """Returns an `_Invert` instance if `bijector` is not a `CompositeTensor."""
148+
# if cls is Invert:
149+
# if args:
150+
# bijector = args[0]
151+
# elif 'bijector' in kwargs:
152+
# bijector = kwargs['bijector']
153+
# else:
154+
# raise TypeError('`Invert.__new__()` is missing argument `bijector`.')
153155

154-
if not isinstance(bijector, tf.__internal__.CompositeTensor):
155-
return _Invert(*args, **kwargs)
156-
return super(Invert, cls).__new__(cls)
156+
# if not isinstance(bijector, tf.__internal__.CompositeTensor):
157+
# return _Invert(*args, **kwargs)
158+
# return super(Invert, cls).__new__(cls)
157159

158160

159-
Invert.__doc__ = _Invert.__doc__ + '/n' + (
160-
'When an `Invert` bijector is constructed, if its `bijector` arg is not a '
161-
'`CompositeTensor` instance, an `_Invert` instance is returned instead. '
162-
'Bijectors subclasses that inherit from `Invert` will also inherit from '
163-
' `CompositeTensor`.')
161+
# Invert.__doc__ = _Invert.__doc__ + '/n' + (
162+
# 'When an `Invert` bijector is constructed, if its `bijector` arg is not a '
163+
# '`CompositeTensor` instance, an `_Invert` instance is returned instead. '
164+
# 'Bijectors subclasses that inherit from `Invert` will also inherit from '
165+
# ' `CompositeTensor`.')

tensorflow_probability/python/experimental/composite_tensor_test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,11 @@ def test_basics_mixture_same_family(self):
454454
self.evaluate(unflat.log_prob(.5))
455455

456456
def test_already_composite_tensor(self):
457-
b = tfb.Scale(2.)
457+
AutoScale = tfp.experimental.auto_composite_tensor( # pylint: disable=invalid-name
458+
tfb.Scale, omit_kwargs=('parameters',),
459+
non_identifying_kwargs=('name',),
460+
module_name=('tfp.bijectors'))
461+
b = AutoScale(2.)
458462
b2 = tfp.experimental.as_composite(b)
459463
self.assertIsInstance(b, tf.__internal__.CompositeTensor)
460464
self.assertIs(b, b2)

0 commit comments

Comments
 (0)