@@ -71,6 +71,7 @@ def get_inputs(
     num_key_value_heads: Optional[int] = None,
     head_dim: Optional[int] = None,
     cls_cache: Optional[Union[type, str]] = None,
+    add_second_input: bool = False,
     **kwargs,  # unused
 ):
     """
@@ -88,6 +89,7 @@ def get_inputs(
         :class:`transformers.cache_utils.DynamicCache`
+    :param add_second_input: if True, returns a second set of inputs under ``inputs2``
     :return: dictionary
     """
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
     cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
@@ -192,7 +194,24 @@ def get_inputs(
             ]
         ),
     )
-    return dict(inputs=inputs, dynamic_shapes=shapes)
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        # same inputs with every dynamic dimension bumped by one
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_hidden_layers=num_hidden_layers,
+            batch_size=batch_size + 1,
+            sequence_length=sequence_length + 1,
+            sequence_length2=sequence_length2 + 1,
+            dynamic_rope=dynamic_rope,
+            num_key_value_heads=num_key_value_heads,
+            head_dim=head_dim,
+            cls_cache=cls_cache,
+            **kwargs,
+        )["inputs"]
+    return res
 
 
 def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
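For context, a minimal sketch of how the new flag might be consumed; the `model`/`config` wiring and the parameter values below are assumptions for illustration, not part of the commit. Since `inputs2` bumps every dynamic dimension by one, running it through the exported program is a cheap check that the declared dynamic shapes were actually honored:

```python
import torch

# Assumed: ``model`` and ``config`` are a causal LM and its transformers
# config compatible with ``get_inputs``; the values are illustrative only.
res = get_inputs(
    model=model,
    config=config,
    dummy_max_token_id=31999,
    num_hidden_layers=2,
    add_second_input=True,  # flag introduced by this commit
)

# Export with the first input set and its dynamic shapes.
ep = torch.export.export(
    model,
    (),
    kwargs=res["inputs"],
    dynamic_shapes=res["dynamic_shapes"],
)

# ``inputs2`` uses batch_size + 1, sequence_length + 1, sequence_length2 + 1,
# so this call fails if any of those dimensions was exported as static.
ep.module()(**res["inputs2"])
```

Note that the recursive call does not forward `add_second_input`, so it falls back to its `False` default and the recursion terminates after one level.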