Skip to content

Commit 17e4ee8

Browse files
authored
Implement compute_output_shape method
1 parent 201dc23 commit 17e4ee8

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

tensorflow_probability/python/layers/dense_variational_v2.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,24 @@ def call(self, inputs):
140140

141141
return outputs
142142

143+
def compute_output_shape(self, input_shape):
144+
"""
145+
Computes the output shape of the layer.
146+
Args:
147+
input_shape: `TensorShape` or `list` of `TensorShape`
148+
(only last dim is used)
149+
Returns:
150+
The output shape.
151+
Raises:
152+
ValueError: If the innermost dimension of `input_shape` is not defined.
153+
"""
154+
input_shape = tf.TensorShape(input_shape)
155+
input_shape = input_shape.with_rank_at_least(2)
156+
if input_shape[-1] is None:
157+
raise ValueError(
158+
'The innermost dimension of input_shape must be defined, but saw: %s' % (
159+
input_shape,))
160+
return input_shape[:-1].concatenate(self.units)
143161

144162
def _make_kl_divergence_penalty(
145163
use_exact_kl=False,

0 commit comments

Comments
 (0)