Skip to content

Commit 413074e

Browse files
committed
fix
1 parent ea5170b commit 413074e

File tree

3 files changed

+170
-51
lines changed

3 files changed

+170
-51
lines changed

_unittests/ut_torch_export_patches/test_patch_module.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,31 @@
1+
import ast
2+
import inspect
3+
import textwrap
14
import unittest
25
import torch
36
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
4-
from onnx_diagnostic.torch_export_patches.patch_module import transform_method
7+
from onnx_diagnostic.torch_export_patches.patch_module import (
8+
transform_method,
9+
inplace_add_parent,
10+
)
511

612

713
class TestPatchModule(ExtTestCase):
14+
def test_parent(self):
15+
class Model(torch.nn.Module):
16+
def forward(self, x, y):
17+
if x.sum() > 0:
18+
return x + y
19+
else:
20+
return torch.abs(x) + y + 1
21+
22+
src = inspect.getsource(Model.forward)
23+
tree = ast.parse(textwrap.dedent(src))
24+
inplace_add_parent(tree)
25+
assert all(
26+
hasattr(node, "parent") for node in ast.walk(tree)
27+
), f"Missing parent in {ast.dump(tree, indent=2)}"
28+
829
def test_rewrite_forward_return1(self):
930

1031
class Model(torch.nn.Module):
@@ -71,7 +92,7 @@ def forward(self, x, y):
7192
x, y = torch.rand((3, 4)), torch.rand((3, 4))
7293
expected, expected_ = Model()(x, y), Model()(-x, y)
7394

74-
rewritten = transform_method(Model.forward, verbose=0)
95+
rewritten = transform_method(Model.forward, verbose=self.verbose)
7596
self.assertIn("torch.abs(", rewritten.code)
7697
self.assertIn("abs", rewritten.dump)
7798
Model.forward = rewritten.func
@@ -98,7 +119,7 @@ def forward(self, x, y):
98119
x, y = torch.rand((3, 4)), torch.rand((3, 4))
99120
expected, expected_ = Model()(x, y), Model()(-x, y)
100121

101-
rewritten = transform_method(Model.forward, verbose=0)
122+
rewritten = transform_method(Model.forward, verbose=self.verbose)
102123
self.assertIn("torch.abs(", rewritten.code)
103124
self.assertIn("abs", rewritten.dump)
104125
Model.forward = rewritten.func
@@ -123,7 +144,7 @@ def forward(self, x, y):
123144
x, y = torch.rand((3, 4)), torch.rand((3, 4))
124145
expected, expected_ = Model()(x, y), Model()(-x, y)
125146

126-
rewritten = transform_method(Model.forward, verbose=0)
147+
rewritten = transform_method(Model.forward, verbose=self.verbose)
127148
self.assertIn("torch.abs(", rewritten.code)
128149
self.assertIn("abs", rewritten.dump)
129150
Model.forward = rewritten.func
@@ -146,7 +167,7 @@ def forward(self, x, y):
146167
return x + y
147168

148169
self.assertRaise(
149-
lambda: transform_method(Model.forward, verbose=0), NotImplementedError
170+
lambda: transform_method(Model.forward, verbose=self.verbose), NotImplementedError
150171
)
151172

152173
def test_rewrite_forward_assign2_in_2(self):
@@ -164,7 +185,7 @@ def forward(self, x, y):
164185
x, y = torch.rand((3, 4)), torch.rand((3, 4))
165186
expected, expected_ = Model()(x, y), Model()(-x, y)
166187

167-
rewritten = transform_method(Model.forward, verbose=0)
188+
rewritten = transform_method(Model.forward, verbose=self.verbose)
168189
self.assertIn("torch.abs(", rewritten.code)
169190
self.assertIn("abs", rewritten.dump)
170191
Model.forward = rewritten.func
@@ -194,7 +215,7 @@ def forward(self, x, y):
194215
x, y = torch.rand((3, 4)), torch.rand((3, 4))
195216
expected, expected_ = Model()(x, y), Model()(-x, y)
196217

197-
rewritten = transform_method(Model.forward, verbose=1)
218+
rewritten = transform_method(Model.forward, verbose=self.verbose)
198219
self.assertIn("torch.abs(", rewritten.code)
199220
self.assertIn("abs", rewritten.dump)
200221
code = rewritten.code

onnx_diagnostic/ext_test_case.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,11 @@ class ExtTestCase(unittest.TestCase):
731731
_warns: List[Tuple[str, int, Warning]] = []
732732
_todos: List[Tuple[Callable, str]] = []
733733

734+
@property
735+
def verbose(self):
736+
"Returns the the value of environment variable ``VERBOSE``."
737+
return int(os.environ.get("VERBOSE", "0"))
738+
734739
@classmethod
735740
def setUpClass(cls):
736741
logger = logging.getLogger("onnxscript.optimizer.constant_folding")

0 commit comments

Comments
 (0)