File tree Expand file tree Collapse file tree 1 file changed +18
-0
lines changed
tensorflow_probability/python/layers Expand file tree Collapse file tree 1 file changed +18
-0
lines changed Original file line number Diff line number Diff line change @@ -140,6 +140,24 @@ def call(self, inputs):
140
140
141
141
return outputs
142
142
143
+ def compute_output_shape (self , input_shape ):
144
+ """
145
+ Computes the output shape of the layer.
146
+ Args:
147
+ input_shape: `TensorShape` or `list` of `TensorShape`
148
+ (only last dim is used)
149
+ Returns:
150
+ The output shape.
151
+ Raises:
152
+ ValueError: If the innermost dimension of `input_shape` is not defined.
153
+ """
154
+ input_shape = tf .TensorShape (input_shape )
155
+ input_shape = input_shape .with_rank_at_least (2 )
156
+ if input_shape [- 1 ] is None :
157
+ raise ValueError (
158
+ 'The innermost dimension of input_shape must be defined, but saw: %s' % (
159
+ input_shape ,))
160
+ return input_shape [:- 1 ].concatenate (self .units )
143
161
144
162
def _make_kl_divergence_penalty (
145
163
use_exact_kl = False ,
You can’t perform that action at this time.
0 commit comments