Commit accedfc

Author: Johannes Ballé

Implemented example image compression model.

1 parent 866f2a5 commit accedfc

3 files changed: 305 additions & 10 deletions


README.md

Lines changed: 29 additions & 3 deletions
````diff
@@ -10,13 +10,39 @@ For usage questions and discussions, please head over to our
 **Please note**: You need TensorFlow 1.9 (or the master branch as of May 2018)
 or later.
 
-To make sure the library imports succeed, try running the unit tests.
-```
+To make sure the library imports succeed, try running the unit tests:
+
+```bash
 for i in tensorflow_compression/python/*/*_test.py; do
   python $i
 done
 ```
 
+## Example model
+
+The `examples` directory contains an implementation of the image compression
+model described in:
+
+> J. Ballé, V. Laparra, E. P. Simoncelli:
+> "End-to-end optimized image compression"
+> https://arxiv.org/abs/1611.01704
+
+To see a list of options, change to the directory and run:
+
+```bash
+python BLS2017.py -h
+```
+
+To train the model, you need to supply it with a dataset of RGB training images.
+They should be provided in PNG format and must all have the same shape.
+Following training, the python script can be used to compress and decompress
+images as follows:
+
+```bash
+python BLS2017.py [options] compress original.png compressed.bin
+python BLS2017.py [options] decompress compressed.bin reconstruction.png
+```
+
 ## Entropy bottleneck layer
 
 This layer exposes a high-level interface to model the entropy (the amount of
@@ -95,7 +121,7 @@ main_loss = 0.5 * tf.reduce_mean(squared_error) + tf.reduce_mean(bits)
 
 # Minimize loss and auxiliary loss, and execute update op.
 main_optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
-main_step = optimizer.minimize(main_loss)
+main_step = main_optimizer.minimize(main_loss)
 # 1e-3 is a good starting point for the learning rate of the auxiliary loss,
 # assuming Adam is used.
 aux_optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
````
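For reference, a training invocation consistent with the argument parser of `examples/BLS2017.py` (shown below) might look like this. This is an illustrative sketch, run from inside the `examples` directory; the `--data_glob` and `--checkpoint_dir` values simply restate the script's defaults.

```bash
# Illustrative sketch: the "train" command and the --verbose, --data_glob and
# --checkpoint_dir options are defined in BLS2017.py; the values shown here
# restate the script's defaults.
python BLS2017.py --verbose --data_glob="images/*.png" --checkpoint_dir=train train
```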

examples/BLS2017.py

Lines changed: 275 additions & 0 deletions
@@ -0,0 +1,275 @@

```python
# -*- coding: utf-8 -*-
# Copyright 2018 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Basic nonlinear transform coder for RGB images.

This is a close approximation of the image compression model of
Ballé, Laparra, Simoncelli (2017):
End-to-end optimized image compression
https://arxiv.org/abs/1611.01704
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse

# Dependency imports

import numpy as np
import tensorflow as tf
import tensorflow_compression as tfc


def load_image(filename):
  string = tf.read_file(filename)
  image = tf.image.decode_image(string, channels=3)
  image = tf.cast(image, tf.float32)
  image /= 255
  return image


def save_image(filename, image):
  image = tf.clip_by_value(image, 0, 1)
  image = tf.round(image * 255)
  image = tf.cast(image, tf.uint8)
  string = tf.image.encode_png(image)
  return tf.write_file(filename, string)


def analysis_transform(tensor, num_filters):
  with tf.variable_scope("analysis"):
    with tf.variable_scope("layer_0"):
      layer = tfc.SignalConv2D(
          num_filters, (9, 9), corr=True, strides_down=4, padding="same_zeros",
          use_bias=True, activation=tfc.GDN())
      tensor = layer(tensor)

    with tf.variable_scope("layer_1"):
      layer = tfc.SignalConv2D(
          num_filters, (5, 5), corr=True, strides_down=2, padding="same_zeros",
          use_bias=True, activation=tfc.GDN())
      tensor = layer(tensor)

    with tf.variable_scope("layer_2"):
      layer = tfc.SignalConv2D(
          num_filters, (5, 5), corr=True, strides_down=2, padding="same_zeros",
          use_bias=False, activation=None)
      tensor = layer(tensor)

    return tensor


def synthesis_transform(tensor, num_filters):
  with tf.variable_scope("synthesis"):
    with tf.variable_scope("layer_0"):
      layer = tfc.SignalConv2D(
          num_filters, (5, 5), corr=False, strides_up=2, padding="same_zeros",
          use_bias=True, activation=tfc.GDN(inverse=True))
      tensor = layer(tensor)

    with tf.variable_scope("layer_1"):
      layer = tfc.SignalConv2D(
          num_filters, (5, 5), corr=False, strides_up=2, padding="same_zeros",
          use_bias=True, activation=tfc.GDN(inverse=True))
      tensor = layer(tensor)

    with tf.variable_scope("layer_2"):
      layer = tfc.SignalConv2D(
          3, (9, 9), corr=False, strides_up=4, padding="same_zeros",
          use_bias=True, activation=None)
      tensor = layer(tensor)

    return tensor


def train(args):
  # Load all training images into a constant.
  images = tf.map_fn(
      load_image, tf.matching_files(args.data_glob),
      dtype=tf.float32, back_prop=False)
  with tf.Session() as sess:
    images = tf.constant(sess.run(images), name="images")

  # Training inputs are random crops out of the images tensor.
  crop_shape = (args.batchsize, args.patchsize, args.patchsize, 3)
  x = tf.random_crop(images, crop_shape)
  num_pixels = np.prod(crop_shape[:-1])

  # Build autoencoder.
  y = analysis_transform(x, args.num_filters)
  entropy_bottleneck = tfc.EntropyBottleneck()
  y_tilde, likelihoods = entropy_bottleneck(y, training=True)
  x_tilde = synthesis_transform(y_tilde, args.num_filters)

  # Total number of bits divided by number of pixels.
  train_bpp = tf.reduce_sum(tf.log(likelihoods)) / (-np.log(2) * num_pixels)

  # Mean squared error across pixels.
  train_mse = tf.reduce_sum(tf.squared_difference(x, x_tilde)) / num_pixels

  # The rate-distortion cost.
  train_loss = args.lmbda * train_mse + train_bpp

  # Minimize loss and auxiliary loss, and execute update op.
  step = tf.train.create_global_step()
  main_optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
  main_step = main_optimizer.minimize(train_loss, global_step=step)

  aux_optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
  aux_step = aux_optimizer.minimize(entropy_bottleneck.losses[0])

  train_op = tf.group(main_step, aux_step, entropy_bottleneck.updates[0])

  hooks = [
      tf.train.StopAtStepHook(last_step=args.last_step),
      tf.train.NanTensorHook(train_loss),
  ]
  with tf.train.MonitoredTrainingSession(
      hooks=hooks, checkpoint_dir=args.checkpoint_dir) as sess:
    while not sess.should_stop():
      sess.run(train_op)


def compress(args):
  # Load input image and add batch dimension.
  x = load_image(args.input)
  x = tf.expand_dims(x, 0)
  x.set_shape([1, None, None, 3])

  # Transform and compress the image, then remove batch dimension.
  y = analysis_transform(x, args.num_filters)
  entropy_bottleneck = tfc.EntropyBottleneck()
  string = entropy_bottleneck.compress(y)
  string = tf.squeeze(string, axis=0)

  # Transform the quantized image back (if requested).
  y_hat, likelihoods = entropy_bottleneck(y, training=False)
  x_hat = synthesis_transform(y_hat, args.num_filters)

  num_pixels = tf.to_float(tf.reduce_prod(tf.shape(x)[:-1]))

  # Total number of bits divided by number of pixels.
  eval_bpp = tf.reduce_sum(tf.log(likelihoods)) / (-np.log(2) * num_pixels)

  # Mean squared error across pixels.
  mse = tf.reduce_sum(tf.squared_difference(x, x_hat)) / num_pixels

  with tf.Session() as sess:
    # Load the latest model checkpoint, get the compressed string and the tensor
    # shapes.
    latest = tf.train.latest_checkpoint(checkpoint_dir=args.checkpoint_dir)
    tf.train.Saver().restore(sess, save_path=latest)
    string, x_shape, y_shape = sess.run([string, tf.shape(x), tf.shape(y)])

    # Write a binary file with the shape information and the compressed string.
    with open(args.output, "wb") as file:
      file.write(np.array(x_shape[1:-1], dtype=np.uint16).tobytes())
      file.write(np.array(y_shape[1:-1], dtype=np.uint16).tobytes())
      file.write(string)

    # If requested, transform the quantized image back and measure performance.
    if args.verbose:
      eval_bpp, mse, num_pixels = sess.run([eval_bpp, mse, num_pixels])

      # The actual bits per pixel including overhead.
      bpp = (8 + len(string)) * 8 / num_pixels

      print("Mean squared error: {:0.4}".format(mse))
      print("Information content of this image in bpp: {:0.4}".format(eval_bpp))
      print("Actual bits per pixel for this image: {:0.4}".format(bpp))


def decompress(args):
  # Read the shape information and compressed string from the binary file.
  with open(args.input, "rb") as file:
    x_shape = np.frombuffer(file.read(4), dtype=np.uint16)
    y_shape = np.frombuffer(file.read(4), dtype=np.uint16)
    string = file.read()

  bits = 8 * len(string)
  y_shape = [int(s) for s in y_shape] + [args.num_filters]

  # Add a batch dimension, then decompress and transform the image back.
  strings = tf.expand_dims(string, 0)
  entropy_bottleneck = tfc.EntropyBottleneck(dtype=tf.float32)
  y_hat = entropy_bottleneck.decompress(
      strings, y_shape, channels=args.num_filters)
  x_hat = synthesis_transform(y_hat, args.num_filters)

  # Remove batch dimension, and crop away any extraneous padding on the bottom
  # or right boundaries.
  x_hat = x_hat[0, :x_shape[0], :x_shape[1], :]

  # Write reconstructed image out as a PNG file.
  op = save_image(args.output, x_hat)

  # Load the latest model checkpoint, and perform the above actions.
  with tf.Session() as sess:
    latest = tf.train.latest_checkpoint(checkpoint_dir=args.checkpoint_dir)
    tf.train.Saver().restore(sess, save_path=latest)
    sess.run(op)


if __name__ == "__main__":
  parser = argparse.ArgumentParser(
      formatter_class=argparse.ArgumentDefaultsHelpFormatter)

  parser.add_argument(
      "command", choices=["train", "compress", "decompress"],
      help="What to do: 'train' loads training data and trains (or continues "
           "to train) a new model. 'compress' reads an image file (lossless "
           "PNG format) and writes a compressed binary file. 'decompress' "
           "reads a binary file and reconstructs the image (in PNG format). "
           "input and output filenames need to be provided for the latter "
           "two options.")
  parser.add_argument(
      "input", nargs="?",
      help="Input filename.")
  parser.add_argument(
      "output", nargs="?",
      help="Output filename.")
  parser.add_argument("--verbose", "-v", action="store_true",
      help="Report bitrate and distortion when training or compressing.")
  parser.add_argument("--num_filters", type=int, default=128,
      help="Number of filters per layer.")
  parser.add_argument("--checkpoint_dir", default="train",
      help="Directory where to save/load model checkpoints.")
  parser.add_argument("--data_glob", default="images/*.png",
      help="Glob pattern identifying training data. This pattern must expand "
           "to a list of RGB images in PNG format which all have the same "
           "shape.")
  parser.add_argument("--batchsize", type=int, default=8,
      help="Batch size for training.")
  parser.add_argument("--patchsize", type=int, default=128,
      help="Size of image patches for training.")
  parser.add_argument("--lambda", type=float, default=0.1, dest="lmbda",
      help="Lambda for rate-distortion tradeoff.")
  parser.add_argument("--last_step", type=int, default=1000000,
      help="Train up to this number of steps.")

  args = parser.parse_args()

  if args.command == "train":
    train(args)
  elif args.command == "compress":
    if args.input is None or args.output is None:
      raise ValueError("Need input and output filename for compression.")
    compress(args)
  elif args.command == "decompress":
    if args.input is None or args.output is None:
      raise ValueError("Need input and output filename for decompression.")
    decompress(args)
```
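The binary format written by `compress()` above is minimal: the image height and width and the latent tensor's height and width as four `uint16` values, followed by the range-coded string. As a small sketch, the header can be read back as follows, mirroring the first lines of `decompress()`; the filename is hypothetical, and native byte order is assumed because the header is written with NumPy defaults.

```python
# Sketch only: "compressed.bin" is a hypothetical filename, assumed to have
# been written by compress() above on a machine with the same byte order.
import numpy as np

with open("compressed.bin", "rb") as f:
  x_shape = np.frombuffer(f.read(4), dtype=np.uint16)  # image height, width
  y_shape = np.frombuffer(f.read(4), dtype=np.uint16)  # latent height, width
  string = f.read()  # range-coded string from EntropyBottleneck.compress()

print("image shape:", tuple(int(s) for s in x_shape))
print("latent shape:", tuple(int(s) for s in y_shape))
print("payload size in bytes:", len(string))
```

The 8-byte header is also why `compress()` reports `(8 + len(string)) * 8 / num_pixels` as the actual bits per pixel.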

tensorflow_compression/python/layers/entropy_models.py

Lines changed: 1 addition & 7 deletions
```diff
@@ -359,15 +359,9 @@ def quantiles_initializer(shape, dtype=None, partition_info=None):
 
     cdf = coder_ops.pmf_to_quantized_cdf(
         pmf, precision=self.range_coder_precision)
-    def cdf_getter(*args, **kwargs):
-      del args, kwargs  # ignored
-      return variable_scope.get_variable(
-          "quantized_cdf", dtype=dtypes.int32, initializer=cdf,
-          trainable=False, validate_shape=False, collections=())
-    # Need to provide a fake shape here since add_variable insists on it.
     self._quantized_cdf = self.add_variable(
         "quantized_cdf", shape=(channels, 1), dtype=dtypes.int32,
-        getter=cdf_getter, trainable=False)
+        trainable=False)
 
     update_op = state_ops.assign(
         self._quantized_cdf, cdf, validate_shape=False)
```
