Skip to content

Commit afbc1a5

Browse files
fchollettensorflower-gardener
authored andcommitted
Remove usage of deprecated layer API.
PiperOrigin-RevId: 427561526
1 parent 1d651d6 commit afbc1a5

File tree

2 files changed

+4
-16
lines changed

2 files changed

+4
-16
lines changed

tensorflow_probability/python/layers/conv_variational_test.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -245,14 +245,14 @@ def _testKLPenaltyKernel(self, layer_class): # pylint: disable=invalid-name
245245
inputs = self.maybe_transpose_tensor(inputs)
246246

247247
# No keys.
248-
input_dependent_losses = layer.get_losses_for(inputs=None)
248+
input_dependent_losses = layer.losses
249249
self.assertEqual(len(layer.losses), 0)
250250
self.assertListEqual(layer.losses, input_dependent_losses)
251251

252252
_ = layer(inputs)
253253

254254
# Yes keys.
255-
input_dependent_losses = layer.get_losses_for(inputs=None)
255+
input_dependent_losses = layer.losses
256256
self.assertEqual(len(layer.losses), 1)
257257
self.assertEqual(layer.losses[0].shape, ())
258258
self.assertListEqual(layer.losses, input_dependent_losses)
@@ -277,18 +277,14 @@ def _testKLPenaltyBoth(self, layer_class): # pylint: disable=invalid-name
277277
inputs = self.maybe_transpose_tensor(inputs)
278278

279279
# No keys.
280-
input_dependent_losses = layer.get_losses_for(inputs=None)
281280
self.assertEqual(len(layer.losses), 0)
282-
self.assertListEqual(layer.losses, input_dependent_losses)
283281

284282
_ = layer(inputs)
285283

286284
# Yes keys.
287-
input_dependent_losses = layer.get_losses_for(inputs=None)
288285
self.assertEqual(len(layer.losses), 2)
289286
self.assertEqual(layer.losses[0].shape, ())
290287
self.assertEqual(layer.losses[1].shape, ())
291-
self.assertListEqual(layer.losses, input_dependent_losses)
292288

293289
def _testConvSetUp(self, layer_class, batch_size, depth=None,
294290
height=None, width=None, channels=None, filters=None,
@@ -350,7 +346,7 @@ def _testConvSetUp(self, layer_class, batch_size, depth=None,
350346

351347
outputs = layer(inputs)
352348

353-
kl_penalty = layer.get_losses_for(inputs=None)
349+
kl_penalty = layer.losses
354350
return (kernel_posterior, kernel_prior, kernel_divergence,
355351
bias_posterior, bias_prior, bias_divergence,
356352
layer, inputs, outputs, kl_penalty, kernel_shape)

tensorflow_probability/python/layers/dense_variational_test.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -140,17 +140,13 @@ def _testKLPenaltyKernel(self, layer_class):
140140
inputs = tf.random.uniform([2, 3], seed=1)
141141

142142
# No keys.
143-
input_dependent_losses = layer.get_losses_for(inputs=None)
144143
self.assertEqual(len(layer.losses), 0)
145-
self.assertListEqual(layer.losses, input_dependent_losses)
146144

147145
_ = layer(inputs)
148146

149147
# Yes keys.
150-
input_dependent_losses = layer.get_losses_for(inputs=None)
151148
self.assertEqual(len(layer.losses), 1)
152149
self.assertEqual(layer.losses[0].shape, ())
153-
self.assertListEqual(layer.losses, input_dependent_losses)
154150

155151
def _testKLPenaltyBoth(self, layer_class):
156152
with self.cached_session():
@@ -161,18 +157,14 @@ def _testKLPenaltyBoth(self, layer_class):
161157
inputs = tf.random.uniform([2, 3], seed=1)
162158

163159
# No keys.
164-
input_dependent_losses = layer.get_losses_for(inputs=None)
165160
self.assertEqual(len(layer.losses), 0)
166-
self.assertListEqual(layer.losses, input_dependent_losses)
167161

168162
_ = layer(inputs)
169163

170164
# Yes keys.
171-
input_dependent_losses = layer.get_losses_for(inputs=None)
172165
self.assertEqual(len(layer.losses), 2)
173166
self.assertEqual(layer.losses[0].shape, ())
174167
self.assertEqual(layer.losses[1].shape, ())
175-
self.assertListEqual(layer.losses, input_dependent_losses)
176168

177169
def _testDenseSetUp(self, layer_class, batch_size, in_size, out_size,
178170
**kwargs):
@@ -215,7 +207,7 @@ def _testDenseSetUp(self, layer_class, batch_size, in_size, out_size,
215207

216208
outputs = layer(inputs)
217209

218-
kl_penalty = layer.get_losses_for(inputs=None)
210+
kl_penalty = layer.losses
219211
return (kernel_posterior, kernel_prior, kernel_divergence,
220212
bias_posterior, bias_prior, bias_divergence,
221213
layer, inputs, outputs, kl_penalty)

0 commit comments

Comments
 (0)