Skip to content

Commit 956a376

Browse files
Johannes Ballécopybara-github
authored andcommitted
Correction to tfci.py: also lists receiver-side tensors of a model.
PiperOrigin-RevId: 363456244 Change-Id: I5a5ac4fc2d91663c1578120bd0c3051f9b0b3512
1 parent ced9fb8 commit 956a376

File tree

1 file changed

+23
-9
lines changed

1 file changed

+23
-9
lines changed

models/tfci.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,16 +75,19 @@ def load_cached(filename):
7575
return string
7676

7777

78-
def instantiate_model_signature(model, signature, outputs=None):
78+
def instantiate_model_signature(model, signature, inputs=None, outputs=None):
7979
"""Imports a trained model and returns one of its signatures as a function."""
8080
string = load_cached(model + ".metagraph")
8181
metagraph = tf.compat.v1.MetaGraphDef()
8282
metagraph.ParseFromString(string)
8383
wrapped_import = tf.compat.v1.wrap_function(
8484
lambda: tf.compat.v1.train.import_meta_graph(metagraph), [])
8585
graph = wrapped_import.graph
86-
inputs = metagraph.signature_def[signature].inputs
87-
inputs = [graph.as_graph_element(inputs[k].name) for k in sorted(inputs)]
86+
if inputs is None:
87+
inputs = metagraph.signature_def[signature].inputs
88+
inputs = [graph.as_graph_element(inputs[k].name) for k in sorted(inputs)]
89+
else:
90+
inputs = [graph.as_graph_element(t) for t in inputs]
8891
if outputs is None:
8992
outputs = metagraph.signature_def[signature].outputs
9093
outputs = [graph.as_graph_element(outputs[k].name) for k in sorted(outputs)]
@@ -174,13 +177,22 @@ def list_models():
174177

175178

176179
def list_tensors(model):
177-
"""Lists all internal tensors of the sender signature of a given model."""
180+
"""Lists all internal tensors of a given model."""
181+
def get_names_dtypes_shapes(function):
182+
for op in function.graph.get_operations():
183+
for tensor in op.outputs:
184+
yield tensor.name, tensor.dtype.name, tensor.shape
185+
178186
sender = instantiate_model_signature(model, "sender")
179-
tensors = sorted(
180-
(t.name, t.dtype.name, t.shape)
181-
for o in sender.graph.get_operations()
182-
for t in o.outputs
183-
)
187+
tensors = sorted(get_names_dtypes_shapes(sender))
188+
print("Sender-side tensors:")
189+
for name, dtype, shape in tensors:
190+
print(f"{name} (dtype={dtype}, shape={shape})")
191+
print()
192+
193+
receiver = instantiate_model_signature(model, "receiver")
194+
tensors = sorted(get_names_dtypes_shapes(receiver))
195+
print("Receiver-side tensors:")
184196
for name, dtype, shape in tensors:
185197
print(f"{name} (dtype={dtype}, shape={shape})")
186198

@@ -189,6 +201,8 @@ def dump_tensor(model, tensors, input_file, output_file):
189201
"""Dumps the given tensors of a model in .npz format."""
190202
if not output_file:
191203
output_file = input_file + ".npz"
204+
# Note: if receiver-side tensors are requested, this is no problem, as the
205+
# metagraph contains the union of the sender and receiver graphs.
192206
sender = instantiate_model_signature(model, "sender", outputs=tensors)
193207
input_image = read_png(input_file)
194208
# Replace special characters in tensor names with underscores.

0 commit comments

Comments
 (0)