@@ -72,25 +72,29 @@ def prior_trainable(kernel_size, bias_size=0, dtype=None):
72
72
@test_util .test_all_tf_execution_regimes
73
73
class DenseVariationalLayerTest (test_util .TestCase ):
74
74
75
- def test_end_to_end (self ):
76
- # Get dataset.
77
- y , x , x_tst = create_dataset ()
75
+ def test_end_to_end (self ):
76
+ # Get dataset.
77
+ y , x , x_tst = create_dataset ()
78
78
79
- # Build model.
80
- model = tf .keras .Sequential ([
81
- tfp .layers .DenseVariational (1 , posterior_mean_field , prior_trainable ),
82
- tfp .layers .DistributionLambda (lambda t : tfd .Normal (loc = t , scale = 1 )),
83
- ])
79
+ layer = tfp .layers .DenseVariational (1 , posterior_mean_field , prior_trainable )
84
80
85
- # Do inference.
86
- model . compile ( optimizer = tf . optimizers . Adam ( learning_rate = 0.05 ) ,
87
- loss = negloglik )
88
- model . fit ( x , y , epochs = 2 , verbose = False )
81
+ model = tf . keras . Sequential ([
82
+ layer ,
83
+ tfp . layers . DistributionLambda ( lambda t : tfd . Normal ( loc = t , scale = 1 ) )
84
+ ] )
89
85
90
- # Profit.
91
- yhat = model (x_tst )
92
- assert isinstance (yhat , tfd .Distribution )
86
+ # Do inference.
87
+ model .compile (optimizer = tf .optimizers .Adam (learning_rate = 0.05 ),
88
+ loss = negloglik )
89
+ model .fit (x , y , epochs = 2 , verbose = False )
93
90
91
+ # Check the output_shape.
92
+ expected_output_shape = layer .compute_output_shape ((None , x .shape [- 1 ]))
93
+ self .assertAllEqual (expected_output_shape , (None , 1 ))
94
+
95
+ # Profit.
96
+ yhat = model (x_tst )
97
+ assert isinstance (yhat , tfd .Distribution )
94
98
95
99
if __name__ == '__main__' :
96
100
test_util .main ()
0 commit comments