Skip to content

Commit 832cd7a

Browse files
Johannes Ballécopybara-github
authored andcommitted
Adds unit test for straight-through gradients.
PiperOrigin-RevId: 306778342 Change-Id: I229c469467865f4d54430aa3bb8581eaa1327a3f
1 parent 6ece08b commit 832cd7a

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

tensorflow_compression/python/entropy_models/continuous_batched_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,17 @@ def test_quantizes_to_integers_modulo_offset(self):
5757
x_quantized = em.quantize(x_perturbed)
5858
self.assertAllEqual(x, x_quantized)
5959

60+
def test_gradients_are_straight_through(self):
61+
noisy = uniform_noise.NoisyNormal(loc=0, scale=1)
62+
em = ContinuousBatchedEntropyModel(noisy, 1)
63+
x = tf.range(-20., 20.)
64+
x_perturbed = x + tf.random.uniform(x.shape, -.49, .49)
65+
with tf.GradientTape() as tape:
66+
tape.watch(x_perturbed)
67+
x_quantized = em.quantize(x_perturbed)
68+
gradients = tape.gradient(x_quantized, x_perturbed)
69+
self.assertAllEqual(gradients, tf.ones_like(gradients))
70+
6071
def test_default_kwargs_throw_error_on_compression(self):
6172
noisy = uniform_noise.NoisyNormal(loc=.25, scale=10.)
6273
em = ContinuousBatchedEntropyModel(noisy, 1)

0 commit comments

Comments
 (0)