Skip to content

Commit b18cf30

Browse files
jburnimtensorflower-gardener
authored andcommitted
Update distribution_layer_test for upcoming Keras changes.
Soon, when Keras Sequential and Functional models explicitly specify their input shape, those models will only accept inputs with exactly zero or one batch dimensions + the specified input shape. (Plus some corner cases where some dimensions have size 1.) PiperOrigin-RevId: 386470452
1 parent c99c86e commit b18cf30

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

tensorflow_probability/python/layers/distribution_layer_test.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def test_keras_sequential_api_multiple_draws(self):
276276
])
277277

278278
decoder_model = tfk.Sequential([
279-
tfkl.InputLayer(input_shape=[self.encoded_size]),
279+
tfkl.InputLayer(input_shape=[None, self.encoded_size]),
280280
tfkl.Dense(10, activation='relu'),
281281
tfkl.Dense(tfpl.IndependentBernoulli.params_size(
282282
self.input_shape)),
@@ -859,7 +859,8 @@ def test_serialization(self):
859859
batch_shape + [params_size], seed=42))
860860

861861
model = tfk.Sequential([
862-
tfkl.Dense(params_size, input_shape=(params_size,), dtype=self.dtype),
862+
tfkl.Dense(params_size, input_shape=batch_shape[1:] + [params_size],
863+
dtype=self.dtype),
863864
self.layer_class(event_shape, validate_args=True, dtype=self.dtype),
864865
])
865866

@@ -1222,12 +1223,12 @@ def test_serialization(self):
12221223
n = 3
12231224
event_shape = []
12241225
params_size = self.layer_class.params_size(n, event_shape)
1225-
batch_shape = [4, 1]
1226+
batch_size = 7
12261227

12271228
low = self._build_tensor(-3., dtype=self.dtype)
12281229
high = self._build_tensor(3., dtype=self.dtype)
12291230
x = self.evaluate(tfd.Uniform(low, high).sample(
1230-
batch_shape + [params_size], seed=42))
1231+
[batch_size] + [params_size], seed=42))
12311232

12321233
model = tfk.Sequential([
12331234
tfkl.Dense(params_size, input_shape=(params_size,), dtype=self.dtype),
@@ -1243,7 +1244,7 @@ def test_serialization(self):
12431244

12441245
self.assertEqual(self.dtype, model(x).mean().dtype.as_numpy_dtype)
12451246

1246-
ones = np.ones([7] + batch_shape + event_shape, dtype=self.dtype)
1247+
ones = np.ones([3, 2] + [batch_size] + event_shape, dtype=self.dtype)
12471248
self.assertAllEqual(self.evaluate(model(x).log_prob(ones)),
12481249
self.evaluate(model_copy(x).log_prob(ones)))
12491250

0 commit comments

Comments
 (0)