Skip to content

Commit 12fdb93

Browse files
avikchaudhuripytorchmergebot
authored andcommitted
fix non-strict placeholder naming with kwargs (pytorch#144278)
Fixes pytorch#143732 Differential Revision: [D67872055](https://our.internmc.facebook.com/intern/diff/D67872055/) Pull Request resolved: pytorch#144278 Approved by: https://github.com/yushangdi, https://github.com/pianpwk
1 parent c3b2849 commit 12fdb93

File tree

2 files changed

+46
-1
lines changed

2 files changed

+46
-1
lines changed

test/export/test_export.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9752,6 +9752,47 @@ def forward(self, x):
97529752
return (foo_functional,)""",
97539753
)
97549754

9755+
def test_placeholder_naming_order(self):
9756+
# See https://github.com/pytorch/pytorch/issues/143732
9757+
9758+
class Mod(torch.nn.Module):
9759+
def __init__(self):
9760+
super().__init__()
9761+
self.layer1 = torch.nn.Linear(3, 16)
9762+
self.layer2 = torch.nn.Linear(3, 32)
9763+
9764+
def forward(self, x1, x2, flag=True):
9765+
x1o = self.layer1(x1)
9766+
x2o = self.layer2(x2)
9767+
return torch.cat([x1o, x2o], dim=1)
9768+
9769+
mod = Mod()
9770+
args = (torch.rand(1, 3),)
9771+
kwargs = {"flag": False, "x2": torch.rand(1, 3)}
9772+
ep = export(mod, args, kwargs)
9773+
9774+
# check that graph is behaviorally correct
9775+
self.assertTrue(
9776+
torch.allclose(ep.module()(*args, **kwargs), mod(*args, **kwargs))
9777+
)
9778+
9779+
# check that graph input names are as expected
9780+
self.assertEqual(ep.graph_signature.user_inputs, ("x1", False, "x2"))
9781+
9782+
def test_placeholder_naming_order_variadic(self):
9783+
class Mod(torch.nn.Module):
9784+
def forward(self, a, b, c, **kwargs):
9785+
return a - b + c * kwargs["d"]
9786+
9787+
mod = Mod()
9788+
args = (torch.randn(3),)
9789+
kwargs = {"c": torch.randn(3), "b": torch.randn(3), "d": torch.randn(3)}
9790+
ep = export(mod, args, kwargs)
9791+
self.assertTrue(
9792+
torch.allclose(ep.module()(*args, **kwargs), mod(*args, **kwargs))
9793+
)
9794+
self.assertEqual(ep.graph_signature.user_inputs, ("a", "c", "b", "d"))
9795+
97559796
def test_placeholder_naming_collisions(self):
97569797
# test collisions between nested user inputs
97579798
class Foo(torch.nn.Module):

torch/_export/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -725,7 +725,11 @@ def _bind_signature_to_inputs(mod, fake_args, fake_kwargs):
725725
else:
726726
sig = inspect.signature(mod.forward)
727727

728-
return sig.bind(*fake_args, **fake_kwargs).arguments
728+
# Rather than binding both fake_args and fake_kwargs to sig names, we
729+
# (partially) bind only fake_args, while reusing fake_kwarg names. This
730+
# ensures that fake_kwargs do not get reordered, which is important to
731+
# match flattened user inputs.
732+
return {**sig.bind_partial(*fake_args).arguments, **fake_kwargs}
729733

730734

731735
def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None:

0 commit comments

Comments
 (0)