Skip to content

Commit 2e8e8bd

Browse files
committed
small fixes
1 parent 356ef3d commit 2e8e8bd

File tree

1 file changed

+24
-24
lines changed

1 file changed

+24
-24
lines changed

tensorflow_probability/python/experimental/bijectors/highway_flow.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,13 @@ def build_highway_flow_layer(width,
2828
activation_fn=False,
2929
gate_first_n=-1,
3030
seed=None):
31-
"""Builds HighwayFlow making sure that all the requirements ar satisfied.
31+
"""Builds HighwayFlow making sure that all the requirements are satisfied.
3232
3333
Args:
3434
width: Input dimension of the bijector.
3535
residual_fraction_initial_value: Initial value for gating parameter, must be
36-
between 0 and 1.
37-
activation_fn: Whether or not use SoftPlus activation function.
36+
between 0 and 1.
37+
activation_fn: Whether or not use SoftPlus activation function.
3838
gate_first_n: Decides which part of the input should be gated (useful for
3939
example when using auxiliary variables).
4040
seed: Seed for random initialization of the weights.
@@ -58,8 +58,8 @@ def build_highway_flow_layer(width,
5858
name='residual_fraction_initial_value')
5959
dtype = residual_fraction_initial_value.dtype
6060

61-
bias_seed, upper_seed, lower_seed, diagonal_seed = samplers.split_seed(
62-
seed, n=4)
61+
bias_seed, upper_seed, lower_seed = samplers.split_seed(
62+
seed, n=3)
6363
lower_bijector = tfb.Chain(
6464
[tfb.TransformDiagonal(diag_bijector=tfb.Shift(1.)),
6565
tfb.Pad(paddings=[(1, 0), (0, 1)]),
@@ -103,9 +103,9 @@ class HighwayFlow(tfb.Bijector):
103103
104104
HighwayFlow interpolates the input `X` with the transformations at each step
105105
of the bjiector. The Highway Flow can be used as building block for a
106-
Cascading flow [1 or as a generic normalizing flow.
106+
Cascading flow [1] or as a generic normalizing flow.
107107
108-
The transformation consists in a convex update between the input `X` and a
108+
The transformation consists of a convex update between the input `X` and a
109109
linear transformation of `X` followed by activation with the form `g(A @
110110
X + b)`, where `g(.)` is a differentiable non-decreasing activation
111111
function, and `A` and `b` are trainable weights.
@@ -128,7 +128,7 @@ class HighwayFlow(tfb.Bijector):
128128
129129
For more details on Highway Flow and Cascading Flows see [1].
130130
131-
#### Usage example:
131+
#### Usage example
132132
```python
133133
tfd = tfp.distributions
134134
tfb = tfp.bijectors
@@ -145,21 +145,8 @@ class HighwayFlow(tfb.Bijector):
145145
#### References
146146
147147
[1]: Ambrogioni, Luca, Gianluigi Silvestri, and Marcel van Gerven.
148-
"Automatic variational inference with
149-
cascading flows." arXiv preprint arXiv:2102.04801 (2021).
150-
151-
Attributes:
152-
residual_fraction: Scalar `Tensor` used for the convex update, must be
153-
between 0 and 1.
154-
activation_fn: Boolean to decide whether to use SoftPlus (True) activation
155-
or no activation (False).
156-
bias: Bias vector.
157-
upper_diagonal_weights_matrix: Lower diagional matrix of size (width, width)
158-
with positive diagonal (is transposed to Upper diagonal within the
159-
bijector).
160-
lower_diagonal_weights_matrix: Lower diagonal matrix with ones on the main
161-
diagional.
162-
gate_first_n: Integer that decides what part of the input is gated.
148+
"Automatic variational inference with cascading flows." arXiv preprint
149+
arXiv:2102.04801 (2021).
163150
"""
164151

165152
# HighWay Flow simultaneously computes `forward` and `fldj`
@@ -177,7 +164,20 @@ def __init__(self, residual_fraction, activation_fn, bias,
177164
gate_first_n,
178165
validate_args=False,
179166
name=None):
180-
"""Initializes the HighwayFlow."""
167+
"""Initializes the HighwayFlow.
168+
Args:
169+
residual_fraction: Scalar `Tensor` used for the convex update, must be
170+
between 0 and 1.
171+
activation_fn: Boolean to decide whether to use SoftPlus (True) activation
172+
or no activation (False).
173+
bias: Bias vector.
174+
upper_diagonal_weights_matrix: Lower diagional matrix of size
175+
(width, width) with positive diagonal (is transposed to Upper diagonal
176+
within the bijector).
177+
lower_diagonal_weights_matrix: Lower diagonal matrix with ones on the main
178+
diagional.
179+
gate_first_n: Integer that decides what part of the input is gated.
180+
"""
181181
parameters = dict(locals())
182182
name = name or 'highway_flow'
183183
with tf.name_scope(name) as name:

0 commit comments

Comments
 (0)