17
17
Use this script to compress images with pre-trained models as published. See the
'models' subcommand for a list of available models.

This script requires TFC v2 (`pip install tensorflow-compression==2.*`).
"""
22
22
23
23
import argparse
24
24
import os
25
25
import sys
26
26
import urllib
27
-
28
27
from absl import app
29
28
from absl .flags import argparse_flags
30
- import tensorflow .compat .v1 as tf
31
-
29
+ import tensorflow as tf
32
30
import tensorflow_compression as tfc # pylint:disable=unused-import
33
31
32
+
34
33
# Default URL to fetch metagraphs from.
35
34
URL_PREFIX = "https://storage.googleapis.com/tensorflow_compression/metagraphs"
36
35
# Default location to store cached metagraphs.
37
36
METAGRAPH_CACHE = "/tmp/tfc_metagraphs"
38
37
39
38
40
39
def read_png(filename):
  """Loads a PNG image file and returns it as a batched tensor."""
  raw = tf.io.read_file(filename)
  decoded = tf.image.decode_image(raw)
  # Prepend a batch dimension of 1 so the tensor matches model input shape.
  return tf.expand_dims(decoded, 0)
46
44
47
45
48
46
def write_png(filename, image):
  """Writes a batched image tensor to a PNG file."""
  # Drop the batch dimension added by `read_png`/the model.
  squeezed = tf.squeeze(image, 0)
  if squeezed.dtype.is_floating:
    # Round float pixel values before quantizing to integers.
    squeezed = tf.round(squeezed)
  if squeezed.dtype != tf.uint8:
    # Clamp into [0, 255] while converting, avoiding wrap-around artifacts.
    squeezed = tf.saturate_cast(squeezed, tf.uint8)
  encoded = tf.image.encode_png(squeezed)
  tf.io.write_file(filename, encoded)
57
55
58
56
59
57
def load_cached (filename ):
@@ -63,9 +61,9 @@ def load_cached(filename):
63
61
with tf .io .gfile .GFile (pathname , "rb" ) as f :
64
62
string = f .read ()
65
63
except tf .errors .NotFoundError :
66
- url = URL_PREFIX + "/" + filename
64
+ url = f"{ URL_PREFIX } /{ filename } "
65
+ request = urllib .request .urlopen (url )
67
66
try :
68
- request = urllib .request .urlopen (url )
69
67
string = request .read ()
70
68
finally :
71
69
request .close ()
@@ -75,50 +73,29 @@ def load_cached(filename):
75
73
return string
76
74
77
75
78
def instantiate_model_signature(model, signature):
  """Imports a trained model and returns one of its signatures as a function.

  Args:
    model: Model name; its `.metagraph` file is fetched via `load_cached`.
    signature: Key of the signature def to instantiate (e.g. "sender").

  Returns:
    A callable (pruned `wrap_function`) mapping the signature's input tensors
    to its output tensors, each ordered alphabetically by signature key.
  """
  serialized = load_cached(model + ".metagraph")
  metagraph = tf.compat.v1.MetaGraphDef()
  metagraph.ParseFromString(serialized)
  # Import the TF1-style metagraph into a fresh graph wrapped as a TF2
  # function, so it can be called eagerly without a Session.
  wrapped = tf.compat.v1.wrap_function(
      lambda: tf.compat.v1.train.import_meta_graph(metagraph), [])
  sig_def = metagraph.signature_def[signature]
  # Resolve the signature's tensor names inside the wrapped graph. Sorting by
  # key gives the pruned function a deterministic argument/return order.
  input_tensors = [
      wrapped.graph.as_graph_element(sig_def.inputs[k].name)
      for k in sorted(sig_def.inputs)
  ]
  output_tensors = [
      wrapped.graph.as_graph_element(sig_def.outputs[k].name)
      for k in sorted(sig_def.outputs)
  ]
  return wrapped.prune(input_tensors, output_tensors)
99
89
100
90
101
91
def compress_image(model, input_image):
  """Compresses an image tensor into a bitstring.

  Args:
    model: Name of the model whose "sender" signature performs the encoding.
    input_image: Batched image tensor to compress.

  Returns:
    The serialized `PackedTensors` string holding the compressed channels.
  """
  sender = instantiate_model_signature(model, "sender")
  packed = tfc.PackedTensors()
  # Record the model name so the receiver knows which decoder to load.
  packed.model = model
  packed.pack(sender(input_image))
  return packed.string
122
99
123
100
124
101
def compress (model , input_file , output_file , target_bpp = None , bpp_strict = False ):
@@ -127,10 +104,8 @@ def compress(model, input_file, output_file, target_bpp=None, bpp_strict=False):
127
104
output_file = input_file + ".tfci"
128
105
129
106
# Load image.
130
- with tf .Graph ().as_default ():
131
- with tf .Session () as sess :
132
- input_image = sess .run (read_png (input_file ))
133
- num_pixels = input_image .shape [- 2 ] * input_image .shape [- 3 ]
107
+ input_image = read_png (input_file )
108
+ num_pixels = input_image .shape [- 2 ] * input_image .shape [- 3 ]
134
109
135
110
if not target_bpp :
136
111
# Just compress with a specific model.
@@ -175,27 +150,12 @@ def decompress(input_file, output_file):
175
150
"""Decompresses a TFCI file and writes a PNG file."""
176
151
if not output_file :
177
152
output_file = input_file + ".png"
178
-
179
- with tf .Graph ().as_default ():
180
- # Unserialize packed data from disk.
181
- with tf .io .gfile .GFile (input_file , "rb" ) as f :
182
- packed = tfc .PackedTensors (f .read ())
183
-
184
- # Load model metagraph.
185
- signature_defs = import_metagraph (packed .model )
186
- inputs , outputs = instantiate_signature (signature_defs ["receiver" ])
187
-
188
- # Multiple input tensors, ordered alphabetically, without names.
189
- inputs = [inputs [k ] for k in sorted (inputs ) if k .startswith ("channel:" )]
190
- # Just one output operation.
191
- outputs = write_png (output_file , outputs ["output_image" ])
192
-
193
- # Unpack data.
194
- arrays = packed .unpack (inputs )
195
-
196
- # Run decoder.
197
- with tf .Session () as sess :
198
- sess .run (outputs , feed_dict = dict (zip (inputs , arrays )))
153
+ with tf .io .gfile .GFile (input_file , "rb" ) as f :
154
+ packed = tfc .PackedTensors (f .read ())
155
+ receiver = instantiate_model_signature (packed .model , "receiver" )
156
+ tensors = packed .unpack ([t .dtype for t in receiver .inputs ])
157
+ output_image , = receiver (* tensors )
158
+ write_png (output_file , output_image )
199
159
200
160
201
161
def list_models ():
0 commit comments