
Commit 1620ebd

committed: removed cascading_flows from pr
1 parent 08ed88b commit 1620ebd

File tree

3 files changed (+240, -897 lines)


tensorflow_probability/python/experimental/bijectors/highway_flow.py

Lines changed: 240 additions & 1 deletion
@@ -22,7 +22,6 @@
 from tensorflow_probability.python.internal import cache_util
 from tensorflow_probability.python.internal import samplers
 
-
 def build_highway_flow_layer(width,
                              residual_fraction_initial_value=0.5,
                              activation_fn=False,
@@ -363,3 +362,243 @@ def _inverse_log_det_jacobian(self, y):
     _, attrs = self._augmented_inverse(y)
     cached.update(attrs)
     return cached['ildj']
+=======
+def build_highway_flow_layer(width, residual_fraction_initial_value=0.5,
+                             activation_fn=False, seed=None):
+  # TODO: add a check that residual_fraction_initial_value is between 0 and 1.
+  residual_fraction_initial_value = tf.convert_to_tensor(
+      residual_fraction_initial_value,
+      dtype_hint=tf.float32,
+      name='residual_fraction_initial_value')
+  dtype = residual_fraction_initial_value.dtype
+
+  bias_seed, upper_seed, lower_seed, diagonal_seed = samplers.split_seed(
+      seed, n=4)
+  return HighwayFlow(
+      residual_fraction=util.TransformedVariable(
+          initial_value=residual_fraction_initial_value,
+          bijector=tfb.Sigmoid(),
+          dtype=dtype),
+      activation_fn=activation_fn,
+      bias=tf.Variable(
+          samplers.normal((width,), mean=0., stddev=0.01, seed=bias_seed),
+          dtype=dtype),
+      upper_diagonal_weights_matrix=util.TransformedVariable(
+          initial_value=tf.experimental.numpy.tril(
+              samplers.normal((width, width), mean=0., stddev=1.,
+                              seed=upper_seed),
+              k=-1) + tf.linalg.diag(
+                  samplers.uniform((width,), minval=0., maxval=1.,
+                                   seed=diagonal_seed)),
+          bijector=tfb.FillScaleTriL(diag_bijector=tfb.Softplus(),
+                                     diag_shift=None),
+          dtype=dtype),
+      lower_diagonal_weights_matrix=util.TransformedVariable(
+          initial_value=samplers.normal((width, width), mean=0., stddev=1.,
+                                        seed=lower_seed),
+          bijector=tfb.Chain(
+              [tfb.TransformDiagonal(diag_bijector=tfb.Shift(1.)),
+               tfb.Pad(paddings=[(1, 0), (0, 1)]),
+               tfb.FillTriangular()]),
+          dtype=dtype)
+  )
+
+
+class HighwayFlow(tfb.Bijector):
+  """Implements a Highway Flow bijector [1], which interpolates the input
+  `X` with the transformations at each step of the bijector.
+
+  The Highway Flow can be used as a building block for a Cascading Flow [1]
+  or as a generic normalizing flow.
+
+  The transformation consists of a convex update between the input `X` and a
+  linear transformation of `X` followed by activation, of the form
+  `g(A @ X + b)`, where `g(.)` is a differentiable non-decreasing activation
+  function and `A` and `b` are trainable weights.
+
+  The convex update is regulated by a trainable residual fraction `l`
+  constrained between 0 and 1, and can be formalized as:
+  `Y = l * X + (1 - l) * g(A @ X + b)`.
+
+  To make this transformation invertible, the bijector is split into three
+  convex updates:
+  - `Y1 = l * X + (1 - l) * L @ X`, with `L` a lower triangular matrix with
+    ones on the diagonal;
+  - `Y2 = l * Y1 + (1 - l) * (U @ Y1 + b)`, with `U` an upper triangular
+    matrix with positive diagonal;
+  - `Y = l * Y2 + (1 - l) * g(Y2)`.
+
+  The function `build_highway_flow_layer` helps initialize the bijector with
+  variables that respect the various constraints.
+
+  For more details on Highway Flow and Cascading Flows see [1].
+
+  #### Usage example:
+  ```python
+  tfd = tfp.distributions
+  tfb = tfp.bijectors
+
+  dim = 4  # last input dimension
+
+  bijector = build_highway_flow_layer(dim, activation_fn=True)
+  y = bijector.forward(x)  # forward mapping
+  x = bijector.inverse(y)  # inverse mapping
+  base = tfd.MultivariateNormalDiag(loc=tf.zeros(dim))  # Base distribution
+  transformed_distribution = tfd.TransformedDistribution(base, bijector)
+  ```
+
+  #### References
+
+  [1]: Ambrogioni, Luca, Gianluigi Silvestri, and Marcel van Gerven.
+       "Automatic variational inference with cascading flows." arXiv preprint
+       arXiv:2102.04801 (2021).
+  """
+
+  # HighwayFlow simultaneously computes `forward` and `fldj`
+  # (and `inverse`/`ildj`), so we override the bijector cache to update the
+  # LDJ entries of attrs on forward/inverse calls (instead of updating them
+  # only when the LDJ methods themselves are called).
+
+  _cache = cache_util.BijectorCacheWithGreedyAttrs(
+      forward_name='_augmented_forward',
+      inverse_name='_augmented_inverse')
+
+  def __init__(self, residual_fraction, activation_fn, bias,
+               upper_diagonal_weights_matrix,
+               lower_diagonal_weights_matrix, validate_args=False,
+               name='highway_flow'):
+    """
+    Args:
+      residual_fraction: scalar `Tensor` used for the convex update; must be
+        between 0 and 1.
+      activation_fn: bool deciding whether to use a softplus activation
+        (True) or no activation (False).
+      bias: bias vector.
+      upper_diagonal_weights_matrix: lower triangular matrix of size
+        (width, width) with positive diagonal (it is transposed to upper
+        triangular within the bijector).
+      lower_diagonal_weights_matrix: lower triangular matrix with ones on
+        the main diagonal.
+    """
+    parameters = dict(locals())
+    with tf.name_scope(name) as name:
+      self._width = tf.shape(bias)[-1]
+      self._bias = bias
+      self._residual_fraction = residual_fraction
+      # The upper matrix is still stored as lower triangular; the transpose
+      # is done in the _inverse and _forward methods, within matvec.
+      self._upper_diagonal_weights_matrix = upper_diagonal_weights_matrix
+      self._lower_diagonal_weights_matrix = lower_diagonal_weights_matrix
+      self._activation_fn = activation_fn
+
+      super(HighwayFlow, self).__init__(
+          validate_args=validate_args,
+          forward_min_event_ndims=1,
+          parameters=parameters,
+          name=name)
+
+  @property
+  def bias(self):
+    return self._bias
+
+  @property
+  def width(self):
+    return self._width
+
+  @property
+  def residual_fraction(self):
+    return self._residual_fraction
+
+  @property
+  def upper_diagonal_weights_matrix(self):
+    return self._upper_diagonal_weights_matrix
+
+  @property
+  def lower_diagonal_weights_matrix(self):
+    return self._lower_diagonal_weights_matrix
+
+  @property
+  def activation_fn(self):
+    return self._activation_fn
+
+  def _derivative_of_sigmoid(self, x):
+    return self.residual_fraction + (
+        1. - self.residual_fraction) * tf.math.sigmoid(x)
+
+  def _convex_update(self, weights_matrix):
+    return self.residual_fraction * tf.eye(self.width) + (
+        1. - self.residual_fraction) * weights_matrix
+
+  def _inverse_of_sigmoid(self, y, N=20):
+    # Inverse of the softplus activation layer, computed with Newton
+    # iterations.
+    x = tf.ones(y.shape)
+    for _ in range(N):
+      x = x - (self.residual_fraction * x + (
+          1. - self.residual_fraction) * tf.math.softplus(x) - y) / (
+              self._derivative_of_sigmoid(x))
+    return x
+
+  def _augmented_forward(self, x):
+    # Log-determinant term from the upper matrix. Note that the
+    # log-determinant of the lower matrix is zero.
+    fldj = tf.zeros(x.shape[:-1]) + tf.reduce_sum(
+        tf.math.log(self.residual_fraction + (
+            1. - self.residual_fraction) * tf.linalg.diag_part(
+                self.upper_diagonal_weights_matrix)))
+    x = tf.linalg.matvec(
+        self._convex_update(self.lower_diagonal_weights_matrix), x)
+    x = tf.linalg.matvec(
+        tf.transpose(self._convex_update(self.upper_diagonal_weights_matrix)),
+        x) + (1 - self.residual_fraction) * self.bias
+    if self.activation_fn:
+      fldj += tf.reduce_sum(tf.math.log(self._derivative_of_sigmoid(x)), -1)
+      x = self.residual_fraction * x + (
+          1. - self.residual_fraction) * self.activation_fn(x)
+    return x, {'ildj': -fldj, 'fldj': fldj}
+
+  def _augmented_inverse(self, y):
+    ildj = tf.zeros(y.shape[:-1]) - tf.reduce_sum(
+        tf.math.log(self.residual_fraction + (
+            1. - self.residual_fraction) * tf.linalg.diag_part(
+                self.upper_diagonal_weights_matrix)))
+    if self.activation_fn:
+      y = self._inverse_of_sigmoid(y)
+      ildj -= tf.reduce_sum(tf.math.log(self._derivative_of_sigmoid(y)), -1)
+
+    y = tf.linalg.triangular_solve(
+        tf.transpose(self._convex_update(self.upper_diagonal_weights_matrix)),
+        tf.linalg.matrix_transpose(
+            y - (1 - self.residual_fraction) * self.bias),
+        lower=False)
+    y = tf.linalg.triangular_solve(
+        self._convex_update(self.lower_diagonal_weights_matrix), y)
+    return tf.linalg.matrix_transpose(y), {'ildj': ildj, 'fldj': -ildj}
+
+  def _forward(self, x):
+    y, _ = self._augmented_forward(x)
+    return y
+
+  def _inverse(self, y):
+    x, _ = self._augmented_inverse(y)
+    return x
+
+  def _forward_log_det_jacobian(self, x):
+    cached = self._cache.forward_attributes(x)
+    # If LDJ isn't in the cache, call forward once.
+    if 'fldj' not in cached:
+      _, attrs = self._augmented_forward(x)
+      cached.update(attrs)
+    return cached['fldj']
+
+  def _inverse_log_det_jacobian(self, y):
+    cached = self._cache.inverse_attributes(y)
+    # If LDJ isn't in the cache, call inverse once.
+    if 'ildj' not in cached:
+      _, attrs = self._augmented_inverse(y)
+      cached.update(attrs)
+    return cached['ildj']
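
The `HighwayFlow` docstring above pins down the forward log-det Jacobian of the three-step decomposition: the lower step has unit diagonal (zero log-det), the upper step contributes `sum(log(l + (1 - l) * diag(U)))`, and the softplus step contributes `sum(log(l + (1 - l) * sigmoid(Y2)))`. The following standalone NumPy sketch is not part of the commit; it only mirrors those formulas with made-up toy values and checks the analytic log-determinant against a finite-difference Jacobian.

```python
import numpy as np

def softplus(z):
  return np.log1p(np.exp(z))

def sigmoid(z):
  return 1. / (1. + np.exp(-z))

width = 3
rng = np.random.default_rng(0)
l = 0.6                                      # residual fraction in (0, 1)
L = np.tril(rng.normal(size=(width, width)), -1) + np.eye(width)  # unit lower triangular
U = (np.triu(rng.normal(size=(width, width)), 1)
     + np.diag(rng.uniform(0.1, 1., size=width)))  # upper triangular, positive diagonal
b = rng.normal(size=width)
x0 = rng.normal(size=width)

def forward(x):
  y1 = l * x + (1. - l) * (L @ x)            # lower step: zero log-det
  y2 = l * y1 + (1. - l) * (U @ y1 + b)      # upper step
  y = l * y2 + (1. - l) * softplus(y2)       # activation step
  fldj = (np.sum(np.log(l + (1. - l) * np.diag(U)))
          + np.sum(np.log(l + (1. - l) * sigmoid(y2))))
  return y, fldj

y0, fldj = forward(x0)

# Independent check: log|det| of a finite-difference Jacobian.
eps = 1e-6
jac = np.stack([(forward(x0 + eps * e)[0] - y0) / eps for e in np.eye(width)],
               axis=1)
print(np.allclose(fldj, np.linalg.slogdet(jac)[1], atol=1e-4))  # True
```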
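
When `activation_fn` is set, the inverse pass has to undo the elementwise map `y = l * x + (1 - l) * softplus(x)`, which `_inverse_of_sigmoid` does with a fixed number of Newton steps using the derivative `l + (1 - l) * sigmoid(x)`. Below is a minimal NumPy illustration of that solve, a sketch under the same toy assumptions as above rather than the commit's TensorFlow code path; the all-ones starting point and the 20 iterations follow the diff.

```python
import numpy as np

def invert_activation(y, l, n_iter=20):
  x = np.ones_like(y)                                # fixed starting point, as in the diff
  for _ in range(n_iter):
    f = l * x + (1. - l) * np.log1p(np.exp(x)) - y   # residual of the forward map
    df = l + (1. - l) / (1. + np.exp(-x))            # l + (1 - l) * sigmoid(x)
    x = x - f / df                                   # Newton update
  return x

l = 0.6
x_true = np.linspace(-3., 3., 7)
y = l * x_true + (1. - l) * np.log1p(np.exp(x_true))
print(np.allclose(invert_activation(y, l), x_true, atol=1e-6))  # True
```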
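
The comment above `_cache` explains why the bijector cache is overridden: `_augmented_forward` returns the log-det Jacobian as a by-product, so a later `forward_log_det_jacobian` call should reuse it instead of rerunning the transformation. The toy class below is only a sketch of that "compute together, cache greedily" idea with an invented affine flow; it is not the `cache_util.BijectorCacheWithGreedyAttrs` API.

```python
import math

class CachedAffine:
  """Toy flow y = 2 * x illustrating greedy caching of the log-det Jacobian."""

  def __init__(self):
    self._fldj_cache = {}

  def _augmented_forward(self, x):
    # Forward value and its log-det Jacobian are produced in one pass.
    return 2.0 * x, {'fldj': math.log(2.0), 'ildj': -math.log(2.0)}

  def forward(self, x):
    y, attrs = self._augmented_forward(x)
    self._fldj_cache[x] = attrs['fldj']    # greedily store the LDJ as well
    return y

  def forward_log_det_jacobian(self, x):
    if x not in self._fldj_cache:          # cache miss: run forward once more
      _, attrs = self._augmented_forward(x)
      self._fldj_cache[x] = attrs['fldj']
    return self._fldj_cache[x]

flow = CachedAffine()
y = flow.forward(1.5)
print(flow.forward_log_det_jacobian(1.5))  # reuses the value cached by forward
```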
