import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

-from tensorflow_compression.python.distributions import helpers
-from tensorflow_compression.python.ops import math_ops
+from tensorflow_compression.python.distributions import uniform_noise


-__all__ = ["DeepFactorized"]
+__all__ = ["DeepFactorized", "NoisyDeepFactorized"]
+
+
+def log_expm1(x):
+  """Computes log(exp(x)-1) stably.
+
+  For large values of x, exp(x) will return Inf, whereas log(exp(x)-1) ~= x.
+  Here we use this approximation for x > 15, such that the output is non-Inf
+  for all positive values of x.
+
+  Args:
+    x: A tensor.
+
+  Returns:
+    log(exp(x)-1)
+
+  """
+  # If x < 15.0, we can compute it directly. For larger values,
+  # we have log(exp(x)-1) ~= log(exp(x)) = x.
+  cond = (x < 15.0)
+  x_small = tf.minimum(x, 15.0)
+  return tf.where(cond, tf.math.log(tf.math.expm1(x_small)), x)
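# Illustrative sketch (not part of the patch above): a quick check of the
# stable branch, assuming TF 2.x eager execution and the imports at the top of
# this file. The naive formula overflows to inf in float32 once exp(x) does,
# while log_expm1 falls back to log(exp(x)-1) ~= x.
x = tf.constant([0.1, 1.0, 20.0, 100.0])
naive = tf.math.log(tf.math.expm1(x))  # last entry overflows to inf in float32
stable = log_expm1(x)                  # finite everywhere; ~= x for x > 15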


class DeepFactorized(tfp.distributions.Distribution):
@@ -34,7 +54,7 @@ class DeepFactorized(tfp.distributions.Distribution):
  > J. Ballé, D. Minnen, S. Singh, S. J. Hwang, N. Johnston<br />
  > https://openreview.net/forum?id=rkcQFMZRb

-  This implementation already includes convolution with a unit-width uniform
+  but *without* convolution with a unit-width uniform
  density, as described in appendix 6.2 of the same paper. Please cite the paper
  if you use this code for scientific work.

@@ -43,7 +63,8 @@ class DeepFactorized(tfp.distributions.Distribution):
  trainable distribution parameters.
  """

-  def __init__(self, batch_shape=(), num_filters=(3, 3), init_scale=10,
+  def __init__(self,
+               batch_shape=(), num_filters=(3, 3), init_scale=10,
               allow_nan_stats=False, dtype=tf.float32, name="DeepFactorized"):
    """Initializer.

@@ -98,22 +119,31 @@ def _make_variables(self):
    self._factors = []

    for i in range(len(self.num_filters) + 1):
-      init = tf.math.log(tf.math.expm1(1 / scale / filters[i + 1]))
-      init = tf.cast(init, dtype=self.dtype)
-      init = tf.broadcast_to(init, (channels, filters[i + 1], filters[i]))
-      matrix = tf.Variable(init, name="matrix_{}".format(i))
+
+      def matrix_initializer(i=i):
+        init = log_expm1(1 / scale / filters[i + 1])
+        init = tf.cast(init, dtype=self.dtype)
+        init = tf.broadcast_to(init, (channels, filters[i + 1], filters[i]))
+        return init
+
+      matrix = tf.Variable(matrix_initializer, name="matrix_{}".format(i))
      self._matrices.append(matrix)

-      bias = tf.Variable(
-          tf.random.uniform(
-              (channels, filters[i + 1], 1), -.5, .5, dtype=self.dtype),
-          name="bias_{}".format(i))
+      def bias_initializer(i=i):
+        return tf.random.uniform((channels, filters[i + 1], 1),
+                                 -.5,
+                                 .5,
+                                 dtype=self.dtype)
+
+      bias = tf.Variable(bias_initializer, name="bias_{}".format(i))
      self._biases.append(bias)

      if i < len(self.num_filters):
-        factor = tf.Variable(
-            tf.zeros((channels, filters[i + 1], 1), dtype=self.dtype),
-            name="factor_{}".format(i))
+
+        def factor_initializer(i=i):
+          return tf.zeros((channels, filters[i + 1], 1), dtype=self.dtype)
+
+        factor = tf.Variable(factor_initializer, name="factor_{}".format(i))
        self._factors.append(factor)

  def _batch_shape_tensor(self):
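# Illustrative sketch (not part of the patch above): why the initializer
# closures above take an `i=i` default argument. Python closures bind names
# late, so without the default every initializer would read the loop
# variable's final value. A minimal standalone illustration:
late_bound = [lambda: i for i in range(3)]
early_bound = [lambda i=i: i for i in range(3)]
print([f() for f in late_bound])   # [2, 2, 2]
print([f() for f in early_bound])  # [0, 1, 2]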
@@ -132,13 +162,20 @@ def _logits_cumulative(self, inputs):
    """Evaluate logits of the cumulative densities.

    Arguments:
-      inputs: The values at which to evaluate the cumulative densities, expected
-        to be a `tf.Tensor` of shape `(channels, 1, batch)`.
+      inputs: The values at which to evaluate the cumulative densities.

    Returns:
      A `tf.Tensor` of the same shape as `inputs`, containing the logits of the
      cumulative densities evaluated at the given inputs.
    """
+    # Convert to (channels, 1, batch) format by collapsing dimensions and then
+    # commuting channels to front.
+    inputs = tf.broadcast_to(
+        inputs,
+        tf.broadcast_dynamic_shape(tf.shape(inputs), self.batch_shape_tensor()))
+    shape = tf.shape(inputs)
+    inputs = tf.reshape(inputs, (-1, 1, self.batch_shape.num_elements()))
+    inputs = tf.transpose(inputs, (2, 1, 0))
    logits = inputs
    for i in range(len(self.num_filters) + 1):
      matrix = tf.nn.softplus(self._matrices[i])
@@ -147,48 +184,53 @@ def _logits_cumulative(self, inputs):
      if i < len(self.num_filters):
        factor = tf.math.tanh(self._factors[i])
        logits += factor * tf.math.tanh(logits)
-    return logits
-
-  def _prob(self, y):
-    """Called by the base class to compute likelihoods."""
-    # Convert to (channels, 1, batch) format by collapsing dimensions and then
-    # commuting channels to front.
-    y = tf.broadcast_to(
-        y, tf.broadcast_dynamic_shape(tf.shape(y), self.batch_shape_tensor()))
-    shape = tf.shape(y)
-    y = tf.reshape(y, (-1, 1, self.batch_shape.num_elements()))
-    y = tf.transpose(y, (2, 1, 0))
-
-    # Evaluate densities.
-    # We can use the special rule below to only compute differences in the left
-    # tail of the sigmoid. This increases numerical stability: sigmoid(x) is 1
-    # for large x, 0 for small x. Subtracting two numbers close to 0 can be done
-    # with much higher precision than subtracting two numbers close to 1.
-    lower = self._logits_cumulative(y - .5)
-    upper = self._logits_cumulative(y + .5)
-    # Flip signs if we can move more towards the left tail of the sigmoid.
-    sign = tf.stop_gradient(-tf.math.sign(lower + upper))
-    p = abs(tf.sigmoid(sign * upper) - tf.sigmoid(sign * lower))
-    p = math_ops.lower_bound(p, 0.)

    # Convert back to (broadcasted) input tensor shape.
-    p = tf.transpose(p, (2, 1, 0))
-    p = tf.reshape(p, shape)
-    return p
+    logits = tf.transpose(logits, (2, 1, 0))
+    logits = tf.reshape(logits, shape)
+    return logits
+
+  def _log_cdf(self, inputs):
+    logits = self._logits_cumulative(inputs)
+    return tf.math.log_sigmoid(logits)
+
+  def _log_survival_function(self, inputs):
+    logits = self._logits_cumulative(inputs)
+    # 1 - sigmoid(x) = sigmoid(-x)
+    return tf.math.log_sigmoid(-logits)
+
+  def _cdf(self, inputs):
+    logits = self._logits_cumulative(inputs)
+    return tf.math.sigmoid(logits)
+
+  def _prob(self, inputs):
+    with tf.GradientTape() as tape:
+      tape.watch(inputs)
+      cdf = self._cdf(inputs)
+    prob = tape.gradient(cdf, inputs)
+    return prob
+
+  def _log_prob(self, inputs):
+    # Let x = inputs and s(x) = sigmoid(x).
+    with tf.GradientTape() as tape:
+      tape.watch(inputs)
+      logits = self._logits_cumulative(inputs)
+    # We have F(x) = s(logits(x)),
+    # so p(x) = F'(x)
+    #         = s'(logits(x)) * logits'(x)
+    #         = s(logits(x)) * s(-logits(x)) * logits'(x),
+    # so log p(x) = log(s(logits(x))) + log(s(-logits(x))) + log(logits'(x)).
+    log_s_logits = tf.math.log_sigmoid(logits)
+    log_s_neg_logits = tf.math.log_sigmoid(-logits)
+    dlogits = tape.gradient(logits, inputs)
+    return log_s_logits + log_s_neg_logits + tf.math.log(dlogits)
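# Illustrative sketch (not part of the patch above): the chain-rule identity
# used in _log_prob, d/dx sigmoid(l(x)) = sigmoid(l(x)) * sigmoid(-l(x)) * l'(x),
# checked here on a toy logit function l(x) = 3x + 1 with autodiff.
x = tf.constant(0.7)
with tf.GradientTape() as tape:
  tape.watch(x)
  cdf = tf.math.sigmoid(3.0 * x + 1.0)
autodiff = tape.gradient(cdf, x)
l = 3.0 * x + 1.0
closed_form = tf.math.sigmoid(l) * tf.math.sigmoid(-l) * 3.0
# autodiff and closed_form agree up to float32 rounding.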

  def _quantization_offset(self):
    return tf.constant(0, dtype=self.dtype)

-  def _lower_tail(self, tail_mass):
-    tail = helpers.estimate_tails(
-        self._logits_cumulative, -tf.math.log(2 / tail_mass - 1),
-        tf.constant([self.batch_shape.num_elements(), 1, 1], tf.int32),
-        self.dtype)
-    return tf.reshape(tail, self.batch_shape_tensor())
-
-  def _upper_tail(self, tail_mass):
-    tail = helpers.estimate_tails(
-        self._logits_cumulative, tf.math.log(2 / tail_mass - 1),
-        tf.constant([self.batch_shape.num_elements(), 1, 1], tf.int32),
-        self.dtype)
-    return tf.reshape(tail, self.batch_shape_tensor())
+
+
+class NoisyDeepFactorized(uniform_noise.UniformNoiseAdapter):
+  """DeepFactorized that is convolved with uniform noise."""
+
+  def __init__(self, name="NoisyDeepFactorized", **kwargs):
+    super().__init__(DeepFactorized(**kwargs), name=name)
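# Illustrative sketch (not part of the patch above): one way the new class
# might be used. The shapes and arguments below are a hypothetical
# configuration; per the constructor, all keyword arguments are forwarded to
# the underlying DeepFactorized prior.
noisy = NoisyDeepFactorized(batch_shape=(16,), num_filters=(3, 3))
y = tf.random.normal((8, 16))
log_likelihood = noisy.log_prob(y)  # one value per element, shape (8, 16)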