35
35
36
36
37
37
def load_image (filename ):
38
+ """Loads a PNG image file."""
39
+
38
40
string = tf .read_file (filename )
39
41
image = tf .image .decode_image (string , channels = 3 )
40
42
image = tf .cast (image , tf .float32 )
@@ -43,6 +45,8 @@ def load_image(filename):
43
45
44
46
45
47
def save_image (filename , image ):
48
+ """Saves an image to a PNG file."""
49
+
46
50
image = tf .clip_by_value (image , 0 , 1 )
47
51
image = tf .round (image * 255 )
48
52
image = tf .cast (image , tf .uint8 )
@@ -51,6 +55,8 @@ def save_image(filename, image):
51
55
52
56
53
57
def analysis_transform (tensor , num_filters ):
58
+ """Builds the analysis transform."""
59
+
54
60
with tf .variable_scope ("analysis" ):
55
61
with tf .variable_scope ("layer_0" ):
56
62
layer = tfc .SignalConv2D (
@@ -74,6 +80,8 @@ def analysis_transform(tensor, num_filters):
74
80
75
81
76
82
def synthesis_transform (tensor , num_filters ):
83
+ """Builds the synthesis transform."""
84
+
77
85
with tf .variable_scope ("synthesis" ):
78
86
with tf .variable_scope ("layer_0" ):
79
87
layer = tfc .SignalConv2D (
@@ -96,11 +104,16 @@ def synthesis_transform(tensor, num_filters):
96
104
return tensor
97
105
98
106
99
- def train (args ):
107
+ def train ():
108
+ """Trains the model."""
109
+
110
+ if args .verbose :
111
+ tf .logging .set_verbosity (tf .logging .INFO )
112
+
100
113
# Load all training images into a constant.
101
114
images = tf .map_fn (
102
- load_image , tf .matching_files (args .data_glob ),
103
- dtype = tf .float32 , back_prop = False )
115
+ load_image , tf .matching_files (args .data_glob ),
116
+ dtype = tf .float32 , back_prop = False )
104
117
with tf .Session () as sess :
105
118
images = tf .constant (sess .run (images ), name = "images" )
106
119
@@ -119,7 +132,9 @@ def train(args):
119
132
train_bpp = tf .reduce_sum (tf .log (likelihoods )) / (- np .log (2 ) * num_pixels )
120
133
121
134
# Mean squared error across pixels.
122
- train_mse = tf .reduce_sum (tf .squared_difference (x , x_tilde )) / num_pixels
135
+ train_mse = tf .reduce_sum (tf .squared_difference (x , x_tilde ))
136
+ # Multiply by 255^2 to correct for rescaling.
137
+ train_mse *= 255 ** 2 / num_pixels
123
138
124
139
# The rate-distortion cost.
125
140
train_loss = args .lmbda * train_mse + train_bpp
@@ -134,17 +149,25 @@ def train(args):
134
149
135
150
train_op = tf .group (main_step , aux_step , entropy_bottleneck .updates [0 ])
136
151
152
+ logged_tensors = [
153
+ tf .identity (train_loss , name = "train_loss" ),
154
+ tf .identity (train_bpp , name = "train_bpp" ),
155
+ tf .identity (train_mse , name = "train_mse" ),
156
+ ]
137
157
hooks = [
138
158
tf .train .StopAtStepHook (last_step = args .last_step ),
139
159
tf .train .NanTensorHook (train_loss ),
160
+ tf .train .LoggingTensorHook (logged_tensors , every_n_secs = 60 ),
140
161
]
141
162
with tf .train .MonitoredTrainingSession (
142
163
hooks = hooks , checkpoint_dir = args .checkpoint_dir ) as sess :
143
164
while not sess .should_stop ():
144
165
sess .run (train_op )
145
166
146
167
147
- def compress (args ):
168
+ def compress ():
169
+ """Compresses an image."""
170
+
148
171
# Load input image and add batch dimension.
149
172
x = load_image (args .input )
150
173
x = tf .expand_dims (x , 0 )
@@ -166,7 +189,9 @@ def compress(args):
166
189
eval_bpp = tf .reduce_sum (tf .log (likelihoods )) / (- np .log (2 ) * num_pixels )
167
190
168
191
# Mean squared error across pixels.
169
- mse = tf .reduce_sum (tf .squared_difference (x , x_hat )) / num_pixels
192
+ x_hat = tf .clip_by_value (x_hat , 0 , 1 )
193
+ x_hat = tf .round (x_hat * 255 )
194
+ mse = tf .reduce_sum (tf .squared_difference (x * 255 , x_hat )) / num_pixels
170
195
171
196
with tf .Session () as sess :
172
197
# Load the latest model checkpoint, get the compressed string and the tensor
@@ -176,10 +201,10 @@ def compress(args):
176
201
string , x_shape , y_shape = sess .run ([string , tf .shape (x ), tf .shape (y )])
177
202
178
203
# Write a binary file with the shape information and the compressed string.
179
- with open (args .output , "wb" ) as file :
180
- file .write (np .array (x_shape [1 :- 1 ], dtype = np .uint16 ).tobytes ())
181
- file .write (np .array (y_shape [1 :- 1 ], dtype = np .uint16 ).tobytes ())
182
- file .write (string )
204
+ with open (args .output , "wb" ) as f :
205
+ f .write (np .array (x_shape [1 :- 1 ], dtype = np .uint16 ).tobytes ())
206
+ f .write (np .array (y_shape [1 :- 1 ], dtype = np .uint16 ).tobytes ())
207
+ f .write (string )
183
208
184
209
# If requested, transform the quantized image back and measure performance.
185
210
if args .verbose :
@@ -193,14 +218,15 @@ def compress(args):
193
218
print ("Actual bits per pixel for this image: {:0.4}" .format (bpp ))
194
219
195
220
196
- def decompress (args ):
221
+ def decompress ():
222
+ """Decompresses an image."""
223
+
197
224
# Read the shape information and compressed string from the binary file.
198
- with open (args .input , "rb" ) as file :
199
- x_shape = np .frombuffer (file .read (4 ), dtype = np .uint16 )
200
- y_shape = np .frombuffer (file .read (4 ), dtype = np .uint16 )
201
- string = file .read ()
225
+ with open (args .input , "rb" ) as f :
226
+ x_shape = np .frombuffer (f .read (4 ), dtype = np .uint16 )
227
+ y_shape = np .frombuffer (f .read (4 ), dtype = np .uint16 )
228
+ string = f .read ()
202
229
203
- bits = 8 * len (string )
204
230
y_shape = [int (s ) for s in y_shape ] + [args .num_filters ]
205
231
206
232
# Add a batch dimension, then decompress and transform the image back.
@@ -242,34 +268,42 @@ def decompress(args):
242
268
parser .add_argument (
243
269
"output" , nargs = "?" ,
244
270
help = "Output filename." )
245
- parser .add_argument ("--verbose" , "-v" , action = "store_true" ,
271
+ parser .add_argument (
272
+ "--verbose" , "-v" , action = "store_true" ,
246
273
help = "Report bitrate and distortion when training or compressing." )
247
- parser .add_argument ("--num_filters" , type = int , default = 128 ,
274
+ parser .add_argument (
275
+ "--num_filters" , type = int , default = 128 ,
248
276
help = "Number of filters per layer." )
249
- parser .add_argument ("--checkpoint_dir" , default = "train" ,
277
+ parser .add_argument (
278
+ "--checkpoint_dir" , default = "train" ,
250
279
help = "Directory where to save/load model checkpoints." )
251
- parser .add_argument ("--data_glob" , default = "images/*.png" ,
280
+ parser .add_argument (
281
+ "--data_glob" , default = "images/*.png" ,
252
282
help = "Glob pattern identifying training data. This pattern must expand "
253
283
"to a list of RGB images in PNG format which all have the same "
254
284
"shape." )
255
- parser .add_argument ("--batchsize" , type = int , default = 8 ,
285
+ parser .add_argument (
286
+ "--batchsize" , type = int , default = 8 ,
256
287
help = "Batch size for training." )
257
- parser .add_argument ("--patchsize" , type = int , default = 128 ,
288
+ parser .add_argument (
289
+ "--patchsize" , type = int , default = 128 ,
258
290
help = "Size of image patches for training." )
259
- parser .add_argument ("--lambda" , type = float , default = 0.1 , dest = "lmbda" ,
291
+ parser .add_argument (
292
+ "--lambda" , type = float , default = 0.1 , dest = "lmbda" ,
260
293
help = "Lambda for rate-distortion tradeoff." )
261
- parser .add_argument ("--last_step" , type = int , default = 1000000 ,
294
+ parser .add_argument (
295
+ "--last_step" , type = int , default = 1000000 ,
262
296
help = "Train up to this number of steps." )
263
297
264
298
args = parser .parse_args ()
265
299
266
300
if args .command == "train" :
267
- train (args )
301
+ train ()
268
302
elif args .command == "compress" :
269
303
if args .input is None or args .output is None :
270
304
raise ValueError ("Need input and output filename for compression." )
271
- compress (args )
305
+ compress ()
272
306
elif args .command == "decompress" :
273
307
if args .input is None or args .output is None :
274
308
raise ValueError ("Need input and output filename for decompression." )
275
- decompress (args )
309
+ decompress ()
0 commit comments