Commit 903ea4a

Improve FX node naming (#4418)
* draft commit
* Polish and add corresponding test
* Update docs
* Update torchvision/models/feature_extraction.py
* Update docs/source/feature_extraction.rst

Co-authored-by: Francisco Massa <[email protected]>
1 parent 6518372 commit 903ea4a

File tree: 3 files changed, +139 -57 lines


docs/source/feature_extraction.rst

Lines changed: 39 additions & 6 deletions
@@ -19,8 +19,8 @@ It works by following roughly these steps:
 
 1. Symbolically tracing the model to get a graphical representation of
    how it transforms the input, step by step.
-2. Setting the user-selected graph nodes as ouputs.
-3. Removing all redundant nodes (anything downstream of the ouput nodes).
+2. Setting the user-selected graph nodes as outputs.
+3. Removing all redundant nodes (anything downstream of the output nodes).
 4. Generating python code from the resulting graph and bundling that into a
    PyTorch module together with the graph itself.
 
@@ -30,6 +30,39 @@ The `torch.fx documentation <https://pytorch.org/docs/stable/fx.html>`_
 provides a more general and detailed explanation of the above procedure and
 the inner workings of the symbolic tracing.
 
+.. _about-node-names:
+
+**About Node Names**
+
+In order to specify which nodes should be output nodes for extracted
+features, one should be familiar with the node naming convention used here
+(which differs slightly from that used in ``torch.fx``). A node name is
+specified as a ``.`` separated path walking the module hierarchy from top level
+module down to leaf operation or leaf module. For instance ``"layer4.2.relu"``
+in ResNet-50 represents the output of the ReLU of the 2nd block of the 4th
+layer of the ``ResNet`` module. Here are some finer points to keep in mind:
+
+- When specifying node names for :func:`create_feature_extractor`, you may
+  provide a truncated version of a node name as a shortcut. To see how this
+  works, try creating a ResNet-50 model and printing the node names with
+  ``train_nodes, _ = get_graph_node_names(model); print(train_nodes)`` and
+  observe that the last node pertaining to ``layer4`` is
+  ``"layer4.2.relu_2"``. One may specify ``"layer4.2.relu_2"`` as the return
+  node, or just ``"layer4"`` as this, by convention, refers to the last node
+  (in order of execution) of ``layer4``.
+- If a certain module or operation is repeated more than once, node names get
+  an additional ``_{int}`` postfix to disambiguate. For instance, maybe the
+  addition (``+``) operation is used three times in the same ``forward``
+  method. Then there would be ``"path.to.module.add"``,
+  ``"path.to.module.add_1"``, ``"path.to.module.add_2"``. The counter is
+  maintained within the scope of the direct parent. So in ResNet-50 there is
+  a ``"layer4.1.add"`` and a ``"layer4.2.add"``. Because the addition
+  operations reside in different blocks, there is no need for a postfix to
+  disambiguate.
+
+
+**An Example**
+
 Here is an example of how we might extract features for MaskRCNN:
 
 .. code-block:: python
@@ -80,10 +113,10 @@ Here is an example of how we might extract features for MaskRCNN:
     # Now you can build the feature extractor. This returns a module whose forward
     # method returns a dictionary like:
     # {
-    #     'layer1': ouput of layer 1,
-    #     'layer2': ouput of layer 2,
-    #     'layer3': ouput of layer 3,
-    #     'layer4': ouput of layer 4,
+    #     'layer1': output of layer 1,
+    #     'layer2': output of layer 2,
+    #     'layer3': output of layer 3,
+    #     'layer4': output of layer 4,
     # }
     create_feature_extractor(m, return_nodes=return_nodes)
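As a quick illustration of the convention these docs describe, here is a
minimal sketch (assuming a torchvision build containing this commit):

import torch
from torchvision.models import resnet50
from torchvision.models.feature_extraction import (
    create_feature_extractor, get_graph_node_names)

model = resnet50()

# Per the docs above, the last node pertaining to layer4 is 'layer4.2.relu_2'
train_nodes, eval_nodes = get_graph_node_names(model)
print([n for n in train_nodes if n.startswith('layer4')][-1])

# 'layer4' is a truncated node name: by convention it resolves to the last
# node (in order of execution) whose name begins with 'layer4'
extractor = create_feature_extractor(model, return_nodes={'layer4': 'feat'})
out = extractor(torch.rand(1, 3, 224, 224))
print(out['feat'].shape)  # expected: torch.Size([1, 2048, 7, 7])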

test/test_backbone_utils.py

Lines changed: 40 additions & 0 deletions
@@ -33,6 +33,41 @@ def leaf_function(x):
     return int(x)
 
 
+# Needed by TestFXFeatureExtraction. Checking that node naming conventions
+# are respected. Particularly the index postfix of repeated node names
+class TestSubModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.relu = torch.nn.ReLU()
+
+    def forward(self, x):
+        x = x + 1
+        x = x + 1
+        x = self.relu(x)
+        x = self.relu(x)
+        return x
+
+
+class TestModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.submodule = TestSubModule()
+        self.relu = torch.nn.ReLU()
+
+    def forward(self, x):
+        x = self.submodule(x)
+        x = x + 1
+        x = x + 1
+        x = self.relu(x)
+        x = self.relu(x)
+        return x
+
+
+test_module_nodes = [
+    'x', 'submodule.add', 'submodule.add_1', 'submodule.relu',
+    'submodule.relu_1', 'add', 'add_1', 'relu', 'relu_1']
+
+
 class TestFxFeatureExtraction:
     inp = torch.rand(1, 3, 224, 224, dtype=torch.float32, device='cpu')
     model_defaults = {
@@ -104,6 +139,11 @@ def test_build_fx_feature_extractor(self, model_name):
         else:  # otherwise skip this check
             raise ValueError
 
+    def test_node_name_conventions(self):
+        model = TestModule()
+        train_nodes, _ = get_graph_node_names(model)
+        assert all(a == b for a, b in zip(train_nodes, test_module_nodes))
+
     @pytest.mark.parametrize('model_name', get_available_models())
     def test_forward_backward(self, model_name):
         model = models.__dict__[model_name](**self.model_defaults).train()
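The expected list in the new test encodes the locally scoped counter. The same
behaviour can be seen on a self-contained toy module (the `TwoAdds` class here
is a hypothetical stand-in, not part of this commit):

import torch
from torchvision.models.feature_extraction import get_graph_node_names


class TwoAdds(torch.nn.Module):
    # Hypothetical toy module: the same op repeated in one forward
    def forward(self, x):
        x = x + 1
        x = x + 1
        return x


train_nodes, _ = get_graph_node_names(TwoAdds())
# The repeated add gets a locally scoped postfix, so the names are
# expected to start with: ['x', 'add', 'add_1']
print(train_nodes)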

torchvision/models/feature_extraction.py

Lines changed: 60 additions & 51 deletions
@@ -97,31 +97,39 @@ def create_proxy(self, kind: str, target: fx.node.Target, args, kwargs,
     def _get_node_qualname(
             self, module_qualname: str, node: fx.node.Node) -> str:
         node_qualname = module_qualname
-        if node.op == 'call_module':
-            # Node terminates in a leaf module so the module_qualname is a
-            # complete description of the node
-            for existing_qualname in reversed(self.node_to_qualname.values()):
-                # Check to see if existing_qualname is of the form
-                # {node_qualname} or {node_qualname}_{int}
-                if re.match(rf'{node_qualname}(_[0-9]+)?$',
-                            existing_qualname) is not None:
-                    postfix = existing_qualname.replace(node_qualname, '')
-                    if len(postfix):
-                        # Existing_qualname is of the form {node_qualname}_{int}
-                        next_index = int(postfix[1:]) + 1
-                    else:
-                        # existing_qualname is of the form {node_qualname}
-                        next_index = 1
-                    node_qualname += f'_{next_index}'
-                    break
-            pass
-        else:
-            # Node terminates in non- leaf module so the node name needs to be
-            # appended
+
+        if node.op != 'call_module':
+            # In this case module_qualname from torch.fx doesn't go all the
+            # way to the leaf function/op so we need to append it
             if len(node_qualname) > 0:
                 # Only append '.' if we are deeper than the top level module
                 node_qualname += '.'
             node_qualname += str(node)
+
+        # Now we need to add an _{index} postfix on any repeated node names
+        # For modules we do this from scratch
+        # But for anything else, torch.fx already has a globally scoped
+        # _{index} postfix. But we want it locally (relative to direct parent)
+        # scoped. So first we need to undo the torch.fx postfix
+        if re.match(r'.+_[0-9]+$', node_qualname) is not None:
+            node_qualname = node_qualname.rsplit('_', 1)[0]
+
+        # ... and now we add on our own postfix
+        for existing_qualname in reversed(self.node_to_qualname.values()):
+            # Check to see if existing_qualname is of the form
+            # {node_qualname} or {node_qualname}_{int}
+            if re.match(rf'{node_qualname}(_[0-9]+)?$',
+                        existing_qualname) is not None:
+                postfix = existing_qualname.replace(node_qualname, '')
+                if len(postfix):
+                    # existing_qualname is of the form {node_qualname}_{int}
+                    next_index = int(postfix[1:]) + 1
+                else:
+                    # existing_qualname is of the form {node_qualname}
+                    next_index = 1
+                node_qualname += f'_{next_index}'
+                break
+
         return node_qualname
 
 
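The strip-then-reapply counter logic above can be exercised in isolation. Here
is a minimal standalone sketch; the `add_local_postfix` helper and the `seen`
list are hypothetical stand-ins for the tracer's `node_to_qualname` bookkeeping:

import re


def add_local_postfix(node_qualname, seen_qualnames):
    # Hypothetical re-implementation of the counter logic above, for
    # illustration only. First undo any torch.fx-style global _{int} postfix
    if re.match(r'.+_[0-9]+$', node_qualname) is not None:
        node_qualname = node_qualname.rsplit('_', 1)[0]
    # ... then append a postfix scoped to the names already recorded
    for existing in reversed(seen_qualnames):
        if re.match(rf'{node_qualname}(_[0-9]+)?$', existing) is not None:
            postfix = existing.replace(node_qualname, '')
            next_index = int(postfix[1:]) + 1 if postfix else 1
            return f'{node_qualname}_{next_index}'
    return node_qualname


seen = []
# torch.fx would hand us globally counted names like 'add_2', 'add_3' for
# the top-level adds; the local scoping turns them back into 'add', 'add_1'
for raw in ['submodule.add', 'submodule.add_1', 'add_2', 'add_3']:
    seen.append(add_local_postfix(raw, seen))
print(seen)  # ['submodule.add', 'submodule.add_1', 'add', 'add_1']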

@@ -171,19 +179,23 @@ def get_graph_node_names(
     names are available for feature extraction. There are two reasons that
     node names can't easily be read directly from the code for a model:
 
-    1. Not all submodules are traced through. Modules from `torch.nn` all
+    1. Not all submodules are traced through. Modules from ``torch.nn`` all
        fall within this category.
     2. Nodes representing the repeated application of the same operation
-       or leaf module get a `_{counter}` postfix.
+       or leaf module get a ``_{counter}`` postfix.
 
     The model is traced twice: once in train mode, and once in eval mode. Both
-    sets of nodes are returned.
+    sets of node names are returned.
+
+    For more details on the node naming conventions used here, please see the
+    :ref:`relevant subheading <about-node-names>` in the
+    `documentation <https://pytorch.org/vision/stable/feature_extraction.html>`_.
 
     Args:
         model (nn.Module): model for which we'd like to print node names
         tracer_kwargs (dict, optional): a dictionary of keyword arguments for
-            `NodePathTracer` (they are eventually passed onto
-            `torch.fx.Tracer`).
+            ``NodePathTracer`` (they are eventually passed onto
+            `torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
         suppress_diff_warning (bool, optional): whether to suppress a warning
             when there are discrepancies between the train and eval version of
             the graph. Defaults to False.
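A sketch of the documented train/eval double trace, using a hypothetical
module that branches on `self.training` (passing `suppress_diff_warning=True`
silences the expected discrepancy warning):

import torch
from torchvision.models.feature_extraction import get_graph_node_names


class TrainEvalSplit(torch.nn.Module):
    # Hypothetical module whose traced graph depends on self.training
    def forward(self, x):
        if self.training:
            x = x + 1  # this op only appears in the train-mode trace
        return torch.relu(x)


train_nodes, eval_nodes = get_graph_node_names(
    TrainEvalSplit(), suppress_diff_warning=True)
print(train_nodes)  # expected to contain 'add' as well as 'relu'
print(eval_nodes)   # expected to contain 'relu' but no 'add'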
@@ -289,58 +301,55 @@ def create_feature_extractor(
     the model via FX to return the desired nodes as outputs. All unused nodes
     are removed, together with their corresponding parameters.
 
-    A note on node specification: For the purposes of this feature extraction
-    utility, a node name is specified as a `.` seperated path walking the
-    hierarchy from top level module down to leaf operation or leaf module. For
-    instance `blocks.5.3.bn1`. The keys of the `return_nodes` argument should
-    point to either a node's name, or some truncated version of it. For
-    example, one could provide `blocks.5` as a key, and the last node with
-    that prefix will be selected. :func:`get_graph_node_names` is a useful
-    helper function for getting a list of node names of a model.
+    Desired output nodes must be specified as a ``.`` separated
+    path walking the module hierarchy from top level module down to leaf
+    operation or leaf module. For more details on the node naming conventions
+    used here, please see the :ref:`relevant subheading <about-node-names>`
+    in the `documentation <https://pytorch.org/vision/stable/feature_extraction.html>`_.
 
     Not all models will be FX traceable, although with some massaging they can
     be made to cooperate. Here's a (not exhaustive) list of tips:
 
     - If you don't need to trace through a particular, problematic
       sub-module, turn it into a "leaf module" by passing a list of
-      `leaf_modules` as one of the `tracer_kwargs` (see example below). It
-      will not be traced through, but rather, the resulting graph will
+      ``leaf_modules`` as one of the ``tracer_kwargs`` (see example below).
+      It will not be traced through, but rather, the resulting graph will
       hold a reference to that module's forward method.
     - Likewise, you may turn functions into leaf functions by passing a
-      list of `autowrap_functions` as one of the `tracer_kwargs` (see
+      list of ``autowrap_functions`` as one of the ``tracer_kwargs`` (see
       example below).
     - Some inbuilt Python functions can be problematic. For instance,
-      `int` will raise an error during tracing. You may wrap them in your
-      own function and then pass that in `autowrap_functions` as one of
-      the `tracer_kwargs`.
+      ``int`` will raise an error during tracing. You may wrap them in your
+      own function and then pass that in ``autowrap_functions`` as one of
+      the ``tracer_kwargs``.
 
     For further information on FX see the
     `torch.fx documentation <https://pytorch.org/docs/stable/fx.html>`_.
 
     Args:
         model (nn.Module): model on which we will extract the features
-        return_nodes (list or dict, optional): either a `List` or a `Dict`
+        return_nodes (list or dict, optional): either a ``List`` or a ``Dict``
             containing the names (or partial names - see note above)
             of the nodes for which the activations will be returned. If it is
-            a `Dict`, the keys are the node names, and the values
+            a ``Dict``, the keys are the node names, and the values
             are the user-specified keys for the graph module's returned
-            dictionary. If it is a `List`, it is treated as a `Dict` mapping
+            dictionary. If it is a ``List``, it is treated as a ``Dict`` mapping
             node specification strings directly to output names. In the case
-            that `train_return_nodes` and `eval_return_nodes` are specified,
+            that ``train_return_nodes`` and ``eval_return_nodes`` are specified,
             this should not be specified.
         train_return_nodes (list or dict, optional): similar to
-            `return_nodes`. This can be used if the return nodes
+            ``return_nodes``. This can be used if the return nodes
             for train mode are different than those from eval mode.
-            If this is specified, `eval_return_nodes` must also be specified,
-            and `return_nodes` should not be specified.
+            If this is specified, ``eval_return_nodes`` must also be specified,
+            and ``return_nodes`` should not be specified.
         eval_return_nodes (list or dict, optional): similar to
-            `return_nodes`. This can be used if the return nodes
+            ``return_nodes``. This can be used if the return nodes
             for train mode are different than those from eval mode.
-            If this is specified, `train_return_nodes` must also be specified,
+            If this is specified, ``train_return_nodes`` must also be specified,
             and `return_nodes` should not be specified.
         tracer_kwargs (dict, optional): a dictionary of keyword arguments for
-            `NodePathTracer` (which passes them onto it's parent class
-            `torch.fx.Tracer`).
+            ``NodePathTracer`` (which passes them onto its parent class
+            `torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
         suppress_diff_warning (bool, optional): whether to suppress a warning
             when there are discrepancies between the train and eval version of
             the graph. Defaults to False.
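Finally, a sketch of the docstring's traceability tips (the `Untraceable`
module and `safe_int` wrapper are hypothetical stand-ins, not part of this
commit):

import torch
from torchvision.models.feature_extraction import create_feature_extractor


def safe_int(x):
    # Hypothetical wrapper: the inbuilt int() raises during symbolic tracing,
    # so we expose it as a leaf function instead
    return int(x)


class Untraceable(torch.nn.Module):
    # Hypothetical submodule we would rather not trace through
    def forward(self, x):
        return x * safe_int(torch.tensor(2.0))


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)
        self.untraceable = Untraceable()

    def forward(self, x):
        return self.untraceable(self.linear(x))


# 'leaf_modules' keeps Untraceable opaque in the graph; 'autowrap_functions'
# would do the same for safe_int if it were called from a traced forward
extractor = create_feature_extractor(
    Model(), return_nodes=['linear'],
    tracer_kwargs={'leaf_modules': [Untraceable],
                   'autowrap_functions': [safe_int]})
print(extractor(torch.rand(4, 8))['linear'].shape)  # torch.Size([4, 8])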
