Skip to content

Commit b32ab32

Browse files
emilyfertigtensorflower-gardener
authored andcommitted
Convert JointMap bijector to AutoCompositeTensor.
PiperOrigin-RevId: 377622021
1 parent f4a7fe0 commit b32ab32

File tree

2 files changed

+82
-2
lines changed

2 files changed

+82
-2
lines changed

tensorflow_probability/python/bijectors/joint_map.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from __future__ import print_function
2020

2121
import tensorflow.compat.v2 as tf
22+
from tensorflow_probability.python.bijectors import bijector as bijector_lib
2223
from tensorflow_probability.python.bijectors import composition
2324
from tensorflow.python.util import nest # pylint: disable=g-direct-tensorflow-import
2425

@@ -28,7 +29,7 @@
2829
]
2930

3031

31-
class JointMap(composition.Composition):
32+
class _JointMap(composition.Composition):
3233
"""Bijector which applies a structure of bijectors in parallel.
3334
3435
This is the "structured" counterpart to `Chain`. Whereas `Chain` applies an
@@ -92,7 +93,7 @@ def __init__(self,
9293
self._nested_structure = self._no_dependency(
9394
nest.map_structure(lambda b: None, bijectors))
9495

95-
super(JointMap, self).__init__(
96+
super(_JointMap, self).__init__(
9697
bijectors=bijectors,
9798
validate_args=validate_args,
9899
forward_min_event_ndims=self._nested_structure,
@@ -115,3 +116,27 @@ def _walk_inverse(self, step_fn, ys, **kwargs):
115116
self._nested_structure,
116117
lambda bij, y: step_fn(bij, y, **kwargs.get(bij.name, {})), # pylint: disable=unnecessary-lambda
117118
self._bijectors, ys, check_types=False)
119+
120+
121+
class JointMap(_JointMap, bijector_lib.AutoCompositeTensorBijector):
122+
123+
def __new__(cls, *args, **kwargs):
124+
"""Returns a `_JointMap` any of `bijectors` is not a `CompositeTensor."""
125+
if cls is JointMap:
126+
if args:
127+
bijectors = args[0]
128+
else:
129+
bijectors = kwargs.get('bijectors')
130+
if bijectors is not None:
131+
if not all(isinstance(b, tf.__internal__.CompositeTensor)
132+
for b in tf.nest.flatten(bijectors)):
133+
return _JointMap(*args, **kwargs)
134+
return super(JointMap, cls).__new__(cls)
135+
136+
137+
JointMap.__doc__ = _JointMap.__doc__ + '\n' + (
138+
'If every element of `bijectors` is a `CompositeTensor`, the resulting '
139+
'`JointMap` bijector is a `CompositeTensor` as well. If any element of '
140+
'`bijectors` is not a `CompositeTensor`, then a non-`CompositeTensor` '
141+
'`_JointMap` instance is created instead. Bijector subclasses that inherit '
142+
'from `JointMap` will also inherit from `CompositeTensor`.')

tensorflow_probability/python/bijectors/joint_map_test.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,61 @@ def test_inverse_has_event_ndims(self):
129129
bij_reshape.inverse_event_ndims([10]) # expect [9]
130130
self.assertEqual(bij_reshape.inverse_event_ndims([10]), [9])
131131

132+
@test_util.disable_test_for_backend(
133+
disable_numpy=True, disable_jax=True,
134+
reason='Numpy and JAX have no notion of CompositeTensor/saved_model.')
135+
def testCompositeTensor(self):
136+
exp = tfb.Exp()
137+
sp = tfb.Softplus()
138+
aff = tfb.Scale(scale=2.)
139+
bij = tfb.JointMap(bijectors=[exp, sp, aff])
140+
self.assertIsInstance(bij, tf.__internal__.CompositeTensor)
141+
142+
# Bijector may be flattened into `Tensor` components and rebuilt.
143+
flat = tf.nest.flatten(bij, expand_composites=True)
144+
unflat = tf.nest.pack_sequence_as(bij, flat, expand_composites=True)
145+
self.assertIsInstance(unflat, tfb.JointMap)
146+
147+
# Bijector may be input to a `tf.function`-decorated callable.
148+
@tf.function
149+
def call_forward(bij, x):
150+
return bij.forward(x)
151+
152+
x = [1., 2., 3.]
153+
self.assertAllClose(call_forward(unflat, x), bij.forward(x))
154+
155+
# Type spec can be encoded/decoded.
156+
struct_coder = tf.__internal__.saved_model.StructureCoder()
157+
enc = struct_coder.encode_structure(bij._type_spec)
158+
dec = struct_coder.decode_proto(enc)
159+
self.assertEqual(bij._type_spec, dec)
160+
161+
def testNonCompositeTensor(self):
162+
163+
# TODO(b/182603117): Move NonComposite* into test_util.
164+
class NonCompositeScale(tfb.Bijector):
165+
"""Bijector that is not a `CompositeTensor`."""
166+
167+
def __init__(self, scale):
168+
parameters = dict(locals())
169+
self.scale = scale
170+
super(NonCompositeScale, self).__init__(
171+
validate_args=True,
172+
forward_min_event_ndims=0.,
173+
parameters=parameters,
174+
name='non_composite_scale')
175+
176+
def _forward(self, x):
177+
return x * self.scale
178+
179+
exp = tfb.Exp()
180+
scale = NonCompositeScale(scale=tf.constant(3.))
181+
bij = tfb.JointMap(bijectors=[exp, scale])
182+
self.assertNotIsInstance(bij, tf.__internal__.CompositeTensor)
183+
self.assertAllClose(
184+
bij.forward([1., 1.]),
185+
[exp.forward(1.), scale.forward(1.)])
186+
132187

133188
if __name__ == '__main__':
134189
tf.test.main()

0 commit comments

Comments
 (0)