Skip to content

Commit 05e2fb0

Browse files
author
Johannes Ballé
committed
Implements --target_bpp for TFCI.
PiperOrigin-RevId: 258114868 Change-Id: I18bc785435281eecb3f69fe1fc8347daff9e74ad
1 parent 9a2e72e commit 05e2fb0

File tree

1 file changed

+122
-42
lines changed

1 file changed

+122
-42
lines changed

examples/tfci.py

Lines changed: 122 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121

2222
import argparse
2323
import os
24+
import sys
2425

2526
from absl import app
2627
from absl.flags import argparse_flags
@@ -30,6 +31,11 @@
3031

3132
import tensorflow_compression as tfc # pylint:disable=unused-import
3233

34+
# Default URL to fetch metagraphs from.
35+
URL_PREFIX = "https://storage.googleapis.com/tensorflow_compression/metagraphs"
36+
# Default location to store cached metagraphs.
37+
METAGRAPH_CACHE = "/tmp/tfc_metagraphs"
38+
3339

3440
def read_png(filename):
3541
"""Creates graph to load a PNG image file."""
@@ -50,22 +56,28 @@ def write_png(filename, image):
5056
return tf.io.write_file(filename, string)
5157

5258

53-
def load_metagraph(model, url_prefix, metagraph_cache):
54-
"""Loads and caches a trained model metagraph."""
55-
filename = os.path.join(metagraph_cache, model + ".metagraph")
59+
def load_cached(filename):
  """Downloads and caches files from web storage.

  The file is first looked up under `METAGRAPH_CACHE`. If it is not found
  there, it is fetched from `URL_PREFIX + "/" + filename` and written to the
  cache before being returned.

  Args:
    filename: Name of the file, relative to both the cache directory and the
      URL prefix.

  Returns:
    The raw contents of the file as a byte string.
  """
  pathname = os.path.join(METAGRAPH_CACHE, filename)
  try:
    with tf.io.gfile.GFile(pathname, "rb") as f:
      string = f.read()
  except tf.errors.NotFoundError:
    url = URL_PREFIX + "/" + filename
    # Open the connection *before* entering the try block: if urlopen() itself
    # fails there is nothing to close, and the original error must propagate
    # instead of being masked by an unbound `request` in the finally clause.
    request = urllib.request.urlopen(url)
    try:
      string = request.read()
    finally:
      request.close()
    # Populate the cache so that subsequent calls do not hit the network.
    tf.io.gfile.makedirs(os.path.dirname(pathname))
    with tf.io.gfile.GFile(pathname, "wb") as f:
      f.write(string)
  return string
76+
77+
78+
def import_metagraph(model):
79+
"""Imports a trained model metagraph into the current graph."""
80+
string = load_cached(model + ".metagraph")
6981
metagraph = tf.MetaGraphDef()
7082
metagraph.ParseFromString(string)
7183
tf.train.import_meta_graph(metagraph)
@@ -86,14 +98,11 @@ def instantiate_signature(signature_def):
8698
return inputs, outputs
8799

88100

89-
def compress(model, input_file, output_file, url_prefix, metagraph_cache):
90-
"""Compresses a PNG file to a TFCI file."""
91-
if not output_file:
92-
output_file = input_file + ".tfci"
93-
101+
def compress_image(model, input_image):
102+
"""Compresses an image array into a bitstring."""
94103
with tf.Graph().as_default():
95104
# Load model metagraph.
96-
signature_defs = load_metagraph(model, url_prefix, metagraph_cache)
105+
signature_defs = import_metagraph(model)
97106
inputs, outputs = instantiate_signature(signature_defs["sender"])
98107

99108
# Just one input tensor.
@@ -103,12 +112,12 @@ def compress(model, input_file, output_file, url_prefix, metagraph_cache):
103112

104113
# Run encoder.
105114
with tf.Session() as sess:
106-
feed_dict = {inputs: sess.run(read_png(input_file))}
115+
feed_dict = {inputs: input_image}
107116
arrays = sess.run(outputs, feed_dict=feed_dict)
108117

109118
# Pack data into tf.Example.
110119
example = tf.train.Example()
111-
example.features.feature["MD"].bytes_list.value[:] = [model]
120+
example.features.feature["MD"].bytes_list.value[:] = [model.encode("ascii")]
112121
for i, (array, tensor) in enumerate(zip(arrays, outputs)):
113122
feature = example.features.feature[chr(i + 1)]
114123
if array.ndim != 1:
@@ -121,12 +130,60 @@ def compress(model, input_file, output_file, url_prefix, metagraph_cache):
121130
raise RuntimeError(
122131
"Unexpected tensor dtype: '{}'.".format(tensor.dtype))
123132

