22import torch
33import transformers
44from onnx_diagnostic .ext_test_case import ExtTestCase , requires_transformers
5- from onnx_diagnostic .helpers import string_type
5+ from onnx_diagnostic .helpers import string_type , max_diff
66from onnx_diagnostic .helpers .cache_helper import (
77 flatten_unflatten_for_dynamic_shapes ,
88 make_dynamic_cache ,
99 make_encoder_decoder_cache ,
1010 make_mamba_cache ,
1111 make_sliding_window_cache ,
12+ make_static_cache ,
1213)
1314from onnx_diagnostic .export import CoupleInputsDynamicShapes
1415from onnx_diagnostic .torch_export_patches .patch_inputs import (
@@ -104,6 +105,7 @@ def test_unflatten_flatten_encoder_decoder_cache(self):
104105 ]
105106 ),
106107 )
108+ self .assertEqual (0 , max_diff (c2 , c2 )["abs" ])
107109 self .assertIsInstance (c2 , transformers .cache_utils .EncoderDecoderCache )
108110 flat , _spec = torch .utils ._pytree .tree_flatten (c2 )
109111 self .assertIsInstance (flat , list )
@@ -149,6 +151,7 @@ def test_make_mamba_cache(self):
149151 "ssm_states=#3[T1s4x4x4,T1s4x4x4,T1s4x4x4])" ,
150152 text ,
151153 )
154+ self .assertEqual (0 , max_diff (cache , cache )["abs" ])
152155
153156 def test_make_sliding_window_cache (self ):
154157 cache = make_sliding_window_cache (
@@ -164,6 +167,45 @@ def test_make_sliding_window_cache(self):
164167 "value_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7])" ,
165168 text ,
166169 )
170+ self .assertEqual (0 , max_diff (cache , cache )["abs" ])
171+
172+ def test_make_static_cache (self ):
173+ cache = make_static_cache (
174+ [
175+ (torch .rand ((4 , 5 , 6 , 7 )), torch .rand ((4 , 5 , 6 , 7 ))),
176+ (torch .rand ((4 , 5 , 6 , 7 )), torch .rand ((4 , 5 , 6 , 7 ))),
177+ (torch .rand ((4 , 5 , 6 , 7 )), torch .rand ((4 , 5 , 6 , 7 ))),
178+ ]
179+ )
180+ text = self .string_type (cache , with_shape = True )
181+ self .assertEqual (
182+ "StaticCache(key_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7], "
183+ "value_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7])" ,
184+ text ,
185+ )
186+ self .assertEqual (0 , max_diff (cache , cache )["abs" ])
187+
188+ def test_unflatten_flatten_static_cache (self ):
189+ with torch_export_patches (patch_transformers = True ):
190+ c2 = make_static_cache (
191+ [
192+ (torch .rand ((4 , 5 , 6 , 7 )), torch .rand ((4 , 5 , 6 , 7 ))),
193+ (torch .rand ((4 , 5 , 6 , 7 )), torch .rand ((4 , 5 , 6 , 7 ))),
194+ (torch .rand ((4 , 5 , 6 , 7 )), torch .rand ((4 , 5 , 6 , 7 ))),
195+ ]
196+ )
197+ self .assertEqual (0 , max_diff (c2 , c2 )["abs" ])
198+ self .assertIsInstance (c2 , transformers .cache_utils .StaticCache )
199+ flat , _spec = torch .utils ._pytree .tree_flatten (c2 )
200+ self .assertIsInstance (flat , list )
201+ self .assertEqual (len (flat ), 6 )
202+ unflat = flatten_unflatten_for_dynamic_shapes (c2 )
203+ self .assertIsInstance (unflat , list )
204+ self .assertEqual (len (unflat ), 2 )
205+ self .assertEqual (
206+ "#2[#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7],#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7]]" ,
207+ self .string_type (unflat , with_shape = True ),
208+ )
167209
168210
169211if __name__ == "__main__" :
0 commit comments