|
| 1 | +# Lint as: python3 |
| 2 | +# Copyright 2020 Google LLC. All Rights Reserved. |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +# ============================================================================== |
| 16 | +"""Deep fully factorized distribution based on cumulative.""" |
| 17 | + |
| 18 | +import tensorflow.compat.v2 as tf |
| 19 | +import tensorflow_probability as tfp |
| 20 | + |
| 21 | +from tensorflow_compression.python.distributions import helpers |
| 22 | + |
| 23 | + |
| 24 | +__all__ = ["DeepFactorized"] |
| 25 | + |
| 26 | + |
| 27 | +class DeepFactorized(tfp.distributions.Distribution): |
| 28 | + """Fully factorized distribution based on neural network cumulative. |
| 29 | +
|
| 30 | + This is a flexible, nonparametric probability density model, described in |
| 31 | + appendix 6.1 of the paper: |
| 32 | +
|
| 33 | + > "Variational image compression with a scale hyperprior"<br /> |
| 34 | + > J. Ballé, D. Minnen, S. Singh, S. J. Hwang, N. Johnston<br /> |
| 35 | + > https://openreview.net/forum?id=rkcQFMZRb |
| 36 | +
|
| 37 | + This implementation already includes convolution with a unit-width uniform |
| 38 | + density, as described in appendix 6.2 of the same paper. Please cite the paper |
| 39 | + if you use this code for scientific work. |
| 40 | +
|
| 41 | + This is a scalar distribution (i.e., its `event_shape` is always length 0), |
| 42 | + and the density object always creates its own `tf.Variable`s representing the |
| 43 | + trainable distribution parameters. |
| 44 | + """ |
| 45 | + |
| 46 | + def __init__(self, batch_shape=(), num_filters=(3, 3), init_scale=10, |
| 47 | + allow_nan_stats=False, dtype=tf.float32, name="DeepFactorized"): |
| 48 | + """Initializer. |
| 49 | +
|
| 50 | + Arguments: |
| 51 | + batch_shape: Iterable of integers. The desired batch shape for the |
| 52 | + `Distribution` (rightmost dimensions which are assumed independent, but |
| 53 | + not identically distributed). |
| 54 | + num_filters: Iterable of integers. The number of filters for each of the |
| 55 | + hidden layers. The first and last layer of the network implementing the |
| 56 | + cumulative distribution are not included (they are assumed to be 1). |
| 57 | + init_scale: Float. Scale factor for the density at initialization. It is |
| 58 | + recommended to choose a large enough scale factor such that most values |
| 59 | + initially lie within a region of high likelihood. This improves |
| 60 | + training. |
| 61 | + allow_nan_stats: Boolean. Whether to allow `NaN`s to be returned when |
| 62 | + querying distribution statistics. |
| 63 | + dtype: A floating point `tf.dtypes.DType`. Computations relating to this |
| 64 | + distribution will be performed at this precision. |
| 65 | + name: String. A name for this distribution. |
| 66 | + """ |
| 67 | + parameters = dict(locals()) |
| 68 | + self._batch_shape_tuple = tuple(int(s) for s in batch_shape) |
| 69 | + self._num_filters = tuple(int(f) for f in num_filters) |
| 70 | + self._init_scale = float(init_scale) |
| 71 | + self._estimated_tail_mass = None |
| 72 | + super().__init__( |
| 73 | + dtype=dtype, |
| 74 | + reparameterization_type=tfp.distributions.NOT_REPARAMETERIZED, |
| 75 | + validate_args=False, |
| 76 | + allow_nan_stats=allow_nan_stats, |
| 77 | + parameters=parameters, |
| 78 | + name=name, |
| 79 | + ) |
| 80 | + self._make_variables() |
| 81 | + |
| 82 | + @property |
| 83 | + def num_filters(self): |
| 84 | + return self._num_filters |
| 85 | + |
| 86 | + @property |
| 87 | + def init_scale(self): |
| 88 | + return self._init_scale |
| 89 | + |
| 90 | + def _make_variables(self): |
| 91 | + """Creates the variables representing the parameters of the distribution.""" |
| 92 | + channels = self.batch_shape.num_elements() |
| 93 | + filters = (1,) + self.num_filters + (1,) |
| 94 | + scale = self.init_scale ** (1 / (len(self.num_filters) + 1)) |
| 95 | + self._matrices = [] |
| 96 | + self._biases = [] |
| 97 | + self._factors = [] |
| 98 | + |
| 99 | + for i in range(len(self.num_filters) + 1): |
| 100 | + init = tf.math.log(tf.math.expm1(1 / scale / filters[i + 1])) |
| 101 | + init = tf.cast(init, dtype=self.dtype) |
| 102 | + init = tf.broadcast_to(init, (channels, filters[i + 1], filters[i])) |
| 103 | + matrix = tf.Variable(init, name="matrix_{}".format(i)) |
| 104 | + self._matrices.append(matrix) |
| 105 | + |
| 106 | + bias = tf.Variable( |
| 107 | + tf.random.uniform( |
| 108 | + (channels, filters[i + 1], 1), -.5, .5, dtype=self.dtype), |
| 109 | + name="bias_{}".format(i)) |
| 110 | + self._biases.append(bias) |
| 111 | + |
| 112 | + if i < len(self.num_filters): |
| 113 | + factor = tf.Variable( |
| 114 | + tf.zeros((channels, filters[i + 1], 1), dtype=self.dtype), |
| 115 | + name="factor_{}".format(i)) |
| 116 | + self._factors.append(factor) |
| 117 | + |
| 118 | + def _batch_shape_tensor(self): |
| 119 | + return tf.constant(self._batch_shape_tuple, dtype=int) |
| 120 | + |
| 121 | + def _batch_shape(self): |
| 122 | + return tf.TensorShape(self._batch_shape_tuple) |
| 123 | + |
| 124 | + def _event_shape_tensor(self): |
| 125 | + return tf.constant((), dtype=int) |
| 126 | + |
| 127 | + def _event_shape(self): |
| 128 | + return tf.TensorShape(()) |
| 129 | + |
| 130 | + def _logits_cumulative(self, inputs): |
| 131 | + """Evaluate logits of the cumulative densities. |
| 132 | +
|
| 133 | + Arguments: |
| 134 | + inputs: The values at which to evaluate the cumulative densities, expected |
| 135 | + to be a `tf.Tensor` of shape `(channels, 1, batch)`. |
| 136 | +
|
| 137 | + Returns: |
| 138 | + A `tf.Tensor` of the same shape as `inputs`, containing the logits of the |
| 139 | + cumulative densities evaluated at the given inputs. |
| 140 | + """ |
| 141 | + logits = inputs |
| 142 | + for i in range(len(self.num_filters) + 1): |
| 143 | + matrix = tf.nn.softplus(self._matrices[i]) |
| 144 | + logits = tf.linalg.matmul(matrix, logits) |
| 145 | + logits += self._biases[i] |
| 146 | + if i < len(self.num_filters): |
| 147 | + factor = tf.math.tanh(self._factors[i]) |
| 148 | + logits += factor * tf.math.tanh(logits) |
| 149 | + return logits |
| 150 | + |
| 151 | + def _prob(self, y): |
| 152 | + """Called by the base class to compute likelihoods.""" |
| 153 | + # Convert to (channels, 1, batch) format by collapsing dimensions and then |
| 154 | + # commuting channels to front. |
| 155 | + y = tf.broadcast_to( |
| 156 | + y, tf.broadcast_dynamic_shape(tf.shape(y), self.batch_shape_tensor())) |
| 157 | + shape = tf.shape(y) |
| 158 | + y = tf.reshape(y, (-1, 1, self.batch_shape.num_elements())) |
| 159 | + y = tf.transpose(y, (2, 1, 0)) |
| 160 | + |
| 161 | + # Evaluate densities. |
| 162 | + # We can use the special rule below to only compute differences in the left |
| 163 | + # tail of the sigmoid. This increases numerical stability: sigmoid(x) is 1 |
| 164 | + # for large x, 0 for small x. Subtracting two numbers close to 0 can be done |
| 165 | + # with much higher precision than subtracting two numbers close to 1. |
| 166 | + lower = self._logits_cumulative(y - .5) |
| 167 | + upper = self._logits_cumulative(y + .5) |
| 168 | + # Flip signs if we can move more towards the left tail of the sigmoid. |
| 169 | + sign = tf.stop_gradient(-tf.math.sign(lower + upper)) |
| 170 | + p = abs(tf.sigmoid(sign * upper) - tf.sigmoid(sign * lower)) |
| 171 | + |
| 172 | + # Convert back to (broadcasted) input tensor shape. |
| 173 | + p = tf.transpose(p, (2, 1, 0)) |
| 174 | + p = tf.reshape(p, shape) |
| 175 | + return p |
| 176 | + |
| 177 | + def _quantization_offset(self): |
| 178 | + return tf.constant(0, dtype=self.dtype) |
| 179 | + |
| 180 | + def _lower_tail(self, tail_mass): |
| 181 | + tail = helpers.estimate_tail( |
| 182 | + self._logits_cumulative, -tf.math.log(2 / tail_mass - 1), |
| 183 | + [self.batch_shape.num_elements(), 1, 1], self.dtype) |
| 184 | + return tf.reshape(tail, self.batch_shape_tensor()) |
| 185 | + |
| 186 | + def _upper_tail(self, tail_mass): |
| 187 | + tail = helpers.estimate_tail( |
| 188 | + self._logits_cumulative, tf.math.log(2 / tail_mass - 1), |
| 189 | + [self.batch_shape.num_elements(), 1, 1], self.dtype) |
| 190 | + return tf.reshape(tail, self.batch_shape_tensor()) |
0 commit comments