@@ -28,13 +28,13 @@ def build_highway_flow_layer(width,
28
28
activation_fn = False ,
29
29
gate_first_n = - 1 ,
30
30
seed = None ):
31
- """Builds HighwayFlow making sure that all the requirements ar satisfied.
31
+ """Builds HighwayFlow making sure that all the requirements are satisfied.
32
32
33
33
Args:
34
34
width: Input dimension of the bijector.
35
35
residual_fraction_initial_value: Initial value for gating parameter, must be
36
- between 0 and 1.
37
- activation_fn: Whether or not use SoftPlus activation function.
36
+ between 0 and 1.
37
+ activation_fn: Whether or not use SoftPlus activation function.
38
38
gate_first_n: Decides which part of the input should be gated (useful for
39
39
example when using auxiliary variables).
40
40
seed: Seed for random initialization of the weights.
@@ -58,8 +58,8 @@ def build_highway_flow_layer(width,
58
58
name = 'residual_fraction_initial_value' )
59
59
dtype = residual_fraction_initial_value .dtype
60
60
61
- bias_seed , upper_seed , lower_seed , diagonal_seed = samplers .split_seed (
62
- seed , n = 4 )
61
+ bias_seed , upper_seed , lower_seed = samplers .split_seed (
62
+ seed , n = 3 )
63
63
lower_bijector = tfb .Chain (
64
64
[tfb .TransformDiagonal (diag_bijector = tfb .Shift (1. )),
65
65
tfb .Pad (paddings = [(1 , 0 ), (0 , 1 )]),
@@ -103,9 +103,9 @@ class HighwayFlow(tfb.Bijector):
103
103
104
104
HighwayFlow interpolates the input `X` with the transformations at each step
105
105
of the bjiector. The Highway Flow can be used as building block for a
106
- Cascading flow [1 or as a generic normalizing flow.
106
+ Cascading flow [1] or as a generic normalizing flow.
107
107
108
- The transformation consists in a convex update between the input `X` and a
108
+ The transformation consists of a convex update between the input `X` and a
109
109
linear transformation of `X` followed by activation with the form `g(A @
110
110
X + b)`, where `g(.)` is a differentiable non-decreasing activation
111
111
function, and `A` and `b` are trainable weights.
@@ -128,7 +128,7 @@ class HighwayFlow(tfb.Bijector):
128
128
129
129
For more details on Highway Flow and Cascading Flows see [1].
130
130
131
- #### Usage example:
131
+ #### Usage example
132
132
```python
133
133
tfd = tfp.distributions
134
134
tfb = tfp.bijectors
@@ -145,21 +145,8 @@ class HighwayFlow(tfb.Bijector):
145
145
#### References
146
146
147
147
[1]: Ambrogioni, Luca, Gianluigi Silvestri, and Marcel van Gerven.
148
- "Automatic variational inference with
149
- cascading flows." arXiv preprint arXiv:2102.04801 (2021).
150
-
151
- Attributes:
152
- residual_fraction: Scalar `Tensor` used for the convex update, must be
153
- between 0 and 1.
154
- activation_fn: Boolean to decide whether to use SoftPlus (True) activation
155
- or no activation (False).
156
- bias: Bias vector.
157
- upper_diagonal_weights_matrix: Lower diagional matrix of size (width, width)
158
- with positive diagonal (is transposed to Upper diagonal within the
159
- bijector).
160
- lower_diagonal_weights_matrix: Lower diagonal matrix with ones on the main
161
- diagional.
162
- gate_first_n: Integer that decides what part of the input is gated.
148
+ "Automatic variational inference with cascading flows." arXiv preprint
149
+ arXiv:2102.04801 (2021).
163
150
"""
164
151
165
152
# HighWay Flow simultaneously computes `forward` and `fldj`
@@ -177,7 +164,20 @@ def __init__(self, residual_fraction, activation_fn, bias,
177
164
gate_first_n ,
178
165
validate_args = False ,
179
166
name = None ):
180
- """Initializes the HighwayFlow."""
167
+ """Initializes the HighwayFlow.
168
+ Args:
169
+ residual_fraction: Scalar `Tensor` used for the convex update, must be
170
+ between 0 and 1.
171
+ activation_fn: Boolean to decide whether to use SoftPlus (True) activation
172
+ or no activation (False).
173
+ bias: Bias vector.
174
+ upper_diagonal_weights_matrix: Lower diagional matrix of size
175
+ (width, width) with positive diagonal (is transposed to Upper diagonal
176
+ within the bijector).
177
+ lower_diagonal_weights_matrix: Lower diagonal matrix with ones on the main
178
+ diagional.
179
+ gate_first_n: Integer that decides what part of the input is gated.
180
+ """
181
181
parameters = dict (locals ())
182
182
name = name or 'highway_flow'
183
183
with tf .name_scope (name ) as name :
0 commit comments