21
21
22
22
import argparse
23
23
import os
24
+ import sys
24
25
25
26
from absl import app
26
27
from absl .flags import argparse_flags
30
31
31
32
import tensorflow_compression as tfc # pylint:disable=unused-import
32
33
34
+ # Default URL to fetch metagraphs from.
35
+ URL_PREFIX = "https://storage.googleapis.com/tensorflow_compression/metagraphs"
36
+ # Default location to store cached metagraphs.
37
+ METAGRAPH_CACHE = "/tmp/tfc_metagraphs"
38
+
33
39
34
40
def read_png (filename ):
35
41
"""Creates graph to load a PNG image file."""
@@ -50,22 +56,28 @@ def write_png(filename, image):
50
56
return tf .io .write_file (filename , string )
51
57
52
58
53
- def load_metagraph ( model , url_prefix , metagraph_cache ):
54
- """Loads and caches a trained model metagraph ."""
55
- filename = os .path .join (metagraph_cache , model + ".metagraph" )
59
+ def load_cached ( filename ):
60
+ """Downloads and caches files from web storage ."""
61
+ pathname = os .path .join (METAGRAPH_CACHE , filename )
56
62
try :
57
- with tf .io .gfile .GFile (filename , "rb" ) as f :
63
+ with tf .io .gfile .GFile (pathname , "rb" ) as f :
58
64
string = f .read ()
59
65
except tf .errors .NotFoundError :
60
- url = url_prefix + "/" + model + ".metagraph"
66
+ url = URL_PREFIX + "/" + filename
61
67
try :
62
68
request = urllib .request .urlopen (url )
63
69
string = request .read ()
64
70
finally :
65
71
request .close ()
66
- tf .io .gfile .makedirs (os .path .dirname (filename ))
67
- with tf .io .gfile .GFile (filename , "wb" ) as f :
72
+ tf .io .gfile .makedirs (os .path .dirname (pathname ))
73
+ with tf .io .gfile .GFile (pathname , "wb" ) as f :
68
74
f .write (string )
75
+ return string
76
+
77
+
78
+ def import_metagraph (model ):
79
+ """Imports a trained model metagraph into the current graph."""
80
+ string = load_cached (model + ".metagraph" )
69
81
metagraph = tf .MetaGraphDef ()
70
82
metagraph .ParseFromString (string )
71
83
tf .train .import_meta_graph (metagraph )
@@ -86,14 +98,11 @@ def instantiate_signature(signature_def):
86
98
return inputs , outputs
87
99
88
100
89
- def compress (model , input_file , output_file , url_prefix , metagraph_cache ):
90
- """Compresses a PNG file to a TFCI file."""
91
- if not output_file :
92
- output_file = input_file + ".tfci"
93
-
101
+ def compress_image (model , input_image ):
102
+ """Compresses an image array into a bitstring."""
94
103
with tf .Graph ().as_default ():
95
104
# Load model metagraph.
96
- signature_defs = load_metagraph (model , url_prefix , metagraph_cache )
105
+ signature_defs = import_metagraph (model )
97
106
inputs , outputs = instantiate_signature (signature_defs ["sender" ])
98
107
99
108
# Just one input tensor.
@@ -103,12 +112,12 @@ def compress(model, input_file, output_file, url_prefix, metagraph_cache):
103
112
104
113
# Run encoder.
105
114
with tf .Session () as sess :
106
- feed_dict = {inputs : sess . run ( read_png ( input_file )) }
115
+ feed_dict = {inputs : input_image }
107
116
arrays = sess .run (outputs , feed_dict = feed_dict )
108
117
109
118
# Pack data into tf.Example.
110
119
example = tf .train .Example ()
111
- example .features .feature ["MD" ].bytes_list .value [:] = [model ]
120
+ example .features .feature ["MD" ].bytes_list .value [:] = [model . encode ( "ascii" ) ]
112
121
for i , (array , tensor ) in enumerate (zip (arrays , outputs )):
113
122
feature = example .features .feature [chr (i + 1 )]
114
123
if array .ndim != 1 :
@@ -121,12 +130,60 @@ def compress(model, input_file, output_file, url_prefix, metagraph_cache):
121
130
raise RuntimeError (
122
131
"Unexpected tensor dtype: '{}'." .format (tensor .dtype ))
123
132
124
- # Write serialized tf.Example to disk.
125
- with tf .io .gfile .GFile (output_file , "wb" ) as f :
126
- f .write (example .SerializeToString ())
133
+ return example .SerializeToString ()
127
134
128
135
129
- def decompress (input_file , output_file , url_prefix , metagraph_cache ):
136
+ def compress (model , input_file , output_file , target_bpp = None , bpp_strict = False ):
137
+ """Compresses a PNG file to a TFCI file."""
138
+ if not output_file :
139
+ output_file = input_file + ".tfci"
140
+
141
+ # Load image.
142
+ with tf .Graph ().as_default ():
143
+ with tf .Session () as sess :
144
+ input_image = sess .run (read_png (input_file ))
145
+ num_pixels = input_image .shape [- 2 ] * input_image .shape [- 3 ]
146
+
147
+ if not target_bpp :
148
+ # Just compress with a specific model.
149
+ bitstring = compress_image (model , input_image )
150
+ else :
151
+ # Get model list.
152
+ models = load_cached (model + ".models" )
153
+ models = models .decode ("ascii" ).split ()
154
+
155
+ # Do a binary search over all RD points.
156
+ lower = - 1
157
+ upper = len (models )
158
+ bpp = None
159
+ best_bitstring = None
160
+ best_bpp = None
161
+ while bpp != target_bpp and upper - lower > 1 :
162
+ i = (upper + lower ) // 2
163
+ bitstring = compress_image (models [i ], input_image )
164
+ bpp = 8 * len (bitstring ) / num_pixels
165
+ is_admissible = bpp <= target_bpp or not bpp_strict
166
+ is_better = (best_bpp is None or
167
+ abs (bpp - target_bpp ) < abs (best_bpp - target_bpp ))
168
+ if is_admissible and is_better :
169
+ best_bitstring = bitstring
170
+ best_bpp = bpp
171
+ if bpp < target_bpp :
172
+ lower = i
173
+ if bpp > target_bpp :
174
+ upper = i
175
+ if best_bpp is None :
176
+ assert bpp_strict
177
+ raise RuntimeError (
178
+ "Could not compress image to less than {} bpp." .format (target_bpp ))
179
+ bitstring = best_bitstring
180
+
181
+ # Write bitstring to disk.
182
+ with tf .io .gfile .GFile (output_file , "wb" ) as f :
183
+ f .write (bitstring )
184
+
185
+
186
+ def decompress (input_file , output_file ):
130
187
"""Decompresses a TFCI file and writes a PNG file."""
131
188
if not output_file :
132
189
output_file = input_file + ".png"
@@ -136,10 +193,10 @@ def decompress(input_file, output_file, url_prefix, metagraph_cache):
136
193
with tf .io .gfile .GFile (input_file , "rb" ) as f :
137
194
example = tf .train .Example ()
138
195
example .ParseFromString (f .read ())
139
- model = example .features .feature ["MD" ].bytes_list .value [0 ]
196
+ model = example .features .feature ["MD" ].bytes_list .value [0 ]. decode ( "ascii" )
140
197
141
198
# Load model metagraph.
142
- signature_defs = load_metagraph (model , url_prefix , metagraph_cache )
199
+ signature_defs = import_metagraph (model )
143
200
inputs , outputs = instantiate_signature (signature_defs ["receiver" ])
144
201
145
202
# Multiple input tensors, ordered alphabetically, without names.
@@ -166,52 +223,59 @@ def decompress(input_file, output_file, url_prefix, metagraph_cache):
166
223
sess .run (outputs , feed_dict = feed_dict )
167
224
168
225
169
- def list_models (url_prefix ):
170
- url = url_prefix + "/models.txt"
226
+ def list_models ():
227
+ url = URL_PREFIX + "/models.txt"
171
228
try :
172
229
request = urllib .request .urlopen (url )
173
- print (request .read ())
230
+ print (request .read (). decode ( "utf-8" ) )
174
231
finally :
175
232
request .close ()
176
233
177
234
178
235
def parse_args (argv ):
236
+ """Parses command line arguments."""
179
237
parser = argparse_flags .ArgumentParser (
180
238
formatter_class = argparse .ArgumentDefaultsHelpFormatter )
181
239
182
240
# High-level options.
183
241
parser .add_argument (
184
242
"--url_prefix" ,
185
- default = "https://storage.googleapis.com/tensorflow_compression/"
186
- "metagraphs" ,
243
+ default = URL_PREFIX ,
187
244
help = "URL prefix for downloading model metagraphs." )
188
245
parser .add_argument (
189
246
"--metagraph_cache" ,
190
- default = "/tmp/tfc_metagraphs" ,
247
+ default = METAGRAPH_CACHE ,
191
248
help = "Directory where to cache model metagraphs." )
192
249
subparsers = parser .add_subparsers (
193
- title = "commands" , help = "Invoke '<command> -h' for more information." )
250
+ title = "commands" , dest = "command" ,
251
+ help = "Invoke '<command> -h' for more information." )
194
252
195
253
# 'compress' subcommand.
196
254
compress_cmd = subparsers .add_parser (
197
255
"compress" ,
198
256
description = "Reads a PNG file, compresses it using the given model, and "
199
257
"writes a TFCI file." )
200
- compress_cmd .set_defaults (
201
- f = compress ,
202
- a = ["model" , "input_file" , "output_file" , "url_prefix" , "metagraph_cache" ])
203
258
compress_cmd .add_argument (
204
259
"model" ,
205
- help = "Unique model identifier. See 'models' command for options." )
260
+ help = "Unique model identifier. See 'models' command for options. If "
261
+ "'target_bpp' is provided, don't specify the index at the end of "
262
+ "the model identifier." )
263
+ compress_cmd .add_argument (
264
+ "--target_bpp" , type = float ,
265
+ help = "Target bits per pixel. If provided, a binary search is used to try "
266
+ "to match the given bpp as close as possible. In this case, don't "
267
+ "specify the index at the end of the model identifier. It will be "
268
+ "automatically determined." )
269
+ compress_cmd .add_argument (
270
+ "--bpp_strict" , action = "store_true" ,
271
+ help = "Try never to exceed 'target_bpp'. Ignored if 'target_bpp' is not "
272
+ "set." )
206
273
207
274
# 'decompress' subcommand.
208
275
decompress_cmd = subparsers .add_parser (
209
276
"decompress" ,
210
277
description = "Reads a TFCI file, reconstructs the image using the model "
211
278
"it was compressed with, and writes back a PNG file." )
212
- decompress_cmd .set_defaults (
213
- f = decompress ,
214
- a = ["input_file" , "output_file" , "url_prefix" , "metagraph_cache" ])
215
279
216
280
# Arguments for both 'compress' and 'decompress'.
217
281
for cmd , ext in ((compress_cmd , ".tfci" ), (decompress_cmd , ".png" )):
@@ -224,18 +288,34 @@ def parse_args(argv):
224
288
"the input filename." .format (ext ))
225
289
226
290
# 'models' subcommand.
227
- models_cmd = subparsers .add_parser (
291
+ subparsers .add_parser (
228
292
"models" ,
229
293
description = "Lists available trained models. Requires an internet "
230
294
"connection." )
231
- models_cmd .set_defaults (f = list_models , a = ["url_prefix" ])
232
295
233
296
# Parse arguments.
234
- return parser .parse_args (argv [1 :])
297
+ args = parser .parse_args (argv [1 :])
298
+ if args .command is None :
299
+ parser .print_usage ()
300
+ sys .exit (2 )
301
+ return args
302
+
303
+
304
+ def main (args ):
305
+ # Command line can override these defaults.
306
+ global URL_PREFIX , METAGRAPH_CACHE
307
+ URL_PREFIX = args .url_prefix
308
+ METAGRAPH_CACHE = args .metagraph_cache
309
+
310
+ # Invoke subcommand.
311
+ if args .command == "compress" :
312
+ compress (args .model , args .input_file , args .output_file ,
313
+ args .target_bpp , args .bpp_strict )
314
+ if args .command == "decompress" :
315
+ decompress (args .input_file , args .output_file )
316
+ if args .command == "models" :
317
+ list_models ()
235
318
236
319
237
320
if __name__ == "__main__" :
238
- # Parse arguments and run function determined by subcommand.
239
- app .run (
240
- lambda args : args .f (** {k : getattr (args , k ) for k in args .a }),
241
- flags_parser = parse_args )
321
+ app .run (main , flags_parser = parse_args )
0 commit comments