Skip to content

Commit 63501a0

Browse files
authored
Added input_shape param when testing layers
1 parent 75dae2d commit 63501a0

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tensorflow_probability/python/layers/conv_variational_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,8 @@ def _testLayerInSequential(self, layer_class): # pylint: disable=invalid-name
609609
outputs = self.maybe_transpose_tensor(outputs)
610610

611611
net = tf.keras.Sequential([
612-
layer_class(filters=2, kernel_size=3, data_format=self.data_format),
612+
layer_class(filters=2, kernel_size=3, data_format=self.data_format,
613+
input_shape = inputs.shape.as_list()[1:]),
613614
layer_class(filters=2, kernel_size=1, data_format=self.data_format)])
614615

615616
net.compile(loss='mse', optimizer='adam')

0 commit comments

Comments
 (0)