File tree Expand file tree Collapse file tree 2 files changed +28
-2
lines changed
_unittests/ut_torch_export_patches
onnx_diagnostic/torch_export_patches Expand file tree Collapse file tree 2 files changed +28
-2
lines changed Original file line number Diff line number Diff line change 11import unittest
22import torch
33from onnx_diagnostic .ext_test_case import ExtTestCase , hide_stdout
4- from onnx_diagnostic .torch_export_patches import torch_export_patches
4+ from onnx_diagnostic .torch_export_patches import (
5+ torch_export_patches ,
6+ bypass_export_some_errors ,
7+ )
58
69
710class TestPatchBaseClass (ExtTestCase ):
@@ -76,6 +79,28 @@ def m1(self, x):
7679 with torch_export_patches (custom_patches = [patched_Model ], verbose = 10 ):
7780 self .assertEqualArray (x ** 3 , model (x ))
7881
82+ @hide_stdout ()
83+ def test_bypass_export_some_errors (self ):
84+ class Model2 (torch .nn .Module ):
85+ def m2 (self , x ):
86+ return x * x
87+
88+ def forward (self , x ):
89+ return self .m2 (x )
90+
91+ class patched_Model :
92+ _PATCHED_CLASS_ = Model2
93+ _PATCHES_ = ["m2" ]
94+
95+ def m2 (self , x ):
96+ return x ** 3
97+
98+ model = Model2 ()
99+ x = torch .arange (4 )
100+ self .assertEqualArray (x * x , model (x ))
101+ with bypass_export_some_errors (custom_patches = [patched_Model ], verbose = 10 ):
102+ self .assertEqualArray (x ** 3 , model (x ))
103+
79104
80105if __name__ == "__main__" :
81106 unittest .main (verbosity = 2 )
Original file line number Diff line number Diff line change 44)
55
66
7- bypass_export_some_errors = torch_export_patches
7+ # bypass_export_some_errors is the first name given to the patches.
8+ bypass_export_some_errors = torch_export_patches # type: ignore
You can’t perform that action at this time.
0 commit comments