Skip to content

Commit 783eab8

Browse files
Johannes Ballécopybara-github
authored andcommitted
Updates tfci.py to TF2 API.
PiperOrigin-RevId: 361233287 Change-Id: I7c7633f418312229d41cd0a65b6a24b1af6176ca
1 parent 72366f5 commit 783eab8

File tree

1 file changed

+35
-75
lines changed

1 file changed

+35
-75
lines changed

models/tfci.py

Lines changed: 35 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -17,43 +17,41 @@
1717
Use this script to compress images with pre-trained models as published. See the
1818
'models' subcommand for a list of available models.
1919
20-
Currently, this script requires tensorflow-compression v1.3.
20+
This script requires TFC v2 (`pip install tensorflow-compression==2.*`).
2121
"""
2222

2323
import argparse
2424
import os
2525
import sys
2626
import urllib
27-
2827
from absl import app
2928
from absl.flags import argparse_flags
30-
import tensorflow.compat.v1 as tf
31-
29+
import tensorflow as tf
3230
import tensorflow_compression as tfc # pylint:disable=unused-import
3331

32+
3433
# Default URL to fetch metagraphs from.
3534
URL_PREFIX = "https://storage.googleapis.com/tensorflow_compression/metagraphs"
3635
# Default location to store cached metagraphs.
3736
METAGRAPH_CACHE = "/tmp/tfc_metagraphs"
3837

3938

4039
def read_png(filename):
41-
"""Creates graph to load a PNG image file."""
40+
"""Loads a PNG image file."""
4241
string = tf.io.read_file(filename)
4342
image = tf.image.decode_image(string)
44-
image = tf.expand_dims(image, 0)
45-
return image
43+
return tf.expand_dims(image, 0)
4644

4745

4846
def write_png(filename, image):
49-
"""Creates graph to write a PNG image file."""
47+
"""Writes a PNG image file."""
5048
image = tf.squeeze(image, 0)
5149
if image.dtype.is_floating:
5250
image = tf.round(image)
5351
if image.dtype != tf.uint8:
5452
image = tf.saturate_cast(image, tf.uint8)
5553
string = tf.image.encode_png(image)
56-
return tf.io.write_file(filename, string)
54+
tf.io.write_file(filename, string)
5755

5856

5957
def load_cached(filename):
@@ -63,9 +61,9 @@ def load_cached(filename):
6361
with tf.io.gfile.GFile(pathname, "rb") as f:
6462
string = f.read()
6563
except tf.errors.NotFoundError:
66-
url = URL_PREFIX + "/" + filename
64+
url = f"{URL_PREFIX}/{filename}"
65+
request = urllib.request.urlopen(url)
6766
try:
68-
request = urllib.request.urlopen(url)
6967
string = request.read()
7068
finally:
7169
request.close()
@@ -75,50 +73,29 @@ def load_cached(filename):
7573
return string
7674

7775

78-
def import_metagraph(model):
79-
"""Imports a trained model metagraph into the current graph."""
76+
def instantiate_model_signature(model, signature):
77+
"""Imports a trained model and returns one of its signatures as a function."""
8078
string = load_cached(model + ".metagraph")
81-
metagraph = tf.MetaGraphDef()
79+
metagraph = tf.compat.v1.MetaGraphDef()
8280
metagraph.ParseFromString(string)
83-
tf.train.import_meta_graph(metagraph)
84-
return metagraph.signature_def
85-
86-
87-
def instantiate_signature(signature_def):
88-
"""Fetches tensors defined in a signature from the graph."""
89-
graph = tf.get_default_graph()
90-
inputs = {
91-
k: graph.get_tensor_by_name(v.name)
92-
for k, v in signature_def.inputs.items()
93-
}
94-
outputs = {
95-
k: graph.get_tensor_by_name(v.name)
96-
for k, v in signature_def.outputs.items()
97-
}
98-
return inputs, outputs
81+
wrapped_import = tf.compat.v1.wrap_function(
82+
lambda: tf.compat.v1.train.import_meta_graph(metagraph), [])
83+
graph = wrapped_import.graph
84+
inputs = metagraph.signature_def[signature].inputs
85+
outputs = metagraph.signature_def[signature].outputs
86+
inputs = [graph.as_graph_element(inputs[k].name) for k in sorted(inputs)]
87+
outputs = [graph.as_graph_element(outputs[k].name) for k in sorted(outputs)]
88+
return wrapped_import.prune(inputs, outputs)
9989

10090

10191
def compress_image(model, input_image):
102-
"""Compresses an image array into a bitstring."""
103-
with tf.Graph().as_default():
104-
# Load model metagraph.
105-
signature_defs = import_metagraph(model)
106-
inputs, outputs = instantiate_signature(signature_defs["sender"])
107-
108-
# Just one input tensor.
109-
inputs = inputs["input_image"]
110-
# Multiple output tensors, ordered alphabetically, without names.
111-
outputs = [outputs[k] for k in sorted(outputs) if k.startswith("channel:")]
112-
113-
# Run encoder.
114-
with tf.Session() as sess:
115-
arrays = sess.run(outputs, feed_dict={inputs: input_image})
116-
117-
# Pack data into bitstring.
118-
packed = tfc.PackedTensors()
119-
packed.model = model
120-
packed.pack(outputs, arrays)
121-
return packed.string
92+
"""Compresses an image tensor into a bitstring."""
93+
sender = instantiate_model_signature(model, "sender")
94+
tensors = sender(input_image)
95+
packed = tfc.PackedTensors()
96+
packed.model = model
97+
packed.pack(tensors)
98+
return packed.string
12299

123100

124101
def compress(model, input_file, output_file, target_bpp=None, bpp_strict=False):
@@ -127,10 +104,8 @@ def compress(model, input_file, output_file, target_bpp=None, bpp_strict=False):
127104
output_file = input_file + ".tfci"
128105

129106
# Load image.
130-
with tf.Graph().as_default():
131-
with tf.Session() as sess:
132-
input_image = sess.run(read_png(input_file))
133-
num_pixels = input_image.shape[-2] * input_image.shape[-3]
107+
input_image = read_png(input_file)
108+
num_pixels = input_image.shape[-2] * input_image.shape[-3]
134109

135110
if not target_bpp:
136111
# Just compress with a specific model.
@@ -175,27 +150,12 @@ def decompress(input_file, output_file):
175150
"""Decompresses a TFCI file and writes a PNG file."""
176151
if not output_file:
177152
output_file = input_file + ".png"
178-
179-
with tf.Graph().as_default():
180-
# Unserialize packed data from disk.
181-
with tf.io.gfile.GFile(input_file, "rb") as f:
182-
packed = tfc.PackedTensors(f.read())
183-
184-
# Load model metagraph.
185-
signature_defs = import_metagraph(packed.model)
186-
inputs, outputs = instantiate_signature(signature_defs["receiver"])
187-
188-
# Multiple input tensors, ordered alphabetically, without names.
189-
inputs = [inputs[k] for k in sorted(inputs) if k.startswith("channel:")]
190-
# Just one output operation.
191-
outputs = write_png(output_file, outputs["output_image"])
192-
193-
# Unpack data.
194-
arrays = packed.unpack(inputs)
195-
196-
# Run decoder.
197-
with tf.Session() as sess:
198-
sess.run(outputs, feed_dict=dict(zip(inputs, arrays)))
153+
with tf.io.gfile.GFile(input_file, "rb") as f:
154+
packed = tfc.PackedTensors(f.read())
155+
receiver = instantiate_model_signature(packed.model, "receiver")
156+
tensors = packed.unpack([t.dtype for t in receiver.inputs])
157+
output_image, = receiver(*tensors)
158+
write_png(output_file, output_image)
199159

200160

201161
def list_models():

0 commit comments

Comments
 (0)