1
+ # Copyright 2021 The TensorFlow Probability Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
1
15
"""Tests for HighwayFlow."""
2
16
import tensorflow .compat .v2 as tf
3
17
8
22
tfb = tfp .bijectors
9
23
tfd = tfp .distributions
10
24
11
# FIXME: test_util.test_seed throws an error here, so fall back to a fixed
# integer seed for reproducibility.
seed = 1  # test_util.test_seed(sampler_type='stateless')
13
-
14
-
15
def _dx(x, activation):
  """Elementwise derivative of the named activation, evaluated at `x`.

  Returns `None` for activation names not listed here (callers are expected
  to handle the 'none' case before calling).
  """
  if activation == 'tanh':
    t = tf.math.tanh(x)
    # d/dx tanh(x) = 1 - tanh(x)^2
    return 1. - t * t
  if activation == 'softplus':
    # d/dx softplus(x) = sigmoid(x)
    return tf.math.sigmoid(x)
  if activation == 'sigmoid':
    s = tf.math.sigmoid(x)
    # d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x))
    return s * (1 - s)
22
-
23
-
24
def _activation_log_det_jacobian(x, residual_fraction, activation, width,
                                 gate_first_n):
  """Expected log-det-Jacobian of the gated activation step.

  The first `gate_first_n` coordinates are convexly gated by
  `residual_fraction`; the remaining `width - gate_first_n` coordinates pass
  through the activation ungated. With no activation the Jacobian is the
  identity, so the log-det is zero.
  """
  if activation == 'none':
    return tf.zeros(x.shape[0])
  # Per-coordinate weight on the identity (residual) branch.
  gated = tf.concat(
      [residual_fraction * tf.ones(gate_first_n),
       tf.zeros(width - gate_first_n)], axis=0)
  # Per-coordinate weight on the activation branch.
  ungated = tf.concat(
      [(1. - residual_fraction) * tf.ones(gate_first_n),
       tf.ones(width - gate_first_n)], axis=0)
  # The Jacobian is diagonal; sum log of its entries over the event axis.
  diag = gated + ungated * _dx(x, activation)
  return tf.reduce_sum(tf.math.log(diag), -1)
