Skip to content

Commit 606d73b

Browse files
yushangdipytorchmergebot
authored andcommitted
Adding from_node for nodes in gm.module() (pytorch#155053)
Summary: Adding "from_node" information that indicates which nodes are unlifted in `.module()` call. The lifted nodes will have "ExportedProgram.module().unlift()" passname in the last entry of from_node. Test Plan: ``` buck run fbcode//caffe2/test:test_export -- -r test_from_node_metadata_export ``` Rollback Plan: Reviewed By: angelayi Differential Revision: D75837494 Pull Request resolved: pytorch#155053 Approved by: https://github.com/angelayi
1 parent c8c892b commit 606d73b

File tree

2 files changed

+84
-0
lines changed

2 files changed

+84
-0
lines changed

test/export/test_export.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,70 @@ def forward(self, x):
571571

572572
self.assertEqual(counter, 1)
573573

574+
def test_from_node_metadata_export(self):
575+
class Foo(torch.nn.Module):
576+
def __init__(self) -> None:
577+
super().__init__()
578+
self.conv1d = torch.nn.Conv1d(3, 3, 3)
579+
self.conv2d = torch.nn.Conv2d(3, 3, 3)
580+
581+
def forward(self, x):
582+
x = self.conv2d(x)
583+
x = x.squeeze(0)
584+
x = self.conv1d(x)
585+
return x
586+
587+
def example_inputs(self):
588+
return
589+
590+
f = Foo()
591+
inputs = (torch.randn(1, 3, 5, 5),)
592+
gm = export(f, inputs).module()
593+
from torch.fx.traceback import NodeSourceAction
594+
595+
for node in gm.graph.nodes:
596+
if node.op in ("placeholder", "output"):
597+
continue
598+
if "weight" in node.name or "bias" in node.name:
599+
self.assertTrue(
600+
node.meta["from_node"][-1].pass_name
601+
== "ExportedProgram.module().unlift()"
602+
)
603+
self.assertTrue(
604+
node.meta["from_node"][-1].action
605+
== [NodeSourceAction.CREATE, NodeSourceAction.REPLACE]
606+
)
607+
else:
608+
self.assertTrue(
609+
node.meta["from_node"][-1].pass_name == "ExportedProgram.module()"
610+
)
611+
self.assertTrue(
612+
node.meta["from_node"][-1].action == [NodeSourceAction.CREATE]
613+
)
614+
615+
## re-export
616+
gm2 = export(gm, inputs).module()
617+
618+
for node in gm2.graph.nodes:
619+
if node.op in ("placeholder", "output"):
620+
continue
621+
if "weight" in node.name or "bias" in node.name:
622+
self.assertTrue(
623+
node.meta["from_node"][-1].pass_name
624+
== "ExportedProgram.module().unlift()"
625+
)
626+
self.assertTrue(
627+
node.meta["from_node"][-1].action
628+
== [NodeSourceAction.CREATE, NodeSourceAction.REPLACE]
629+
)
630+
else:
631+
self.assertTrue(
632+
node.meta["from_node"][-1].pass_name == "ExportedProgram.module()"
633+
)
634+
self.assertTrue(
635+
node.meta["from_node"][-1].action == [NodeSourceAction.CREATE]
636+
)
637+
574638
def test_bincount(self):
575639
class M(torch.nn.Module):
576640
def __init__(self):

torch/export/_unlift.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from torch.export.unflatten import _assign_attr, _AttrKind
1717
from torch.fx.experimental.proxy_tensor import _pytree_subclasses_that_lose_info
1818
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
19+
from torch.fx.traceback import NodeSource, NodeSourceAction
1920

2021
from ._remove_effect_tokens_pass import _remove_effect_tokens
2122
from ._tree_utils import reorder_kwargs
@@ -115,6 +116,13 @@ def _unlift_inputs_as_getattr(
115116
metadata = input_node.meta
116117
gm.graph.erase_node(input_node)
117118
getattr_node.meta = metadata
119+
getattr_node.meta["from_node"] = [
120+
NodeSource(
121+
input_node,
122+
"ExportedProgram.module().unlift()",
123+
[NodeSourceAction.CREATE, NodeSourceAction.REPLACE],
124+
)
125+
]
118126
unlifted_name_to_node[lifted_node] = getattr_node
119127

120128
return unlifted_name_to_node, input_name_to_node
@@ -172,6 +180,13 @@ def _insert_copy_for_mutations(
172180
gm.graph.erase_node(output_node)
173181
new_output.name = output_node.name
174182
new_output.meta.update(output_node.meta)
183+
new_output.meta["from_node"] = [
184+
NodeSource(
185+
output_node,
186+
"ExportedProgram.module().unlift()",
187+
[NodeSourceAction.CREATE, NodeSourceAction.REPLACE],
188+
)
189+
]
175190

176191

177192
def _get_codegen(
@@ -446,6 +461,11 @@ def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.nn.Modu
446461
for out_spec in ep.graph_signature.output_specs
447462
]
448463

464+
for node in new_gm.graph.nodes:
465+
node.meta["from_node"] = [
466+
NodeSource(node, "ExportedProgram.module()", NodeSourceAction.CREATE)
467+
]
468+
449469
new_gm = _unlift(
450470
new_gm,
451471
lifted_inputs,

0 commit comments

Comments
 (0)