@@ -28,24 +28,25 @@ def build_highway_flow_layer(width,
28
28
activation_fn = False ,
29
29
gate_first_n = - 1 ,
30
30
seed = None ):
31
- """
32
- Builds an HighwayFlow layer making sure that all the requirements are
33
- satisfied, namely:
34
- - `residual_fraction` is bounded between 0 and 1;
35
- - `upper_diagonal_weights_matrix` is a randomly initialized (lower) diagonal
36
- matrix with positive diagonal of size `width x width`;
37
- - `lower_diagonal_weights_matrix` is a randomly initialized lower diagonal
38
- matrix with ones on the diagonal of size `width x width`;
39
- - `bias` is a randomly initialized vector of size `width`
40
-
41
- :param int width: input dimension of the bijector
42
- :param float residual_fraction_initial_value: initial value for residual
43
- fraction, must be between 0. and 1.
44
- :param bool activation_fn: whether or not use activation function in the
45
- HighwayFlow
46
- :param int seed: seed for random initialization of the weights
47
- :returns: the initialized HighwayFlow bijector
48
- :rtype: `tfb.Bijector`
31
+ """Builds HighwayFlow making sure that all the requirements ar satisfied.
32
+
33
+ Args:
34
+ width: Input dimension of the bijector.
35
+ residual_fraction_initial_value: Initial value for gating parameter, must be
36
+ between 0 and 1.
37
+ activation_fn: Whether or not use SoftPlus activation function.
38
+ gate_first_n: Decides which part of the input should be gated (useful for
39
+ example when using auxiliary variables).
40
+ seed: Seed for random initialization of the weights.
41
+
42
+ Returns:
43
+ The initialized bijector with the following elements:
44
+ `residual_fraction` is bounded between 0 and 1.
45
+ `upper_diagonal_weights_matrix` is a randomly initialized (lower) diagonal
46
+ matrix with positive diagonal of size `width x width`.
47
+ `lower_diagonal_weights_matrix` is a randomly initialized lower diagonal
48
+ matrix with ones on the diagonal of size `width x width`;
49
+ `bias` is a randomly initialized vector of size `width`
49
50
"""
50
51
51
52
if gate_first_n == - 1 :
@@ -98,10 +99,11 @@ def build_highway_flow_layer(width,
98
99
99
100
100
101
class HighwayFlow (tfb .Bijector ):
101
- """Implements an Highway Flow bijector [1], which interpolates the input
102
- `X` with the transformations at each step of the bjiector.
103
- The Highway Flow can be used as building block for a Cascading flow [1]
104
- or as a generic normalizing flow.
102
+ """Implements an Highway Flow bijector [1].
103
+
104
+ HighwayFlow interpolates the input `X` with the transformations at each step
105
+ of the bjiector. The Highway Flow can be used as building block for a
106
+ Cascading flow [1 or as a generic normalizing flow.
105
107
106
108
The transformation consists in a convex update between the input `X` and a
107
109
linear transformation of `X` followed by activation with the form `g(A @
@@ -145,6 +147,19 @@ class HighwayFlow(tfb.Bijector):
145
147
[1]: Ambrogioni, Luca, Gianluigi Silvestri, and Marcel van Gerven.
146
148
"Automatic variational inference with
147
149
cascading flows." arXiv preprint arXiv:2102.04801 (2021).
150
+
151
+ Attributes:
152
+ residual_fraction: Scalar `Tensor` used for the convex update, must be
153
+ between 0 and 1.
154
+ activation_fn: Boolean to decide whether to use SoftPlus (True) activation
155
+ or no activation (False).
156
+ bias: Bias vector.
157
+ upper_diagonal_weights_matrix: Lower diagional matrix of size (width, width)
158
+ with positive diagonal (is transposed to Upper diagonal within the
159
+ bijector).
160
+ lower_diagonal_weights_matrix: Lower diagonal matrix with ones on the main
161
+ diagional.
162
+ gate_first_n: Integer that decides what part of the input is gated.
148
163
"""
149
164
150
165
# HighWay Flow simultaneously computes `forward` and `fldj`
@@ -162,20 +177,7 @@ def __init__(self, residual_fraction, activation_fn, bias,
162
177
gate_first_n ,
163
178
validate_args = False ,
164
179
name = None ):
165
- '''
166
- Args:
167
- residual_fraction: scalar `Tensor` used for the convex update,
168
- must be
169
- between 0 and 1
170
- activation_fn: bool to decide whether to use softplus (True)
171
- activation or no activation (False)
172
- bias: bias vector
173
- upper_diagonal_weights_matrix: Lower diagional matrix of size
174
- (width, width) with positive diagonal
175
- (is transposed to Upper diagonal within the bijector)
176
- lower_diagonal_weights_matrix: Lower diagonal matrix with ones on
177
- the main diagional.
178
- '''
180
+ """Initializes the HighwayFlow."""
179
181
parameters = dict (locals ())
180
182
name = name or 'highway_flow'
181
183
with tf .name_scope (name ) as name :
@@ -242,7 +244,7 @@ def _convex_update(self, weights_matrix):
242
244
axis = 0 ) * weights_matrix
243
245
244
246
def _inverse_of_softplus (self , y , n = 20 ):
245
- # Inverse of the activation layer with softplus using Newton iteration.
247
+ """ Inverse of the activation layer with softplus using Newton iteration."""
246
248
x = tf .ones (y .shape )
247
249
for _ in range (n ):
248
250
x = x - (tf .concat ([(self .residual_fraction ) * tf .ones (
@@ -257,12 +259,18 @@ def _inverse_of_softplus(self, y, n=20):
257
259
258
260
def _augmented_forward (self , x ):
259
261
"""Computes forward and forward_log_det_jacobian transformations.
260
- :param tf.Tensor x: input of the bijector
261
- :returns: x after forward flow
262
- :rtype: `tf.Tensor`
262
+
263
+ Args:
264
+ x: Input of the bijector.
265
+
266
+ Returns:
267
+ x after forward flow and a dict containing forward and inverse log
268
+ determinant of the jacobian.
263
269
"""
270
+
264
271
# Log determinant term from the upper matrix. Note that the log determinant
265
272
# of the lower matrix is zero.
273
+
266
274
added_batch = False
267
275
if len (x .shape ) <= 1 :
268
276
if len (x .shape ) == 0 :
@@ -302,10 +310,15 @@ def _augmented_forward(self, x):
302
310
303
311
def _augmented_inverse (self , y ):
304
312
"""Computes inverse and inverse_log_det_jacobian transformations.
305
- :param tf.Tensor y: input of the (inverse) bijector
306
- :returns: y after inverse flow
307
- :rtype: `tf.Tensor`
313
+
314
+ Args:
315
+ y: input of the (inverse) bijectorr.
316
+
317
+ Returns:
318
+ y after inverse flow and a dict containing inverse and forward log
319
+ determinant of the jacobian.
308
320
"""
321
+
309
322
added_batch = False
310
323
if len (y .shape ) <= 1 :
311
324
if len (y .shape ) == 0 :
0 commit comments