 Ballé, Laparra, Simoncelli (2017):
 End-to-end optimized image compression
 https://arxiv.org/abs/1611.01704
+
+With patches from Victor Xing <[email protected]>
 """
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
 import argparse
+import glob
 
 # Dependency imports
 
@@ -44,12 +47,16 @@ def load_image(filename):
   return image
 
 
-def save_image(filename, image):
-  """Saves an image to a PNG file."""
-
+def quantize_image(image):
   image = tf.clip_by_value(image, 0, 1)
   image = tf.round(image * 255)
   image = tf.cast(image, tf.uint8)
+  return image
+
+
+def save_image(filename, image):
+  """Saves an image to a PNG file."""
+  image = quantize_image(image)
   string = tf.image.encode_png(image)
   return tf.write_file(filename, string)
 
@@ -110,17 +117,22 @@ def train():
   if args.verbose:
     tf.logging.set_verbosity(tf.logging.INFO)
 
-  # Load all training images into a constant.
-  images = tf.map_fn(
-      load_image, tf.matching_files(args.data_glob),
-      dtype=tf.float32, back_prop=False)
-  with tf.Session() as sess:
-    images = tf.constant(sess.run(images), name="images")
+  # Create input data pipeline.
+  with tf.device('/cpu:0'):
+    train_files = glob.glob(args.train_glob)
+    train_dataset = tf.data.Dataset.from_tensor_slices(train_files)
+    train_dataset = train_dataset.shuffle(buffer_size=len(train_files)).repeat()
+    train_dataset = train_dataset.map(
+        load_image, num_parallel_calls=args.preprocess_threads)
+    train_dataset = train_dataset.map(
+        lambda x: tf.random_crop(x, (args.patchsize, args.patchsize, 3)))
+    train_dataset = train_dataset.batch(args.batchsize)
+    train_dataset = train_dataset.prefetch(32)
+
+  num_pixels = args.batchsize * args.patchsize ** 2
 
-  # Training inputs are random crops out of the images tensor.
-  crop_shape = (args.batchsize, args.patchsize, args.patchsize, 3)
-  x = tf.random_crop(images, crop_shape)
-  num_pixels = np.prod(crop_shape[:-1])
+  # Get training patch from dataset.
+  x = train_dataset.make_one_shot_iterator().get_next()
 
   # Build autoencoder.
   y = analysis_transform(x, args.num_filters)
@@ -132,9 +144,9 @@ def train():
   train_bpp = tf.reduce_sum(tf.log(likelihoods)) / (-np.log(2) * num_pixels)
 
   # Mean squared error across pixels.
-  train_mse = tf.reduce_sum(tf.squared_difference(x, x_tilde))
+  train_mse = tf.reduce_mean(tf.squared_difference(x, x_tilde))
   # Multiply by 255^2 to correct for rescaling.
-  train_mse *= 255 ** 2 / num_pixels
+  train_mse *= 255 ** 2
 
   # The rate-distortion cost.
   train_loss = args.lmbda * train_mse + train_bpp
@@ -149,18 +161,24 @@ def train():
 
   train_op = tf.group(main_step, aux_step, entropy_bottleneck.updates[0])
 
-  logged_tensors = [
-      tf.identity(train_loss, name="train_loss"),
-      tf.identity(train_bpp, name="train_bpp"),
-      tf.identity(train_mse, name="train_mse"),
-  ]
+  tf.summary.scalar("loss", train_loss)
+  tf.summary.scalar("bpp", train_bpp)
+  tf.summary.scalar("mse", train_mse)
+
+  tf.summary.image("original", quantize_image(x))
+  tf.summary.image("reconstruction", quantize_image(x_tilde))
+
+  # Creates summary for the probability mass function (PMF) estimated in the
+  # bottleneck.
+  entropy_bottleneck.visualize()
+
   hooks = [
       tf.train.StopAtStepHook(last_step=args.last_step),
       tf.train.NanTensorHook(train_loss),
-      tf.train.LoggingTensorHook(logged_tensors, every_n_secs=60),
   ]
   with tf.train.MonitoredTrainingSession(
-      hooks=hooks, checkpoint_dir=args.checkpoint_dir) as sess:
+      hooks=hooks, checkpoint_dir=args.checkpoint_dir,
+      save_checkpoint_secs=300, save_summaries_secs=60) as sess:
     while not sess.should_stop():
       sess.run(train_op)
 
@@ -188,10 +206,14 @@ def compress():
   # Total number of bits divided by number of pixels.
   eval_bpp = tf.reduce_sum(tf.log(likelihoods)) / (-np.log(2) * num_pixels)
 
-  # Mean squared error across pixels.
+  # Bring both images back to 0..255 range.
+  x *= 255
   x_hat = tf.clip_by_value(x_hat, 0, 1)
   x_hat = tf.round(x_hat * 255)
-  mse = tf.reduce_sum(tf.squared_difference(x * 255, x_hat)) / num_pixels
+
+  mse = tf.reduce_mean(tf.squared_difference(x, x_hat))
+  psnr = tf.squeeze(tf.image.psnr(x_hat, x, 255))
+  msssim = tf.squeeze(tf.image.ssim_multiscale(x_hat, x, 255))
 
   with tf.Session() as sess:
     # Load the latest model checkpoint, get the compressed string and the tensor
@@ -208,14 +230,18 @@ def compress():
 
     # If requested, transform the quantized image back and measure performance.
     if args.verbose:
-      eval_bpp, mse, num_pixels = sess.run([eval_bpp, mse, num_pixels])
+      eval_bpp, mse, psnr, msssim, num_pixels = sess.run(
+          [eval_bpp, mse, psnr, msssim, num_pixels])
 
       # The actual bits per pixel including overhead.
       bpp = (8 + len(string)) * 8 / num_pixels
 
-      print("Mean squared error: {:0.4}".format(mse))
-      print("Information content of this image in bpp: {:0.4}".format(eval_bpp))
-      print("Actual bits per pixel for this image: {:0.4}".format(bpp))
+      print("Mean squared error: {:0.4f}".format(mse))
+      print("PSNR (dB): {:0.2f}".format(psnr))
+      print("Multiscale SSIM: {:0.4f}".format(msssim))
+      print("Multiscale SSIM (dB): {:0.2f}".format(-10 * np.log10(1 - msssim)))
+      print("Information content in bpp: {:0.4f}".format(eval_bpp))
+      print("Actual bits per pixel: {:0.4f}".format(bpp))
 
 
 def decompress():
@@ -278,22 +304,25 @@ def decompress():
       "--checkpoint_dir", default="train",
       help="Directory where to save/load model checkpoints.")
   parser.add_argument(
-      "--data_glob", default="images/*.png",
+      "--train_glob", default="images/*.png",
       help="Glob pattern identifying training data. This pattern must expand "
-           "to a list of RGB images in PNG format which all have the same "
-           "shape.")
+           "to a list of RGB images in PNG format.")
   parser.add_argument(
       "--batchsize", type=int, default=8,
       help="Batch size for training.")
   parser.add_argument(
-      "--patchsize", type=int, default=128,
+      "--patchsize", type=int, default=256,
      help="Size of image patches for training.")
   parser.add_argument(
-      "--lambda", type=float, default=0.1, dest="lmbda",
+      "--lambda", type=float, default=0.01, dest="lmbda",
      help="Lambda for rate-distortion tradeoff.")
   parser.add_argument(
       "--last_step", type=int, default=1000000,
       help="Train up to this number of steps.")
+  parser.add_argument(
+      "--preprocess_threads", type=int, default=16,
+      help="Number of CPU threads to use for parallel decoding of training "
+           "images.")
 
   args = parser.parse_args()
 
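Note: the tf.data input pipeline added to train() above can be sanity-checked in isolation. The sketch below is a minimal stand-alone version assuming the TF 1.x API used in the diff; the load_image helper and the glob pattern are illustrative stand-ins (the script's own load_image is assumed to return float32 RGB images scaled to [0, 1]), not the repository's exact code.

```python
# Minimal sketch of the training input pipeline (TF 1.x assumed).
import glob

import tensorflow as tf


def load_image(filename):
  """Stand-in loader: reads a PNG file and returns float32 RGB in [0, 1]."""
  string = tf.read_file(filename)
  image = tf.image.decode_png(string, channels=3)
  return tf.cast(image, tf.float32) / 255


def make_dataset(train_glob, batchsize=8, patchsize=256, threads=16):
  """Builds the shuffled, cropped, batched dataset as in train()."""
  files = glob.glob(train_glob)
  dataset = tf.data.Dataset.from_tensor_slices(files)
  dataset = dataset.shuffle(buffer_size=len(files)).repeat()
  dataset = dataset.map(load_image, num_parallel_calls=threads)
  dataset = dataset.map(
      lambda x: tf.random_crop(x, (patchsize, patchsize, 3)))
  dataset = dataset.batch(batchsize)
  return dataset.prefetch(32)


if __name__ == "__main__":
  # "images/*.png" is a placeholder pattern; point it at real training PNGs.
  x = make_dataset("images/*.png").make_one_shot_iterator().get_next()
  with tf.Session() as sess:
    # One batch of patches: shape (batchsize, patchsize, patchsize, 3).
    print(sess.run(x).shape)
```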
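Note: as a quick cross-check of the metrics printed in compress(), the PSNR reported by tf.image.psnr with a peak value of 255 can be recovered from the printed MSE as 10 * log10(255^2 / MSE), and the "Multiscale SSIM (dB)" figure is the -10 * log10(1 - MS-SSIM) mapping used in the print statement. A small NumPy sketch with made-up values:

```python
# NumPy sketch of the metric conversions printed by compress().
import numpy as np

mse = 42.5      # mean squared error on the 0..255 scale (placeholder value)
msssim = 0.987  # multiscale SSIM in [0, 1] (placeholder value)

# PSNR for 8-bit images: 10 * log10(peak^2 / MSE) with peak = 255.
psnr = 10 * np.log10(255.0 ** 2 / mse)

# Same dB-style mapping applied to MS-SSIM as in the print statement above.
msssim_db = -10 * np.log10(1 - msssim)

print("PSNR (dB): {:0.2f}".format(psnr))
print("Multiscale SSIM (dB): {:0.2f}".format(msssim_db))
```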