Skip to content

Commit 356ef3d

Browse files
committed
refactored and linted
1 parent 8f78852 commit 356ef3d

File tree

1 file changed

+44
-89
lines changed

1 file changed

+44
-89
lines changed

tensorflow_probability/python/experimental/bijectors/highway_flow_test.py

Lines changed: 44 additions & 89 deletions
Original file line number · Diff line number · Diff line change
@@ -1,3 +1,17 @@
1+
# Copyright 2021 The TensorFlow Probability Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
115
"""Tests for HighwayFlow."""
216
import tensorflow.compat.v2 as tf
317

@@ -8,32 +22,6 @@
822
tfb = tfp.bijectors
923
tfd = tfp.distributions
1024

11-
#FIXME: test_util.test_seed throws an error
12-
seed = 1 # test_util.test_seed(sampler_type='stateless')
13-
14-
15-
def _dx(x, activation):
16-
if activation == 'sigmoid':
17-
return tf.math.sigmoid(x) * (1 - tf.math.sigmoid(x))
18-
elif activation == 'softplus':
19-
return tf.math.sigmoid(x)
20-
elif activation == 'tanh':
21-
return 1. - tf.math.tanh(x) ** 2
22-
23-
24-
def _activation_log_det_jacobian(x, residual_fraction, activation, width,
25-
gate_first_n):
26-
if activation == 'none':
27-
return tf.zeros(x.shape[0])
28-
else:
29-
return tf.reduce_sum(tf.math.log(
30-
tf.concat([(residual_fraction) * tf.ones(
31-
gate_first_n), tf.zeros(width - gate_first_n)],
32-
axis=0) + tf.concat([(1. - residual_fraction) * tf.ones(
33-
gate_first_n), tf.ones(width - gate_first_n)],
34-
axis=0) * _dx(x, activation)),
35-
-1)
36-
3725

3826
@test_util.test_all_tf_execution_regimes
3927
class HighwayFlowTests(test_util.TestCase):
@@ -45,13 +33,16 @@ def testBijector(self):
4533
# Test generic case with scalar input
4634
x = tf.ones((width,)) * samplers.uniform((width,), minval=-1.,
4735
maxval=1.,
48-
seed=seed)
36+
seed=test_util.test_seed(
37+
sampler_type='stateless'))
4938
elif dim == 1:
5039
# Test with 2D tensor + batch
5140
x = tf.ones((5, width,
5241
width)) * samplers.uniform((5, width, width),
5342
minval=-1.,
54-
maxval=1., seed=seed)
43+
maxval=1.,
44+
seed=test_util.test_seed(
45+
sampler_type='stateless'))
5546

5647
bijector = tfp.experimental.bijectors.build_highway_flow_layer(
5748
width, activation_fn=True)
@@ -65,12 +56,34 @@ def testBijector(self):
6556
-bijector.inverse_log_det_jacobian(
6657
tf.identity(bijector.forward(x)), event_ndims=dim + 1))
6758

59+
def testBijectorWithoutActivation(self):
60+
width = 4
61+
x = tf.ones(2, width,
62+
width) * samplers.uniform((2, width, width),
63+
minval=-1.,
64+
maxval=1.,
65+
seed=test_util.test_seed(
66+
sampler_type='stateless'))
67+
68+
bijector = tfp.experimental.bijectors.build_highway_flow_layer(
69+
width, activation_fn=False)
70+
self.evaluate(
71+
[v.initializer for v in bijector.trainable_variables])
72+
self.assertStartsWith(bijector.name, 'highway_flow')
73+
self.assertAllClose(x, bijector.inverse(
74+
tf.identity(bijector.forward(x))))
75+
self.assertAllClose(
76+
bijector.forward_log_det_jacobian(x, event_ndims=2),
77+
-bijector.inverse_log_det_jacobian(
78+
tf.identity(bijector.forward(x)), event_ndims=2))
79+
6880
def testGating(self):
6981
width = 4
7082
x = tf.ones((2, width,
7183
width)) * samplers.uniform((2, width, width),
7284
minval=-1.,
73-
maxval=1., seed=seed)
85+
maxval=1., seed=test_util.test_seed(
86+
sampler_type='stateless'))
7487

7588
# Test with gating half of the inputs
7689
bijector = tfp.experimental.bijectors.build_highway_flow_layer(
@@ -98,66 +111,6 @@ def testGating(self):
98111
-bijector.inverse_log_det_jacobian(
99112
tf.identity(bijector.forward(x)), event_ndims=2))
100113

101-
def testJacobianWithActivation(self):
102-
activations = ['softplus']
103-
batch_size = 3
104-
width = 4
105-
dtype = tf.float32
106-
gate_first_n = 2
107-
residual_fraction = tf.constant(0.5)
108-
for activation in activations:
109-
110-
if activation == 'sigmoid':
111-
activation_fn = tf.nn.sigmoid
112-
elif activation == 'softplus':
113-
activation_fn = tf.nn.softplus
114-
elif activation == 'tanh':
115-
activation_fn = tf.nn.tanh
116-
elif activation == 'none':
117-
activation_fn = None
118-
119-
bijector = tfp.experimental.bijectors.HighwayFlow(
120-
residual_fraction=residual_fraction,
121-
activation_fn=activation_fn,
122-
bias=tf.zeros(width),
123-
upper_diagonal_weights_matrix=tf.eye(width),
124-
lower_diagonal_weights_matrix=tf.eye(width),
125-
gate_first_n=gate_first_n,
126-
)
127-
128-
self.evaluate(
129-
[v.initializer for v in bijector.trainable_variables])
130-
x = tf.ones((batch_size,
131-
width)) * samplers.uniform((batch_size, width), -10.,
132-
10., seed=seed)
133-
if activation == 'none':
134-
y = x
135-
else:
136-
y = tf.concat([(residual_fraction) * tf.ones(gate_first_n),
137-
tf.zeros(width - gate_first_n)],
138-
axis=0) * x + tf.concat(
139-
[(1. - residual_fraction) * tf.ones(
140-
gate_first_n), tf.ones(width - gate_first_n)],
141-
axis=0) * activation_fn(x)
142-
expected_forward_log_det_jacobian = \
143-
_activation_log_det_jacobian(x,
144-
residual_fraction,
145-
activation,
146-
width,
147-
gate_first_n)
148-
expected_inverse_log_det_jacobian = \
149-
-expected_forward_log_det_jacobian
150-
self.assertAllClose(y, bijector.forward(x))
151-
self.assertAllClose(x, bijector.inverse(y))
152-
self.assertAllClose(
153-
expected_inverse_log_det_jacobian,
154-
bijector.inverse_log_det_jacobian(y, event_ndims=1),
155-
)
156-
self.assertAllClose(
157-
expected_forward_log_det_jacobian,
158-
bijector.forward_log_det_jacobian(x, event_ndims=1),
159-
)
160-
161114
def testResidualFractionGradientsWithCenteredDifference(self):
162115
width = 4
163116
batch_size = 3
@@ -180,10 +133,12 @@ def testResidualFractionGradientsWithCenteredDifference(self):
180133

181134
h = 1e-3
182135

136+
# pylint: disable=protected-access
183137
bijector._residual_fraction = residual_fraction + h
184138
y1 = tf.reduce_mean(target.log_prob(bijector.forward(x)))
185139
bijector._residual_fraction = residual_fraction - h
186140
y2 = tf.reduce_mean(target.log_prob(bijector.forward(x)))
141+
# pylint: enable=protected-access
187142

188143
manual_grad = (y1 - y2) / (2 * h)
189144

0 commit comments

Comments (0)