@@ -173,37 +173,34 @@ def get_inputs(
         # static
         shapes = {
             "input_ids": {0: batch, 1: seq_length},
-            "attention_mask": {0: batch, 2: "sequence_length+past_sequence_length"},
-            "cache_position": {0: "sequence_length+past_sequence_length"},
+            "attention_mask": {0: batch, 2: "past_sequence_length"},
+            "cache_position": {0: "past_sequence_length"},
             "past_key_values": [
-                # [{0: batch, 2: past_seq_length} for _ in range(num_hidden_layers)],
-                # [{0: batch, 2: past_seq_length} for _ in range(num_hidden_layers)],
+                # past_sequence_length is now static
                 [{0: batch} for _ in range(num_hidden_layers)],
                 [{0: batch} for _ in range(num_hidden_layers)],
             ],
         }
         inputs = dict(
             input_ids=torch.randint(
-                0, dummy_max_token_id, (batch_size, sequence_length)
+                0, dummy_max_token_id, (batch_size, past_sequence_length)
             ).to(torch.int64),
             attention_mask=torch.ones(
                 (
                     batch_size,
                     num_key_value_heads,
-                    past_sequence_length + sequence_length,
+                    past_sequence_length,
                     head_dim,
                 )
             ).to(torch.bool),
-            cache_position=torch.arange(past_sequence_length + sequence_length).to(
-                torch.int64
-            ),
+            cache_position=torch.arange(past_sequence_length).to(torch.int64),
             past_key_values=make_static_cache(
                 [
                     (
                         torch.randn(
                             batch_size,
                             num_key_value_heads,
-                            past_sequence_length + sequence_length,
+                            sequence_length + past_sequence_length,
                             head_dim,
                         ),
                         torch.randn(
@@ -215,7 +212,7 @@ def get_inputs(
                     )
                     for i in range(num_hidden_layers)
                 ],
-                max_cache_len=max(sequence_length + past_sequence_length, head_dim),
+                max_cache_len=max(past_sequence_length, head_dim),
             ),
         )
     else:
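For context, the sketch below rebuilds the same dummy inputs as the new code path with purely illustrative sizes. It substitutes a plain list of (key, value) tuples for make_static_cache, whose behavior is only known here from the call site above, and it assumes the value tensor's shape mirrors the key tensor's, since the diff elides those lines. It simply prints the resulting shapes.

import torch

# Illustrative sizes only; the real function derives them from the model config.
batch_size = 2
sequence_length = 3
past_sequence_length = 30
num_hidden_layers = 2
num_key_value_heads = 4
head_dim = 64
dummy_max_token_id = 31999

inputs = dict(
    input_ids=torch.randint(
        0, dummy_max_token_id, (batch_size, past_sequence_length)
    ).to(torch.int64),
    # the mask and cache_position now cover only past_sequence_length
    attention_mask=torch.ones(
        (batch_size, num_key_value_heads, past_sequence_length, head_dim)
    ).to(torch.bool),
    cache_position=torch.arange(past_sequence_length).to(torch.int64),
    # plain (key, value) tuples stand in for make_static_cache(...);
    # the value tensor's shape is assumed to match the key tensor's
    past_key_values=[
        (
            torch.randn(
                batch_size,
                num_key_value_heads,
                sequence_length + past_sequence_length,
                head_dim,
            ),
            torch.randn(
                batch_size,
                num_key_value_heads,
                sequence_length + past_sequence_length,
                head_dim,
            ),
        )
        for _ in range(num_hidden_layers)
    ],
)

for name, value in inputs.items():
    if isinstance(value, list):
        print(name, [tuple(t.shape) for t in value[0]])
    else:
        print(name, tuple(value.shape))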