Skip to content

Commit 540110a

Browse files
author
jballe
committed
Adds support for trained models and TFCI image format.
PiperOrigin-RevId: 243349288
1 parent 637c00d commit 540110a

File tree

1 file changed

+242
-0
lines changed

1 file changed

+242
-0
lines changed

examples/tfci.py

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
# -*- coding: utf-8 -*-
2+
# Copyright 2019 Google LLC. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
"""Converts an image between PNG and TFCI formats."""
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
import argparse
23+
import os
24+
25+
# Dependency imports
26+
27+
from absl import app
28+
from absl.flags import argparse_flags
29+
import numpy as np
30+
from six.moves import urllib
31+
import tensorflow as tf
32+
import tensorflow_compression as tfc # pylint:disable=unused-import
33+
34+
35+
def read_png(filename):
36+
"""Creates graph to load a PNG image file."""
37+
string = tf.io.read_file(filename)
38+
image = tf.image.decode_image(string)
39+
image = tf.expand_dims(image, 0)
40+
return image
41+
42+
43+
def write_png(filename, image):
44+
"""Creates graph to write a PNG image file."""
45+
image = tf.squeeze(image, 0)
46+
if image.dtype.is_floating:
47+
image = tf.round(image)
48+
if image.dtype != tf.uint8:
49+
image = tf.saturate_cast(image, tf.uint8)
50+
string = tf.image.encode_png(image)
51+
return tf.io.write_file(filename, string)
52+
53+
54+
def load_metagraph(model, url_prefix, metagraph_cache):
55+
"""Loads and caches a trained model metagraph."""
56+
filename = os.path.join(metagraph_cache, model + ".metagraph")
57+
try:
58+
with tf.io.gfile.GFile(filename, "rb") as f:
59+
string = f.read()
60+
except tf.errors.NotFoundError:
61+
url = url_prefix + "/" + model + ".metagraph"
62+
try:
63+
request = urllib.request.urlopen(url)
64+
string = request.read()
65+
finally:
66+
request.close()
67+
tf.io.gfile.makedirs(os.path.dirname(filename))
68+
with tf.io.gfile.GFile(filename, "wb") as f:
69+
f.write(string)
70+
metagraph = tf.MetaGraphDef()
71+
metagraph.ParseFromString(string)
72+
tf.train.import_meta_graph(metagraph)
73+
return metagraph.signature_def
74+
75+
76+
def instantiate_signature(signature_def):
77+
"""Fetches tensors defined in a signature from the graph."""
78+
graph = tf.get_default_graph()
79+
inputs = {
80+
k: graph.get_tensor_by_name(v.name)
81+
for k, v in signature_def.inputs.items()
82+
}
83+
outputs = {
84+
k: graph.get_tensor_by_name(v.name)
85+
for k, v in signature_def.outputs.items()
86+
}
87+
return inputs, outputs
88+
89+
90+
def compress(model, input_file, output_file, url_prefix, metagraph_cache):
91+
"""Compresses a PNG file to a TFCI file."""
92+
if not output_file:
93+
output_file = input_file + ".tfci"
94+
95+
with tf.Graph().as_default():
96+
# Load model metagraph.
97+
signature_defs = load_metagraph(model, url_prefix, metagraph_cache)
98+
inputs, outputs = instantiate_signature(signature_defs["sender"])
99+
100+
# Just one input tensor.
101+
inputs = inputs["input_image"]
102+
# Multiple output tensors, ordered alphabetically, without names.
103+
outputs = [outputs[k] for k in sorted(outputs) if k.startswith("channel:")]
104+
105+
# Run encoder.
106+
with tf.Session() as sess:
107+
feed_dict = {inputs: sess.run(read_png(input_file))}
108+
arrays = sess.run(outputs, feed_dict=feed_dict)
109+
110+
# Pack data into tf.Example.
111+
example = tf.train.Example()
112+
example.features.feature["MD"].bytes_list.value[:] = [model]
113+
for i, (array, tensor) in enumerate(zip(arrays, outputs)):
114+
feature = example.features.feature[chr(i + 1)]
115+
if array.ndim != 1:
116+
raise RuntimeError("Unexpected tensor rank: {}.".format(array.ndim))
117+
if tensor.dtype.is_integer:
118+
feature.int64_list.value[:] = array
119+
elif tensor.dtype == tf.string:
120+
feature.bytes_list.value[:] = array
121+
else:
122+
raise RuntimeError(
123+
"Unexpected tensor dtype: '{}'.".format(tensor.dtype))
124+
125+
# Write serialized tf.Example to disk.
126+
with tf.io.gfile.GFile(output_file, "wb") as f:
127+
f.write(example.SerializeToString())
128+
129+
130+
def decompress(input_file, output_file, url_prefix, metagraph_cache):
131+
"""Decompresses a TFCI file and writes a PNG file."""
132+
if not output_file:
133+
output_file = input_file + ".png"
134+
135+
with tf.Graph().as_default():
136+
# Deserialize tf.Example from disk and determine model.
137+
with tf.io.gfile.GFile(input_file, "rb") as f:
138+
example = tf.train.Example()
139+
example.ParseFromString(f.read())
140+
model = example.features.feature["MD"].bytes_list.value[0]
141+
142+
# Load model metagraph.
143+
signature_defs = load_metagraph(model, url_prefix, metagraph_cache)
144+
inputs, outputs = instantiate_signature(signature_defs["receiver"])
145+
146+
# Multiple input tensors, ordered alphabetically, without names.
147+
inputs = [inputs[k] for k in sorted(inputs) if k.startswith("channel:")]
148+
# Just one output operation.
149+
outputs = write_png(output_file, outputs["output_image"])
150+
151+
# Unpack data from tf.Example.
152+
arrays = []
153+
for i, tensor in enumerate(inputs):
154+
feature = example.features.feature[chr(i + 1)]
155+
np_dtype = tensor.dtype.as_numpy_dtype
156+
if tensor.dtype.is_integer:
157+
arrays.append(np.array(feature.int64_list.value, dtype=np_dtype))
158+
elif tensor.dtype == tf.string:
159+
arrays.append(np.array(feature.bytes_list.value, dtype=np_dtype))
160+
else:
161+
raise RuntimeError(
162+
"Unexpected tensor dtype: '{}'.".format(tensor.dtype))
163+
164+
# Run decoder.
165+
with tf.Session() as sess:
166+
feed_dict = dict(zip(inputs, arrays))
167+
sess.run(outputs, feed_dict=feed_dict)
168+
169+
170+
def list_models(url_prefix):
171+
url = url_prefix + "/models.txt"
172+
try:
173+
request = urllib.request.urlopen(url)
174+
print(request.read())
175+
finally:
176+
request.close()
177+
178+
179+
def parse_args(argv):
180+
parser = argparse_flags.ArgumentParser(
181+
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
182+
183+
# High-level options.
184+
parser.add_argument(
185+
"--url_prefix",
186+
default="https://storage.googleapis.com/tensorflow_compression/"
187+
"metagraphs",
188+
help="URL prefix for downloading model metagraphs.")
189+
parser.add_argument(
190+
"--metagraph_cache",
191+
default="/tmp/tfc_metagraphs",
192+
help="Directory where to cache model metagraphs.")
193+
subparsers = parser.add_subparsers(
194+
title="commands", help="Invoke '<command> -h' for more information.")
195+
196+
# 'compress' subcommand.
197+
compress_cmd = subparsers.add_parser(
198+
"compress",
199+
description="Reads a PNG file, compresses it using the given model, and "
200+
"writes a TFCI file.")
201+
compress_cmd.set_defaults(
202+
f=compress,
203+
a=["model", "input_file", "output_file", "url_prefix", "metagraph_cache"])
204+
compress_cmd.add_argument(
205+
"model",
206+
help="Unique model identifier. See 'models' command for options.")
207+
208+
# 'decompress' subcommand.
209+
decompress_cmd = subparsers.add_parser(
210+
"decompress",
211+
description="Reads a TFCI file, reconstructs the image using the model "
212+
"it was compressed with, and writes back a PNG file.")
213+
decompress_cmd.set_defaults(
214+
f=decompress,
215+
a=["input_file", "output_file", "url_prefix", "metagraph_cache"])
216+
217+
# Arguments for both 'compress' and 'decompress'.
218+
for cmd, ext in ((compress_cmd, ".tfci"), (decompress_cmd, ".png")):
219+
cmd.add_argument(
220+
"input_file",
221+
help="Input filename.")
222+
cmd.add_argument(
223+
"output_file", nargs="?",
224+
help="Output filename (optional). If not provided, appends '{}' to "
225+
"the input filename.".format(ext))
226+
227+
# 'models' subcommand.
228+
models_cmd = subparsers.add_parser(
229+
"models",
230+
description="Lists available trained models. Requires an internet "
231+
"connection.")
232+
models_cmd.set_defaults(f=list_models, a=["url_prefix"])
233+
234+
# Parse arguments.
235+
return parser.parse_args(argv[1:])
236+
237+
238+
if __name__ == "__main__":
239+
# Parse arguments and run function determined by subcommand.
240+
app.run(
241+
lambda args: args.f(**{k: getattr(args, k) for k in args.a}),
242+
flags_parser=parse_args)

0 commit comments

Comments
 (0)