@@ -59,8 +59,8 @@ def get_inputs(
     dummy_max_token_id: int,
     num_hidden_layers: int,
     batch_size: int = 2,
-    sequence_length: int = 30,
-    sequence_length2: int = 3,
+    past_sequence_length: int = 30,
+    sequence_length: int = 3,
     dynamic_rope: bool = False,
     num_key_value_heads: Optional[int] = None,
     head_dim: Optional[int] = None,
@@ -76,17 +76,18 @@ def get_inputs(
     :param head_dim: last dimension of the cache
     :param dummy_max_token_id: dummy max token id
     :param batch_size: batch size
-    :param sequence_length: sequence length
-    :param sequence_length2: new sequence length
+    :param past_sequence_length: past sequence length
+    :param sequence_length: new sequence length
     :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
     :param cls_cache: cache class, by default it is
         :class:`transformers.cache_utils.DynamicCache`
     :return: dictionary
     """
     batch = "batch"
-    seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
-    cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
+    seq_length = "seq_length"
+    past_seq_length = "past_seq_length"
 
+    # TODO(team): Is this code block still necessary?
     if config is not None and config.__class__.__name__ == "FalconMambaConfig":
         try:
             from transformers.models.mamba.modeling_mamba import MambaCache
@@ -98,23 +99,23 @@ def get_inputs(
             MambaCache,
         ), f"Unexpected value for cls_cache={cls_cache} and config={config}"
         seq_length_multiple = 8
-        sequence_length = (
-            (sequence_length + seq_length_multiple)
+        past_sequence_length = (
+            (past_sequence_length + seq_length_multiple)
             // seq_length_multiple
             * seq_length_multiple
         )
         # sequence_inc = seq_length_multiple
-        sequence_length2 = seq_length_multiple
+        sequence_length = seq_length_multiple
 
         shapes = {
             "input_ids": {0: batch, 1: "sequence_length"},
             "attention_mask": {
                 0: batch,
-                1: "cache+seq",  # cache_length + seq_length
+                1: "cache+seq",  # past_seq_length + seq_length
             },
             "cache_position": {
                 0: batch,
-                1: "cache+seq",  # cache_length + seq_length
+                1: "cache+seq",  # past_seq_length + seq_length
             },
             "cache_params": [
                 [{0: batch} for _ in range(num_hidden_layers)],
@@ -123,9 +124,9 @@ def get_inputs(
         }
         inputs = dict(
             input_ids=torch.randint(
-                0, dummy_max_token_id, (batch_size, sequence_length + sequence_length2)
+                0, dummy_max_token_id, (batch_size, past_sequence_length + sequence_length)
             ).to(torch.int64),
-            attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
+            attention_mask=torch.ones((batch_size, past_sequence_length + sequence_length)).to(
                 torch.int64
             ),
             cache_position=torch.arange(0, kwargs["conv_kernel"]).to(torch.int64),
@@ -167,46 +168,54 @@ def get_inputs(
     make_cache = make_dynamic_cache if cache_name is None else make_caches[cache_name]
     is_static = cache_name == "StaticCache"
 
+    # TODO(team): Is this code block still necessary?
     if is_static:
         # static
         shapes = {
             "input_ids": {0: batch, 1: seq_length},
-            "attention_mask": {0: batch, 2: "seq"},
-            "cache_position": {0: "seq"},
+            "attention_mask": {0: batch, 2: "sequence_length+past_sequence_length"},
+            "cache_position": {0: "sequence_length+past_sequence_length"},
             "past_key_values": [
-                # [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-                # [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+                # [{0: batch, 2: past_seq_length} for _ in range(num_hidden_layers)],
+                # [{0: batch, 2: past_seq_length} for _ in range(num_hidden_layers)],
                 [{0: batch} for _ in range(num_hidden_layers)],
                 [{0: batch} for _ in range(num_hidden_layers)],
             ],
         }
         inputs = dict(
             input_ids=torch.randint(
-                0, dummy_max_token_id, (batch_size, sequence_length2)
+                0, dummy_max_token_id, (batch_size, sequence_length)
             ).to(torch.int64),
             attention_mask=torch.ones(
-                (batch_size, num_key_value_heads, sequence_length2, head_dim)
+                (
+                    batch_size,
+                    num_key_value_heads,
+                    past_sequence_length + sequence_length,
+                    head_dim,
+                )
             ).to(torch.bool),
-            cache_position=torch.arange(sequence_length2).to(torch.int64),
+            cache_position=torch.arange(past_sequence_length + sequence_length).to(
+                torch.int64
+            ),
             past_key_values=make_static_cache(
                 [
                     (
                         torch.randn(
                             batch_size,
                             num_key_value_heads,
-                            sequence_length + sequence_length2,
+                            past_sequence_length + sequence_length,
                             head_dim,
                         ),
                         torch.randn(
                             batch_size,
                             num_key_value_heads,
-                            sequence_length + sequence_length2,
+                            sequence_length + past_sequence_length,
                             head_dim,
                         ),
                     )
                     for i in range(num_hidden_layers)
                 ],
-                max_cache_len=max(sequence_length + sequence_length2, head_dim),
+                max_cache_len=max(sequence_length + past_sequence_length, head_dim),
             ),
         )
     else:
@@ -215,53 +224,56 @@ def get_inputs(
             "input_ids": {0: batch, 1: seq_length},
             "attention_mask": {
                 0: batch,
-                1: "cache+seq",  # cache_length + seq_length
+                1: "cache+seq",  # past_seq_length + seq_length
             },
             "position_ids": {
                 0: batch,
-                1: "cache+seq",  # cache_length + seq_length
+                1: seq_length,
             },
-            "past_key_values": [
-                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-            ],
         }
 
         inputs = dict(
             input_ids=torch.randint(
-                0, dummy_max_token_id, (batch_size, sequence_length2)
+                0, dummy_max_token_id, (batch_size, sequence_length)
             ).to(torch.int64),
-            attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
-                torch.int64
-            ),
-            position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
+            attention_mask=torch.ones(
+                (batch_size, sequence_length + past_sequence_length)
+            ).to(torch.int64),
+            position_ids=torch.arange(
+                past_sequence_length, sequence_length + past_sequence_length
+            )
             .to(torch.int64)
             .expand((batch_size, -1)),
-            past_key_values=make_cache(  # type: ignore[operator]
+        )
+        if past_sequence_length > 0:
+            inputs["past_key_values"] = make_cache(
                 [
                     (
                         torch.randn(
-                            batch_size, num_key_value_heads, sequence_length, head_dim
+                            batch_size, num_key_value_heads, past_sequence_length, head_dim
                         ),
                         torch.randn(
-                            batch_size, num_key_value_heads, sequence_length, head_dim
+                            batch_size, num_key_value_heads, past_sequence_length, head_dim
                         ),
                     )
                     for i in range(num_hidden_layers)
                 ]
-            ),
-        )
+            )
+            shapes["past_key_values"] = [
+                [{0: batch, 2: past_seq_length} for _ in range(num_hidden_layers)],
+                [{0: batch, 2: past_seq_length} for _ in range(num_hidden_layers)],
+            ]
     res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
+        # prompt processing (prefill) testing
         res["inputs2"] = get_inputs(
             model=model,
             config=config,
             dummy_max_token_id=dummy_max_token_id,
             num_hidden_layers=num_hidden_layers,
-            batch_size=(batch_size + 1) if add_second_input > 0 else 1,
-            sequence_length=sequence_length + 1,
-            sequence_length2=sequence_length2
-            + (add_second_input if add_second_input > 0 else -add_second_input),
+            batch_size=batch_size,
+            past_sequence_length=0,
+            sequence_length=32,
             dynamic_rope=dynamic_rope,
             num_key_value_heads=num_key_value_heads,
             head_dim=head_dim,
@@ -276,6 +288,23 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
     """
     Inputs kwargs.
 
+    NOTE: We test two scenarios:
+    1. prompt processing (aka prefill):
+        input_ids=(batch_size, prompt_length)
+        attn_mask=(batch_size, 0+prompt_length) = (batch_size, prompt_length)
+        pos_ids=(batch_size, prompt_length)
+        past_key_values=(batch_size, num_key_value_heads, 0, head_dim)
+        present_key_values=(batch_size, num_key_value_heads, 0+prompt_length, head_dim)
+    2. token generation (aka decode):
+        input_ids=(batch_size, 1)
+        attn_mask=(batch_size, past_sequence_length+1)
+        pos_ids=(batch_size, 1)
+        past_key_values=(batch_size, num_key_value_heads, past_sequence_length,
+            head_dim)
+        present_key_values=(batch_size, num_key_value_heads,
+            past_sequence_length+1, head_dim)
+
+
     If the configuration is None, the function selects typical dimensions.
     """
     if config is not None:
@@ -290,8 +319,8 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         check_hasattr(config, "conv_kernel", "state_size", "intermediate_size")  # 4 and 8
         kwargs = dict(
             batch_size=2,
-            sequence_length=30,
-            sequence_length2=3,
+            past_sequence_length=30,
+            sequence_length=3,
             dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
             num_hidden_layers=4 if config is None else config.num_hidden_layers,
             intermediate_size=256 if config is None else config.intermediate_size,
@@ -300,10 +329,12 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
             conv_kernel=8 if config is None else getattr(config, "conv_kernel", None),
         )
     else:
+        # Token generation (decode) testing
+        # NOTE: We have to export the model in decode mode to preserve the cache
         kwargs = dict(
             batch_size=2,
-            sequence_length=30,
-            sequence_length2=3,
+            past_sequence_length=32,
+            sequence_length=1,
             head_dim=(
                 16
                 if config is None
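The renamed parameters make the two scenarios described in the docstring easy to reproduce by hand. The sketch below is illustrative only: the name dummy_llm_inputs is a hypothetical stand-in for the updated get_inputs(), the default head sizes are assumptions, and past_key_values is returned as plain (key, value) tensor pairs rather than wrapped in a transformers cache via make_dynamic_cache; the shapes, however, mirror the dynamic-cache branch shown in the diff.

# Illustrative sketch, not the repository helper.
import torch


def dummy_llm_inputs(
    batch_size,
    past_sequence_length,
    sequence_length,
    num_key_value_heads=8,
    head_dim=16,
    num_hidden_layers=2,
    dummy_max_token_id=31999,
):
    inputs = dict(
        # only the new tokens are fed as input_ids
        input_ids=torch.randint(
            0, dummy_max_token_id, (batch_size, sequence_length), dtype=torch.int64
        ),
        # the mask covers the cached tokens plus the new tokens
        attention_mask=torch.ones(
            (batch_size, past_sequence_length + sequence_length), dtype=torch.int64
        ),
        # positions continue where the cache stops
        position_ids=torch.arange(
            past_sequence_length, past_sequence_length + sequence_length
        ).expand((batch_size, -1)),
    )
    if past_sequence_length > 0:
        # decode: one non-empty (key, value) pair per layer
        inputs["past_key_values"] = [
            (
                torch.randn(batch_size, num_key_value_heads, past_sequence_length, head_dim),
                torch.randn(batch_size, num_key_value_heads, past_sequence_length, head_dim),
            )
            for _ in range(num_hidden_layers)
        ]
    return inputs


prefill = dummy_llm_inputs(batch_size=2, past_sequence_length=0, sequence_length=32)
decode = dummy_llm_inputs(batch_size=2, past_sequence_length=32, sequence_length=1)
assert "past_key_values" not in prefill           # prefill starts from an empty cache
assert decode["input_ids"].shape == (2, 1)        # decode feeds a single new token
assert decode["attention_mask"].shape == (2, 33)  # past (32) + new (1)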