Skip to content

Commit af490be

Browse files
committed
changed imports
1 parent 66ae7f3 commit af490be

File tree

1 file changed

+26
-15
lines changed

1 file changed

+26
-15
lines changed

tensorflow_probability/python/experimental/bijectors/highway_flow.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,16 @@
1717

1818
import tensorflow.compat.v2 as tf
1919

20-
from tensorflow_probability.python import bijectors as tfb
2120
from tensorflow_probability.python import util
21+
from tensorflow_probability.python.bijectors import bijector
22+
from tensorflow_probability.python.bijectors import chain
23+
from tensorflow_probability.python.bijectors import fill_scale_tril
24+
from tensorflow_probability.python.bijectors import fill_triangular
25+
from tensorflow_probability.python.bijectors import pad
26+
from tensorflow_probability.python.bijectors import shift
27+
from tensorflow_probability.python.bijectors import sigmoid
28+
from tensorflow_probability.python.bijectors import softplus
29+
from tensorflow_probability.python.bijectors import transform_diagonal
2230
from tensorflow_probability.python.internal import cache_util
2331
from tensorflow_probability.python.internal import dtype_util
2432
from tensorflow_probability.python.internal import prefer_static as ps
@@ -61,17 +69,18 @@ def build_highway_flow_layer(width,
6169

6270
bias_seed, upper_seed, lower_seed = samplers.split_seed(
6371
seed, n=3)
64-
lower_bijector = tfb.Chain(
65-
[tfb.TransformDiagonal(diag_bijector=tfb.Shift(1.)),
66-
tfb.Pad(paddings=[(1, 0), (0, 1)]),
67-
tfb.FillTriangular()])
72+
lower_bijector = chain.Chain(
73+
[transform_diagonal.TransformDiagonal(diag_bijector=shift.Shift(1.)),
74+
pad.Pad(paddings=[(1, 0), (0, 1)]),
75+
fill_triangular.FillTriangular()])
6876
unconstrained_lower_initial_values = samplers.normal(
6977
shape=lower_bijector.inverse_event_shape([width, width]),
7078
mean=0.,
7179
stddev=.01,
7280
seed=lower_seed)
73-
upper_bijector = tfb.FillScaleTriL(diag_bijector=tfb.Softplus(),
74-
diag_shift=None)
81+
upper_bijector = fill_scale_tril.FillScaleTriL(
82+
diag_bijector=softplus.Softplus(),
83+
diag_shift=None)
7584
unconstrained_upper_initial_values = samplers.normal(
7685
shape=upper_bijector.inverse_event_shape([width, width]),
7786
mean=0.,
@@ -81,7 +90,7 @@ def build_highway_flow_layer(width,
8190
return HighwayFlow(
8291
residual_fraction=util.TransformedVariable(
8392
initial_value=residual_fraction_initial_value,
84-
bijector=tfb.Sigmoid(),
93+
bijector=sigmoid.Sigmoid(),
8594
dtype=dtype),
8695
activation_fn=activation_fn,
8796
bias=tf.Variable(
@@ -99,7 +108,7 @@ def build_highway_flow_layer(width,
99108
)
100109

101110

102-
class HighwayFlow(tfb.Bijector):
111+
class HighwayFlow(bijector.Bijector):
103112
"""Implements an Highway Flow bijector [1].
104113
105114
HighwayFlow interpolates the input `X` with the transformations at each step
@@ -193,12 +202,14 @@ def __init__(self, residual_fraction, activation_fn, bias,
193202
residual_fraction, dtype=dtype, name='residual_fraction')
194203
# The upper matrix is still lower triangular, transpose is done in
195204
# _inverse and _forwars metowds.
196-
self._upper_diagonal_weights_matrix = tensor_util.convert_nonref_to_tensor(
197-
upper_diagonal_weights_matrix, dtype=dtype,
198-
name='upper_diagonal_weights_matrix')
199-
self._lower_diagonal_weights_matrix = tensor_util.convert_nonref_to_tensor(
200-
lower_diagonal_weights_matrix, dtype=dtype,
201-
name='lower_diagonal_weights_matrix')
205+
self._upper_diagonal_weights_matrix = \
206+
tensor_util.convert_nonref_to_tensor(
207+
upper_diagonal_weights_matrix, dtype=dtype,
208+
name='upper_diagonal_weights_matrix')
209+
self._lower_diagonal_weights_matrix = \
210+
tensor_util.convert_nonref_to_tensor(
211+
lower_diagonal_weights_matrix, dtype=dtype,
212+
name='lower_diagonal_weights_matrix')
202213
self._activation_fn = activation_fn
203214
self._gate_first_n = gate_first_n if gate_first_n else self.width
204215

0 commit comments

Comments
 (0)