     ignore_warnings,
     hide_stdout,
     requires_torch,
-    has_transformers,
 )
 from onnx_diagnostic.helpers import string_type
 from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, CacheKeyValue
@@ -22,76 +21,72 @@ class TestOnnxExportErrors(ExtTestCase):
     @ignore_warnings(UserWarning)
     @hide_stdout()
     def test_export_dynamic_cache_update(self):
-        values = [True, False] if has_transformers("4.50") else [False]
-        for strict in self.subloop(values, verbose=1):
-
-            class SubModelCache(torch.nn.Module):
-                def forward(self, cache):
-                    cc = CacheKeyValue(cache)
-                    # If not patched...
-                    # Fails with transformers>=4.54 because function ``parse_processor_args``
-                    # relies on inspect and the exporter is not very fond of that.
-                    # torch._dynamo.exc.Unsupported: id() with unsupported args
-                    # Explanation: Dynamo doesn't know how to trace id()
-                    # call with args
-                    # (GetAttrVariable(ConstantVariable(NoneType: None), __init__),)
-                    # Hint: Supported args are Tensors, and functions/nn.Modules/user-defined
-                    # objects from outside the compiled region.
-                    # Hint: It may be possible to write Dynamo tracing rules for this code.
-                    d = cache.__class__()
-                    d.update(cc.key_cache[0] + 1, cc.value_cache[0] + 2, 0)
-                    d.update(cc.key_cache[0] + 3, cc.value_cache[0] + 5, 1)
-                    return d
-
-            class SubModel(torch.nn.Module):
-                def forward(self, x, cache):
-                    cc = CacheKeyValue(cache)
-                    return x + cc.key_cache[0] + cc.value_cache[0]
-
-            class Model(torch.nn.Module):
-                def __init__(self):
-                    super().__init__()
-                    self.sub = SubModel()
-                    self.subcache = SubModelCache()
-
-                def forward(self, x, cache):
-                    return self.sub(x, self.subcache(cache))
-
-            # no patch
-            cache = make_dynamic_cache(
-                [(torch.ones((5, 6, 5, 6)), torch.ones((5, 6, 5, 6)) + 2)]
+        class SubModelCache(torch.nn.Module):
+            def forward(self, cache):
+                cc = CacheKeyValue(cache)
+                # If not patched...
+                # Fails with transformers>=4.54 because function ``parse_processor_args``
+                # relies on inspect and the exporter is not very fond of that.
+                # torch._dynamo.exc.Unsupported: id() with unsupported args
+                # Explanation: Dynamo doesn't know how to trace id()
+                # call with args
+                # (GetAttrVariable(ConstantVariable(NoneType: None), __init__),)
+                # Hint: Supported args are Tensors, and functions/nn.Modules/user-defined
+                # objects from outside the compiled region.
+                # Hint: It may be possible to write Dynamo tracing rules for this code.
+                d = cache.__class__()
+                d.update(cc.key_cache[0] + 1, cc.value_cache[0] + 2, 0)
+                d.update(cc.key_cache[0] + 3, cc.value_cache[0] + 5, 1)
+                return d
+
+        class SubModel(torch.nn.Module):
+            def forward(self, x, cache):
+                cc = CacheKeyValue(cache)
+                y = cc.key_cache[0] + cc.value_cache[0]
+                return x + y
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.sub = SubModel()
+                self.subcache = SubModelCache()
+
+            def forward(self, x, cache):
+                return self.sub(x, self.subcache(cache))
+
+        # no patch
+        cache = make_dynamic_cache([(torch.ones((5, 6, 5, 6)), torch.ones((5, 6, 5, 6)) + 2)])
+        model = Model()
+        inputs = (torch.randn((5, 6, 5, 6)), cache)
+        expected = model(*inputs)
+
+        DYN = torch.export.Dim.DYNAMIC
+
+        # patching
+        with torch_export_patches(patch_transformers=True, verbose=10):
+            got = model(*inputs)
+            self.assertEqualArray(expected, got)
+            ep = torch.export.export(
+                model,
+                inputs,
+                dynamic_shapes=(
+                    {0: DYN, 2: DYN},
+                    [[{0: DYN, 2: DYN}], [{0: DYN, 2: DYN}]],
+                ),
+                strict=False,
             )
-            model = Model()
-            inputs = (torch.randn((5, 6, 5, 6)), cache)
-            expected = model(*inputs)
-
-            DYN = torch.export.Dim.DYNAMIC
-
-            # patching
-            with torch_export_patches(patch_transformers=True, verbose=10):
-                got = model(*inputs)
-                self.assertEqualArray(expected, got)
-                ep = torch.export.export(
-                    model,
-                    inputs,
-                    dynamic_shapes=(
-                        {0: DYN, 2: DYN},
-                        [[{0: DYN, 2: DYN}], [{0: DYN, 2: DYN}]],
-                    ),
-                    strict=strict,
-                )
-                mod = ep.module()
-                got = mod(*inputs)
-                self.assertEqualArray(expected, got)
-
-                class MyInterpreter(torch.fx.Interpreter):
-                    def call_function(self, target, args, kwargs):
-                        res = super().call_function(target, args, kwargs)
-                        return res
-
-                args, _spec = torch.utils._pytree.tree_flatten(inputs)
-                got = MyInterpreter(ep.module()).run(*args)
-                self.assertEqualAny(expected, got)
+            mod = ep.module()
+            got = mod(*inputs)
+            self.assertEqualArray(expected, got)
+
+            class MyInterpreter(torch.fx.Interpreter):
+                def call_function(self, target, args, kwargs):
+                    res = super().call_function(target, args, kwargs)
+                    return res
+
+            args, _spec = torch.utils._pytree.tree_flatten(inputs)
+            got = MyInterpreter(ep.module()).run(*args)
+            self.assertEqualAny(expected, got)
 
     @ignore_warnings(UserWarning)
     @requires_torch(