Skip to content

Commit 08ed88b

Browse files
committed
removed tf.ones multiplications and added identity for gradients check
1 parent 1d294b8 commit 08ed88b

File tree

1 file changed

+17
-23
lines changed

1 file changed

+17
-23
lines changed

tensorflow_probability/python/experimental/bijectors/highway_flow_test.py

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,15 @@ def testBijector(self):
3131
for dim in range(2):
3232
if dim == 0:
3333
# Test generic case with scalar input
34-
x = tf.ones((width,)) * samplers.uniform((width,), minval=-1.,
35-
maxval=1.,
36-
seed=test_util.test_seed(
37-
sampler_type='stateless'))
34+
x = samplers.uniform((width,), minval=-1.,
35+
maxval=1.,
36+
seed=test_util.test_seed(sampler_type='stateless'))
3837
elif dim == 1:
3938
# Test with 2D tensor + batch
40-
x = tf.ones((5, width,
41-
width)) * samplers.uniform((5, width, width),
42-
minval=-1.,
43-
maxval=1.,
44-
seed=test_util.test_seed(
45-
sampler_type='stateless'))
39+
x = samplers.uniform((5, width, width),
40+
minval=-1.,
41+
maxval=1.,
42+
seed=test_util.test_seed(sampler_type='stateless'))
4643

4744
bijector = tfp.experimental.bijectors.build_highway_flow_layer(
4845
width, activation_fn=True)
@@ -58,12 +55,10 @@ def testBijector(self):
5855

5956
def testBijectorWithoutActivation(self):
6057
width = 4
61-
x = tf.ones(2, width,
62-
width) * samplers.uniform((2, width, width),
63-
minval=-1.,
64-
maxval=1.,
65-
seed=test_util.test_seed(
66-
sampler_type='stateless'))
58+
x = samplers.uniform((2, width, width),
59+
minval=-1.,
60+
maxval=1.,
61+
seed=test_util.test_seed(sampler_type='stateless'))
6762

6863
bijector = tfp.experimental.bijectors.build_highway_flow_layer(
6964
width, activation_fn=False)
@@ -79,11 +74,10 @@ def testBijectorWithoutActivation(self):
7974

8075
def testGating(self):
8176
width = 4
82-
x = tf.ones((2, width,
83-
width)) * samplers.uniform((2, width, width),
84-
minval=-1.,
85-
maxval=1., seed=test_util.test_seed(
86-
sampler_type='stateless'))
77+
x = samplers.uniform((2, width, width),
78+
minval=-1.,
79+
maxval=1.,
80+
seed=test_util.test_seed(sampler_type='stateless'))
8781

8882
# Test with gating half of the inputs
8983
bijector = tfp.experimental.bijectors.build_highway_flow_layer(
@@ -135,9 +129,9 @@ def testResidualFractionGradientsWithCenteredDifference(self):
135129

136130
# pylint: disable=protected-access
137131
bijector._residual_fraction = residual_fraction + h
138-
y1 = tf.reduce_mean(target.log_prob(bijector.forward(x)))
132+
y1 = tf.reduce_mean(target.log_prob(bijector.forward(tf.identity(x))))
139133
bijector._residual_fraction = residual_fraction - h
140-
y2 = tf.reduce_mean(target.log_prob(bijector.forward(x)))
134+
y2 = tf.reduce_mean(target.log_prob(bijector.forward(tf.identity(x))))
141135
# pylint: enable=protected-access
142136

143137
manual_grad = (y1 - y2) / (2 * h)

0 commit comments

Comments
 (0)