@@ -9752,6 +9752,47 @@ def forward(self, x):
97529752 return (foo_functional,)""" ,
97539753 )
97549754
9755+ def test_placeholder_naming_order (self ):
9756+ # See https://github.com/pytorch/pytorch/issues/143732
9757+
9758+ class Mod (torch .nn .Module ):
9759+ def __init__ (self ):
9760+ super ().__init__ ()
9761+ self .layer1 = torch .nn .Linear (3 , 16 )
9762+ self .layer2 = torch .nn .Linear (3 , 32 )
9763+
9764+ def forward (self , x1 , x2 , flag = True ):
9765+ x1o = self .layer1 (x1 )
9766+ x2o = self .layer2 (x2 )
9767+ return torch .cat ([x1o , x2o ], dim = 1 )
9768+
9769+ mod = Mod ()
9770+ args = (torch .rand (1 , 3 ),)
9771+ kwargs = {"flag" : False , "x2" : torch .rand (1 , 3 )}
9772+ ep = export (mod , args , kwargs )
9773+
9774+ # check that graph is behaviorally correct
9775+ self .assertTrue (
9776+ torch .allclose (ep .module ()(* args , ** kwargs ), mod (* args , ** kwargs ))
9777+ )
9778+
9779+ # check that graph input names are as expected
9780+ self .assertEqual (ep .graph_signature .user_inputs , ("x1" , False , "x2" ))
9781+
9782+ def test_placeholder_naming_order_variadic (self ):
9783+ class Mod (torch .nn .Module ):
9784+ def forward (self , a , b , c , ** kwargs ):
9785+ return a - b + c * kwargs ["d" ]
9786+
9787+ mod = Mod ()
9788+ args = (torch .randn (3 ),)
9789+ kwargs = {"c" : torch .randn (3 ), "b" : torch .randn (3 ), "d" : torch .randn (3 )}
9790+ ep = export (mod , args , kwargs )
9791+ self .assertTrue (
9792+ torch .allclose (ep .module ()(* args , ** kwargs ), mod (* args , ** kwargs ))
9793+ )
9794+ self .assertEqual (ep .graph_signature .user_inputs , ("a" , "c" , "b" , "d" ))
9795+
97559796 def test_placeholder_naming_collisions (self ):
97569797 # test collisions between nested user inputs
97579798 class Foo (torch .nn .Module ):
0 commit comments