17
17
18
18
import tensorflow .compat .v2 as tf
19
19
20
- from tensorflow_probability .python import bijectors as tfb
21
20
from tensorflow_probability .python import util
21
+ from tensorflow_probability .python .bijectors import bijector
22
+ from tensorflow_probability .python .bijectors import chain
23
+ from tensorflow_probability .python .bijectors import fill_scale_tril
24
+ from tensorflow_probability .python .bijectors import fill_triangular
25
+ from tensorflow_probability .python .bijectors import pad
26
+ from tensorflow_probability .python .bijectors import shift
27
+ from tensorflow_probability .python .bijectors import sigmoid
28
+ from tensorflow_probability .python .bijectors import softplus
29
+ from tensorflow_probability .python .bijectors import transform_diagonal
22
30
from tensorflow_probability .python .internal import cache_util
23
31
from tensorflow_probability .python .internal import dtype_util
24
32
from tensorflow_probability .python .internal import prefer_static as ps
@@ -61,17 +69,18 @@ def build_highway_flow_layer(width,
61
69
62
70
bias_seed , upper_seed , lower_seed = samplers .split_seed (
63
71
seed , n = 3 )
64
- lower_bijector = tfb .Chain (
65
- [tfb .TransformDiagonal (diag_bijector = tfb .Shift (1. )),
66
- tfb .Pad (paddings = [(1 , 0 ), (0 , 1 )]),
67
- tfb .FillTriangular ()])
72
+ lower_bijector = chain .Chain (
73
+ [transform_diagonal .TransformDiagonal (diag_bijector = shift .Shift (1. )),
74
+ pad .Pad (paddings = [(1 , 0 ), (0 , 1 )]),
75
+ fill_triangular .FillTriangular ()])
68
76
unconstrained_lower_initial_values = samplers .normal (
69
77
shape = lower_bijector .inverse_event_shape ([width , width ]),
70
78
mean = 0. ,
71
79
stddev = .01 ,
72
80
seed = lower_seed )
73
- upper_bijector = tfb .FillScaleTriL (diag_bijector = tfb .Softplus (),
74
- diag_shift = None )
81
+ upper_bijector = fill_scale_tril .FillScaleTriL (
82
+ diag_bijector = softplus .Softplus (),
83
+ diag_shift = None )
75
84
unconstrained_upper_initial_values = samplers .normal (
76
85
shape = upper_bijector .inverse_event_shape ([width , width ]),
77
86
mean = 0. ,
@@ -81,7 +90,7 @@ def build_highway_flow_layer(width,
81
90
return HighwayFlow (
82
91
residual_fraction = util .TransformedVariable (
83
92
initial_value = residual_fraction_initial_value ,
84
- bijector = tfb .Sigmoid (),
93
+ bijector = sigmoid .Sigmoid (),
85
94
dtype = dtype ),
86
95
activation_fn = activation_fn ,
87
96
bias = tf .Variable (
@@ -99,7 +108,7 @@ def build_highway_flow_layer(width,
99
108
)
100
109
101
110
102
- class HighwayFlow (tfb .Bijector ):
111
+ class HighwayFlow (bijector .Bijector ):
103
112
"""Implements an Highway Flow bijector [1].
104
113
105
114
HighwayFlow interpolates the input `X` with the transformations at each step
@@ -193,12 +202,14 @@ def __init__(self, residual_fraction, activation_fn, bias,
193
202
residual_fraction , dtype = dtype , name = 'residual_fraction' )
194
203
# The upper matrix is still lower triangular, transpose is done in
195
204
# _inverse and _forward methods.
196
- self ._upper_diagonal_weights_matrix = tensor_util .convert_nonref_to_tensor (
197
- upper_diagonal_weights_matrix , dtype = dtype ,
198
- name = 'upper_diagonal_weights_matrix' )
199
- self ._lower_diagonal_weights_matrix = tensor_util .convert_nonref_to_tensor (
200
- lower_diagonal_weights_matrix , dtype = dtype ,
201
- name = 'lower_diagonal_weights_matrix' )
205
+ self ._upper_diagonal_weights_matrix = \
206
+ tensor_util .convert_nonref_to_tensor (
207
+ upper_diagonal_weights_matrix , dtype = dtype ,
208
+ name = 'upper_diagonal_weights_matrix' )
209
+ self ._lower_diagonal_weights_matrix = \
210
+ tensor_util .convert_nonref_to_tensor (
211
+ lower_diagonal_weights_matrix , dtype = dtype ,
212
+ name = 'lower_diagonal_weights_matrix' )
202
213
self ._activation_fn = activation_fn
203
214
self ._gate_first_n = gate_first_n if gate_first_n else self .width
204
215
0 commit comments