21
21
"""
22
22
23
23
import argparse
24
+ import io
24
25
import os
25
26
import sys
26
27
import urllib
27
28
from absl import app
28
29
from absl .flags import argparse_flags
30
+ import numpy as np
29
31
import tensorflow as tf
30
32
import tensorflow_compression as tfc # pylint:disable=unused-import
31
33
@@ -73,7 +75,7 @@ def load_cached(filename):
73
75
return string
74
76
75
77
76
- def instantiate_model_signature (model , signature ):
78
+ def instantiate_model_signature (model , signature , outputs = None ):
77
79
"""Imports a trained model and returns one of its signatures as a function."""
78
80
string = load_cached (model + ".metagraph" )
79
81
metagraph = tf .compat .v1 .MetaGraphDef ()
@@ -82,9 +84,12 @@ def instantiate_model_signature(model, signature):
82
84
lambda : tf .compat .v1 .train .import_meta_graph (metagraph ), [])
83
85
graph = wrapped_import .graph
84
86
inputs = metagraph .signature_def [signature ].inputs
85
- outputs = metagraph .signature_def [signature ].outputs
86
87
inputs = [graph .as_graph_element (inputs [k ].name ) for k in sorted (inputs )]
87
- outputs = [graph .as_graph_element (outputs [k ].name ) for k in sorted (outputs )]
88
+ if outputs is None :
89
+ outputs = metagraph .signature_def [signature ].outputs
90
+ outputs = [graph .as_graph_element (outputs [k ].name ) for k in sorted (outputs )]
91
+ else :
92
+ outputs = [graph .as_graph_element (t ) for t in outputs ]
88
93
return wrapped_import .prune (inputs , outputs )
89
94
90
95
@@ -159,14 +164,45 @@ def decompress(input_file, output_file):
159
164
160
165
161
166
def list_models ():
167
+ """Lists available models in web storage with a description."""
162
168
url = URL_PREFIX + "/models.txt"
169
+ request = urllib .request .urlopen (url )
163
170
try :
164
- request = urllib .request .urlopen (url )
165
171
print (request .read ().decode ("utf-8" ))
166
172
finally :
167
173
request .close ()
168
174
169
175
176
+ def list_tensors (model ):
177
+ """Lists all internal tensors of the sender signature of a given model."""
178
+ sender = instantiate_model_signature (model , "sender" )
179
+ tensors = sorted (
180
+ (t .name , t .dtype .name , t .shape )
181
+ for o in sender .graph .get_operations ()
182
+ for t in o .outputs
183
+ )
184
+ for name , dtype , shape in tensors :
185
+ print (f"{ name } (dtype={ dtype } , shape={ shape } )" )
186
+
187
+
188
+ def dump_tensor (model , tensors , input_file , output_file ):
189
+ """Dumps the given tensors of a model in .npz format."""
190
+ if not output_file :
191
+ output_file = input_file + ".npz"
192
+ sender = instantiate_model_signature (model , "sender" , outputs = tensors )
193
+ input_image = read_png (input_file )
194
+ # Replace special characters in tensor names with underscores.
195
+ table = str .maketrans (r"^./-:" , r"_____" )
196
+ tensors = [t .translate (table ) for t in tensors ]
197
+ values = [t .numpy () for t in sender (input_image )]
198
+ assert len (tensors ) == len (values )
199
+ # Write to buffer first, since GFile might not be random accessible.
200
+ with io .BytesIO () as buf :
201
+ np .savez (buf , ** dict (zip (tensors , values )))
202
+ with tf .io .gfile .GFile (output_file , mode = "wb" ) as f :
203
+ f .write (buf .getvalue ())
204
+
205
+
170
206
def parse_args (argv ):
171
207
"""Parses command line arguments."""
172
208
parser = argparse_flags .ArgumentParser (
@@ -214,23 +250,48 @@ def parse_args(argv):
214
250
description = "Reads a TFCI file, reconstructs the image using the model "
215
251
"it was compressed with, and writes back a PNG file." )
216
252
217
- # Arguments for both 'compress' and 'decompress'.
218
- for cmd , ext in ((compress_cmd , ".tfci" ), (decompress_cmd , ".png" )):
219
- cmd .add_argument (
220
- "input_file" ,
221
- help = "Input filename." )
222
- cmd .add_argument (
223
- "output_file" , nargs = "?" ,
224
- help = "Output filename (optional). If not provided, appends '{}' to "
225
- "the input filename." .format (ext ))
226
-
227
253
# 'models' subcommand.
228
254
subparsers .add_parser (
229
255
"models" ,
230
256
formatter_class = argparse .ArgumentDefaultsHelpFormatter ,
231
257
description = "Lists available trained models. Requires an internet "
232
258
"connection." )
233
259
260
+ tensors_cmd = subparsers .add_parser (
261
+ "tensors" ,
262
+ formatter_class = argparse .ArgumentDefaultsHelpFormatter ,
263
+ description = "Lists names of internal tensors of a given model." )
264
+ tensors_cmd .add_argument (
265
+ "model" ,
266
+ help = "Unique model identifier. See 'models' command for options." )
267
+
268
+ dump_cmd = subparsers .add_parser (
269
+ "dump" ,
270
+ formatter_class = argparse .ArgumentDefaultsHelpFormatter ,
271
+ description = "Dumps values of given internal tensors of a model in "
272
+ "NumPy's .npz format." )
273
+ dump_cmd .add_argument (
274
+ "model" ,
275
+ help = "Unique model identifier. See 'models' command for options." )
276
+ dump_cmd .add_argument (
277
+ "--tensor" , "-t" , nargs = "+" ,
278
+ help = "Name(s) of tensor(s) to dump. Must provide at least one. See "
279
+ "'tensors' command for options." )
280
+
281
+ # Arguments for 'compress', 'decompress', and 'dump'.
282
+ for cmd , ext in (
283
+ (compress_cmd , ".tfci" ),
284
+ (decompress_cmd , ".png" ),
285
+ (dump_cmd , ".npz" ),
286
+ ):
287
+ cmd .add_argument (
288
+ "input_file" ,
289
+ help = "Input filename." )
290
+ cmd .add_argument (
291
+ "output_file" , nargs = "?" ,
292
+ help = f"Output filename (optional). If not provided, appends '{ ext } ' to "
293
+ f"the input filename." )
294
+
234
295
# Parse arguments.
235
296
args = parser .parse_args (argv [1 :])
236
297
if args .command is None :
@@ -249,10 +310,16 @@ def main(args):
249
310
if args .command == "compress" :
250
311
compress (args .model , args .input_file , args .output_file ,
251
312
args .target_bpp , args .bpp_strict )
252
- if args .command == "decompress" :
313
+ elif args .command == "decompress" :
253
314
decompress (args .input_file , args .output_file )
254
- if args .command == "models" :
315
+ elif args .command == "models" :
255
316
list_models ()
317
+ elif args .command == "tensors" :
318
+ list_tensors (args .model )
319
+ elif args .command == "dump" :
320
+ if not args .tensor :
321
+ raise ValueError ("Must provide at least one tensor to dump." )
322
+ dump_tensor (args .model , args .tensor , args .input_file , args .output_file )
256
323
257
324
258
325
if __name__ == "__main__" :
0 commit comments