@@ -155,21 +155,31 @@ def make_mamba_cache(
     key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
 ) -> transformers.cache_utils.MambaCache:
     "Creates a :class:`transformers.cache_utils.MambaCache`."
+    dtype = key_value_pairs[0][0].dtype
 
     class _config:
         def __init__(self):
             self.intermediate_size = key_value_pairs[0][0].shape[1]
             self.conv_kernel = key_value_pairs[0][0].shape[-1]
             self.state_size = key_value_pairs[0][1].shape[-1]
             self.num_hidden_layers = len(key_value_pairs)
-            self.dtype = key_value_pairs[0][0].dtype
+            self.dtype = dtype
 
     cache = transformers.cache_utils.MambaCache(
         _config(),
         max_batch_size=key_value_pairs[0][0].shape[0],
         device=key_value_pairs[0][0].device,
+        dtype=dtype,
     )
     for i in range(len(key_value_pairs)):
+        assert cache.conv_states[i].dtype == dtype, (
+            f"Type mismatch for cache.conv_states[{i}].dtype="
+            f"{cache.conv_states[i].dtype} != {dtype}"
+        )
+        assert cache.ssm_states[i].dtype == dtype, (
+            f"Type mismatch for cache.ssm_states[{i}].dtype="
+            f"{cache.ssm_states[i].dtype} != {dtype}"
+        )
         assert cache.conv_states[i].shape == key_value_pairs[i][0].shape, (
             f"Shape mismatch, expected {cache.conv_states[i].shape}, "
             f"got {key_value_pairs[i][0].shape}"