@@ -1,27 +1,28 @@
 import unittest
 import torch
 from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers, requires_torch
-from onnx_diagnostic.export.shape_helper import (
-    all_dynamic_shape_from_inputs,
-    guess_dynamic_shapes_from_inputs,
-)
+from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
+from onnx_diagnostic.torch_export_patches import torch_export_patches
+from onnx_diagnostic.helpers import flatten_object
 from onnx_diagnostic.helpers.cache_helper import (
     make_dynamic_cache,
     make_sliding_window_cache,
     make_encoder_decoder_cache,
     make_static_cache,
 )
-from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
-from onnx_diagnostic.torch_export_patches import torch_export_patches
+from onnx_diagnostic.export.shape_helper import (
+    all_dynamic_shapes_from_inputs,
+    guess_dynamic_shapes_from_inputs,
+    make_fake_with_dynamic_dimensions,
+)
 
 
 class TestShapeHelper(ExtTestCase):
-
     @requires_transformers("4.52")
     @requires_torch("2.7.99")
     def test_all_dynamic_shape_from_cache(self):
         cache = make_dynamic_cache([(torch.ones((2, 2)), (torch.ones((2, 2)) * 2))])
-        ds = all_dynamic_shape_from_inputs(cache)
+        ds = all_dynamic_shapes_from_inputs(cache)
         self.assertEqual([[{0: "d_0_0", 1: "d_0_1"}], [{0: "d_1_0", 1: "d_1_1"}]], ds)
 
     @requires_torch("2.7.99")
@@ -122,17 +123,17 @@ def test_all_dynamic_shape_all_transformers_cache(self):
         with torch_export_patches(patch_transformers=True):
             for cache, exds in caches:
                 with self.subTest(cache_name=cache.__class__.__name__):
-                    ds = all_dynamic_shape_from_inputs(cache)
+                    ds = all_dynamic_shapes_from_inputs(cache)
                     self.assertEqual(exds, ds)
 
     @requires_transformers("4.52")
     @requires_torch("2.7.99")
-    def test_all_dynamic_shape_from_inputs(self):
-        ds = all_dynamic_shape_from_inputs((torch.randn((5, 6)), torch.randn((1, 6))))
+    def test_all_dynamic_shapes_from_inputs(self):
+        ds = all_dynamic_shapes_from_inputs((torch.randn((5, 6)), torch.randn((1, 6))))
         self.assertEqual(({0: "d_0_0", 1: "d_0_1"}, {0: "d_1_0", 1: "d_1_1"}), ds)
-        ds = all_dynamic_shape_from_inputs([torch.randn((5, 6)), torch.randn((1, 6))])
+        ds = all_dynamic_shapes_from_inputs([torch.randn((5, 6)), torch.randn((1, 6))])
         self.assertEqual([{0: "d_0_0", 1: "d_0_1"}, {0: "d_1_0", 1: "d_1_1"}], ds)
-        ds = all_dynamic_shape_from_inputs(
+        ds = all_dynamic_shapes_from_inputs(
             (torch.randn((5, 6)), torch.randn((1, 6))), dim_prefix=torch.export.Dim.AUTO
         )
         self.assertEqual(
@@ -145,9 +146,9 @@ def test_all_dynamic_shape_from_inputs(self):
 
     @requires_transformers("4.52")
     @requires_torch("2.7.99")
-    def test_all_dynamic_shape_from_inputs_dynamic_cache(self):
+    def test_all_dynamic_shapes_from_inputs_dynamic_cache(self):
         data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
-        ds = all_dynamic_shape_from_inputs(data["inputs"])
+        ds = all_dynamic_shapes_from_inputs(data["inputs"])
         self.assertEqual(
             {
                 "input_ids": {0: "d_0_0", 1: "d_0_1"},
@@ -184,6 +185,60 @@ def test_guess_dynamic_shapes_from_inputs(self):
             guessed,
         )
 
+    @requires_transformers("4.55")
+    @requires_torch("2.9")
+    def test_make_fake_with_dynamic_dimensions_tensor(self):
+        res = make_fake_with_dynamic_dimensions(
+            (torch.rand((2, 32, 30, 96), dtype=torch.float16),),
+            ({0: "batch", 2: "cache_length"},),
+        )
+        reshaped = res[0][0]
+        self.assertIsInstance(reshaped.shape[0], torch.SymInt)
+        self.assertIsInstance(reshaped.shape[2], torch.SymInt)
+        self.assertEqual(reshaped.shape[1], 32)
+        self.assertEqual(reshaped.shape[3], 96)
+        self.assertNotEqual(reshaped.shape[0], reshaped.shape[2])
+
+    @requires_transformers("4.55")
+    @requires_torch("2.9")
+    def test_make_fake_with_dynamic_dimensions_whole(self):
+        res = make_fake_with_dynamic_dimensions(
+            dict(
+                input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
+                attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
+                position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
+                past_key_values=make_dynamic_cache(
+                    [
+                        (
+                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                        ),
+                        (
+                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                        ),
+                    ]
+                ),
+            ),
+            dynamic_shapes={
+                "input_ids": {0: "batch", 1: "seq_length"},
+                "attention_mask": {0: "batch", 1: "cache+seq"},
+                "position_ids": {0: "batch", 1: "seq_length"},
+                "past_key_values": [
+                    [{0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}],
+                    [{0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}],
+                ],
+            },
+        )
+        flat = flatten_object(res[0], drop_keys=True)
+        for t in flat:
+            if len(t.shape) == 4:
+                self.assertIsInstance(t.shape[0], torch.SymInt)
+                self.assertIsInstance(t.shape[2], torch.SymInt)
+                self.assertEqual(t.shape[1], 32)
+                self.assertEqual(t.shape[3], 96)
+                self.assertNotEqual(t.shape[0], t.shape[2])
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
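
For context, a minimal usage sketch of the renamed helper as exercised by the tests above (assuming a build of onnx_diagnostic that includes this commit; the "d_*_*" dimension names are taken directly from the test assertions, not from documentation):

    import torch
    from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs

    # A tuple of two 2-D tensors yields one auto-generated name per dimension,
    # matching test_all_dynamic_shapes_from_inputs above.
    ds = all_dynamic_shapes_from_inputs((torch.randn((5, 6)), torch.randn((1, 6))))
    print(ds)  # ({0: "d_0_0", 1: "d_0_1"}, {0: "d_1_0", 1: "d_1_1"})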