|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +# Copyright 2018 Google LLC. All Rights Reserved. |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +# ============================================================================== |
| 16 | +"""Basic nonlinear transform coder for RGB images. |
| 17 | +
|
| 18 | +This is a close approximation of the image compression model of |
| 19 | +Ballé, Laparra, Simoncelli (2017): |
| 20 | +End-to-end optimized image compression |
| 21 | +https://arxiv.org/abs/1611.01704 |
| 22 | +""" |
| 23 | + |
| 24 | +from __future__ import absolute_import |
| 25 | +from __future__ import division |
| 26 | +from __future__ import print_function |
| 27 | + |
| 28 | +import argparse |
| 29 | + |
| 30 | +# Dependency imports |
| 31 | + |
| 32 | +import numpy as np |
| 33 | +import tensorflow as tf |
| 34 | +import tensorflow_compression as tfc |
| 35 | + |
| 36 | + |
| 37 | +def load_image(filename): |
| 38 | + string = tf.read_file(filename) |
| 39 | + image = tf.image.decode_image(string, channels=3) |
| 40 | + image = tf.cast(image, tf.float32) |
| 41 | + image /= 255 |
| 42 | + return image |
| 43 | + |
| 44 | + |
| 45 | +def save_image(filename, image): |
| 46 | + image = tf.clip_by_value(image, 0, 1) |
| 47 | + image = tf.round(image * 255) |
| 48 | + image = tf.cast(image, tf.uint8) |
| 49 | + string = tf.image.encode_png(image) |
| 50 | + return tf.write_file(filename, string) |
| 51 | + |
| 52 | + |
| 53 | +def analysis_transform(tensor, num_filters): |
| 54 | + with tf.variable_scope("analysis"): |
| 55 | + with tf.variable_scope("layer_0"): |
| 56 | + layer = tfc.SignalConv2D( |
| 57 | + num_filters, (9, 9), corr=True, strides_down=4, padding="same_zeros", |
| 58 | + use_bias=True, activation=tfc.GDN()) |
| 59 | + tensor = layer(tensor) |
| 60 | + |
| 61 | + with tf.variable_scope("layer_1"): |
| 62 | + layer = tfc.SignalConv2D( |
| 63 | + num_filters, (5, 5), corr=True, strides_down=2, padding="same_zeros", |
| 64 | + use_bias=True, activation=tfc.GDN()) |
| 65 | + tensor = layer(tensor) |
| 66 | + |
| 67 | + with tf.variable_scope("layer_2"): |
| 68 | + layer = tfc.SignalConv2D( |
| 69 | + num_filters, (5, 5), corr=True, strides_down=2, padding="same_zeros", |
| 70 | + use_bias=False, activation=None) |
| 71 | + tensor = layer(tensor) |
| 72 | + |
| 73 | + return tensor |
| 74 | + |
| 75 | + |
| 76 | +def synthesis_transform(tensor, num_filters): |
| 77 | + with tf.variable_scope("synthesis"): |
| 78 | + with tf.variable_scope("layer_0"): |
| 79 | + layer = tfc.SignalConv2D( |
| 80 | + num_filters, (5, 5), corr=False, strides_up=2, padding="same_zeros", |
| 81 | + use_bias=True, activation=tfc.GDN(inverse=True)) |
| 82 | + tensor = layer(tensor) |
| 83 | + |
| 84 | + with tf.variable_scope("layer_1"): |
| 85 | + layer = tfc.SignalConv2D( |
| 86 | + num_filters, (5, 5), corr=False, strides_up=2, padding="same_zeros", |
| 87 | + use_bias=True, activation=tfc.GDN(inverse=True)) |
| 88 | + tensor = layer(tensor) |
| 89 | + |
| 90 | + with tf.variable_scope("layer_2"): |
| 91 | + layer = tfc.SignalConv2D( |
| 92 | + 3, (9, 9), corr=False, strides_up=4, padding="same_zeros", |
| 93 | + use_bias=True, activation=None) |
| 94 | + tensor = layer(tensor) |
| 95 | + |
| 96 | + return tensor |
| 97 | + |
| 98 | + |
| 99 | +def train(args): |
| 100 | + # Load all training images into a constant. |
| 101 | + images = tf.map_fn( |
| 102 | + load_image, tf.matching_files(args.data_glob), |
| 103 | + dtype=tf.float32, back_prop=False) |
| 104 | + with tf.Session() as sess: |
| 105 | + images = tf.constant(sess.run(images), name="images") |
| 106 | + |
| 107 | + # Training inputs are random crops out of the images tensor. |
| 108 | + crop_shape = (args.batchsize, args.patchsize, args.patchsize, 3) |
| 109 | + x = tf.random_crop(images, crop_shape) |
| 110 | + num_pixels = np.prod(crop_shape[:-1]) |
| 111 | + |
| 112 | + # Build autoencoder. |
| 113 | + y = analysis_transform(x, args.num_filters) |
| 114 | + entropy_bottleneck = tfc.EntropyBottleneck() |
| 115 | + y_tilde, likelihoods = entropy_bottleneck(y, training=True) |
| 116 | + x_tilde = synthesis_transform(y_tilde, args.num_filters) |
| 117 | + |
| 118 | + # Total number of bits divided by number of pixels. |
| 119 | + train_bpp = tf.reduce_sum(tf.log(likelihoods)) / (-np.log(2) * num_pixels) |
| 120 | + |
| 121 | + # Mean squared error across pixels. |
| 122 | + train_mse = tf.reduce_sum(tf.squared_difference(x, x_tilde)) / num_pixels |
| 123 | + |
| 124 | + # The rate-distortion cost. |
| 125 | + train_loss = args.lmbda * train_mse + train_bpp |
| 126 | + |
| 127 | + # Minimize loss and auxiliary loss, and execute update op. |
| 128 | + step = tf.train.create_global_step() |
| 129 | + main_optimizer = tf.train.AdamOptimizer(learning_rate=1e-4) |
| 130 | + main_step = main_optimizer.minimize(train_loss, global_step=step) |
| 131 | + |
| 132 | + aux_optimizer = tf.train.AdamOptimizer(learning_rate=1e-3) |
| 133 | + aux_step = aux_optimizer.minimize(entropy_bottleneck.losses[0]) |
| 134 | + |
| 135 | + train_op = tf.group(main_step, aux_step, entropy_bottleneck.updates[0]) |
| 136 | + |
| 137 | + hooks = [ |
| 138 | + tf.train.StopAtStepHook(last_step=args.last_step), |
| 139 | + tf.train.NanTensorHook(train_loss), |
| 140 | + ] |
| 141 | + with tf.train.MonitoredTrainingSession( |
| 142 | + hooks=hooks, checkpoint_dir=args.checkpoint_dir) as sess: |
| 143 | + while not sess.should_stop(): |
| 144 | + sess.run(train_op) |
| 145 | + |
| 146 | + |
| 147 | +def compress(args): |
| 148 | + # Load input image and add batch dimension. |
| 149 | + x = load_image(args.input) |
| 150 | + x = tf.expand_dims(x, 0) |
| 151 | + x.set_shape([1, None, None, 3]) |
| 152 | + |
| 153 | + # Transform and compress the image, then remove batch dimension. |
| 154 | + y = analysis_transform(x, args.num_filters) |
| 155 | + entropy_bottleneck = tfc.EntropyBottleneck() |
| 156 | + string = entropy_bottleneck.compress(y) |
| 157 | + string = tf.squeeze(string, axis=0) |
| 158 | + |
| 159 | + # Transform the quantized image back (if requested). |
| 160 | + y_hat, likelihoods = entropy_bottleneck(y, training=False) |
| 161 | + x_hat = synthesis_transform(y_hat, args.num_filters) |
| 162 | + |
| 163 | + num_pixels = tf.to_float(tf.reduce_prod(tf.shape(x)[:-1])) |
| 164 | + |
| 165 | + # Total number of bits divided by number of pixels. |
| 166 | + eval_bpp = tf.reduce_sum(tf.log(likelihoods)) / (-np.log(2) * num_pixels) |
| 167 | + |
| 168 | + # Mean squared error across pixels. |
| 169 | + mse = tf.reduce_sum(tf.squared_difference(x, x_hat)) / num_pixels |
| 170 | + |
| 171 | + with tf.Session() as sess: |
| 172 | + # Load the latest model checkpoint, get the compressed string and the tensor |
| 173 | + # shapes. |
| 174 | + latest = tf.train.latest_checkpoint(checkpoint_dir=args.checkpoint_dir) |
| 175 | + tf.train.Saver().restore(sess, save_path=latest) |
| 176 | + string, x_shape, y_shape = sess.run([string, tf.shape(x), tf.shape(y)]) |
| 177 | + |
| 178 | + # Write a binary file with the shape information and the compressed string. |
| 179 | + with open(args.output, "wb") as file: |
| 180 | + file.write(np.array(x_shape[1:-1], dtype=np.uint16).tobytes()) |
| 181 | + file.write(np.array(y_shape[1:-1], dtype=np.uint16).tobytes()) |
| 182 | + file.write(string) |
| 183 | + |
| 184 | + # If requested, transform the quantized image back and measure performance. |
| 185 | + if args.verbose: |
| 186 | + eval_bpp, mse, num_pixels = sess.run([eval_bpp, mse, num_pixels]) |
| 187 | + |
| 188 | + # The actual bits per pixel including overhead. |
| 189 | + bpp = (8 + len(string)) * 8 / num_pixels |
| 190 | + |
| 191 | + print("Mean squared error: {:0.4}".format(mse)) |
| 192 | + print("Information content of this image in bpp: {:0.4}".format(eval_bpp)) |
| 193 | + print("Actual bits per pixel for this image: {:0.4}".format(bpp)) |
| 194 | + |
| 195 | + |
| 196 | +def decompress(args): |
| 197 | + # Read the shape information and compressed string from the binary file. |
| 198 | + with open(args.input, "rb") as file: |
| 199 | + x_shape = np.frombuffer(file.read(4), dtype=np.uint16) |
| 200 | + y_shape = np.frombuffer(file.read(4), dtype=np.uint16) |
| 201 | + string = file.read() |
| 202 | + |
| 203 | + bits = 8 * len(string) |
| 204 | + y_shape = [int(s) for s in y_shape] + [args.num_filters] |
| 205 | + |
| 206 | + # Add a batch dimension, then decompress and transform the image back. |
| 207 | + strings = tf.expand_dims(string, 0) |
| 208 | + entropy_bottleneck = tfc.EntropyBottleneck(dtype=tf.float32) |
| 209 | + y_hat = entropy_bottleneck.decompress( |
| 210 | + strings, y_shape, channels=args.num_filters) |
| 211 | + x_hat = synthesis_transform(y_hat, args.num_filters) |
| 212 | + |
| 213 | + # Remove batch dimension, and crop away any extraneous padding on the bottom |
| 214 | + # or right boundaries. |
| 215 | + x_hat = x_hat[0, :x_shape[0], :x_shape[1], :] |
| 216 | + |
| 217 | + # Write reconstructed image out as a PNG file. |
| 218 | + op = save_image(args.output, x_hat) |
| 219 | + |
| 220 | + # Load the latest model checkpoint, and perform the above actions. |
| 221 | + with tf.Session() as sess: |
| 222 | + latest = tf.train.latest_checkpoint(checkpoint_dir=args.checkpoint_dir) |
| 223 | + tf.train.Saver().restore(sess, save_path=latest) |
| 224 | + sess.run(op) |
| 225 | + |
| 226 | + |
| 227 | +if __name__ == "__main__": |
| 228 | + parser = argparse.ArgumentParser( |
| 229 | + formatter_class=argparse.ArgumentDefaultsHelpFormatter) |
| 230 | + |
| 231 | + parser.add_argument( |
| 232 | + "command", choices=["train", "compress", "decompress"], |
| 233 | + help="What to do: 'train' loads training data and trains (or continues " |
| 234 | + "to train) a new model. 'compress' reads an image file (lossless " |
| 235 | + "PNG format) and writes a compressed binary file. 'decompress' " |
| 236 | + "reads a binary file and reconstructs the image (in PNG format). " |
| 237 | + "input and output filenames need to be provided for the latter " |
| 238 | + "two options.") |
| 239 | + parser.add_argument( |
| 240 | + "input", nargs="?", |
| 241 | + help="Input filename.") |
| 242 | + parser.add_argument( |
| 243 | + "output", nargs="?", |
| 244 | + help="Output filename.") |
| 245 | + parser.add_argument("--verbose", "-v", action="store_true", |
| 246 | + help="Report bitrate and distortion when training or compressing.") |
| 247 | + parser.add_argument("--num_filters", type=int, default=128, |
| 248 | + help="Number of filters per layer.") |
| 249 | + parser.add_argument("--checkpoint_dir", default="train", |
| 250 | + help="Directory where to save/load model checkpoints.") |
| 251 | + parser.add_argument("--data_glob", default="images/*.png", |
| 252 | + help="Glob pattern identifying training data. This pattern must expand " |
| 253 | + "to a list of RGB images in PNG format which all have the same " |
| 254 | + "shape.") |
| 255 | + parser.add_argument("--batchsize", type=int, default=8, |
| 256 | + help="Batch size for training.") |
| 257 | + parser.add_argument("--patchsize", type=int, default=128, |
| 258 | + help="Size of image patches for training.") |
| 259 | + parser.add_argument("--lambda", type=float, default=0.1, dest="lmbda", |
| 260 | + help="Lambda for rate-distortion tradeoff.") |
| 261 | + parser.add_argument("--last_step", type=int, default=1000000, |
| 262 | + help="Train up to this number of steps.") |
| 263 | + |
| 264 | + args = parser.parse_args() |
| 265 | + |
| 266 | + if args.command == "train": |
| 267 | + train(args) |
| 268 | + elif args.command == "compress": |
| 269 | + if args.input is None or args.output is None: |
| 270 | + raise ValueError("Need input and output filename for compression.") |
| 271 | + compress(args) |
| 272 | + elif args.command == "decompress": |
| 273 | + if args.input is None or args.output is None: |
| 274 | + raise ValueError("Need input and output filename for decompression.") |
| 275 | + decompress(args) |
0 commit comments