Skip to content

Commit 8a15be7

Browse files
committed
fix
1 parent a5bc880 commit 8a15be7

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

_unittests/ut_torch_export_patches/test_patch_base_class.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import unittest
22
import torch
33
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
4-
from onnx_diagnostic.torch_export_patches import torch_export_patches
4+
from onnx_diagnostic.torch_export_patches import (
5+
torch_export_patches,
6+
bypass_export_some_errors,
7+
)
58

69

710
class TestPatchBaseClass(ExtTestCase):
@@ -76,6 +79,28 @@ def m1(self, x):
7679
with torch_export_patches(custom_patches=[patched_Model], verbose=10):
7780
self.assertEqualArray(x**3, model(x))
7881

82+
@hide_stdout()
83+
def test_bypass_export_some_errors(self):
84+
class Model2(torch.nn.Module):
85+
def m2(self, x):
86+
return x * x
87+
88+
def forward(self, x):
89+
return self.m2(x)
90+
91+
class patched_Model:
92+
_PATCHED_CLASS_ = Model2
93+
_PATCHES_ = ["m2"]
94+
95+
def m2(self, x):
96+
return x**3
97+
98+
model = Model2()
99+
x = torch.arange(4)
100+
self.assertEqualArray(x * x, model(x))
101+
with bypass_export_some_errors(custom_patches=[patched_Model], verbose=10):
102+
self.assertEqualArray(x**3, model(x))
103+
79104

80105
if __name__ == "__main__":
81106
unittest.main(verbosity=2)

onnx_diagnostic/torch_export_patches/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@
44
)
55

66

7-
bypass_export_some_errors = torch_export_patches
7+
# bypass_export_some_errors is the first name given to the patches.
8+
bypass_export_some_errors = torch_export_patches # type: ignore

0 commit comments

Comments
 (0)