@@ -144,17 +144,18 @@ def make_fake_with_dynamic_dimensions(self, x: Any, dynamic_shapes: Any) -> Any:
144144 """
145145 See
146146 :func:`onnx_diagnostic.export.shape_helper.make_fake_with_dynamic_dimensions`.
147+ If caches are used, it requires ``transformers>=4.57``.
147148 """
148149 if x is None :
149150 return None , None
150- if isinstance ( x , (list , tuple ) ):
151+ if type ( x ) in (list , tuple ):
151152 return x .__class__ (
152153 [
153154 self .make_fake_with_dynamic_dimensions (i , dynamic_shapes = ds )
154155 for i , ds in zip (x , dynamic_shapes )
155156 ]
156157 )
157- if isinstance ( x , dict ) :
158+ if type ( x ) is dict :
158159 return {
159160 k : self .make_fake_with_dynamic_dimensions (v , dynamic_shapes = dynamic_shapes [k ])
160161 for k , v in x .items ()
@@ -187,6 +188,17 @@ def make_fake_with_dynamic_dimensions(self, x: Any, dynamic_shapes: Any) -> Any:
187188 x .cross_attention_cache , dynamic_shapes = dynamic_shapes [1 ]
188189 )
189190 return x
191+ if x .__class__ .__name__ == "BaseModelOutput" :
192+ assert (
193+ list (x .keys ()) == ["last_hidden_state" ] and x .last_hidden_state is not None
194+ ), (
195+ f"Field 'last_hidden_state' is empty for { type (x )} or other fields "
196+ f"{ list (x .keys ())} are used."
197+ )
198+ x .last_hidden_state = self .make_fake_with_dynamic_dimensions (
199+ x .last_hidden_state , dynamic_shapes = dynamic_shapes [0 ]
200+ )
201+ return x
190202 if hasattr (x , "shape" ):
191203 assert dynamic_shapes is None or isinstance (dynamic_shapes , dict ), (
192204 f"dynamic_shapes must be a dictionary at this stage but "
@@ -197,9 +209,11 @@ def make_fake_with_dynamic_dimensions(self, x: Any, dynamic_shapes: Any) -> Any:
197209 for idim , dim in enumerate (x .shape ):
198210 if dynamic_shapes is not None and idim in dynamic_shapes :
199211 s = dynamic_shapes [idim ]
212+ if s .__class__ .__name__ == "Dim" :
213+ s = s .__name__
200214 assert isinstance (s , str ), (
201215 f"Unexpected type { type (s )} in dynamic_shapes={ dynamic_shapes } "
202- f"at index { idim } "
216+ f"at index { idim } , self._mapping_str= { self . _mapping_str } "
203217 )
204218 if s in self ._mapping_str :
205219 dim = self ._mapping_str [s ]
@@ -221,6 +235,9 @@ def make_fake_with_dynamic_dimensions(self, x: Any, dynamic_shapes: Any) -> Any:
221235 assert t .device == x .device , f"device mismatch { x .device } -> { t .device } "
222236 assert t .dtype == x .dtype , f"dtype mismatch { x .dtype } -> { t .dtype } "
223237 return t
238+ if isinstance (x , (int , bool , float )):
239+ # It is a constant, we don't change that.
240+ return x
224241 from ..helpers import string_type
225242
226243 raise TypeError (
0 commit comments