@@ -8,139 +8,134 @@ class TestPatchModule(ExtTestCase):
88 def test_rewrite_forward_return1 (self ):
99
1010 class Model (torch .nn .Module ):
11- def __init__ (self ):
12- super ().__init__ ()
13-
1411 def forward (self , x , y ):
1512 if x .sum () > 0 :
1613 return x + y
1714 else :
18- return torch .abs (x ) + y
15+ return torch .abs (x ) + y + 1
1916
2017 x , y = torch .rand ((3 , 4 )), torch .rand ((3 , 4 ))
21- expected = Model ()(x , y )
18+ expected , expected_ = Model ()(x , y ), Model ()( - x , y )
2219
2320 rewritten = transform_method (Model .forward )
21+ self .assertIn ("torch.abs(" , rewritten .code )
22+ self .assertIn ("'abs'" , rewritten .dump )
2423 Model .forward = rewritten .func
25- Model ()(x , y )
24+ self .assertEqualAny (expected , Model ()(x , y ))
25+ self .assertEqualAny (expected_ , Model ()(- x , y ))
2626
2727 DYN = torch .export .Dim .DYNAMIC
2828 ds = ({0 : DYN , 1 : DYN }, {0 : DYN , 1 : DYN })
2929 ep = torch .export .export (Model (), (x , y ), dynamic_shapes = ds )
3030 self .assertIn ("cond" , [str (getattr (n , "target" , "?" )) for n in ep .graph .nodes ])
31- got = ep .module ()(x , y )
32- self .assertEqualArray ( expected , got )
31+ self . assertEqualAny ( expected , ep .module ()(x , y ) )
32+ self .assertEqualAny ( expected_ , ep . module ()( - x , y ) )
3333
3434 @hide_stdout ()
3535 def test_rewrite_forward_return2 (self ):
3636
3737 class Model (torch .nn .Module ):
38- def __init__ (self ):
39- super ().__init__ ()
40-
4138 def forward (self , x , y ):
4239 if x .sum () > 0 :
4340 return x + y , x - y
4441 else :
45- return torch .abs (x ) + y , torch .abs (x ) - y
42+ return torch .abs (x ) + y + 1 , torch .abs (x ) - y + 1
4643
4744 x , y = torch .rand ((3 , 4 )), torch .rand ((3 , 4 ))
48- expected = Model ()(x , y )
45+ expected , expected_ = Model ()(x , y ), Model ()( - x , y )
4946
5047 rewritten = transform_method (Model .forward , verbose = 10 )
48+ self .assertIn ("torch.abs(" , rewritten .code )
49+ self .assertIn ("abs" , rewritten .dump )
5150 Model .forward = rewritten .func
52- Model ()(x , y )
51+ self .assertEqualAny (expected , Model ()(x , y ))
52+ self .assertEqualAny (expected_ , Model ()(- x , y ))
5353
5454 DYN = torch .export .Dim .DYNAMIC
5555 ds = ({0 : DYN , 1 : DYN }, {0 : DYN , 1 : DYN })
5656 ep = torch .export .export (Model (), (x , y ), dynamic_shapes = ds )
5757 self .assertIn ("cond" , [str (getattr (n , "target" , "?" )) for n in ep .graph .nodes ])
58- got = ep .module ()(x , y )
59- self .assertEqualAny (expected , got )
60- self .assertEqualAny (Model ()(- x , y ), ep .module ()(- x , y ))
58+ self .assertEqualAny (expected , ep .module ()(x , y ))
59+ self .assertEqualAny (expected_ , ep .module ()(- x , y ))
6160
6261 def test_rewrite_forward_assign1 (self ):
6362
6463 class Model (torch .nn .Module ):
65- def __init__ (self ):
66- super ().__init__ ()
67-
6864 def forward (self , x , y ):
6965 if x .sum () > 0 :
7066 z = x + y
7167 else :
72- z = torch .abs (x ) + y
68+ z = torch .abs (x ) + y + 1
7369 return z
7470
7571 x , y = torch .rand ((3 , 4 )), torch .rand ((3 , 4 ))
76- expected = Model ()(x , y )
72+ expected , expected_ = Model ()(x , y ), Model ()( - x , y )
7773
7874 rewritten = transform_method (Model .forward , verbose = 0 )
75+ self .assertIn ("torch.abs(" , rewritten .code )
76+ self .assertIn ("abs" , rewritten .dump )
7977 Model .forward = rewritten .func
80- Model ()(x , y )
78+ self .assertEqualAny (expected , Model ()(x , y ))
79+ self .assertEqualAny (expected_ , Model ()(- x , y ))
8180
8281 DYN = torch .export .Dim .DYNAMIC
8382 ds = ({0 : DYN , 1 : DYN }, {0 : DYN , 1 : DYN })
8483 ep = torch .export .export (Model (), (x , y ), dynamic_shapes = ds )
8584 self .assertIn ("cond" , [str (getattr (n , "target" , "?" )) for n in ep .graph .nodes ])
86- got = ep .module ()(x , y )
87- self .assertEqualArray (expected , got )
88- self .assertEqualArray (Model ()(- x , y ), ep .module ()(- x , y ))
85+ self .assertEqualAny (expected , ep .module ()(x , y ))
86+ self .assertEqualArray (expected_ , ep .module ()(- x , y ))
8987
9088 def test_rewrite_forward_assign2 (self ):
9189
9290 class Model (torch .nn .Module ):
93- def __init__ (self ):
94- super ().__init__ ()
95-
9691 def forward (self , x , y ):
9792 if x .sum () > 0 :
9893 w , z = x + y , x - y
9994 else :
100- w , z = torch .abs (x ) + y , torch .abs (x ) - y
95+ w , z = torch .abs (x ) + y + 1 , torch .abs (x ) - y + 1
10196 return w , z
10297
10398 x , y = torch .rand ((3 , 4 )), torch .rand ((3 , 4 ))
104- expected = Model ()(x , y )
99+ expected , expected_ = Model ()(x , y ), Model ()( - x , y )
105100
106101 rewritten = transform_method (Model .forward , verbose = 0 )
102+ self .assertIn ("torch.abs(" , rewritten .code )
103+ self .assertIn ("abs" , rewritten .dump )
107104 Model .forward = rewritten .func
108- Model ()(x , y )
105+ self .assertEqualAny (expected , Model ()(x , y ))
106+ self .assertEqualAny (expected_ , Model ()(- x , y ))
109107
110108 DYN = torch .export .Dim .DYNAMIC
111109 ds = ({0 : DYN , 1 : DYN }, {0 : DYN , 1 : DYN })
112110 ep = torch .export .export (Model (), (x , y ), dynamic_shapes = ds )
113111 self .assertIn ("cond" , [str (getattr (n , "target" , "?" )) for n in ep .graph .nodes ])
114- got = ep .module ()(x , y )
115- self .assertEqualAny (expected , got )
116- self .assertEqualAny (Model ()(- x , y ), ep .module ()(- x , y ))
112+ self .assertEqualAny (expected , ep .module ()(x , y ))
113+ self .assertEqualAny (expected_ , ep .module ()(- x , y ))
117114
118115 def test_rewrite_forward_noelse (self ):
119116
120117 class Model (torch .nn .Module ):
121- def __init__ (self ):
122- super ().__init__ ()
123-
124118 def forward (self , x , y ):
125119 if x .sum () > 0 :
126- x = torch .abs (x )
120+ x = torch .abs (x ) + 1
127121 return x + y
128122
129123 x , y = torch .rand ((3 , 4 )), torch .rand ((3 , 4 ))
130- expected = Model ()(x , y )
124+ expected , expected_ = Model ()(x , y ), Model ()( - x , y )
131125
132126 rewritten = transform_method (Model .forward , verbose = 0 )
127+ self .assertIn ("torch.abs(" , rewritten .code )
128+ self .assertIn ("abs" , rewritten .dump )
133129 Model .forward = rewritten .func
134- Model ()(x , y )
130+ self .assertEqualAny (expected , Model ()(x , y ))
131+ self .assertEqualAny (expected_ , Model ()(- x , y ))
135132
136133 DYN = torch .export .Dim .DYNAMIC
137134 ds = ({0 : DYN , 1 : DYN }, {0 : DYN , 1 : DYN })
138135 ep = torch .export .export (Model (), (x , y ), dynamic_shapes = ds )
139136 self .assertIn ("cond" , [str (getattr (n , "target" , "?" )) for n in ep .graph .nodes ])
140- got = ep .module ()(x , y )
141- self .assertEqualAny (expected , got )
142- self .assertEqualAny (Model ()(- x , y ), ep .module ()(- x , y ))
143- self .assertEqualAny (Model ()(- x , y ), ep .module ()(- x , y ))
137+ self .assertEqualAny (expected , ep .module ()(x , y ))
138+ self .assertEqualAny (expected_ , ep .module ()(- x , y ))
144139
145140
146141if __name__ == "__main__" :
0 commit comments