36
-
37
25
38
26
@test_util .test_all_tf_execution_regimes
39
27
class HighwayFlowTests (test_util .TestCase ):
@@ -45,13 +33,16 @@ def testBijector(self):
45
33
# Test generic case with scalar input
46
34
x = tf .ones ((width ,)) * samplers .uniform ((width ,), minval = - 1. ,
47
35
maxval = 1. ,
48
- seed = seed )
36
+ seed = test_util .test_seed (
37
+ sampler_type = 'stateless' ))
49
38
elif dim == 1 :
50
39
# Test with 2D tensor + batch
51
40
x = tf .ones ((5 , width ,
52
41
width )) * samplers .uniform ((5 , width , width ),
53
42
minval = - 1. ,
54
- maxval = 1. , seed = seed )
43
+ maxval = 1. ,
44
+ seed = test_util .test_seed (
45
+ sampler_type = 'stateless' ))
55
46
56
47
bijector = tfp .experimental .bijectors .build_highway_flow_layer (
57
48
width , activation_fn = True )
@@ -65,12 +56,34 @@ def testBijector(self):
65
56
- bijector .inverse_log_det_jacobian (
66
57
tf .identity (bijector .forward (x )), event_ndims = dim + 1 ))
67
58
59
+ def testBijectorWithoutActivation (self ):
60
+ width = 4
61
+ x = tf .ones (2 , width ,
62
+ width ) * samplers .uniform ((2 , width , width ),
63
+ minval = - 1. ,
64
+ maxval = 1. ,
65
+ seed = test_util .test_seed (
66
+ sampler_type = 'stateless' ))
67
+
68
+ bijector = tfp .experimental .bijectors .build_highway_flow_layer (
69
+ width , activation_fn = False )
70
+ self .evaluate (
71
+ [v .initializer for v in bijector .trainable_variables ])
72
+ self .assertStartsWith (bijector .name , 'highway_flow' )
73
+ self .assertAllClose (x , bijector .inverse (
74
+ tf .identity (bijector .forward (x ))))
75
+ self .assertAllClose (
76
+ bijector .forward_log_det_jacobian (x , event_ndims = 2 ),
77
+ - bijector .inverse_log_det_jacobian (
78
+ tf .identity (bijector .forward (x )), event_ndims = 2 ))
79
+
68
80
def testGating (self ):
69
81
width = 4
70
82
x = tf .ones ((2 , width ,
71
83
width )) * samplers .uniform ((2 , width , width ),
72
84
minval = - 1. ,
73
- maxval = 1. , seed = seed )
85
+ maxval = 1. , seed = test_util .test_seed (
86
+ sampler_type = 'stateless' ))
74
87
75
88
# Test with gating half of the inputs
76
89
bijector = tfp .experimental .bijectors .build_highway_flow_layer (
@@ -98,66 +111,6 @@ def testGating(self):
98
111
- bijector .inverse_log_det_jacobian (
99
112
tf .identity (bijector .forward (x )), event_ndims = 2 ))
100
113
101
- def testJacobianWithActivation (self ):
102
- activations = ['softplus' ]
103
- batch_size = 3
104
- width = 4
105
- dtype = tf .float32
106
- gate_first_n = 2
107
- residual_fraction = tf .constant (0.5 )
108
- for activation in activations :
109
-
110
- if activation == 'sigmoid' :
111
- activation_fn = tf .nn .sigmoid
112
- elif activation == 'softplus' :
113
- activation_fn = tf .nn .softplus
114
- elif activation == 'tanh' :
115
- activation_fn = tf .nn .tanh
116
- elif activation == 'none' :
117
- activation_fn = None
118
-
119
- bijector = tfp .experimental .bijectors .HighwayFlow (
120
- residual_fraction = residual_fraction ,
121
- activation_fn = activation_fn ,
122
- bias = tf .zeros (width ),
123
- upper_diagonal_weights_matrix = tf .eye (width ),
124
- lower_diagonal_weights_matrix = tf .eye (width ),
125
- gate_first_n = gate_first_n ,
126
- )
127
-
128
- self .evaluate (
129
- [v .initializer for v in bijector .trainable_variables ])
130
- x = tf .ones ((batch_size ,
131
- width )) * samplers .uniform ((batch_size , width ), - 10. ,
132
- 10. , seed = seed )
133
- if activation == 'none' :
134
- y = x
135
- else :
136
- y = tf .concat ([(residual_fraction ) * tf .ones (gate_first_n ),
137
- tf .zeros (width - gate_first_n )],
138
- axis = 0 ) * x + tf .concat (
139
- [(1. - residual_fraction ) * tf .ones (
140
- gate_first_n ), tf .ones (width - gate_first_n )],
141
- axis = 0 ) * activation_fn (x )
142
- expected_forward_log_det_jacobian = \
143
- _activation_log_det_jacobian (x ,
144
- residual_fraction ,
145
- activation ,
146
- width ,
147
- gate_first_n )
148
- expected_inverse_log_det_jacobian = \
149
- - expected_forward_log_det_jacobian
150
- self .assertAllClose (y , bijector .forward (x ))
151
- self .assertAllClose (x , bijector .inverse (y ))
152
- self .assertAllClose (
153
- expected_inverse_log_det_jacobian ,
154
- bijector .inverse_log_det_jacobian (y , event_ndims = 1 ),
155
- )
156
- self .assertAllClose (
157
- expected_forward_log_det_jacobian ,
158
- bijector .forward_log_det_jacobian (x , event_ndims = 1 ),
159
- )
160
-
161
114
def testResidualFractionGradientsWithCenteredDifference (self ):
162
115
width = 4
163
116
batch_size = 3
@@ -180,10 +133,12 @@ def testResidualFractionGradientsWithCenteredDifference(self):
180
133
181
134
h = 1e-3
182
135
136
+ # pylint: disable=protected-access
183
137
bijector ._residual_fraction = residual_fraction + h
184
138
y1 = tf .reduce_mean (target .log_prob (bijector .forward (x )))
185
139
bijector ._residual_fraction = residual_fraction - h
186
140
y2 = tf .reduce_mean (target .log_prob (bijector .forward (x )))
141
+ # pylint: enable=protected-access
187
142
188
143
manual_grad = (y1 - y2 ) / (2 * h )
189
144
0 commit comments