File tree Expand file tree Collapse file tree 1 file changed +11
-0
lines changed
tensorflow_compression/python/entropy_models Expand file tree Collapse file tree 1 file changed +11
-0
lines changed Original file line number Diff line number Diff line change @@ -57,6 +57,17 @@ def test_quantizes_to_integers_modulo_offset(self):
57
57
x_quantized = em .quantize (x_perturbed )
58
58
self .assertAllEqual (x , x_quantized )
59
59
60
+ def test_gradients_are_straight_through (self ):
61
+ noisy = uniform_noise .NoisyNormal (loc = 0 , scale = 1 )
62
+ em = ContinuousBatchedEntropyModel (noisy , 1 )
63
+ x = tf .range (- 20. , 20. )
64
+ x_perturbed = x + tf .random .uniform (x .shape , - .49 , .49 )
65
+ with tf .GradientTape () as tape :
66
+ tape .watch (x_perturbed )
67
+ x_quantized = em .quantize (x_perturbed )
68
+ gradients = tape .gradient (x_quantized , x_perturbed )
69
+ self .assertAllEqual (gradients , tf .ones_like (gradients ))
70
+
60
71
def test_default_kwargs_throw_error_on_compression (self ):
61
72
noisy = uniform_noise .NoisyNormal (loc = .25 , scale = 10. )
62
73
em = ContinuousBatchedEntropyModel (noisy , 1 )
You can’t perform that action at this time.
0 commit comments