
Commit 4132455

emilyfertig authored and tensorflower-gardener committed
Convert Chain, FillScaleTriL, and LambertWTransform bijectors to AutoCompositeTensor.
PiperOrigin-RevId: 378014572
1 parent b32ab32 commit 4132455

File tree

5 files changed: +101 -28 lines

tensorflow_probability/python/bijectors/bijector_properties_test.py

Lines changed: 7 additions & 20 deletions
@@ -196,11 +196,6 @@
 })
 COMPOSITE_TENSOR_ATOL = collections.defaultdict(lambda: 1e-6)

-# TODO(b/182603117): Enable AutoCT for meta-bijectors.
-AUTO_COMPOSITE_TENSOR_IS_BROKEN = [
-    'FillScaleTriL',
-]
-

 def is_invert(bijector):
   return isinstance(bijector, (tfb.Invert, invert_lib._Invert))
@@ -915,24 +910,16 @@ def testCompositeTensor(self, bijector_name, data):
                            'bijectors.')
       self.skipTest('`_Invert` bijectors are not `CompositeTensor`s.')

-    # TODO(b/182603117): Remove "if" condition and s/composite_bij/bijector
-    # when AutoCT is enabled for meta-bijectors and LinearOperator.
-    if type(bijector).__name__ in AUTO_COMPOSITE_TENSOR_IS_BROKEN:
-      composite_bij = experimental.as_composite(bijector)
-    else:
-      composite_bij = bijector
-
     if not tf.executing_eagerly():
-      composite_bij = tf.nest.map_structure(
+      bijector = tf.nest.map_structure(
           lambda x: (tf.convert_to_tensor(x)  # pylint: disable=g-long-lambda
                      if isinstance(x, DeferredTensor) else x),
-          composite_bij,
+          bijector,
           expand_composites=True)

-    self.assertIsInstance(composite_bij, tf.__internal__.CompositeTensor)
-    flat = tf.nest.flatten(composite_bij, expand_composites=True)
-    unflat = tf.nest.pack_sequence_as(
-        composite_bij, flat, expand_composites=True)
+    self.assertIsInstance(bijector, tf.__internal__.CompositeTensor)
+    flat = tf.nest.flatten(bijector, expand_composites=True)
+    unflat = tf.nest.pack_sequence_as(bijector, flat, expand_composites=True)

     # Compare forward maps before and after compositing.
     n = 3
@@ -950,15 +937,15 @@ def testCompositeTensor(self, bijector_name, data):
     # Input to tf.function
     self.assertAllClose(
         before_ys,
-        tf.function(lambda b: b.forward(xs))(composite_bij),
+        tf.function(lambda b: b.forward(xs))(bijector),
         rtol=COMPOSITE_TENSOR_RTOL[bijector_name],
         atol=COMPOSITE_TENSOR_ATOL[bijector_name])

     # Forward mapping: Check differentiation through forward mapping with
     # respect to the input and parameter variables. Also check that any
     # variables are not referenced overmuch.
     xs = self._draw_domain_tensor(bijector, data, event_dim)
-    wrt_vars = [xs] + [v for v in composite_bij.trainable_variables
+    wrt_vars = [xs] + [v for v in bijector.trainable_variables
                        if v.dtype.is_floating]
     with tf.GradientTape() as tape:
       tape.watch(wrt_vars)
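
Note (illustrative, not part of the diff): with AutoCompositeTensor enabled for these bijectors, the property test no longer needs `experimental.as_composite` and operates on the bijector directly. A minimal sketch of the round trip the test exercises, assuming a bijector class that is already an `AutoCompositeTensor`, such as `tfb.Scale`:

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

bij = tfb.Scale(scale=tf.constant(2.))
# Decompose the bijector into its Tensor components...
flat = tf.nest.flatten(bij, expand_composites=True)
# ...and rebuild an equivalent bijector from the components and its TypeSpec.
rebuilt = tf.nest.pack_sequence_as(bij, flat, expand_composites=True)
assert isinstance(rebuilt, tfb.Scale)
print(rebuilt.forward([1., 2.]))  # Matches bij.forward([1., 2.]).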

tensorflow_probability/python/bijectors/chain.py

Lines changed: 30 additions & 4 deletions
@@ -20,6 +20,7 @@
 from __future__ import print_function

 import tensorflow.compat.v2 as tf
+from tensorflow_probability.python.bijectors import bijector as bijector_lib
 from tensorflow_probability.python.bijectors import composition
 from tensorflow_probability.python.bijectors import ldj_ratio
 from tensorflow_probability.python.internal import parameter_properties
@@ -31,7 +32,7 @@
 ]


-class Chain(composition.Composition):
+class _Chain(composition.Composition):
   """Bijector which applies a sequence of bijectors.

   Example Use:
@@ -124,7 +125,7 @@ def __init__(self,
       inverse_min_event_ndims = None  # Inferred by base class.

     with tf.name_scope(name) as name:
-      super(Chain, self).__init__(
+      super(_Chain, self).__init__(
           bijectors=bijectors or (),
           validate_args=validate_args,
           validate_event_size=validate_event_size,
@@ -160,7 +161,32 @@ def _walk_inverse(self, step_fn, y, **kwargs):
     return y  # Now `x`


-@ldj_ratio.RegisterFLDJRatio(Chain)
+class Chain(_Chain, bijector_lib.AutoCompositeTensorBijector):
+
+  def __new__(cls, *args, **kwargs):
+    """Returns a `_Chain` if any of `bijectors` is not a `CompositeTensor."""
+    if cls is Chain:
+      if args:
+        bijectors = args[0]
+      else:
+        bijectors = kwargs.get('bijectors')
+
+      if bijectors is not None:
+        if not all(isinstance(b, tf.__internal__.CompositeTensor)
+                   for b in bijectors):
+          return _Chain(*args, **kwargs)
+    return super(Chain, cls).__new__(cls)
+
+
+Chain.__doc__ = _Chain.__doc__ + '\n' + (
+    'If every element of the `bijectors` list is a `CompositeTensor`, the '
+    'resulting `Chain` bijector is a `CompositeTensor` as well. If any element '
+    'of `bijectors` is not a `CompositeTensor`, then a non-`CompositeTensor` '
+    '`_Chain` instance is created instead. Bijector subclasses that inherit '
+    'from `Chain` will also inherit from `CompositeTensor`.')
+
+
+@ldj_ratio.RegisterFLDJRatio(_Chain)
 def _fldj_ratio_chain(p, x, q, y):
   """Sum-of-diffs FLDJRatio for Chains."""
   if len(p.bijectors) != len(q.bijectors):
@@ -177,7 +203,7 @@ def _fldj_ratio_chain(p, x, q, y):
   return tf.add_n(ratios)


-@ldj_ratio.RegisterILDJRatio(Chain)
+@ldj_ratio.RegisterILDJRatio(_Chain)
 def _ildj_ratio_chain(p, x, q, y):
   """Sum-of-diffs ILDJRatio for Chains."""
   if len(p.bijectors) != len(q.bijectors):
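
Note (a rough sketch under the assumption that the installed TFP build contains this change; not part of the diff): the `__new__` override above keeps the public `Chain` name but falls back to the non-composite `_Chain` when any constituent bijector cannot be decomposed.

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

# Every constituent bijector is a CompositeTensor, so the CompositeTensor
# `Chain` subclass is instantiated.
chain = tfb.Chain([tfb.Exp(), tfb.Scale(scale=2.)])
print(isinstance(chain, tf.__internal__.CompositeTensor))  # Expected: True

# A CompositeTensor chain can be passed directly as a tf.function argument.
@tf.function
def forward(bij, x):
  return bij.forward(x)

print(forward(chain, tf.constant([0., 1.])))

# If any element of `bijectors` were not a CompositeTensor, `__new__` would
# return a plain `_Chain` instance, which is not a CompositeTensor.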

tensorflow_probability/python/bijectors/chain_test.py

Lines changed: 54 additions & 0 deletions
@@ -404,6 +404,60 @@ def testDofChangeError(self):
     ignore_bij = tfb.Chain([exp, smc], validate_event_size=False)
     self.evaluate(ignore_bij.forward_log_det_jacobian([1., 2., 3.], 1))

+  @test_util.disable_test_for_backend(
+      disable_numpy=True, disable_jax=True,
+      reason="Numpy and JAX have no notion of CompositeTensor/saved_model.")
+  def testCompositeTensor(self):
+    exp = tfb.Exp()
+    sp = tfb.Softplus()
+    aff = tfb.Scale(scale=2.)
+    chain = tfb.Chain(bijectors=[exp, sp, aff])
+    self.assertIsInstance(chain, tf.__internal__.CompositeTensor)
+
+    # Bijector may be flattened into `Tensor` components and rebuilt.
+    flat = tf.nest.flatten(chain, expand_composites=True)
+    unflat = tf.nest.pack_sequence_as(chain, flat, expand_composites=True)
+    self.assertIsInstance(unflat, tfb.Chain)
+
+    # Bijector may be input to a `tf.function`-decorated callable.
+    @tf.function
+    def call_forward(bij, x):
+      return bij.forward(x)
+
+    x = tf.ones([2, 3], dtype=tf.float32)
+    self.assertAllClose(call_forward(unflat, x), chain.forward(x))
+
+    # TypeSpec can be encoded/decoded.
+    struct_coder = tf.__internal__.saved_model.StructureCoder()
+    enc = struct_coder.encode_structure(chain._type_spec)
+    dec = struct_coder.decode_proto(enc)
+    self.assertEqual(chain._type_spec, dec)
+
+  def testNonCompositeTensor(self):
+
+    class NonCompositeScale(tfb.Bijector):
+      """Bijector that is not a `CompositeTensor`."""
+
+      def __init__(self, scale):
+        parameters = dict(locals())
+        self.scale = scale
+        super(NonCompositeScale, self).__init__(
+            validate_args=True,
+            forward_min_event_ndims=0.,
+            parameters=parameters,
+            name="non_composite_scale")
+
+      def _forward(self, x):
+        return x * self.scale
+
+      def _inverse(self, y):
+        return y / self.scale
+
+    exp = tfb.Exp()
+    scale = NonCompositeScale(scale=tf.constant(3.))
+    chain = tfb.Chain(bijectors=[exp, scale])
+    self.assertNotIsInstance(chain, tf.__internal__.CompositeTensor)
+    self.assertAllClose(chain.forward([1.]), exp.forward(scale.forward([1.])))

 if __name__ == "__main__":
   tf.test.main()

tensorflow_probability/python/bijectors/fill_scale_tril.py

Lines changed: 9 additions & 2 deletions
@@ -33,7 +33,6 @@
 ]


-# TODO(b/182603117): Enable AutoCompositeTensor once Chain subclasses it.
 class FillScaleTriL(chain.Chain):
   """Transforms unconstrained vectors to TriL matrices with positive diagonal.

@@ -91,7 +90,8 @@ def __init__(self,

     Args:
       diag_bijector: `Bijector` instance, used to transform the output diagonal
-        to be positive.
+        to be positive. Must be an instance of `tf.__internal__.CompositeTensor`
+        (including `tfb.AutoCompositeTensorBijector`).
         Default value: `None` (i.e., `tfb.Softplus()`).
       diag_shift: Float value broadcastable and added to all diagonal entries
         after applying the `diag_bijector`. Setting a positive
@@ -104,11 +104,18 @@ def __init__(self,
         Default value: `False` (i.e., arguments are not validated).
       name: Python `str` name given to ops managed by this object.
         Default value: `fill_scale_tril`.
+
+    Raises:
+      TypeError, if `diag_bijector` is not an instance of
+        `tf.__internal__.CompositeTensor`.
     """
     parameters = dict(locals())
     with tf.name_scope(name) as name:
       if diag_bijector is None:
         diag_bijector = softplus.Softplus(validate_args=validate_args)
+      if not isinstance(diag_bijector, tf.__internal__.CompositeTensor):
+        raise TypeError('`diag_bijector` must be an instance of '
+                        '`tf.__internal__.CompositeTensor`.')

       if diag_shift is not None:
         dtype = dtype_util.common_dtype([diag_bijector, diag_shift], tf.float32)
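
Note (an illustrative sketch, not part of the diff; `PlainShift` below is a hypothetical non-`CompositeTensor` bijector invented for the example): the constructor now rejects `diag_bijector` arguments that cannot be decomposed.

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

# The default diag_bijector (Softplus) is a CompositeTensor, so FillScaleTriL
# constructs normally and is itself expected to be a CompositeTensor.
fst = tfb.FillScaleTriL()
print(isinstance(fst, tf.__internal__.CompositeTensor))  # Expected: True


class PlainShift(tfb.Bijector):
  """Hypothetical bijector that is not a CompositeTensor."""

  def __init__(self, shift):
    parameters = dict(locals())
    self.shift = shift
    super(PlainShift, self).__init__(
        forward_min_event_ndims=0, parameters=parameters, name='plain_shift')

  def _forward(self, x):
    return x + self.shift

  def _inverse(self, y):
    return y - self.shift


try:
  tfb.FillScaleTriL(diag_bijector=PlainShift(shift=1.))
except TypeError as e:
  # Expected message: '`diag_bijector` must be an instance of
  # `tf.__internal__.CompositeTensor`.'
  print(e)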

tensorflow_probability/python/bijectors/lambertw_transform.py

Lines changed: 1 addition & 2 deletions
@@ -106,7 +106,7 @@ def _w_delta_squared(z, delta):


 # Private class that implements the heavy tail transformation.
-class _HeavyTailOnly(bijector.Bijector):
+class _HeavyTailOnly(bijector.AutoCompositeTensorBijector):
   """Heavy tail transformation for Lambert W x F distributions.

   This bijector defines the transformation z = u * exp(0.5 * delta * u**2)
@@ -178,7 +178,6 @@ def _inverse_log_det_jacobian(self, y):
 # fix batch_shape inconsistencies when running distribution_properties_test.


-# TODO(b/182603117): Enable AutoCompositeTensor when Chain is enabled.
 class LambertWTail(chain.Chain):
   """LambertWTail transformation for heavy-tail Lambert W x F random variables.
184183
