Skip to content

Commit 39d9e61

Browse files
emilyfertigtensorflower-gardener
authored andcommitted
Subclass Composition instead of Chain and Blockwise in Glow bijector, in preparation for converting Chain and Blockwise to CompositeTensor.
PiperOrigin-RevId: 375154471
1 parent 847489b commit 39d9e61

File tree

3 files changed

+48
-21
lines changed

3 files changed

+48
-21
lines changed

tensorflow_probability/python/bijectors/composition.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21-
import abc
2221
import collections
2322
import functools
2423
import sys
@@ -115,7 +114,11 @@ def __init__(self,
115114
"""Instantiates a Composition of bijectors.
116115
117116
Args:
118-
bijectors: A nest-compatible structure of bijector instances.
117+
bijectors: A nest-compatible structure of bijector instances or a
118+
`Composition` bijector. If `bijectors` is a nested structure, then
119+
`_walk_forward` and `_walk_inverse` must be implemented. If `bijectors`
120+
is a `Composition` bijector, `_walk_forward` and `_walk_inverse` call
121+
its corresponding methods.
119122
name: Name of this bijector.
120123
parameters: Dictionary of parameters used to initialize this bijector.
121124
These must be the exact values passed to `__init__`.
@@ -195,6 +198,8 @@ def _parameter_properties(cls, dtype):
195198

196199
@property
197200
def bijectors(self):
201+
if isinstance(self._bijectors, Composition):
202+
return self._bijectors.bijectors # pylint: disable=protected-access
198203
return self._bijectors
199204

200205
@property
@@ -346,9 +351,6 @@ def transform_wrapper(bij, packed_ys, **nested):
346351
transform_wrapper, packed_args, **kwargs)
347352
return unpack_structs_like(self.forward_min_event_ndims, packed_result)
348353

349-
### Abstract Methods
350-
351-
@abc.abstractmethod
352354
def _walk_forward(self, step_fn, argument, **kwargs):
353355
"""Subclass stub for forward-mode traversals.
354356
@@ -363,11 +365,17 @@ def _walk_forward(self, step_fn, argument, **kwargs):
363365
`_call_walk_forward` instead.
364366
argument: A (structure of) Tensor matching `self.forward_min_event_ndims`.
365367
**kwargs: Keyword arguments to be forwarded to nested bijectors.
368+
Returns:
369+
bijectors_forward: The value returned by `self._bijectors._walk_forward`
370+
if `self._bijectors` is a `Composition` bijector.
371+
Raises:
372+
NotImplementedError, if `self._bijectors` is a nested structure.
366373
"""
374+
if isinstance(self._bijectors, Composition):
375+
return self._bijectors._walk_forward(step_fn, argument, **kwargs) # pylint: disable=protected-access
367376
raise NotImplementedError('{}._walk_forward is not implemented'.format(
368377
type(self).__name__))
369378

370-
@abc.abstractmethod
371379
def _walk_inverse(self, step_fn, argument, **kwargs):
372380
"""Subclass stub for inverse-mode traversals.
373381
@@ -382,7 +390,14 @@ def _walk_inverse(self, step_fn, argument, **kwargs):
382390
`_call_walk_inverse` instead.
383391
argument: A (structure of) Tensor matching `self.inverse_min_event_ndims`.
384392
**kwargs: Keyword arguments to be forwarded to nested bijectors.
393+
Returns:
394+
bijectors_inverse: The value returned by `self._bijectors._walk_inverse`
395+
if `self._bijectors` is a `Composition` bijector.
396+
Raises:
397+
NotImplementedError, if `self._bijectors` is a nested structure.
385398
"""
399+
if isinstance(self._bijectors, Composition):
400+
return self._bijectors._walk_inverse(step_fn, argument, **kwargs) # pylint: disable=protected-access
386401
raise NotImplementedError('{}._walk_inverse is not implemented'.format(
387402
type(self).__name__))
388403

tensorflow_probability/python/bijectors/glow.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from tensorflow_probability.python.bijectors import bijector
2626
from tensorflow_probability.python.bijectors import blockwise
2727
from tensorflow_probability.python.bijectors import chain
28+
from tensorflow_probability.python.bijectors import composition
2829
from tensorflow_probability.python.bijectors import exp
2930
from tensorflow_probability.python.bijectors import identity
3031
from tensorflow_probability.python.bijectors import invert
@@ -50,7 +51,7 @@
5051
]
5152

5253

53-
class Glow(chain.Chain):
54+
class Glow(composition.Composition):
5455
r"""Implements the Glow Bijector from Kingma & Dhariwal (2018)[1].
5556
5657
Overview: `Glow` is a chain of bijectors which transforms a rank-1 tensor
@@ -319,6 +320,7 @@ def __init__(self,
319320
name: Python `str`, name given to ops managed by this object.
320321
Default value: `'glow'`.
321322
"""
323+
parameters = dict(locals())
322324
# Make sure that the input shape is fully defined.
323325
if not tensorshape_util.is_fully_defined(output_shape):
324326
raise ValueError('Shape must be fully defined.')
@@ -441,10 +443,13 @@ def __init__(self,
441443
]))
442444

443445
glow_chain = glow_chain[::-1]
444-
# To finish off, we initialize the bijector with the chain we've built
445-
# This way, the rest of the model attributes are taken care of for us.
446+
# To finish off, we build a bijector that chains the components together
447+
# sequentially.
446448
super(Glow, self).__init__(
447-
bijectors=glow_chain, validate_args=validate_args, name=name)
449+
bijectors=chain.Chain(glow_chain, validate_args=validate_args),
450+
validate_args=validate_args,
451+
parameters=parameters,
452+
name=name)
448453

449454
@classmethod
450455
def _parameter_properties(cls, dtype):
@@ -491,7 +496,7 @@ def blockwise_splits(self):
491496
return self._blockwise_splits
492497

493498

494-
class ExitBijector(blockwise.Blockwise):
499+
class ExitBijector(composition.Composition):
495500
"""The spatial coupling bijector used in Glow.
496501
497502
This bijector consists of a blockwise bijector of a realNVP bijector. It is
@@ -524,7 +529,7 @@ def __init__(self,
524529
tensor with event shape `input_shape`, and returns a tensor with shape
525530
`input_shape`.
526531
"""
527-
532+
parameters = dict(locals())
528533
nleave, ngrab, npass = blockwise_splits
529534

530535
new_input_shape = input_shape[:-1]+(nleave,)
@@ -551,7 +556,10 @@ def __init__(self,
551556
bijector_fn=exit_bijector_fn)
552557

553558
super(ExitBijector, self).__init__(
554-
[shift_distribution, identity.Identity()], [nleave + ngrab, npass])
559+
blockwise.Blockwise(
560+
[shift_distribution, identity.Identity()], [nleave + ngrab, npass]),
561+
parameters=parameters,
562+
name='exit_bijector')
555563

556564
@staticmethod
557565
def make_bijector_fn(layer, target_shape, scale_fn=tf.nn.sigmoid):
@@ -601,7 +609,7 @@ def bijector_fn(inputs, ignored_input):
601609
return bijector_fn
602610

603611

604-
class GlowBlock(chain.Chain):
612+
class GlowBlock(composition.Composition):
605613
"""Single block for a glow model.
606614
607615
This bijector contains `num_steps` steps of the flow, each consisting of an
@@ -613,7 +621,7 @@ class GlowBlock(chain.Chain):
613621

614622
def __init__(self, input_shape, num_steps, coupling_bijector_fn,
615623
use_actnorm, seedstream):
616-
624+
parameters = dict(locals())
617625
rnvp_block = [identity.Identity()]
618626
this_nchan = input_shape[-1]
619627

@@ -646,7 +654,8 @@ def __init__(self, input_shape, num_steps, coupling_bijector_fn,
646654

647655
# Note that we reverse the list since Chain applies bijectors in reverse
648656
# order.
649-
super(GlowBlock, self).__init__(rnvp_block[::-1])
657+
super(GlowBlock, self).__init__(
658+
chain.Chain(rnvp_block[::-1]), parameters=parameters, name='glow_block')
650659

651660
@staticmethod
652661
def make_bijector_fn(layer, scale_fn=tf.nn.sigmoid):
@@ -809,7 +818,7 @@ def _init():
809818
return tf.cond(self._initialized, tf.no_op, _init)
810819

811820

812-
class Expand(chain.Chain):
821+
class Expand(composition.Composition):
813822
"""A bijector to transform channels into spatial pixels."""
814823

815824
def __init__(self, input_shape, block_size=2, validate_args=False, name=None):
@@ -827,8 +836,10 @@ def __init__(self, input_shape, block_size=2, validate_args=False, name=None):
827836
event_shape_in=[h, w, c],
828837
event_shape_out=[h, w, c // n**2, n, n]),
829838
]
830-
super(Expand, self).__init__(b, name=name or 'Expand',
831-
parameters=parameters)
839+
super(Expand, self).__init__(
840+
bijectors=chain.Chain(b, validate_args=validate_args),
841+
name=name or 'Expand',
842+
parameters=parameters)
832843

833844

834845
class GlowDefaultNetwork(tfk.Sequential):

tensorflow_probability/python/bijectors/glow_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import tensorflow.compat.v2 as tf
2121
import tensorflow_probability as tfp
2222
from tensorflow_probability.python.bijectors import bijector_test_util
23+
from tensorflow_probability.python.bijectors import composition
2324
from tensorflow_probability.python.internal import test_util
2425
from tensorflow_probability.python.math.gradient import batch_jacobian
2526

@@ -131,7 +132,7 @@ def testDataInit_inverse(self):
131132
x = tf.concat([x1, x2], axis=-1)
132133
nblocks += 1
133134

134-
elif isinstance(b, tfb.chain.Chain):
135+
elif isinstance(b, composition.Composition):
135136
for bb in b.bijectors:
136137
x = self.evaluate(bb.inverse(x))
137138
if isinstance(bb, tfb.glow.ActivationNormalization):
@@ -172,7 +173,7 @@ def testDataInit_forward(self):
172173
y = tf.concat([y1, y2], axis=-1)
173174
nblocks += 1
174175

175-
elif isinstance(b, tfb.chain.Chain):
176+
elif isinstance(b, composition.Composition):
176177
for bb in reversed(b.bijectors):
177178
y = self.evaluate(bb.forward(y))
178179
if isinstance(bb, tfb.glow.ActivationNormalization):

0 commit comments

Comments
 (0)