Skip to content

Commit 8f78852

Browse files
committed
refactored docstrings
1 parent d4de7b1 commit 8f78852

File tree

1 file changed

+56
-43
lines changed

1 file changed

+56
-43
lines changed

tensorflow_probability/python/experimental/bijectors/highway_flow.py

Lines changed: 56 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -28,24 +28,25 @@ def build_highway_flow_layer(width,
2828
activation_fn=False,
2929
gate_first_n=-1,
3030
seed=None):
31-
"""
32-
Builds an HighwayFlow layer making sure that all the requirements are
33-
satisfied, namely:
34-
- `residual_fraction` is bounded between 0 and 1;
35-
- `upper_diagonal_weights_matrix` is a randomly initialized (lower) diagonal
36-
matrix with positive diagonal of size `width x width`;
37-
- `lower_diagonal_weights_matrix` is a randomly initialized lower diagonal
38-
matrix with ones on the diagonal of size `width x width`;
39-
- `bias` is a randomly initialized vector of size `width`
40-
41-
:param int width: input dimension of the bijector
42-
:param float residual_fraction_initial_value: initial value for residual
43-
fraction, must be between 0. and 1.
44-
:param bool activation_fn: whether or not use activation function in the
45-
HighwayFlow
46-
:param int seed: seed for random initialization of the weights
47-
:returns: the initialized HighwayFlow bijector
48-
:rtype: `tfb.Bijector`
31+
"""Builds HighwayFlow making sure that all the requirements ar satisfied.
32+
33+
Args:
34+
width: Input dimension of the bijector.
35+
residual_fraction_initial_value: Initial value for gating parameter, must be
36+
between 0 and 1.
37+
activation_fn: Whether or not use SoftPlus activation function.
38+
gate_first_n: Decides which part of the input should be gated (useful for
39+
example when using auxiliary variables).
40+
seed: Seed for random initialization of the weights.
41+
42+
Returns:
43+
The initialized bijector with the following elements:
44+
`residual_fraction` is bounded between 0 and 1.
45+
`upper_diagonal_weights_matrix` is a randomly initialized (lower) diagonal
46+
matrix with positive diagonal of size `width x width`.
47+
`lower_diagonal_weights_matrix` is a randomly initialized lower diagonal
48+
matrix with ones on the diagonal of size `width x width`;
49+
`bias` is a randomly initialized vector of size `width`
4950
"""
5051

5152
if gate_first_n == -1:
@@ -98,10 +99,11 @@ def build_highway_flow_layer(width,
9899

99100

100101
class HighwayFlow(tfb.Bijector):
101-
"""Implements an Highway Flow bijector [1], which interpolates the input
102-
`X` with the transformations at each step of the bjiector.
103-
The Highway Flow can be used as building block for a Cascading flow [1]
104-
or as a generic normalizing flow.
102+
"""Implements an Highway Flow bijector [1].
103+
104+
HighwayFlow interpolates the input `X` with the transformations at each step
105+
of the bjiector. The Highway Flow can be used as building block for a
106+
Cascading flow [1 or as a generic normalizing flow.
105107
106108
The transformation consists in a convex update between the input `X` and a
107109
linear transformation of `X` followed by activation with the form `g(A @
@@ -145,6 +147,19 @@ class HighwayFlow(tfb.Bijector):
145147
[1]: Ambrogioni, Luca, Gianluigi Silvestri, and Marcel van Gerven.
146148
"Automatic variational inference with
147149
cascading flows." arXiv preprint arXiv:2102.04801 (2021).
150+
151+
Attributes:
152+
residual_fraction: Scalar `Tensor` used for the convex update, must be
153+
between 0 and 1.
154+
activation_fn: Boolean to decide whether to use SoftPlus (True) activation
155+
or no activation (False).
156+
bias: Bias vector.
157+
upper_diagonal_weights_matrix: Lower diagional matrix of size (width, width)
158+
with positive diagonal (is transposed to Upper diagonal within the
159+
bijector).
160+
lower_diagonal_weights_matrix: Lower diagonal matrix with ones on the main
161+
diagional.
162+
gate_first_n: Integer that decides what part of the input is gated.
148163
"""
149164

150165
# HighWay Flow simultaneously computes `forward` and `fldj`
@@ -162,20 +177,7 @@ def __init__(self, residual_fraction, activation_fn, bias,
162177
gate_first_n,
163178
validate_args=False,
164179
name=None):
165-
'''
166-
Args:
167-
residual_fraction: scalar `Tensor` used for the convex update,
168-
must be
169-
between 0 and 1
170-
activation_fn: bool to decide whether to use softplus (True)
171-
activation or no activation (False)
172-
bias: bias vector
173-
upper_diagonal_weights_matrix: Lower diagional matrix of size
174-
(width, width) with positive diagonal
175-
(is transposed to Upper diagonal within the bijector)
176-
lower_diagonal_weights_matrix: Lower diagonal matrix with ones on
177-
the main diagional.
178-
'''
180+
"""Initializes the HighwayFlow."""
179181
parameters = dict(locals())
180182
name = name or 'highway_flow'
181183
with tf.name_scope(name) as name:
@@ -242,7 +244,7 @@ def _convex_update(self, weights_matrix):
242244
axis=0) * weights_matrix
243245

244246
def _inverse_of_softplus(self, y, n=20):
245-
# Inverse of the activation layer with softplus using Newton iteration.
247+
"""Inverse of the activation layer with softplus using Newton iteration."""
246248
x = tf.ones(y.shape)
247249
for _ in range(n):
248250
x = x - (tf.concat([(self.residual_fraction) * tf.ones(
@@ -257,12 +259,18 @@ def _inverse_of_softplus(self, y, n=20):
257259

258260
def _augmented_forward(self, x):
259261
"""Computes forward and forward_log_det_jacobian transformations.
260-
:param tf.Tensor x: input of the bijector
261-
:returns: x after forward flow
262-
:rtype: `tf.Tensor`
262+
263+
Args:
264+
x: Input of the bijector.
265+
266+
Returns:
267+
x after forward flow and a dict containing forward and inverse log
268+
determinant of the jacobian.
263269
"""
270+
264271
# Log determinant term from the upper matrix. Note that the log determinant
265272
# of the lower matrix is zero.
273+
266274
added_batch = False
267275
if len(x.shape) <= 1:
268276
if len(x.shape) == 0:
@@ -302,10 +310,15 @@ def _augmented_forward(self, x):
302310

303311
def _augmented_inverse(self, y):
304312
"""Computes inverse and inverse_log_det_jacobian transformations.
305-
:param tf.Tensor y: input of the (inverse) bijector
306-
:returns: y after inverse flow
307-
:rtype: `tf.Tensor`
313+
314+
Args:
315+
y: input of the (inverse) bijectorr.
316+
317+
Returns:
318+
y after inverse flow and a dict containing inverse and forward log
319+
determinant of the jacobian.
308320
"""
321+
309322
added_batch = False
310323
if len(y.shape) <= 1:
311324
if len(y.shape) == 0:

0 commit comments

Comments
 (0)