
Commit 1ad515a

fixed merging mess and changed gate_first_n default to None
1 parent 5bab1c9 commit 1ad515a

1 file changed

tensorflow_probability/python/experimental/bijectors/highway_flow.py

Lines changed: 6 additions & 245 deletions
@@ -26,7 +26,7 @@
 def build_highway_flow_layer(width,
                              residual_fraction_initial_value=0.5,
                              activation_fn=False,
-                             gate_first_n=-1,
+                             gate_first_n=None,
                              seed=None):
   """Builds HighwayFlow making sure that all the requirements are satisfied.

@@ -49,8 +49,6 @@ def build_highway_flow_layer(width,
     `bias` is a randomly initialized vector of size `width`
   """

-  if gate_first_n == -1:
-    gate_first_n = width
   # TODO: add control that residual_fraction_initial_value is between 0 and 1
   residual_fraction_initial_value = tf.convert_to_tensor(
       residual_fraction_initial_value,
@@ -189,7 +187,10 @@ def __init__(self, residual_fraction, activation_fn, bias,
       self._upper_diagonal_weights_matrix = upper_diagonal_weights_matrix
       self._lower_diagonal_weights_matrix = lower_diagonal_weights_matrix
       self._activation_fn = activation_fn
-      self._gate_first_n = gate_first_n
+      if gate_first_n:
+        self._gate_first_n = gate_first_n if gate_first_n else self.width
+
+

       super(HighwayFlow, self).__init__(
           validate_args=validate_args,
@@ -362,244 +363,4 @@ def _inverse_log_det_jacobian(self, y):
     if 'ildj' not in cached:
       _, attrs = self._augmented_inverse(y)
       cached.update(attrs)
-    return cached['ildj']
-
-def build_highway_flow_layer(width, residual_fraction_initial_value=0.5,
-                             activation_fn=False, seed=None):
-  # TODO: add control that residual_fraction_initial_value is between 0 and 1
-  residual_fraction_initial_value = tf.convert_to_tensor(
-      residual_fraction_initial_value,
-      dtype_hint=tf.float32,
-      name='residual_fraction_initial_value')
-  dtype = residual_fraction_initial_value.dtype
-
-  bias_seed, upper_seed, lower_seed, diagonal_seed = samplers.split_seed(seed,
-                                                                         n=4)
-  return HighwayFlow(
-      residual_fraction=util.TransformedVariable(
-          initial_value=residual_fraction_initial_value,
-          bijector=tfb.Sigmoid(),
-          dtype=dtype),
-      activation_fn=activation_fn,
-      bias=tf.Variable(
-          samplers.normal((width,), mean=0., stddev=0.01, seed=bias_seed),
-          dtype=dtype),
-      upper_diagonal_weights_matrix=util.TransformedVariable(
-          initial_value=tf.experimental.numpy.tril(
-              samplers.normal((width, width), mean=0., stddev=1.,
-                              seed=upper_seed),
-              k=-1) + tf.linalg.diag(
-              samplers.uniform((width,), minval=0., maxval=1.,
-                               seed=diagonal_seed)),
-          bijector=tfb.FillScaleTriL(diag_bijector=tfb.Softplus(),
-                                     diag_shift=None),
-          dtype=dtype),
-      lower_diagonal_weights_matrix=util.TransformedVariable(
-          initial_value=samplers.normal((width, width), mean=0., stddev=1.,
-                                        seed=lower_seed),
-          bijector=tfb.Chain(
-              [tfb.TransformDiagonal(diag_bijector=tfb.Shift(1.)),
-               tfb.Pad(paddings=[(1, 0), (0, 1)]),
-               tfb.FillTriangular()]),
-          dtype=dtype)
-  )
-
-
-class HighwayFlow(tfb.Bijector):
-  """Implements an Highway Flow bijector [1], which interpolates the input
-  `X` with the transformations at each step of the bjiector.
-  The Highway Flow can be used as building block for a Cascading flow [1]
-  or as a generic normalizing flow.
-
-  The transformation consists in a convex update between the input `X` and a
-  linear transformation of `X` followed by activation with the form `g(A @
-  X + b)`, where `g(.)` is a differentiable non-decreasing activation
-  function, and `A` and `b` are trainable weights.
-
-  The convex update is regulated by a trainable residual fraction `l`
-  constrained between 0 and 1, and can be
-  formalized as:
-  `Y = l * X + (1 - l) * g(A @ X + b)`.
-
-  To make this transformation invertible, the bijector is split in three
-  convex updates:
-  - `Y1 = l * X + (1 - l) * L @ X`, with `L` lower diagonal matrix with ones
-    on the diagonal;
-  - `Y2 = l * Y1 + (1 - l) * (U @ Y1 + b)`, with `U` upper diagonal matrix
-    with positive diagonal;
-  - `Y = l * Y2 + (1 - l) * g(Y2)`
-
-  The function `build_highway_flow_layer` helps initializing the bijector
-  with the variables respecting the various constraints.
-
-  For more details on Highway Flow and Cascading Flows see [1].
-
-  #### Usage example:
-  ```python
-  tfd = tfp.distributions
-  tfb = tfp.bijectors
-
-  dim = 4  # last input dimension
-
-  bijector = build_highway_flow_layer(dim, activation_fn=True)
-  y = bijector.forward(x)  # forward mapping
-  x = bijector.inverse(y)  # inverse mapping
-  base = tfd.MultivariateNormalDiag(loc=tf.zeros(dim))  # Base distribution
-  transformed_distribution = tfd.TransformedDistribution(base, bijector)
-  ```
-
-  #### References
-
-  [1]: Ambrogioni, Luca, Gianluigi Silvestri, and Marcel van Gerven.
-  "Automatic variational inference with
-  cascading flows." arXiv preprint arXiv:2102.04801 (2021).
-  """
-
-  # HighWay Flow simultaneously computes `forward` and `fldj`
-  # (and `inverse`/`ildj`), so we override the bijector cache to update the
-  # LDJ entries of attrs on forward/inverse inverse calls (instead of
-  # updating them only when the LDJ methods themselves are called).
-
-  _cache = cache_util.BijectorCacheWithGreedyAttrs(
-      forward_name='_augmented_forward',
-      inverse_name='_augmented_inverse')
-
-  def __init__(self, residual_fraction, activation_fn, bias,
-               upper_diagonal_weights_matrix,
-               lower_diagonal_weights_matrix, validate_args=False,
-               name='highway_flow'):
-    '''
-    Args:
-      residual_fraction: scalar `Tensor` used for the convex update,
-        must be
-        between 0 and 1
-      activation_fn: bool to decide whether to use softplus (True)
-        activation or no activation (False)
-      bias: bias vector
-      upper_diagonal_weights_matrix: Lower diagional matrix of size
-        (width, width) with positive diagonal
-        (is transposed to Upper diagonal within the bijector)
-      lower_diagonal_weights_matrix: Lower diagonal matrix with ones on
-        the main diagional.
-    '''
-    parameters = dict(locals())
-    with tf.name_scope(name) as name:
-      self._width = tf.shape(bias)[-1]
-      self._bias = bias
-      self._residual_fraction = residual_fraction
-      # The upper matrix is still lower triangular, transpose is done in
-      # _inverse and _forwars metowds, within matvec.
-      self._upper_diagonal_weights_matrix = upper_diagonal_weights_matrix
-      self._lower_diagonal_weights_matrix = lower_diagonal_weights_matrix
-      self._activation_fn = activation_fn
-
-      super(HighwayFlow, self).__init__(
-          validate_args=validate_args,
-          forward_min_event_ndims=1,
-          parameters=parameters,
-          name=name)
-
-  @property
-  def bias(self):
-    return self._bias
-
-  @property
-  def width(self):
-    return self._width
-
-  @property
-  def residual_fraction(self):
-    return self._residual_fraction
-
-  @property
-  def upper_diagonal_weights_matrix(self):
-    return self._upper_diagonal_weights_matrix
-
-  @property
-  def lower_diagonal_weights_matrix(self):
-    return self._lower_diagonal_weights_matrix
-
-  @property
-  def activation_fn(self):
-    return self._activation_fn
-
-  def _derivative_of_sigmoid(self, x):
-    return self.residual_fraction + (
-        1. - self.residual_fraction) * tf.math.sigmoid(x)
-
-  def _convex_update(self, weights_matrix):
-    return self.residual_fraction * tf.eye(self.width) + (
-        1. - self.residual_fraction) * weights_matrix
-
-  def _inverse_of_sigmoid(self, y, N=20):
-    # Inverse of the activation layer with softplus using Newton iteration.
-    x = tf.ones(y.shape)
-    for _ in range(N):
-      x = x - (self.residual_fraction * x + (
-          1. - self.residual_fraction) * tf.math.softplus(
-          x) - y) / (
-          self._derivative_of_sigmoid(x))
-    return x
-
-  def _augmented_forward(self, x):
-    # Log determinant term from the upper matrix. Note that the log determinant
-    # of the lower matrix is zero.
-    fldj = tf.zeros(x.shape[:-1]) + tf.reduce_sum(
-        tf.math.log(self.residual_fraction + (
-            1. - self.residual_fraction) * tf.linalg.diag_part(
-            self.upper_diagonal_weights_matrix)))
-    x = tf.linalg.matvec(
-        self._convex_update(self.lower_diagonal_weights_matrix), x)
-    x = tf.linalg.matvec(tf.transpose(
-        self._convex_update(self.upper_diagonal_weights_matrix)),
-        x) + (
-        1 - self.residual_fraction) * self.bias
-    if self.activation_fn:
-      fldj += tf.reduce_sum(tf.math.log(self._derivative_of_sigmoid(x)),
-                            -1)
-      x = self.residual_fraction * x + (
-          1. - self.residual_fraction) * self.activation_fn(x)
-    return x, {'ildj': -fldj, 'fldj': fldj}
-
-  def _augmented_inverse(self, y):
-    ildj = tf.zeros(y.shape[:-1]) - tf.reduce_sum(
-        tf.math.log(self.residual_fraction + (
-            1. - self.residual_fraction) * tf.linalg.diag_part(
-            self.upper_diagonal_weights_matrix)))
-    if self.activation_fn:
-      y = self._inverse_of_sigmoid(y)
-      ildj -= tf.reduce_sum(tf.math.log(self._derivative_of_sigmoid(y)),
-                            -1)
-
-    y = tf.linalg.triangular_solve(tf.transpose(
-        self._convex_update(self.upper_diagonal_weights_matrix)),
-        tf.linalg.matrix_transpose(y - (
-            1 - self.residual_fraction) * self.bias),
-        lower=False)
-    y = tf.linalg.triangular_solve(
-        self._convex_update(self.lower_diagonal_weights_matrix), y)
-    return tf.linalg.matrix_transpose(y), {'ildj': ildj, 'fldj': -ildj}
-
-  def _forward(self, x):
-    y, _ = self._augmented_forward(x)
-    return y
-
-  def _inverse(self, y):
-    x, _ = self._augmented_inverse(y)
-    return x
-
-  def _forward_log_det_jacobian(self, x):
-    cached = self._cache.forward_attributes(x)
-    # If LDJ isn't in the cache, call forward once.
-    if 'fldj' not in cached:
-      _, attrs = self._augmented_forward(x)
-      cached.update(attrs)
-    return cached['fldj']
-
-  def _inverse_log_det_jacobian(self, y):
-    cached = self._cache.inverse_attributes(y)
-    # If LDJ isn't in the cache, call inverse once.
-    if 'ildj' not in cached:
-      _, attrs = self._augmented_inverse(y)
-      cached.update(attrs)
-    return cached['ildj']
+    return cached['ildj']
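For context, here is a minimal sketch of what the sentinel change means for callers. It is not part of this commit, and the `resolve_gate_first_n` helper below is hypothetical, written only to illustrate the intent of the added `gate_first_n if gate_first_n else self.width` expression: the old explicit `-1` sentinel and the new `None` default both resolve to gating all `width` dimensions, while a positive integer gates only the first n.

# Hypothetical illustration of the gate_first_n default change;
# not code from highway_flow.py.

def resolve_gate_first_n(gate_first_n, width):
  # Old API: gate_first_n=-1 meant "gate all `width` dimensions".
  # New API: gate_first_n=None (the new default) is resolved the same way.
  if gate_first_n is None or gate_first_n == -1:
    return width
  return gate_first_n

assert resolve_gate_first_n(None, width=4) == 4  # new default
assert resolve_gate_first_n(-1, width=4) == 4    # old sentinel
assert resolve_gate_first_n(2, width=4) == 2     # gate only the first 2 dims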