124-
# Write serialized tf.Example to disk.
125-
with tf.io.gfile.GFile(output_file, "wb") as f:
126-
f.write(example.SerializeToString())
133+
return example.SerializeToString()
127134

128135

129-
def decompress(input_file, output_file, url_prefix, metagraph_cache):
136+
def compress(model, input_file, output_file, target_bpp=None, bpp_strict=False):
  """Compresses a PNG file to a TFCI file.

  Args:
    model: Model identifier. Without `target_bpp` it is passed to
      `compress_image` unchanged. With `target_bpp` it names the list file
      `model + ".models"`, from which the candidate models are read.
    input_file: Path of the PNG file to read.
    output_file: Path of the TFCI file to write. If empty or `None`,
      `input_file + ".tfci"` is used.
    target_bpp: Optional target bits per pixel. If given, a binary search over
      the candidate models picks the bitstring whose bpp is closest to it.
    bpp_strict: If true, only bitstrings with bpp <= `target_bpp` are
      accepted. Ignored when `target_bpp` is `None`.

  Raises:
    RuntimeError: If the model list is empty, or if `bpp_strict` is set and no
      tried model produced a bitstring within `target_bpp`.
  """
  if not output_file:
    output_file = input_file + ".tfci"

  # Load image.
  with tf.Graph().as_default():
    with tf.Session() as sess:
      input_image = sess.run(read_png(input_file))
      # NOTE(review): presumably height * width of a channels-last image
      # array -- confirm against read_png.
      num_pixels = input_image.shape[-2] * input_image.shape[-3]

  # Compare against None explicitly so that an explicit target of 0.0 is not
  # silently treated as "no target given".
  if target_bpp is None:
    # Just compress with a specific model.
    bitstring = compress_image(model, input_image)
  else:
    # Get model list (whitespace-separated identifiers).
    models = load_cached(model + ".models").decode("ascii").split()
    if not models:
      raise RuntimeError("Model list '{}.models' is empty.".format(model))

    # Do a binary search over all RD points. The search moves to higher
    # indices when the bitrate is too low, i.e. it assumes that bpp grows with
    # the position in the model list.
    lower = -1
    upper = len(models)
    bpp = None
    best_bitstring = None
    best_bpp = None
    while bpp != target_bpp and upper - lower > 1:
      i = (upper + lower) // 2
      bitstring = compress_image(models[i], input_image)
      # Float literal forces true division even without
      # `from __future__ import division`.
      bpp = 8.0 * len(bitstring) / num_pixels
      is_admissible = bpp <= target_bpp or not bpp_strict
      is_better = (best_bpp is None or
                   abs(bpp - target_bpp) < abs(best_bpp - target_bpp))
      if is_admissible and is_better:
        best_bitstring = bitstring
        best_bpp = bpp
      if bpp < target_bpp:
        lower = i
      elif bpp > target_bpp:
        upper = i
    if best_bpp is None:
      # With a non-empty model list this can only happen in strict mode, when
      # every tried model exceeded the target.
      raise RuntimeError(
          "Could not compress image to less than {} bpp.".format(target_bpp))
    bitstring = best_bitstring

  # Write bitstring to disk.
  with tf.io.gfile.GFile(output_file, "wb") as f:
    f.write(bitstring)
184+
185+
186+
def decompress(input_file, output_file):
130187
"""Decompresses a TFCI file and writes a PNG file."""
131188
if not output_file:
132189
output_file = input_file + ".png"
@@ -136,10 +193,10 @@ def decompress(input_file, output_file, url_prefix, metagraph_cache):
136193
with tf.io.gfile.GFile(input_file, "rb") as f:
137194
example = tf.train.Example()
138195
example.ParseFromString(f.read())
139-
model = example.features.feature["MD"].bytes_list.value[0]
196+
model = example.features.feature["MD"].bytes_list.value[0].decode("ascii")
140197

141198
# Load model metagraph.
142-
signature_defs = load_metagraph(model, url_prefix, metagraph_cache)
199+
signature_defs = import_metagraph(model)
143200
inputs, outputs = instantiate_signature(signature_defs["receiver"])
144201

145202
# Multiple input tensors, ordered alphabetically, without names.
@@ -166,52 +223,59 @@ def decompress(input_file, output_file, url_prefix, metagraph_cache):
166223
sess.run(outputs, feed_dict=feed_dict)
167224

168225

