11import copy
2+ import os
23import unittest
4+ from typing import Any , Dict , List , Tuple
35import torch
4- from onnx_diagnostic .ext_test_case import ExtTestCase , ignore_warnings , hide_stdout
6+ from onnx_diagnostic .ext_test_case import (
7+ ExtTestCase ,
8+ ignore_warnings ,
9+ hide_stdout ,
10+ requires_torch ,
11+ )
512from onnx_diagnostic .helpers import string_type
613from onnx_diagnostic .cache_helpers import make_dynamic_cache
714from onnx_diagnostic .torch_export_patches .onnx_export_errors import (
@@ -45,21 +52,12 @@ def forward(self, x, cache):
4552 expected = model (* inputs )
4653
4754 DYN = torch .export .Dim .DYNAMIC
48- ep = torch .export .export (
49- model ,
50- inputs ,
51- dynamic_shapes = ({0 : DYN , 2 : DYN }, [[{0 : DYN , 2 : DYN }], [{0 : DYN , 2 : DYN }]]),
52- strict = strict ,
53- )
54- mod = ep .module ()
55- got = mod (* inputs )
56- self .assertEqualArray (expected , got )
5755
5856 # patching
5957 with bypass_export_some_errors (patch_transformers = True ):
6058 got = model (* inputs )
6159 self .assertEqualArray (expected , got )
62- ep2 = torch .export .export (
60+ ep = torch .export .export (
6361 model ,
6462 inputs ,
6563 dynamic_shapes = (
@@ -68,11 +66,201 @@ def forward(self, x, cache):
6866 ),
6967 strict = strict ,
7068 )
71- mod = ep2 .module ()
69+ mod = ep .module ()
7270 got = mod (* inputs )
7371 self .assertEqualArray (expected , got )
7472
73+ class MyInterpreter (torch .fx .Interpreter ):
74+ def call_function (self , target , args , kwargs ):
75+ res = super ().call_function (target , args , kwargs )
76+ return res
77+
78+ args , _spec = torch .utils ._pytree .tree_flatten (inputs )
79+ got = MyInterpreter (ep .module ()).run (* args )
80+ self .assertEqualAny (expected , got )
81+
82+ @ignore_warnings (UserWarning )
83+ def test_export_mycache_list_cat (self ):
84+ TreeContext = torch .utils ._pytree .Context
85+ MappingKey = torch .utils ._pytree .MappingKey
86+ KeyEntry = torch .utils ._pytree .KeyEntry
87+
88+ class MyCache77 :
89+ def __init__ (self , key = None , value = None ):
90+ self .key_cache = [key ] if key is not None else []
91+ self .value_cache = [value ] if value is not None else []
92+
93+ class ModelMyCache (torch .nn .Module ):
94+ def forward (self , x , dc ):
95+ y = (
96+ (
97+ torch .cat (dc .key_cache , axis = 1 ) + torch .cat (dc .value_cache , axis = 1 )
98+ ).reshape ((- 1 , x .shape [1 ]))
99+ ).transpose (1 , 0 )
100+ return x @ y
101+
102+ inputs = {
103+ "x" : torch .randn (3 , 8 ),
104+ "dc" : MyCache77 (torch .ones ((3 , 8 , 3 , 8 )), torch .ones ((3 , 8 , 3 , 8 ))),
105+ }
106+ model = ModelMyCache ()
107+ expected = model (** inputs )
108+
109+ def flatten_my_cache77 (cache : MyCache77 ) -> Tuple [List [Any ], TreeContext ]:
110+ flat = [
111+ (k , getattr (cache , k ))
112+ for k in ["key_cache" , "value_cache" ]
113+ if hasattr (cache , k )
114+ ]
115+ return [f [1 ] for f in flat ], [f [0 ] for f in flat ]
116+
117+ def flatten_with_keys_my_cache77 (
118+ d : Dict [Any , Any ],
119+ ) -> Tuple [List [Tuple [KeyEntry , Any ]], TreeContext ]:
120+ values , context = flatten_my_cache77 (d )
121+ return [(MappingKey (k ), v ) for k , v in zip (context , values )], context
122+
123+ def unflatten_my_cache_77 (
124+ values : List [Any ], context : TreeContext , output_type = None
125+ ) -> MyCache77 :
126+ cache = MyCache77 ()
127+ values = dict (zip (context , values ))
128+ for k , v in values .items ():
129+ setattr (cache , k , v )
130+ return cache
131+
132+ torch .utils ._pytree .register_pytree_node (
133+ MyCache77 ,
134+ flatten_my_cache77 ,
135+ unflatten_my_cache_77 ,
136+ serialized_type_name = "MyCache77" ,
137+ flatten_with_keys_fn = flatten_with_keys_my_cache77 ,
138+ )
139+
140+ # DYN = torch.export.Dim.DYNAMIC
141+ ep = torch .export .export (model , (), kwargs = inputs )
142+
143+ args , _spec = torch .utils ._pytree .tree_flatten (inputs )
144+ got = torch .fx .Interpreter (ep .module ()).run (* args )
145+ self .assertEqualAny (expected , got )
146+
147+ mod = ep .module ()
148+ got = mod (** inputs )
149+ self .assertEqualArray (expected , got )
150+
151+ @ignore_warnings (UserWarning )
152+ def test_export_mycache_dict_cat (self ):
153+ TreeContext = torch .utils ._pytree .Context
154+
155+ class MyCache78 :
156+ def __init__ (self , key = None , value = None ):
157+ self .key_cache = [key ] if key is not None else []
158+ self .value_cache = [value ] if value is not None else []
159+
160+ class ModelMyCache (torch .nn .Module ):
161+ def forward (self , x , dc ):
162+ y = (
163+ (
164+ torch .cat (dc .key_cache , axis = 1 ) + torch .cat (dc .value_cache , axis = 1 )
165+ ).reshape ((- 1 , x .shape [1 ]))
166+ ).transpose (1 , 0 )
167+ return x @ y
168+
169+ inputs = {
170+ "x" : torch .randn (3 , 8 ),
171+ "dc" : MyCache78 (torch .ones ((3 , 8 , 3 , 8 )), torch .ones ((3 , 8 , 3 , 8 ))),
172+ }
173+ model = ModelMyCache ()
174+ expected = model (** inputs )
175+
176+ def flatten_my_cache78 (cache : MyCache78 ):
177+ dictionary = {
178+ "key_cache" : cache .key_cache ,
179+ "value_cache" : cache .value_cache ,
180+ }
181+ return torch .utils ._pytree ._dict_flatten (dictionary )
182+
183+ def flatten_with_keys_my_cache78 (cache : MyCache78 ):
184+ dictionary = {
185+ "key_cache" : cache .key_cache ,
186+ "value_cache" : cache .value_cache ,
187+ }
188+ return torch .utils ._pytree ._dict_flatten_with_keys (dictionary )
189+
190+ def unflatten_my_cache_78 (values , context : TreeContext , output_type = None ) -> MyCache78 :
191+ dictionary = torch .utils ._pytree ._dict_unflatten (values , context )
192+ cache = MyCache78 ()
193+ for k , v in dictionary .items ():
194+ setattr (cache , k , v )
195+ return cache
196+
197+ torch .utils ._pytree .register_pytree_node (
198+ MyCache78 ,
199+ flatten_my_cache78 ,
200+ unflatten_my_cache_78 ,
201+ serialized_type_name = "MyCache78" ,
202+ flatten_with_keys_fn = flatten_with_keys_my_cache78 ,
203+ )
204+
205+ # DYN = torch.export.Dim.DYNAMIC
206+ ep = torch .export .export (model , (), kwargs = inputs )
207+
208+ args , _spec = torch .utils ._pytree .tree_flatten (inputs )
209+ got = torch .fx .Interpreter (ep .module ()).run (* args )
210+ self .assertEqualAny (expected , got )
211+
212+ mod = ep .module ()
213+ got = mod (** inputs )
214+ self .assertEqualArray (expected , got )
215+
75216 @ignore_warnings (UserWarning )
217+ def test_export_dynamic_cache_cat (self ):
218+
219+ class ModelDynamicCache (torch .nn .Module ):
220+ def forward (self , x , dc ):
221+ y = (
222+ (
223+ torch .cat (dc .key_cache , axis = 1 ) + torch .cat (dc .value_cache , axis = 1 )
224+ ).reshape ((- 1 , x .shape [1 ]))
225+ ).transpose (1 , 0 )
226+ return x @ y
227+
228+ inputs = {
229+ "x" : torch .randn (3 , 8 ),
230+ "dc" : make_dynamic_cache (
231+ [(torch .ones ((3 , 8 , 3 , 8 )), (torch .ones ((3 , 8 , 3 , 8 )) * 2 ))]
232+ ),
233+ }
234+ model = ModelDynamicCache ()
235+ expected = model (** inputs )
236+
237+ # DYN = torch.export.Dim.DYNAMIC
238+ NOBYPASS = int (os .environ .get ("NOBYBASS" , "0" ))
239+ if NOBYPASS :
240+ ep = torch .export .export (model , (), kwargs = inputs )
241+
242+ args , _spec = torch .utils ._pytree .tree_flatten (inputs )
243+ got = torch .fx .Interpreter (ep .module ()).run (* args )
244+ self .assertEqualAny (expected , got )
245+
246+ mod = ep .module ()
247+ got = mod (** inputs )
248+ self .assertEqualArray (expected , got )
249+ return
250+
251+ with bypass_export_some_errors (patch_transformers = True ):
252+ ep = torch .export .export (model , (), kwargs = inputs )
253+
254+ args , _spec = torch .utils ._pytree .tree_flatten (inputs )
255+ got = torch .fx .Interpreter (ep .module ()).run (* args )
256+ self .assertEqualAny (expected , got )
257+
258+ mod = ep .module ()
259+ got = mod (** inputs )
260+ self .assertEqualArray (expected , got )
261+
262+ @ignore_warnings (UserWarning )
263+ @requires_torch ("2.9" )
76264 def test_phi2_export_module (self ):
77265 data = get_untrained_model_with_inputs ("microsoft/phi-2" )
78266 model , inputs , dyn_shapes = data ["model" ], data ["inputs" ], data ["dynamic_shapes" ]
@@ -100,6 +288,7 @@ def test_phi2_export_module(self):
100288 dynamic_shapes = dyn_shapes ,
101289 strict = False , # True works but then the it fails during the execution
102290 )
291+ # ep = ep.run_decompositions()
103292 mod = ep .module ()
104293 inputs_copied = copy .deepcopy (inputs )
105294 self .assertEqual (
@@ -108,15 +297,8 @@ def test_phi2_export_module(self):
108297 got = mod (** inputs_copied )
109298 self .assertEqualAny (expected , got )
110299
111- inputs_copied = copy .deepcopy (inputs )
112- self .assertEqual (
113- str_inputs , string_type (inputs_copied , with_shape = True , with_min_max = True )
114- )
115- mod = ep .module ()
116- got = mod (** inputs_copied )
117- self .assertEqualAny (expected , got )
118-
119300 @ignore_warnings (UserWarning )
301+ @requires_torch ("2.9" )
120302 def test_phi2_export_interpreter (self ):
121303 data = get_untrained_model_with_inputs ("microsoft/phi-2" )
122304 model , inputs , dyn_shapes = data ["model" ], data ["inputs" ], data ["dynamic_shapes" ]
@@ -144,6 +326,7 @@ def test_phi2_export_interpreter(self):
144326 dynamic_shapes = dyn_shapes ,
145327 strict = False , # True works but then the it fails during the execution
146328 )
329+ # ep = ep.run_decompositions()
147330
148331 # from experimental_experiment.torch_interpreter.tracing import CustomTracer
149332 # CustomTracer.remove_unnecessary_slices(ep.graph)
0 commit comments