Skip to content

Commit 0efbee9

Browse files
Merge pull request #1515 from Frightera:frighterafix#1505
PiperOrigin-RevId: 453335007
2 parents 6d04325 + bd53773 commit 0efbee9

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

tensorflow_probability/python/layers/dense_variational_v2.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,25 @@ def call(self, inputs):
142142

143143
return outputs
144144

145+
def compute_output_shape(self, input_shape):
146+
"""Computes the output shape of the layer.
147+
148+
Args:
149+
input_shape: `TensorShape` or `list` of `TensorShape`
150+
(only last dim is used)
151+
Returns:
152+
The output shape.
153+
Raises:
154+
ValueError: If the innermost dimension of `input_shape` is not defined.
155+
"""
156+
input_shape = tf.TensorShape(input_shape)
157+
input_shape = input_shape.with_rank_at_least(2)
158+
if input_shape[-1] is None:
159+
raise ValueError(
160+
f'The innermost dimension of input_shape must be defined, but saw: {input_shape}'
161+
)
162+
return input_shape[:-1].concatenate(self.units)
163+
145164

146165
def _make_kl_divergence_penalty(
147166
use_exact_kl=False,

tensorflow_probability/python/layers/dense_variational_v2_test.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,12 @@ def test_end_to_end(self):
7676
# Get dataset.
7777
y, x, x_tst = create_dataset()
7878

79-
# Build model.
79+
layer = tfp.layers.DenseVariational(1, posterior_mean_field,
80+
prior_trainable)
81+
8082
model = tf.keras.Sequential([
81-
tfp.layers.DenseVariational(1, posterior_mean_field, prior_trainable),
82-
tfp.layers.DistributionLambda(lambda t: tfd.Normal(loc=t, scale=1)),
83+
layer,
84+
tfp.layers.DistributionLambda(lambda t: tfd.Normal(loc=t, scale=1))
8385
])
8486

8587
# Do inference.
@@ -96,6 +98,11 @@ def test_end_to_end(self):
9698
self.assertContainsSubsequence(posterior.name, '/posterior/')
9799
self.assertContainsSubsequence(prior.name, '/prior/')
98100

101+
# Check the output_shape.
102+
expected_output_shape = layer.compute_output_shape(
103+
(None, x.shape[-1])).as_list()
104+
self.assertAllEqual(expected_output_shape, (None, 1))
105+
99106
# Profit.
100107
yhat = model(x_tst)
101108
self.assertIsInstance(yhat, tfd.Distribution)

0 commit comments

Comments
 (0)