Skip to content

Commit 198fe48

Browse files
Johannes Ballé and copybara-github
authored and committed
Adds subcommands to explore internal tensors of pretrained models.
PiperOrigin-RevId: 363414919 Change-Id: I26a1f5148f6f47d31969bcc5bfcde98cd374d72a
1 parent 64eba71 commit 198fe48

File tree

1 file changed

+83
-16
lines changed

1 file changed

+83
-16
lines changed

models/tfci.py

Lines changed: 83 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@
2121
"""
2222

2323
import argparse
24+
import io
2425
import os
2526
import sys
2627
import urllib
2728
from absl import app
2829
from absl.flags import argparse_flags
30+
import numpy as np
2931
import tensorflow as tf
3032
import tensorflow_compression as tfc # pylint:disable=unused-import
3133

@@ -73,7 +75,7 @@ def load_cached(filename):
7375
return string
7476

7577

76-
def instantiate_model_signature(model, signature):
78+
def instantiate_model_signature(model, signature, outputs=None):
7779
"""Imports a trained model and returns one of its signatures as a function."""
7880
string = load_cached(model + ".metagraph")
7981
metagraph = tf.compat.v1.MetaGraphDef()
@@ -82,9 +84,12 @@ def instantiate_model_signature(model, signature):
8284
lambda: tf.compat.v1.train.import_meta_graph(metagraph), [])
8385
graph = wrapped_import.graph
8486
inputs = metagraph.signature_def[signature].inputs
85-
outputs = metagraph.signature_def[signature].outputs
8687
inputs = [graph.as_graph_element(inputs[k].name) for k in sorted(inputs)]
87-
outputs = [graph.as_graph_element(outputs[k].name) for k in sorted(outputs)]
88+
if outputs is None:
89+
outputs = metagraph.signature_def[signature].outputs
90+
outputs = [graph.as_graph_element(outputs[k].name) for k in sorted(outputs)]
91+
else:
92+
outputs = [graph.as_graph_element(t) for t in outputs]
8893
return wrapped_import.prune(inputs, outputs)
8994

9095

@@ -159,14 +164,45 @@ def decompress(input_file, output_file):
159164

160165

161166
def list_models():
  """Lists available models in web storage with a description.

  Downloads `URL_PREFIX + "/models.txt"` and prints its contents, decoded as
  UTF-8, to standard output.
  """
  # A bare `import urllib` does not guarantee that the `request` submodule is
  # loaded; import it explicitly so this does not depend on some other module
  # having imported it first.
  import urllib.request  # pylint:disable=g-import-not-at-top

  url = URL_PREFIX + "/models.txt"
  # The response object is a context manager, so it is closed even if reading
  # or decoding fails.
  with urllib.request.urlopen(url) as request:
    print(request.read().decode("utf-8"))
168174

169175

176+
def list_tensors(model):
  """Prints every tensor found in the graph of a model's sender signature.

  One line is printed per tensor, sorted by (name, dtype name, shape), in the
  form `<name> (dtype=<dtype>, shape=<shape>)`.

  Args:
    model: Model identifier, forwarded to `instantiate_model_signature`.
  """
  sender = instantiate_model_signature(model, "sender")
  # Collect one (name, dtype, shape) record per output of every operation.
  records = []
  for op in sender.graph.get_operations():
    for tensor in op.outputs:
      records.append((tensor.name, tensor.dtype.name, tensor.shape))
  records.sort()
  for name, dtype, shape in records:
    print(f"{name} (dtype={dtype}, shape={shape})")
186+
187+
188+
def dump_tensor(model, tensors, input_file, output_file):
  """Dumps the given tensors of a model in .npz format.

  Args:
    model: Model identifier, forwarded to `instantiate_model_signature`.
    tensors: Sequence of tensor names to evaluate. They are used as the outputs
      of the model's "sender" signature.
    input_file: Path of the input image, read with `read_png`.
    output_file: Path of the .npz file to write. If empty or None, defaults to
      `input_file + ".npz"`.

  Raises:
    ValueError: If two tensor names map to the same archive key once special
      characters are replaced, since one array would silently overwrite the
      other.
    RuntimeError: If the model returns a different number of values than the
      number of tensors requested.
  """
  if not output_file:
    output_file = input_file + ".npz"
  sender = instantiate_model_signature(model, "sender", outputs=tensors)
  input_image = read_png(input_file)
  # Replace special characters in tensor names with underscores.
  table = str.maketrans(r"^./-:", r"_____")
  keys = [t.translate(table) for t in tensors]
  # Distinct tensor names can collapse to the same key (e.g. "a/b:0" and
  # "a_b:0"); refuse to proceed rather than drop an array silently.
  if len(set(keys)) != len(keys):
    raise ValueError(
        "Tensor names are not unique after replacing special characters: "
        f"{sorted(keys)}")
  values = [t.numpy() for t in sender(input_image)]
  # Explicit check instead of `assert`, which is stripped under `python -O`.
  if len(keys) != len(values):
    raise RuntimeError(
        f"Requested {len(keys)} tensors but model returned {len(values)}.")
  # Write to buffer first, since GFile might not be random accessible.
  with io.BytesIO() as buf:
    np.savez(buf, **dict(zip(keys, values)))
    with tf.io.gfile.GFile(output_file, mode="wb") as f:
      f.write(buf.getvalue())
204+
205+
170206
def parse_args(argv):
171207
"""Parses command line arguments."""
172208
parser = argparse_flags.ArgumentParser(
@@ -214,23 +250,48 @@ def parse_args(argv):
214250
description="Reads a TFCI file, reconstructs the image using the model "
215251
"it was compressed with, and writes back a PNG file.")
216252

217-
# Arguments for both 'compress' and 'decompress'.
218-
for cmd, ext in ((compress_cmd, ".tfci"), (decompress_cmd, ".png")):
219-
cmd.add_argument(
220-
"input_file",
221-
help="Input filename.")
222-
cmd.add_argument(
223-
"output_file", nargs="?",
224-
help="Output filename (optional). If not provided, appends '{}' to "
225-
"the input filename.".format(ext))
226-
227253
# 'models' subcommand.
228254
subparsers.add_parser(
229255
"models",
230256
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
231257
description="Lists available trained models. Requires an internet "
232258
"connection.")
233259

260+
tensors_cmd = subparsers.add_parser(
261+
"tensors",
262+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
263+
description="Lists names of internal tensors of a given model.")
264+
tensors_cmd.add_argument(
265+
"model",
266+
help="Unique model identifier. See 'models' command for options.")
267+
268+
dump_cmd = subparsers.add_parser(
269+
"dump",
270+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
271+
description="Dumps values of given internal tensors of a model in "
272+
"NumPy's .npz format.")
273+
dump_cmd.add_argument(
274+
"model",
275+
help="Unique model identifier. See 'models' command for options.")
276+
dump_cmd.add_argument(
277+
"--tensor", "-t", nargs="+",
278+
help="Name(s) of tensor(s) to dump. Must provide at least one. See "
279+
"'tensors' command for options.")
280+
281+
# Arguments for 'compress', 'decompress', and 'dump'.
282+
for cmd, ext in (
283+
(compress_cmd, ".tfci"),
284+
(decompress_cmd, ".png"),
285+
(dump_cmd, ".npz"),
286+
):
287+
cmd.add_argument(
288+
"input_file",
289+
help="Input filename.")
290+
cmd.add_argument(
291+
"output_file", nargs="?",
292+
help=f"Output filename (optional). If not provided, appends '{ext}' to "
293+
f"the input filename.")
294+
234295
# Parse arguments.
235296
args = parser.parse_args(argv[1:])
236297
if args.command is None:
@@ -249,10 +310,16 @@ def main(args):
249310
if args.command == "compress":
250311
compress(args.model, args.input_file, args.output_file,
251312
args.target_bpp, args.bpp_strict)
252-
if args.command == "decompress":
313+
elif args.command == "decompress":
253314
decompress(args.input_file, args.output_file)
254-
if args.command == "models":
315+
elif args.command == "models":
255316
list_models()
317+
elif args.command == "tensors":
318+
list_tensors(args.model)
319+
elif args.command == "dump":
320+
if not args.tensor:
321+
raise ValueError("Must provide at least one tensor to dump.")
322+
dump_tensor(args.model, args.tensor, args.input_file, args.output_file)
256323

257324

258325
if __name__ == "__main__":

0 commit comments

Comments
 (0)