Skip to content

Commit 164bc19

Browse files
author
Johannes Ballé
committed
Cherry picked changes from Victor Xing.
- Overhaul input pipeline. - Create summaries for Tensorboard. - Compute PSNR and MS-SSIM. In addition: - Modified computation of MSE to be consistent with results reported in paper. Squared error is now averaged across channels, rather than added. - Adjusted some defaults to be more useful out of the box (lambda, patch size).
1 parent 198fc0a commit 164bc19

File tree

1 file changed

+62
-33
lines changed

1 file changed

+62
-33
lines changed

examples/bls2017.py

Lines changed: 62 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,16 @@
1919
Ballé, Laparra, Simoncelli (2017):
2020
End-to-end optimized image compression
2121
https://arxiv.org/abs/1611.01704
22+
23+
With patches from Victor Xing <[email protected]>
2224
"""
2325

2426
from __future__ import absolute_import
2527
from __future__ import division
2628
from __future__ import print_function
2729

2830
import argparse
31+
import glob
2932

3033
# Dependency imports
3134

@@ -44,12 +47,16 @@ def load_image(filename):
4447
return image
4548

4649

47-
def save_image(filename, image):
48-
"""Saves an image to a PNG file."""
49-
50+
def quantize_image(image):
5051
image = tf.clip_by_value(image, 0, 1)
5152
image = tf.round(image * 255)
5253
image = tf.cast(image, tf.uint8)
54+
return image
55+
56+
57+
def save_image(filename, image):
58+
"""Saves an image to a PNG file."""
59+
image = quantize_image(image)
5360
string = tf.image.encode_png(image)
5461
return tf.write_file(filename, string)
5562

@@ -110,17 +117,22 @@ def train():
110117
if args.verbose:
111118
tf.logging.set_verbosity(tf.logging.INFO)
112119

113-
# Load all training images into a constant.
114-
images = tf.map_fn(
115-
load_image, tf.matching_files(args.data_glob),
116-
dtype=tf.float32, back_prop=False)
117-
with tf.Session() as sess:
118-
images = tf.constant(sess.run(images), name="images")
120+
# Create input data pipeline.
121+
with tf.device('/cpu:0'):
122+
train_files = glob.glob(args.train_glob)
123+
train_dataset = tf.data.Dataset.from_tensor_slices(train_files)
124+
train_dataset = train_dataset.shuffle(buffer_size=len(train_files)).repeat()
125+
train_dataset = train_dataset.map(
126+
load_image, num_parallel_calls=args.preprocess_threads)
127+
train_dataset = train_dataset.map(
128+
lambda x: tf.random_crop(x, (args.patchsize, args.patchsize, 3)))
129+
train_dataset = train_dataset.batch(args.batchsize)
130+
train_dataset = train_dataset.prefetch(32)
131+
132+
num_pixels = args.batchsize * args.patchsize ** 2
119133

120-
# Training inputs are random crops out of the images tensor.
121-
crop_shape = (args.batchsize, args.patchsize, args.patchsize, 3)
122-
x = tf.random_crop(images, crop_shape)
123-
num_pixels = np.prod(crop_shape[:-1])
134+
# Get training patch from dataset.
135+
x = train_dataset.make_one_shot_iterator().get_next()
124136

125137
# Build autoencoder.
126138
y = analysis_transform(x, args.num_filters)
@@ -132,9 +144,9 @@ def train():
132144
train_bpp = tf.reduce_sum(tf.log(likelihoods)) / (-np.log(2) * num_pixels)
133145

134146
# Mean squared error across pixels.
135-
train_mse = tf.reduce_sum(tf.squared_difference(x, x_tilde))
147+
train_mse = tf.reduce_mean(tf.squared_difference(x, x_tilde))
136148
# Multiply by 255^2 to correct for rescaling.
137-
train_mse *= 255 ** 2 / num_pixels
149+
train_mse *= 255 ** 2
138150

139151
# The rate-distortion cost.
140152
train_loss = args.lmbda * train_mse + train_bpp
@@ -149,18 +161,24 @@ def train():
149161

150162
train_op = tf.group(main_step, aux_step, entropy_bottleneck.updates[0])
151163

152-
logged_tensors = [
153-
tf.identity(train_loss, name="train_loss"),
154-
tf.identity(train_bpp, name="train_bpp"),
155-
tf.identity(train_mse, name="train_mse"),
156-
]
164+
tf.summary.scalar("loss", train_loss)
165+
tf.summary.scalar("bpp", train_bpp)
166+
tf.summary.scalar("mse", train_mse)
167+
168+
tf.summary.image("original", quantize_image(x))
169+
tf.summary.image("reconstruction", quantize_image(x_tilde))
170+
171+
# Creates summary for the probability mass function (PMF) estimated in the
172+
# bottleneck.
173+
entropy_bottleneck.visualize()
174+
157175
hooks = [
158176
tf.train.StopAtStepHook(last_step=args.last_step),
159177
tf.train.NanTensorHook(train_loss),
160-
tf.train.LoggingTensorHook(logged_tensors, every_n_secs=60),
161178
]
162179
with tf.train.MonitoredTrainingSession(
163-
hooks=hooks, checkpoint_dir=args.checkpoint_dir) as sess:
180+
hooks=hooks, checkpoint_dir=args.checkpoint_dir,
181+
save_checkpoint_secs=300, save_summaries_secs=60) as sess:
164182
while not sess.should_stop():
165183
sess.run(train_op)
166184

@@ -188,10 +206,14 @@ def compress():
188206
# Total number of bits divided by number of pixels.
189207
eval_bpp = tf.reduce_sum(tf.log(likelihoods)) / (-np.log(2) * num_pixels)
190208

191-
# Mean squared error across pixels.
209+
# Bring both images back to 0..255 range.
210+
x *= 255
192211
x_hat = tf.clip_by_value(x_hat, 0, 1)
193212
x_hat = tf.round(x_hat * 255)
194-
mse = tf.reduce_sum(tf.squared_difference(x * 255, x_hat)) / num_pixels
213+
214+
mse = tf.reduce_mean(tf.squared_difference(x, x_hat))
215+
psnr = tf.squeeze(tf.image.psnr(x_hat, x, 255))
216+
msssim = tf.squeeze(tf.image.ssim_multiscale(x_hat, x, 255))
195217

196218
with tf.Session() as sess:
197219
# Load the latest model checkpoint, get the compressed string and the tensor
@@ -208,14 +230,18 @@ def compress():
208230

209231
# If requested, transform the quantized image back and measure performance.
210232
if args.verbose:
211-
eval_bpp, mse, num_pixels = sess.run([eval_bpp, mse, num_pixels])
233+
eval_bpp, mse, psnr, msssim, num_pixels = sess.run(
234+
[eval_bpp, mse, psnr, msssim, num_pixels])
212235

213236
# The actual bits per pixel including overhead.
214237
bpp = (8 + len(string)) * 8 / num_pixels
215238

216-
print("Mean squared error: {:0.4}".format(mse))
217-
print("Information content of this image in bpp: {:0.4}".format(eval_bpp))
218-
print("Actual bits per pixel for this image: {:0.4}".format(bpp))
239+
print("Mean squared error: {:0.4f}".format(mse))
240+
print("PSNR (dB): {:0.2f}".format(psnr))
241+
print("Multiscale SSIM: {:0.4f}".format(msssim))
242+
print("Multiscale SSIM (dB): {:0.2f}".format(-10 * np.log10(1 - msssim)))
243+
print("Information content in bpp: {:0.4f}".format(eval_bpp))
244+
print("Actual bits per pixel: {:0.4f}".format(bpp))
219245

220246

221247
def decompress():
@@ -278,22 +304,25 @@ def decompress():
278304
"--checkpoint_dir", default="train",
279305
help="Directory where to save/load model checkpoints.")
280306
parser.add_argument(
281-
"--data_glob", default="images/*.png",
307+
"--train_glob", default="images/*.png",
282308
help="Glob pattern identifying training data. This pattern must expand "
283-
"to a list of RGB images in PNG format which all have the same "
284-
"shape.")
309+
"to a list of RGB images in PNG format.")
285310
parser.add_argument(
286311
"--batchsize", type=int, default=8,
287312
help="Batch size for training.")
288313
parser.add_argument(
289-
"--patchsize", type=int, default=128,
314+
"--patchsize", type=int, default=256,
290315
help="Size of image patches for training.")
291316
parser.add_argument(
292-
"--lambda", type=float, default=0.1, dest="lmbda",
317+
"--lambda", type=float, default=0.01, dest="lmbda",
293318
help="Lambda for rate-distortion tradeoff.")
294319
parser.add_argument(
295320
"--last_step", type=int, default=1000000,
296321
help="Train up to this number of steps.")
322+
parser.add_argument(
323+
"--preprocess_threads", type=int, default=16,
324+
help="Number of CPU threads to use for parallel decoding of training "
325+
"images.")
297326

298327
args = parser.parse_args()
299328

0 commit comments

Comments
 (0)