@@ -137,6 +137,81 @@ def forward(self, x, y):
         self.assertEqualAny(expected, ep.module()(x, y))
         self.assertEqualAny(expected_, ep.module()(-x, y))

+    def test_rewrite_forward_return_noelse(self):
+
+        class Model(torch.nn.Module):
+            def forward(self, x, y):
+                if x.sum() > 0:
+                    return torch.abs(x) + 1 + y
+                return x + y
+
+        self.assertRaise(
+            lambda: transform_method(Model.forward, verbose=0), NotImplementedError
+        )
+
+    def test_rewrite_forward_assign2_in_2(self):
+
+        class Model(torch.nn.Module):
+            def forward(self, x, y):
+                if x.sum() > 0:
+                    w = x + y
+                    z = x - y
+                else:
+                    w = torch.abs(x) + y + 1
+                    z = torch.abs(x) - y + 1
+                return w, z
+
+        x, y = torch.rand((3, 4)), torch.rand((3, 4))
+        expected, expected_ = Model()(x, y), Model()(-x, y)
+
+        rewritten = transform_method(Model.forward, verbose=0)
+        self.assertIn("torch.abs(", rewritten.code)
+        self.assertIn("abs", rewritten.dump)
+        Model.forward = rewritten.func
+        self.assertEqualAny(expected, Model()(x, y))
+        self.assertEqualAny(expected_, Model()(-x, y))
+
+        DYN = torch.export.Dim.DYNAMIC
+        ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
+        ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
+        self.assertIn("cond", [str(getattr(n, "target", "?")) for n in ep.graph.nodes])
+        self.assertEqualAny(expected, ep.module()(x, y))
+        self.assertEqualAny(expected_, ep.module()(-x, y))
+
+    def test_rewrite_forward_assign2_in_3(self):
+
+        class Model(torch.nn.Module):
+            def forward(self, x, y):
+                if x.sum() > 0:
+                    w = x + y
+                    z = x - y
+                else:
+                    u = y + 1
+                    w = torch.abs(x) + u
+                    z = torch.abs(x) - u
+                return w, z
+
+        x, y = torch.rand((3, 4)), torch.rand((3, 4))
+        expected, expected_ = Model()(x, y), Model()(-x, y)
+
+        rewritten = transform_method(Model.forward, verbose=0)
+        self.assertIn("torch.abs(", rewritten.code)
+        self.assertIn("abs", rewritten.dump)
+        code = rewritten.code
+        assert ("w, z, u" in code and "u, w, z" not in code) or (
+            "w, z, u" not in code and "u, w, z" in code
+        ), f"Order mismatch in\n{code}"
+        Model.forward = rewritten.func
+        self.assertEqualAny(expected, Model()(x, y))
+        self.assertEqualAny(expected_, Model()(-x, y))
+
+        DYN = torch.export.Dim.DYNAMIC
+        ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
+        ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
+        self.assertIn("cond", [str(getattr(n, "target", "?")) for n in ep.graph.nodes])
+        self.assertEqualAny(expected, ep.module()(x, y))
+        self.assertEqualAny(expected_, ep.module()(-x, y))
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)
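
For context, these tests expect transform_method to rewrite a data-dependent if/else in forward into a form torch.export can capture as a "cond" node in the graph. Below is a minimal hand-written sketch, not the actual output of transform_method, of the kind of rewrite the two-assignment test checks; the RewrittenModel name is hypothetical and only the use of torch.cond is taken from the assertions above.

import torch


class RewrittenModel(torch.nn.Module):  # hypothetical name, for illustration only
    def forward(self, x, y):
        # Each branch of the original if/else becomes a local function
        # returning the same tuple of tensors (w, z).
        def branch_then(x, y):
            return x + y, x - y

        def branch_else(x, y):
            return torch.abs(x) + y + 1, torch.abs(x) - y + 1

        # torch.cond keeps the data-dependent branch exportable; under
        # torch.export it appears as the "cond" node the tests assert on.
        w, z = torch.cond(x.sum() > 0, branch_then, branch_else, (x, y))
        return w, z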