
Commit 1329dff

Author: Johannes Ballé
Commit message: Fixed lint warnings, some improvements to example model.
1 parent acbf653 · commit 1329dff

File tree: 3 files changed (+65, -31 lines)


README.md

Lines changed: 3 additions & 3 deletions
@@ -30,7 +30,7 @@ model described in:
 To see a list of options, change to the directory and run:
 
 ```bash
-python BLS2017.py -h
+python bls2017.py -h
 ```
 
 To train the model, you need to supply it with a dataset of RGB training images.
@@ -39,8 +39,8 @@ Following training, the python script can be used to compress and decompress
 images as follows:
 
 ```bash
-python BLS2017.py [options] compress original.png compressed.bin
-python BLS2017.py [options] decompress compressed.bin reconstruction.png
+python bls2017.py [options] compress original.png compressed.bin
+python bls2017.py [options] decompress compressed.bin reconstruction.png
 ```
 
 ## Entropy bottleneck layer

examples/BLS2017.py renamed to examples/bls2017.py

Lines changed: 61 additions & 27 deletions
@@ -35,6 +35,8 @@
 
 
 def load_image(filename):
+  """Loads a PNG image file."""
+
   string = tf.read_file(filename)
   image = tf.image.decode_image(string, channels=3)
   image = tf.cast(image, tf.float32)
@@ -43,6 +45,8 @@ def load_image(filename):
 
 
 def save_image(filename, image):
+  """Saves an image to a PNG file."""
+
   image = tf.clip_by_value(image, 0, 1)
   image = tf.round(image * 255)
   image = tf.cast(image, tf.uint8)
@@ -51,6 +55,8 @@ def save_image(filename, image):
 
 
 def analysis_transform(tensor, num_filters):
+  """Builds the analysis transform."""
+
   with tf.variable_scope("analysis"):
     with tf.variable_scope("layer_0"):
       layer = tfc.SignalConv2D(
@@ -74,6 +80,8 @@ def analysis_transform(tensor, num_filters):
 
 
 def synthesis_transform(tensor, num_filters):
+  """Builds the synthesis transform."""
+
   with tf.variable_scope("synthesis"):
     with tf.variable_scope("layer_0"):
       layer = tfc.SignalConv2D(
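The context lines above break off mid-call. For orientation, a complete first layer of these transforms looks roughly like the sketch below; the argument list is assumed from the surrounding example file, not shown in this hunk:

```python
# Sketch (arguments assumed from the surrounding example file, which this
# hunk truncates): one complete layer is a strided SignalConv2D with a GDN
# nonlinearity, downsampling the signal by a factor of 4.
import tensorflow as tf
import tensorflow_compression as tfc

def first_layer(tensor, num_filters):
  layer = tfc.SignalConv2D(
      num_filters, (9, 9), corr=True, strides_down=4,
      padding="same_zeros", use_bias=True, activation=tfc.GDN())
  return layer(tensor)
```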
@@ -96,11 +104,16 @@ def synthesis_transform(tensor, num_filters):
   return tensor
 
 
-def train(args):
+def train():
+  """Trains the model."""
+
+  if args.verbose:
+    tf.logging.set_verbosity(tf.logging.INFO)
+
   # Load all training images into a constant.
   images = tf.map_fn(
-    load_image, tf.matching_files(args.data_glob),
-    dtype=tf.float32, back_prop=False)
+      load_image, tf.matching_files(args.data_glob),
+      dtype=tf.float32, back_prop=False)
   with tf.Session() as sess:
     images = tf.constant(sess.run(images), name="images")
 
@@ -119,7 +132,9 @@ def train(args):
   train_bpp = tf.reduce_sum(tf.log(likelihoods)) / (-np.log(2) * num_pixels)
 
   # Mean squared error across pixels.
-  train_mse = tf.reduce_sum(tf.squared_difference(x, x_tilde)) / num_pixels
+  train_mse = tf.reduce_sum(tf.squared_difference(x, x_tilde))
+  # Multiply by 255^2 to correct for rescaling.
+  train_mse *= 255 ** 2 / num_pixels
 
   # The rate-distortion cost.
   train_loss = args.lmbda * train_mse + train_bpp
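The factor of 255² makes the distortion term comparable to MSE measured on 8-bit pixel values even though the network operates on inputs rescaled to [0, 1]; without it, λ would have to change by the same factor. A standalone NumPy check of the identity (not part of the diff):

```python
# Standalone NumPy check (not part of the diff): scaling the squared error
# of [0, 1] inputs by 255**2 equals the MSE measured on the 0..255 scale.
import numpy as np

x = np.random.rand(8, 8)                    # "image" in [0, 1]
x_tilde = x + 0.01 * np.random.randn(8, 8)  # perturbed reconstruction

mse_unit = np.mean((x - x_tilde) ** 2) * 255 ** 2
mse_8bit = np.mean((255 * x - 255 * x_tilde) ** 2)
assert np.isclose(mse_unit, mse_8bit)
```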
@@ -134,17 +149,25 @@ def train(args):
 
   train_op = tf.group(main_step, aux_step, entropy_bottleneck.updates[0])
 
+  logged_tensors = [
+      tf.identity(train_loss, name="train_loss"),
+      tf.identity(train_bpp, name="train_bpp"),
+      tf.identity(train_mse, name="train_mse"),
+  ]
   hooks = [
       tf.train.StopAtStepHook(last_step=args.last_step),
       tf.train.NanTensorHook(train_loss),
+      tf.train.LoggingTensorHook(logged_tensors, every_n_secs=60),
   ]
   with tf.train.MonitoredTrainingSession(
       hooks=hooks, checkpoint_dir=args.checkpoint_dir) as sess:
     while not sess.should_stop():
       sess.run(train_op)
 
 
-def compress(args):
+def compress():
+  """Compresses an image."""
+
   # Load input image and add batch dimension.
   x = load_image(args.input)
   x = tf.expand_dims(x, 0)
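The `tf.identity` wrappers exist only to give the logged values stable, human-readable names, since `LoggingTensorHook` tags each printed value with its tensor name. A minimal self-contained sketch of the mechanism (TF 1.x assumed; the constant value is made up):

```python
# Minimal TF 1.x sketch (made-up value): tf.identity fixes the tensor name,
# and LoggingTensorHook uses that name as the tag in its log output.
import tensorflow as tf

tf.logging.set_verbosity(tf.logging.INFO)
tf.train.create_global_step()
loss = tf.identity(tf.constant(1.25), name="train_loss")
hook = tf.train.LoggingTensorHook([loss], every_n_iter=1)
with tf.train.MonitoredTrainingSession(hooks=[hook]) as sess:
  sess.run(loss)  # logs a line like: train_loss = 1.25
```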
@@ -166,7 +189,9 @@ def compress(args):
   eval_bpp = tf.reduce_sum(tf.log(likelihoods)) / (-np.log(2) * num_pixels)
 
   # Mean squared error across pixels.
-  mse = tf.reduce_sum(tf.squared_difference(x, x_hat)) / num_pixels
+  x_hat = tf.clip_by_value(x_hat, 0, 1)
+  x_hat = tf.round(x_hat * 255)
+  mse = tf.reduce_sum(tf.squared_difference(x * 255, x_hat)) / num_pixels
 
   with tf.Session() as sess:
     # Load the latest model checkpoint, get the compressed string and the tensor
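Clipping and rounding `x_hat` onto the 0..255 integer grid first means the reported error matches what a saved 8-bit PNG would actually contain. From an MSE on this scale, PSNR follows by the standard definition (a standalone sketch, not code from this commit):

```python
# Standalone sketch (standard definition, not from this commit): PSNR in dB
# for 8-bit images, given an MSE measured on the 0..255 scale.
import numpy as np

def psnr_from_mse(mse, peak=255.0):
  return 10.0 * np.log10(peak ** 2 / mse)

print(psnr_from_mse(42.0))  # ~31.9 dB
```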
@@ -176,10 +201,10 @@ def compress(args):
     string, x_shape, y_shape = sess.run([string, tf.shape(x), tf.shape(y)])
 
     # Write a binary file with the shape information and the compressed string.
-    with open(args.output, "wb") as file:
-      file.write(np.array(x_shape[1:-1], dtype=np.uint16).tobytes())
-      file.write(np.array(y_shape[1:-1], dtype=np.uint16).tobytes())
-      file.write(string)
+    with open(args.output, "wb") as f:
+      f.write(np.array(x_shape[1:-1], dtype=np.uint16).tobytes())
+      f.write(np.array(y_shape[1:-1], dtype=np.uint16).tobytes())
+      f.write(string)
 
     # If requested, transform the quantized image back and measure performance.
     if args.verbose:
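The container layout is therefore fixed: an 8-byte header holding two uint16 `(height, width)` pairs, one for the image and one for the latent tensor, followed by the entropy-coded string. As an illustration only, a hypothetical helper that recovers the payload bitrate from such a file:

```python
# Hypothetical helper (not part of the commit): actual bits per pixel from
# the container file; the 8-byte header holds two uint16 (height, width) pairs.
import os

def actual_bpp(filename, height, width):
  payload_bits = 8 * (os.path.getsize(filename) - 8)
  return payload_bits / float(height * width)
```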
@@ -193,14 +218,15 @@ def compress(args):
       print("Actual bits per pixel for this image: {:0.4}".format(bpp))
 
 
-def decompress(args):
+def decompress():
+  """Decompresses an image."""
+
   # Read the shape information and compressed string from the binary file.
-  with open(args.input, "rb") as file:
-    x_shape = np.frombuffer(file.read(4), dtype=np.uint16)
-    y_shape = np.frombuffer(file.read(4), dtype=np.uint16)
-    string = file.read()
+  with open(args.input, "rb") as f:
+    x_shape = np.frombuffer(f.read(4), dtype=np.uint16)
+    y_shape = np.frombuffer(f.read(4), dtype=np.uint16)
+    string = f.read()
 
-  bits = 8 * len(string)
   y_shape = [int(s) for s in y_shape] + [args.num_filters]
 
   # Add a batch dimension, then decompress and transform the image back.
@@ -242,34 +268,42 @@ def decompress(args):
   parser.add_argument(
       "output", nargs="?",
       help="Output filename.")
-  parser.add_argument("--verbose", "-v", action="store_true",
+  parser.add_argument(
+      "--verbose", "-v", action="store_true",
       help="Report bitrate and distortion when training or compressing.")
-  parser.add_argument("--num_filters", type=int, default=128,
+  parser.add_argument(
+      "--num_filters", type=int, default=128,
       help="Number of filters per layer.")
-  parser.add_argument("--checkpoint_dir", default="train",
+  parser.add_argument(
+      "--checkpoint_dir", default="train",
       help="Directory where to save/load model checkpoints.")
-  parser.add_argument("--data_glob", default="images/*.png",
+  parser.add_argument(
+      "--data_glob", default="images/*.png",
       help="Glob pattern identifying training data. This pattern must expand "
            "to a list of RGB images in PNG format which all have the same "
            "shape.")
-  parser.add_argument("--batchsize", type=int, default=8,
+  parser.add_argument(
+      "--batchsize", type=int, default=8,
       help="Batch size for training.")
-  parser.add_argument("--patchsize", type=int, default=128,
+  parser.add_argument(
+      "--patchsize", type=int, default=128,
       help="Size of image patches for training.")
-  parser.add_argument("--lambda", type=float, default=0.1, dest="lmbda",
+  parser.add_argument(
+      "--lambda", type=float, default=0.1, dest="lmbda",
       help="Lambda for rate-distortion tradeoff.")
-  parser.add_argument("--last_step", type=int, default=1000000,
+  parser.add_argument(
+      "--last_step", type=int, default=1000000,
       help="Train up to this number of steps.")
 
   args = parser.parse_args()
 
   if args.command == "train":
-    train(args)
+    train()
   elif args.command == "compress":
     if args.input is None or args.output is None:
       raise ValueError("Need input and output filename for compression.")
-    compress(args)
+    compress()
   elif args.command == "decompress":
     if args.input is None or args.output is None:
       raise ValueError("Need input and output filename for decompression.")
-    decompress(args)
+    decompress()
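One non-obvious detail in the flags above: `lambda` is a reserved word in Python, so the flag needs `dest="lmbda"` to land on a usable attribute name. A self-contained check:

```python
# Self-contained check: "lambda" is a reserved word, so argparse stores the
# flag under the attribute name given by dest; it is read back as args.lmbda.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--lambda", type=float, default=0.1, dest="lmbda")
args = parser.parse_args(["--lambda", "0.05"])
assert args.lmbda == 0.05
```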

tensorflow_compression/python/layers/entropy_models.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,6 @@
 from tensorflow.python.ops import nn
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import variable_scope
 from tensorflow.python.summary import summary
 
 
@@ -364,6 +363,7 @@ def quantiles_initializer(shape, dtype=None, partition_info=None):
     # or the variable will return the wrong dynamic shape later. A placeholder
     # with default gets the trick done.
     def cdf_init(*args, **kwargs):
+      del args, kwargs  # unused
       return array_ops.placeholder_with_default(
           array_ops.zeros((channels, 1), dtype=dtypes.int32),
           shape=(channels, None))
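The added `del` line is the conventional idiom for silencing unused-argument lint warnings while keeping the signature a caller expects. In isolation (a generic sketch, not this library's code):

```python
# Generic sketch (not this library's code): deleting unused parameters tells
# both readers and the linter that the signature is required by the call
# convention but the values are intentionally ignored.
def make_initializer(value):
  def init_fn(*args, **kwargs):
    del args, kwargs  # unused; required by the initializer call convention
    return value
  return init_fn

zeros_init = make_initializer(0)
assert zeros_init((3, 1), dtype=None) == 0
```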
