88 make_static_cache ,
99 make_sliding_window_cache ,
1010 flatten_unflatten_for_dynamic_shapes ,
11+ make_dynamic_shapes_kv_cache ,
1112 CacheKeyValue ,
1213)
1314from onnx_diagnostic .torch_export_patches .onnx_export_errors import (
@@ -64,8 +65,8 @@ def forward(self, cache):
6465 model (cache )
6566 DYN = torch .export .Dim .DYNAMIC
6667 ds = [
67- [[{ 0 : DYN } , {0 : DYN }, { 0 : DYN }], [{ 0 : DYN }, { 0 : DYN }, { 0 : DYN }]] ,
68- [[{ 0 : DYN } , {0 : DYN }, { 0 : DYN }], [{ 0 : DYN }, { 0 : DYN }, { 0 : DYN }]] ,
68+ make_dynamic_shapes_kv_cache ( cache1 , {0 : DYN }) ,
69+ make_dynamic_shapes_kv_cache ( cache2 , {0 : DYN }) ,
6970 ]
7071
7172 with torch_export_patches (patch_transformers = True ):
@@ -99,9 +100,15 @@ def forward(self, cache):
99100 model = Model ()
100101 model (cache )
101102 DYN = torch .export .Dim .DYNAMIC
102- ds = [[{0 : DYN }, {0 : DYN }, {0 : DYN }], [{0 : DYN }, {0 : DYN }, {0 : DYN }]]
103+ ds = make_dynamic_shapes_kv_cache (cache , {0 : DYN })
104+ self .assertEqual (len (ds ), 6 )
103105
104- with torch_export_patches ():
106+ with torch_export_patches (patch_transformers = True ):
107+ flat , _spec = torch .utils ._pytree .tree_flatten (cache )
108+ self .assertEqual (len (flat ), len (ds ))
109+ unflat = torch .utils ._pytree .tree_unflatten (flat , _spec )
110+ if hasattr (unflat , "layers" ):
111+ self .assertEqual (len (unflat .layers ), 3 )
105112 torch .export .export (model , (cache ,), dynamic_shapes = (ds ,))
106113
107114 @ignore_warnings (UserWarning )
@@ -195,7 +202,7 @@ def forward(self, cache):
195202 model = Model ()
196203 model (cache )
197204 DYN = torch .export .Dim .DYNAMIC
198- ds = [[{ 0 : DYN } , {0 : DYN }], [{ 0 : DYN }, { 0 : DYN }]]
205+ ds = make_dynamic_shapes_kv_cache ( cache , {0 : DYN })
199206
200207 with torch_export_patches (patch_transformers = True ):
201208 torch .export .export (model , (cache ,), dynamic_shapes = (ds ,))
@@ -265,9 +272,7 @@ def test_static_cache(self):
265272 flat , _spec = torch .utils ._pytree .tree_flatten (bo )
266273 unflat = flatten_unflatten_for_dynamic_shapes (bo , use_dict = True )
267274 self .assertIsInstance (unflat , list )
268- self .assertEqual (
269- "#2[#3[T1r4,T1r4,T1r4],#3[T1r4,T1r4,T1r4]]" , self .string_type (unflat )
270- )
275+ self .assertEqual ("#6[T1r4,T1r4,T1r4,T1r4,T1r4,T1r4]" , self .string_type (unflat ))
271276
272277 # export
273278 class Model (torch .nn .Module ):
@@ -278,7 +283,7 @@ def forward(self, cache):
278283 model = Model ()
279284 model (bo )
280285 DYN = torch .export .Dim .DYNAMIC
281- ds = [[{ 0 : DYN } , {0 : DYN }, { 0 : DYN }], [{ 0 : DYN }, { 0 : DYN }, { 0 : DYN }]]
286+ ds = make_dynamic_shapes_kv_cache ( bo , {0 : DYN })
282287
283288 with torch_export_patches (patch_transformers = True , stop_if_static = 1 ):
284289 torch .export .export (model , (bo ,), dynamic_shapes = (ds ,))
0 commit comments