169-
def list_models(url_prefix):
170-
url = url_prefix + "/models.txt"
226+
def list_models():
  """Prints the contents of `URL_PREFIX + "/models.txt"` to standard output.

  The response body is decoded as UTF-8 before printing.
  """
  url = URL_PREFIX + "/models.txt"
  # Open the connection outside the try block: if urlopen() fails there is
  # nothing to close, and the real error must not be masked by an unbound
  # `request` in the finally clause.
  request = urllib.request.urlopen(url)
  try:
    print(request.read().decode("utf-8"))
  finally:
    request.close()
176233

177234

178235
def parse_args(argv):
236+
"""Parses command line arguments."""
179237
parser = argparse_flags.ArgumentParser(
180238
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
181239

182240
# High-level options.
183241
parser.add_argument(
184242
"--url_prefix",
185-
default="https://storage.googleapis.com/tensorflow_compression/"
186-
"metagraphs",
243+
default=URL_PREFIX,
187244
help="URL prefix for downloading model metagraphs.")
188245
parser.add_argument(
189246
"--metagraph_cache",
190-
default="/tmp/tfc_metagraphs",
247+
default=METAGRAPH_CACHE,
191248
help="Directory where to cache model metagraphs.")
192249
subparsers = parser.add_subparsers(
193-
title="commands", help="Invoke '<command> -h' for more information.")
250+
title="commands", dest="command",
251+
help="Invoke '<command> -h' for more information.")
194252

195253
# 'compress' subcommand.
196254
compress_cmd = subparsers.add_parser(
197255
"compress",
198256
description="Reads a PNG file, compresses it using the given model, and "
199257
"writes a TFCI file.")
200-
compress_cmd.set_defaults(
201-
f=compress,
202-
a=["model", "input_file", "output_file", "url_prefix", "metagraph_cache"])
203258
compress_cmd.add_argument(
204259
"model",
205-
help="Unique model identifier. See 'models' command for options.")
260+
help="Unique model identifier. See 'models' command for options. If "
261+
"'target_bpp' is provided, don't specify the index at the end of "
262+
"the model identifier.")
263+
compress_cmd.add_argument(
264+
"--target_bpp", type=float,
265+
help="Target bits per pixel. If provided, a binary search is used to try "
266+
"to match the given bpp as close as possible. In this case, don't "
267+
"specify the index at the end of the model identifier. It will be "
268+
"automatically determined.")
269+
compress_cmd.add_argument(
270+
"--bpp_strict", action="store_true",
271+
help="Try never to exceed 'target_bpp'. Ignored if 'target_bpp' is not "
272+
"set.")
206273

207274
# 'decompress' subcommand.
208275
decompress_cmd = subparsers.add_parser(
209276
"decompress",
210277
description="Reads a TFCI file, reconstructs the image using the model "
211278
"it was compressed with, and writes back a PNG file.")
212-
decompress_cmd.set_defaults(
213-
f=decompress,
214-
a=["input_file", "output_file", "url_prefix", "metagraph_cache"])
215279

216280
# Arguments for both 'compress' and 'decompress'.
217281
for cmd, ext in ((compress_cmd, ".tfci"), (decompress_cmd, ".png")):
@@ -224,18 +288,34 @@ def parse_args(argv):
224288
"the input filename.".format(ext))
225289

226290
# 'models' subcommand.
227-
models_cmd = subparsers.add_parser(
291+
subparsers.add_parser(
228292
"models",
229293
description="Lists available trained models. Requires an internet "
230294
"connection.")
231-
models_cmd.set_defaults(f=list_models, a=["url_prefix"])
232295

233296
# Parse arguments.
234-
return parser.parse_args(argv[1:])
297+
args = parser.parse_args(argv[1:])
298+
if args.command is None:
299+
parser.print_usage()
300+
sys.exit(2)
301+
return args
302+
303+
304+
def main(args):
  """Applies global overrides from `args` and runs the selected subcommand."""
  global URL_PREFIX, METAGRAPH_CACHE

  # Values given on the command line replace the module-level defaults.
  URL_PREFIX = args.url_prefix
  METAGRAPH_CACHE = args.metagraph_cache

  # Dispatch on the subcommand name stored by the argument parser.
  command = args.command
  if command == "compress":
    compress(
        args.model,
        args.input_file,
        args.output_file,
        args.target_bpp,
        args.bpp_strict,
    )
  elif command == "decompress":
    decompress(args.input_file, args.output_file)
  elif command == "models":
    list_models()
235318

236319

237320
if __name__ == "__main__":
238-
# Parse arguments and run function determined by subcommand.
239-
app.run(
240-
lambda args: args.f(**{k: getattr(args, k) for k in args.a}),
241-
flags_parser=parse_args)
321+
app.run(main, flags_parser=parse_args)

0 commit comments

Comments
 (0)