@@ -75,16 +75,19 @@ def load_cached(filename):
75
75
return string
76
76
77
77
78
- def instantiate_model_signature (model , signature , outputs = None ):
78
+ def instantiate_model_signature (model , signature , inputs = None , outputs = None ):
79
79
"""Imports a trained model and returns one of its signatures as a function."""
80
80
string = load_cached (model + ".metagraph" )
81
81
metagraph = tf .compat .v1 .MetaGraphDef ()
82
82
metagraph .ParseFromString (string )
83
83
wrapped_import = tf .compat .v1 .wrap_function (
84
84
lambda : tf .compat .v1 .train .import_meta_graph (metagraph ), [])
85
85
graph = wrapped_import .graph
86
- inputs = metagraph .signature_def [signature ].inputs
87
- inputs = [graph .as_graph_element (inputs [k ].name ) for k in sorted (inputs )]
86
+ if inputs is None :
87
+ inputs = metagraph .signature_def [signature ].inputs
88
+ inputs = [graph .as_graph_element (inputs [k ].name ) for k in sorted (inputs )]
89
+ else :
90
+ inputs = [graph .as_graph_element (t ) for t in inputs ]
88
91
if outputs is None :
89
92
outputs = metagraph .signature_def [signature ].outputs
90
93
outputs = [graph .as_graph_element (outputs [k ].name ) for k in sorted (outputs )]
@@ -174,13 +177,22 @@ def list_models():
174
177
175
178
176
179
def list_tensors (model ):
177
- """Lists all internal tensors of the sender signature of a given model."""
180
+ """Lists all internal tensors of a given model."""
181
+ def get_names_dtypes_shapes (function ):
182
+ for op in function .graph .get_operations ():
183
+ for tensor in op .outputs :
184
+ yield tensor .name , tensor .dtype .name , tensor .shape
185
+
178
186
sender = instantiate_model_signature (model , "sender" )
179
- tensors = sorted (
180
- (t .name , t .dtype .name , t .shape )
181
- for o in sender .graph .get_operations ()
182
- for t in o .outputs
183
- )
187
+ tensors = sorted (get_names_dtypes_shapes (sender ))
188
+ print ("Sender-side tensors:" )
189
+ for name , dtype , shape in tensors :
190
+ print (f"{ name } (dtype={ dtype } , shape={ shape } )" )
191
+ print ()
192
+
193
+ receiver = instantiate_model_signature (model , "receiver" )
194
+ tensors = sorted (get_names_dtypes_shapes (receiver ))
195
+ print ("Receiver-side tensors:" )
184
196
for name , dtype , shape in tensors :
185
197
print (f"{ name } (dtype={ dtype } , shape={ shape } )" )
186
198
@@ -189,6 +201,8 @@ def dump_tensor(model, tensors, input_file, output_file):
189
201
"""Dumps the given tensors of a model in .npz format."""
190
202
if not output_file :
191
203
output_file = input_file + ".npz"
204
+ # Note: if receiver-side tensors are requested, this is no problem, as the
205
+ # metagraph contains the union of the sender and receiver graphs.
192
206
sender = instantiate_model_signature (model , "sender" , outputs = tensors )
193
207
input_image = read_png (input_file )
194
208
# Replace special characters in tensor names with underscores.
0 commit comments