Skip to content

Commit 520d4c2

Browse files
authored
Update dense_variational_v2_test.py
1 parent 17e4ee8 commit 520d4c2

File tree

1 file changed

+19
-15
lines changed

1 file changed

+19
-15
lines changed

tensorflow_probability/python/layers/dense_variational_v2_test.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -72,25 +72,29 @@ def prior_trainable(kernel_size, bias_size=0, dtype=None):
7272
@test_util.test_all_tf_execution_regimes
7373
class DenseVariationalLayerTest(test_util.TestCase):
7474

75-
def test_end_to_end(self):
76-
# Get dataset.
77-
y, x, x_tst = create_dataset()
75+
def test_end_to_end(self):
76+
# Get dataset.
77+
y, x, x_tst = create_dataset()
7878

79-
# Build model.
80-
model = tf.keras.Sequential([
81-
tfp.layers.DenseVariational(1, posterior_mean_field, prior_trainable),
82-
tfp.layers.DistributionLambda(lambda t: tfd.Normal(loc=t, scale=1)),
83-
])
79+
layer = tfp.layers.DenseVariational(1, posterior_mean_field, prior_trainable)
8480

85-
# Do inference.
86-
model.compile(optimizer=tf.optimizers.Adam(learning_rate=0.05),
87-
loss=negloglik)
88-
model.fit(x, y, epochs=2, verbose=False)
81+
model = tf.keras.Sequential([
82+
layer,
83+
tfp.layers.DistributionLambda(lambda t: tfd.Normal(loc=t, scale=1))
84+
])
8985

90-
# Profit.
91-
yhat = model(x_tst)
92-
assert isinstance(yhat, tfd.Distribution)
86+
# Do inference.
87+
model.compile(optimizer=tf.optimizers.Adam(learning_rate=0.05),
88+
loss=negloglik)
89+
model.fit(x, y, epochs=2, verbose=False)
9390

91+
# Check the output_shape.
92+
expected_output_shape = layer.compute_output_shape((None, x.shape[-1]))
93+
self.assertAllEqual(expected_output_shape, (None, 1))
94+
95+
# Profit.
96+
yhat = model(x_tst)
97+
assert isinstance(yhat, tfd.Distribution)
9498

9599
if __name__ == '__main__':
96100
test_util.main()

0 commit comments

Comments
 (0)