Skip to content

Commit 2f18733

Browse files
emilyfertigtensorflower-gardener
authored andcommitted
Convert Blockwise bijector to AutoCompositeTensor.
PiperOrigin-RevId: 375833034
1 parent 0df33ac commit 2f18733

File tree

4 files changed

+102
-12
lines changed

4 files changed

+102
-12
lines changed

tensorflow_probability/python/bijectors/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,7 @@ multi_substrate_py_library(
306306
# numpy dep,
307307
# tensorflow dep,
308308
"//tensorflow_probability/python/internal:assert_util",
309+
"//tensorflow_probability/python/internal:auto_composite_tensor",
309310
"//tensorflow_probability/python/internal:prefer_static",
310311
"//tensorflow_probability/python/internal:tensorshape_util",
311312
],

tensorflow_probability/python/bijectors/blockwise.py

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,14 @@
2121
import numpy as np
2222
import tensorflow.compat.v2 as tf
2323

24+
from tensorflow_probability.python.bijectors import bijector as bijector_lib
2425
from tensorflow_probability.python.bijectors import chain
26+
from tensorflow_probability.python.bijectors import composition
2527
from tensorflow_probability.python.bijectors import invert
2628
from tensorflow_probability.python.bijectors import joint_map
2729
from tensorflow_probability.python.bijectors import split
2830
from tensorflow_probability.python.internal import assert_util
31+
from tensorflow_probability.python.internal import auto_composite_tensor
2932
from tensorflow_probability.python.internal import prefer_static as ps
3033
from tensorflow_probability.python.internal import tensorshape_util
3134

@@ -42,8 +45,7 @@ def _get_static_splits(splits):
4245
return splits if static_splits is None else static_splits
4346

4447

