1+ import ast
2+ import inspect
3+ import textwrap
14import unittest
25import torch
36from onnx_diagnostic .ext_test_case import ExtTestCase , hide_stdout
4- from onnx_diagnostic .torch_export_patches .patch_module import transform_method
7+ from onnx_diagnostic .torch_export_patches .patch_module import (
8+ transform_method ,
9+ inplace_add_parent ,
10+ )
511
612
713class TestPatchModule (ExtTestCase ):
14+ def test_parent (self ):
15+ class Model (torch .nn .Module ):
16+ def forward (self , x , y ):
17+ if x .sum () > 0 :
18+ return x + y
19+ else :
20+ return torch .abs (x ) + y + 1
21+
22+ src = inspect .getsource (Model .forward )
23+ tree = ast .parse (textwrap .dedent (src ))
24+ inplace_add_parent (tree )
25+ assert all (
26+ hasattr (node , "parent" ) for node in ast .walk (tree )
27+ ), f"Missing parent in { ast .dump (tree , indent = 2 )} "
28+
829 def test_rewrite_forward_return1 (self ):
930
1031 class Model (torch .nn .Module ):
@@ -71,7 +92,7 @@ def forward(self, x, y):
7192 x , y = torch .rand ((3 , 4 )), torch .rand ((3 , 4 ))
7293 expected , expected_ = Model ()(x , y ), Model ()(- x , y )
7394
74- rewritten = transform_method (Model .forward , verbose = 0 )
95+ rewritten = transform_method (Model .forward , verbose = self . verbose )
7596 self .assertIn ("torch.abs(" , rewritten .code )
7697 self .assertIn ("abs" , rewritten .dump )
7798 Model .forward = rewritten .func
@@ -98,7 +119,7 @@ def forward(self, x, y):
98119 x , y = torch .rand ((3 , 4 )), torch .rand ((3 , 4 ))
99120 expected , expected_ = Model ()(x , y ), Model ()(- x , y )
100121
101- rewritten = transform_method (Model .forward , verbose = 0 )
122+ rewritten = transform_method (Model .forward , verbose = self . verbose )
102123 self .assertIn ("torch.abs(" , rewritten .code )
103124 self .assertIn ("abs" , rewritten .dump )
104125 Model .forward = rewritten .func
@@ -123,7 +144,7 @@ def forward(self, x, y):
123144 x , y = torch .rand ((3 , 4 )), torch .rand ((3 , 4 ))
124145 expected , expected_ = Model ()(x , y ), Model ()(- x , y )
125146
126- rewritten = transform_method (Model .forward , verbose = 0 )
147+ rewritten = transform_method (Model .forward , verbose = self . verbose )
127148 self .assertIn ("torch.abs(" , rewritten .code )
128149 self .assertIn ("abs" , rewritten .dump )
129150 Model .forward = rewritten .func
@@ -146,7 +167,7 @@ def forward(self, x, y):
146167 return x + y
147168
148169 self .assertRaise (
149- lambda : transform_method (Model .forward , verbose = 0 ), NotImplementedError
170+ lambda : transform_method (Model .forward , verbose = self . verbose ), NotImplementedError
150171 )
151172
152173 def test_rewrite_forward_assign2_in_2 (self ):
@@ -164,7 +185,7 @@ def forward(self, x, y):
164185 x , y = torch .rand ((3 , 4 )), torch .rand ((3 , 4 ))
165186 expected , expected_ = Model ()(x , y ), Model ()(- x , y )
166187
167- rewritten = transform_method (Model .forward , verbose = 0 )
188+ rewritten = transform_method (Model .forward , verbose = self . verbose )
168189 self .assertIn ("torch.abs(" , rewritten .code )
169190 self .assertIn ("abs" , rewritten .dump )
170191 Model .forward = rewritten .func
@@ -194,7 +215,7 @@ def forward(self, x, y):
194215 x , y = torch .rand ((3 , 4 )), torch .rand ((3 , 4 ))
195216 expected , expected_ = Model ()(x , y ), Model ()(- x , y )
196217
197- rewritten = transform_method (Model .forward , verbose = 1 )
218+ rewritten = transform_method (Model .forward , verbose = self . verbose )
198219 self .assertIn ("torch.abs(" , rewritten .code )
199220 self .assertIn ("abs" , rewritten .dump )
200221 code = rewritten .code
0 commit comments