@@ -97,31 +97,39 @@ def create_proxy(self, kind: str, target: fx.node.Target, args, kwargs,
     def _get_node_qualname(
             self, module_qualname: str, node: fx.node.Node) -> str:
         node_qualname = module_qualname
-        if node.op == 'call_module':
-            # Node terminates in a leaf module so the module_qualname is a
-            # complete description of the node
-            for existing_qualname in reversed(self.node_to_qualname.values()):
-                # Check to see if existing_qualname is of the form
-                # {node_qualname} or {node_qualname}_{int}
-                if re.match(rf'{node_qualname}(_[0-9]+)?$',
-                            existing_qualname) is not None:
-                    postfix = existing_qualname.replace(node_qualname, '')
-                    if len(postfix):
-                        # Existing_qualname is of the form {node_qualname}_{int}
-                        next_index = int(postfix[1:]) + 1
-                    else:
-                        # existing_qualname is of the form {node_qualname}
-                        next_index = 1
-                    node_qualname += f'_{next_index}'
-                    break
-            pass
-        else:
-            # Node terminates in a non-leaf module so the node name needs to
-            # be appended
+
+        if node.op != 'call_module':
+            # In this case module_qualname from torch.fx doesn't go all the
+            # way to the leaf function/op, so we need to append it
             if len(node_qualname) > 0:
                 # Only append '.' if we are deeper than the top level module
                 node_qualname += '.'
             node_qualname += str(node)
+
+        # Now we need to add an _{index} postfix on any repeated node names.
+        # For modules we do this from scratch.
+        # But for anything else, torch.fx already has a globally scoped
+        # _{index} postfix. But we want it locally (relative to direct parent)
+        # scoped. So first we need to undo the torch.fx postfix
+        if re.match(r'.+_[0-9]+$', node_qualname) is not None:
+            node_qualname = node_qualname.rsplit('_', 1)[0]
+
+        # ... and now we add on our own postfix
+        for existing_qualname in reversed(self.node_to_qualname.values()):
+            # Check to see if existing_qualname is of the form
+            # {node_qualname} or {node_qualname}_{int}
+            if re.match(rf'{node_qualname}(_[0-9]+)?$',
+                        existing_qualname) is not None:
+                postfix = existing_qualname.replace(node_qualname, '')
+                if len(postfix):
+                    # existing_qualname is of the form {node_qualname}_{int}
+                    next_index = int(postfix[1:]) + 1
+                else:
+                    # existing_qualname is of the form {node_qualname}
+                    next_index = 1
+                node_qualname += f'_{next_index}'
+                break
+
         return node_qualname
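The effect of the de-duplication loop introduced above can be seen with a minimal standalone sketch. dedup_qualname below is a hypothetical reduction of the same logic, with the tracer's self.node_to_qualname replaced by a plain list of previously assigned names:

import re

def dedup_qualname(node_qualname, seen):
    # Undo any globally scoped torch.fx-style _{int} postfix first
    if re.match(r'.+_[0-9]+$', node_qualname) is not None:
        node_qualname = node_qualname.rsplit('_', 1)[0]
    # Scan previously assigned names, most recent first, for a match of
    # the form {node_qualname} or {node_qualname}_{int}
    for existing in reversed(seen):
        if re.match(rf'{node_qualname}(_[0-9]+)?$', existing) is not None:
            postfix = existing.replace(node_qualname, '')
            next_index = int(postfix[1:]) + 1 if postfix else 1
            node_qualname += f'_{next_index}'
            break
    return node_qualname

seen = []
for raw in ['layer1.conv', 'layer1.conv', 'layer1.conv']:
    seen.append(dedup_qualname(raw, seen))
print(seen)  # ['layer1.conv', 'layer1.conv_1', 'layer1.conv_2']

Because the counter is scoped relative to the direct parent, a repeated op under a different parent (say 'layer2.conv') would start again without a postfix rather than continuing a global count.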
@@ -171,19 +179,23 @@ def get_graph_node_names(
     names are available for feature extraction. There are two reasons that
     node names can't easily be read directly from the code for a model:
 
-    1. Not all submodules are traced through. Modules from `torch.nn` all
+    1. Not all submodules are traced through. Modules from ``torch.nn`` all
        fall within this category.
     2. Nodes representing the repeated application of the same operation
-       or leaf module get a `_{counter}` postfix.
+       or leaf module get a ``_{counter}`` postfix.
 
     The model is traced twice: once in train mode, and once in eval mode. Both
-    sets of nodes are returned.
+    sets of node names are returned.
+
+    For more details on the node naming conventions used here, please see the
+    :ref:`relevant subheading <about-node-names>` in the
+    `documentation <https://pytorch.org/vision/stable/feature_extraction.html>`_.
 
     Args:
         model (nn.Module): model for which we'd like to print node names
         tracer_kwargs (dict, optional): a dictionary of keyword arguments for
-            `NodePathTracer` (they are eventually passed onto
-            `torch.fx.Tracer`).
+            ``NodePathTracer`` (they are eventually passed onto
+            `torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
         suppress_diff_warning (bool, optional): whether to suppress a warning
             when there are discrepancies between the train and eval version of
             the graph. Defaults to False.
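As a usage sketch (assuming the public torchvision.models.feature_extraction entry point that ships this module):

import torchvision
from torchvision.models.feature_extraction import get_graph_node_names

model = torchvision.models.resnet18()
# Two lists come back: node names from the train-mode trace and from the
# eval-mode trace. For ResNet-18 the two coincide.
train_nodes, eval_nodes = get_graph_node_names(model)
print(eval_nodes[-3:])  # e.g. ['avgpool', 'flatten', 'fc']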
@@ -289,58 +301,55 @@ def create_feature_extractor(
     the model via FX to return the desired nodes as outputs. All unused nodes
     are removed, together with their corresponding parameters.
 
-    A note on node specification: For the purposes of this feature extraction
-    utility, a node name is specified as a `.` seperated path walking the
-    hierarchy from top level module down to leaf operation or leaf module. For
-    instance `blocks.5.3.bn1`. The keys of the `return_nodes` argument should
-    point to either a node's name, or some truncated version of it. For
-    example, one could provide `blocks.5` as a key, and the last node with
-    that prefix will be selected. :func:`get_graph_node_names` is a useful
-    helper function for getting a list of node names of a model.
+    Desired output nodes must be specified as a ``.`` separated
+    path walking the module hierarchy from top level module down to leaf
+    operation or leaf module. For more details on the node naming conventions
+    used here, please see the :ref:`relevant subheading <about-node-names>`
+    in the `documentation <https://pytorch.org/vision/stable/feature_extraction.html>`_.
 
     Not all models will be FX traceable, although with some massaging they can
     be made to cooperate. Here's a (not exhaustive) list of tips:
 
         - If you don't need to trace through a particular, problematic
           sub-module, turn it into a "leaf module" by passing a list of
-          `leaf_modules` as one of the `tracer_kwargs` (see example below). It
-          will not be traced through, but rather, the resulting graph will
+          ``leaf_modules`` as one of the ``tracer_kwargs`` (see example below).
+          It will not be traced through, but rather, the resulting graph will
           hold a reference to that module's forward method.
         - Likewise, you may turn functions into leaf functions by passing a
-          list of `autowrap_functions` as one of the `tracer_kwargs` (see
+          list of ``autowrap_functions`` as one of the ``tracer_kwargs`` (see
           example below).
         - Some inbuilt Python functions can be problematic. For instance,
-          `int` will raise an error during tracing. You may wrap them in your
-          own function and then pass that in `autowrap_functions` as one of
-          the `tracer_kwargs`.
+          ``int`` will raise an error during tracing. You may wrap them in your
+          own function and then pass that in ``autowrap_functions`` as one of
+          the ``tracer_kwargs``.
 
     For further information on FX see the
     `torch.fx documentation <https://pytorch.org/docs/stable/fx.html>`_.
 
     Args:
         model (nn.Module): model on which we will extract the features
-        return_nodes (list or dict, optional): either a `List` or a `Dict`
+        return_nodes (list or dict, optional): either a ``List`` or a ``Dict``
             containing the names (or partial names - see note above)
             of the nodes for which the activations will be returned. If it is
-            a `Dict`, the keys are the node names, and the values
+            a ``Dict``, the keys are the node names, and the values
             are the user-specified keys for the graph module's returned
-            dictionary. If it is a `List`, it is treated as a `Dict` mapping
+            dictionary. If it is a ``List``, it is treated as a ``Dict`` mapping
             node specification strings directly to output names. In the case
-            that `train_return_nodes` and `eval_return_nodes` are specified,
+            that ``train_return_nodes`` and ``eval_return_nodes`` are specified,
             this should not be specified.
         train_return_nodes (list or dict, optional): similar to
-            `return_nodes`. This can be used if the return nodes
+            ``return_nodes``. This can be used if the return nodes
             for train mode are different than those from eval mode.
-            If this is specified, `eval_return_nodes` must also be specified,
-            and `return_nodes` should not be specified.
+            If this is specified, ``eval_return_nodes`` must also be specified,
+            and ``return_nodes`` should not be specified.
         eval_return_nodes (list or dict, optional): similar to
-            `return_nodes`. This can be used if the return nodes
+            ``return_nodes``. This can be used if the return nodes
             for train mode are different than those from eval mode.
-            If this is specified, `train_return_nodes` must also be specified,
+            If this is specified, ``train_return_nodes`` must also be specified,
             and `return_nodes` should not be specified.
         tracer_kwargs (dict, optional): a dictionary of keyword arguments for
-            `NodePathTracer` (which passes them onto it's parent class
-            `torch.fx.Tracer`).
+            ``NodePathTracer`` (which passes them onto its parent class
+            `torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
         suppress_diff_warning (bool, optional): whether to suppress a warning
             when there are discrepancies between the train and eval version of
             the graph. Defaults to False.
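A usage sketch tying this together (the model and node choices are illustrative; the printed shapes assume torchvision's ResNet-18 with a 224x224 input):

import torch
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor

model = torchvision.models.resnet18()
# Keys are (possibly truncated) node names; values are the keys under
# which each activation appears in the output dict.
extractor = create_feature_extractor(
    model, return_nodes={'layer1': 'feat1', 'layer3': 'feat3'})
out = extractor(torch.rand(1, 3, 224, 224))
print({k: v.shape for k, v in out.items()})
# {'feat1': torch.Size([1, 64, 56, 56]), 'feat3': torch.Size([1, 256, 14, 14])}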