@@ -56,6 +56,74 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     return kwargs
 
 
+def _get_input_falcon_mamba(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    dummy_max_token_id: int,
+    num_hidden_layers: int,
+    batch_size: int = 2,
+    sequence_length: int = 30,
+    sequence_length2: int = 3,
+    dynamic_rope: bool = False,
+    num_key_value_heads: Optional[int] = None,
+    head_dim: Optional[int] = None,
+    cls_cache: Optional[Union[type, str]] = None,
+    **kwargs,  # unused
+):
+    try:
+        from transformers.models.mamba.modeling_mamba import MambaCache
+    except ImportError:
+        from transformers.cache_utils import MambaCache
+
+    assert cls_cache in (
+        "MambaCache",
+        MambaCache,
+    ), f"Unexpected value for cls_cache={cls_cache} and config={config}"
+
+    batch = "batch"
+    seq_length_multiple = 8
+    sequence_length = (
+        (sequence_length + seq_length_multiple) // seq_length_multiple * seq_length_multiple
+    )
+    # sequence_inc = seq_length_multiple
+    sequence_length2 = seq_length_multiple
+
+    shapes = {
+        "input_ids": {0: batch, 1: "sequence_length"},
+        "attention_mask": {
+            0: batch,
+            1: "cache+seq",  # cache_length + seq_length
+        },
+        "cache_position": {
+            0: batch,
+            1: "cache+seq",  # cache_length + seq_length
+        },
+        "cache_params": [{0: batch} for _ in range(num_hidden_layers * 2)],
+    }
+    inputs = dict(
+        input_ids=torch.randint(
+            0, dummy_max_token_id, (batch_size, sequence_length + sequence_length2)
+        ).to(torch.int64),
+        attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
+            torch.int64
+        ),
+        cache_position=torch.arange(0, kwargs["conv_kernel"]).to(torch.int64),
+        # .expand((batch_size, -1))
+        cache_params=make_mamba_cache(
+            [
+                (
+                    torch.randn(
+                        batch_size, kwargs["intermediate_size"], kwargs["conv_kernel"]
+                    ),
+                    torch.randn(batch_size, kwargs["intermediate_size"], kwargs["state_size"]),
+                )
+                for i in range(num_hidden_layers)
+            ]
+        ),
+    )
+    return dict(inputs=inputs, dynamic_shapes=shapes)
+
+
 def get_inputs(
     model: torch.nn.Module,
     config: Optional[Any],
@@ -68,7 +136,7 @@ def get_inputs(
     num_key_value_heads: Optional[int] = None,
     head_dim: Optional[int] = None,
     cls_cache: Optional[Union[type, str]] = None,
-    add_second_input: int = 1,
+    add_second_input: Optional[int] = None,
     **kwargs,  # unused
 ):
     """
@@ -84,67 +152,28 @@ def get_inputs(
     :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
     :param cls_cache: cache class, by default it is
         :class:`transformers.cache_utils.DynamicCache`
+    :param add_second_input: adds other kinds of inputs
     :return: dictionary
     """
     batch = "batch"
     seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
     cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
 
     if config is not None and config.__class__.__name__ == "FalconMambaConfig":
-        try:
-            from transformers.models.mamba.modeling_mamba import MambaCache
-        except ImportError:
-            from transformers.cache_utils import MambaCache
-
-        assert cls_cache in (
-            "MambaCache",
-            MambaCache,
-        ), f"Unexpected value for cls_cache={cls_cache} and config={config}"
-        seq_length_multiple = 8
-        sequence_length = (
-            (sequence_length + seq_length_multiple)
-            // seq_length_multiple
-            * seq_length_multiple
-        )
-        # sequence_inc = seq_length_multiple
-        sequence_length2 = seq_length_multiple
-
-        shapes = {
-            "input_ids": {0: batch, 1: "sequence_length"},
-            "attention_mask": {
-                0: batch,
-                1: "cache+seq",  # cache_length + seq_length
-            },
-            "cache_position": {
-                0: batch,
-                1: "cache+seq",  # cache_length + seq_length
-            },
-            "cache_params": [{0: batch} for _ in range(num_hidden_layers * 2)],
-        }
-        inputs = dict(
-            input_ids=torch.randint(
-                0, dummy_max_token_id, (batch_size, sequence_length + sequence_length2)
-            ).to(torch.int64),
-            attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
-                torch.int64
-            ),
-            cache_position=torch.arange(0, kwargs["conv_kernel"]).to(torch.int64),
-            # .expand((batch_size, -1))
-            cache_params=make_mamba_cache(
-                [
-                    (
-                        torch.randn(
-                            batch_size, kwargs["intermediate_size"], kwargs["conv_kernel"]
-                        ),
-                        torch.randn(
-                            batch_size, kwargs["intermediate_size"], kwargs["state_size"]
-                        ),
-                    )
-                    for i in range(num_hidden_layers)
-                ]
-            ),
+        res = _get_input_falcon_mamba(
+            model=model,
+            config=config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_hidden_layers=num_hidden_layers,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            sequence_length2=sequence_length2,
+            dynamic_rope=dynamic_rope,
+            num_key_value_heads=num_key_value_heads,
+            head_dim=head_dim,
+            cls_cache=cls_cache,
+            **kwargs,  # unused
         )
-        res = dict(inputs=inputs, dynamic_shapes=shapes)
     else:
         if head_dim is None:
             assert config, "head_dim is None, the value cannot be set without a configuration"
@@ -244,6 +273,7 @@ def get_inputs(
         )
         res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
+        res["inputs_prompt"] = dict(input_ids=torch.randint(1000, 30000, (1, 11)))
         res["inputs2"] = get_inputs(
             model=model,
             config=config,
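Usage note on the extracted helper: the sketch below is a minimal, standalone approximation of the dummy inputs `_get_input_falcon_mamba` builds for a FalconMamba-style model. The dimension values (`conv_kernel`, `intermediate_size`, `state_size`, vocabulary size, layer count) are placeholders, not values from this commit, and the repository's `make_mamba_cache` helper, which turns the `(conv_state, ssm_state)` pairs into the cache object passed as `cache_params`, is not reproduced here.

```python
import torch

# Placeholder FalconMamba-like dimensions (assumptions, not values from this commit).
batch_size, num_hidden_layers = 2, 4
conv_kernel, intermediate_size, state_size = 4, 512, 16
vocab_size = 32000
# sequence_length=30 is rounded up by the helper to a multiple of 8 (=32),
# then sequence_length2=8 extra tokens are appended.
total_length = 32 + 8

dummy_inputs = dict(
    input_ids=torch.randint(0, vocab_size, (batch_size, total_length), dtype=torch.int64),
    attention_mask=torch.ones((batch_size, total_length), dtype=torch.int64),
    cache_position=torch.arange(0, conv_kernel, dtype=torch.int64),
    # One (conv_state, ssm_state) pair per hidden layer; the actual helper feeds
    # such pairs to make_mamba_cache to build the MambaCache used as cache_params.
    cache_states=[
        (
            torch.randn(batch_size, intermediate_size, conv_kernel),
            torch.randn(batch_size, intermediate_size, state_size),
        )
        for _ in range(num_hidden_layers)
    ],
)

for name, value in dummy_inputs.items():
    if isinstance(value, torch.Tensor):
        print(name, tuple(value.shape))
    else:
        print(name, len(value), "layer state pairs")
```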