45-
# TODO(b/182603117): Enable AutoCompositeTensor once Chain subclasses it.
46-
class Blockwise(chain.Chain):
48+
class _Blockwise(composition.Composition):
4749
"""Bijector which applies a list of bijectors to blocks of a `Tensor`.
4850
4951
More specifically, given [F_0, F_1, ... F_n] which are scalar or vector
@@ -151,9 +153,12 @@ def __init__(self,
151153
name='concat')
152154

153155
self._maybe_changes_size = maybe_changes_size
154-
super(Blockwise, self).__init__(
155-
bijectors=[b_concat, b_joint, b_split],
156+
self._chain = chain.Chain(
157+
[b_concat, b_joint, b_split], validate_args=validate_args)
158+
super(_Blockwise, self).__init__(
159+
bijectors=self._chain.bijectors,
156160
validate_args=validate_args,
161+
validate_event_size=True,
157162
parameters=parameters,
158163
name=name)
159164

@@ -186,13 +191,13 @@ def inverse_block_sizes(self):
186191
return self._b_concat.split_sizes
187192

188193
def _forward(self, x, **kwargs):
189-
y = super(Blockwise, self)._forward(x, **kwargs)
194+
y = super(_Blockwise, self)._forward(x, **kwargs)
190195
if not self._maybe_changes_size:
191196
tensorshape_util.set_shape(y, x.shape)
192197
return y
193198

194199
def _inverse(self, y, **kwargs):
195-
x = super(Blockwise, self)._inverse(y, **kwargs)
200+
x = super(_Blockwise, self)._inverse(y, **kwargs)
196201
if not self._maybe_changes_size:
197202
tensorshape_util.set_shape(x, y.shape)
198203
return x
@@ -220,19 +225,19 @@ def _inverse_event_shape(self, output_shape):
220225
def _forward_event_shape_tensor(self, x, **kwargs):
221226
if not self._maybe_changes_size:
222227
return x
223-
return super(Blockwise, self)._forward_event_shape_tensor(x, **kwargs)
228+
return super(_Blockwise, self)._forward_event_shape_tensor(x, **kwargs)
224229

225230
def _inverse_event_shape_tensor(self, y, **kwargs):
226231
if not self._maybe_changes_size:
227232
return y
228-
return super(Blockwise, self)._inverse_event_shape_tensor(y, **kwargs)
233+
return super(_Blockwise, self)._inverse_event_shape_tensor(y, **kwargs)
229234

230235
def _walk_forward(self, step_fn, x, **kwargs):
231-
return super(Blockwise, self)._walk_forward(
236+
return self._chain._walk_forward( # pylint: disable=protected-access
232237
step_fn, x, **{self._b_joint.name: kwargs})
233238

234239
def _walk_inverse(self, step_fn, x, **kwargs):
235-
return super(Blockwise, self)._walk_inverse(
240+
return self._chain._walk_inverse( # pylint: disable=protected-access
236241
step_fn, x, **{self._b_joint.name: kwargs})
237242

238243

@@ -263,3 +268,32 @@ def _validate_block_sizes(block_sizes, bijectors, validate_args):
263268
# Set the shape if missing to pass statically known structure to split.
264269
tensorshape_util.set_shape(block_sizes, [len(bijectors)])
265270
return block_sizes
271+
272+
273+
@bijector_lib.auto_composite_tensor_bijector
274+
class Blockwise(_Blockwise, auto_composite_tensor.AutoCompositeTensor):
275+
276+
def __new__(cls, *args, **kwargs):
277+
"""Returns a `_Blockwise` if any of `bijectors` is not `CompositeTensor."""
278+
if cls is Blockwise:
279+
if args:
280+
bijectors = args[0]
281+
elif 'bijectors' in kwargs:
282+
bijectors = kwargs['bijectors']
283+
else:
284+
raise TypeError(
285+
'`Blockwise.__new__()` is missing argument `bijectors`.')
286+
287+
if not all(isinstance(b, tf.__internal__.CompositeTensor)
288+
for b in bijectors):
289+
return _Blockwise(*args, **kwargs)
290+
return super(Blockwise, cls).__new__(cls)
291+
292+
293+
Blockwise.__doc__ = _Blockwise.__doc__ + '\n' + (
294+
'If every element of the `bijectors` list is a `CompositeTensor`, the '
295+
'resulting `Blockwise` bijector is a `CompositeTensor` as well. If any '
296+
'element of `bijectors` is not a `CompositeTensor`, then a '
297+
'non-`CompositeTensor` `_Blockwise` instance is created instead. Bijector '
298+
'subclasses that inherit from `Blockwise` will also inherit from '
299+
'`CompositeTensor`.')

tensorflow_probability/python/bijectors/blockwise_test.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,60 @@ def testKwargs(self):
294294
bijectors[0]._inverse_log_det_jacobian.assert_called_with(mock.ANY, arg=7)
295295
bijectors[1]._inverse_log_det_jacobian.assert_called_with(mock.ANY, arg=8)
296296

297+
@test_util.disable_test_for_backend(
298+
disable_numpy=True, disable_jax=True,
299+
reason='Numpy and JAX have no notion of CompositeTensor.')
300+
def testCompositeTensor(self):
301+
exp = tfb.Exp()
302+
sp = tfb.Softplus()
303+
aff = tfb.Scale(scale=2.)
304+
blockwise = tfb.Blockwise(bijectors=[exp, sp, aff])
305+
self.assertIsInstance(blockwise, tf.__internal__.CompositeTensor)
306+
307+
# Bijector may be flattened into `Tensor` components and rebuilt.
308+
flat = tf.nest.flatten(blockwise, expand_composites=True)
309+
unflat = tf.nest.pack_sequence_as(blockwise, flat, expand_composites=True)
310+
self.assertIsInstance(unflat, tfb.Blockwise)
311+
312+
# Bijector may be input to a `tf.function`-decorated callable.
313+
@tf.function
314+
def call_forward(bij, x):
315+
return bij.forward(x)
316+
317+
x = tf.ones([2, 3], dtype=tf.float32)
318+
self.assertAllClose(call_forward(unflat, x), blockwise.forward(x))
319+
320+
# Type spec can be encoded/decoded.
321+
struct_coder = tf.__internal__.saved_model.StructureCoder()
322+
enc = struct_coder.encode_structure(blockwise._type_spec)
323+
dec = struct_coder.decode_proto(enc)
324+
self.assertEqual(blockwise._type_spec, dec)
325+
326+
def testNonCompositeTensor(self):
327+
328+
class NonCompositeScale(tfb.Bijector):
329+
"""Bijector that is not a `CompositeTensor`."""
330+
331+
def __init__(self, scale):
332+
parameters = dict(locals())
333+
self.scale = scale
334+
super(NonCompositeScale, self).__init__(
335+
validate_args=True,
336+
forward_min_event_ndims=0.,
337+
parameters=parameters,
338+
name='non_composite_scale')
339+
340+
def _forward(self, x):
341+
return x * self.scale
342+
343+
exp = tfb.Exp()
344+
scale = NonCompositeScale(scale=tf.constant(3.))
345+
blockwise = tfb.Blockwise(bijectors=[exp, scale])
346+
self.assertNotIsInstance(blockwise, tf.__internal__.CompositeTensor)
347+
self.assertAllClose(
348+
blockwise.forward([1., 1.]),
349+
tf.convert_to_tensor([exp.forward(1.), scale.forward(1.)]))
350+
297351

298352
if __name__ == '__main__':
299353
tf.test.main()

tensorflow_probability/python/bijectors/glow_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import tensorflow.compat.v2 as tf
2121
import tensorflow_probability as tfp
2222
from tensorflow_probability.python.bijectors import bijector_test_util
23+
from tensorflow_probability.python.bijectors import blockwise
2324
from tensorflow_probability.python.bijectors import composition
2425
from tensorflow_probability.python.internal import test_util
2526
from tensorflow_probability.python.math.gradient import batch_jacobian
@@ -116,7 +117,7 @@ def testDataInit_inverse(self):
116117

117118
x = bijection.bijectors[0].inverse(x)
118119
for b in bijection.bijectors[1].bijectors:
119-
if isinstance(b, tfb.blockwise.Blockwise):
120+
if isinstance(b, blockwise._Blockwise):
120121
x1, x2 = tf.split(x, splits[-2-nblocks], axis=-1)
121122

122123
for bb in b.bijectors[0].bijectors:
@@ -157,7 +158,7 @@ def testDataInit_forward(self):
157158
splits = [[bs[0]+bs[1], bs[2]] for bs in splits]
158159

159160
for b in reversed(bijection.bijectors[1].bijectors):
160-
if isinstance(b, tfb.blockwise.Blockwise):
161+
if isinstance(b, blockwise._Blockwise):
161162
y1, y2 = tf.split(y, splits[nblocks], axis=-1)
162163

163164
for bb in reversed(b.bijectors[0].bijectors):

0 commit comments

Comments
 (0)