@@ -119,10 +119,7 @@ def get_inputs(
119119 0 : batch ,
120120 1 : "cache+seq" , # cache_length + seq_length
121121 },
122- "cache_params" : [
123- [{0 : batch } for _ in range (num_hidden_layers )],
124- [{0 : batch } for _ in range (num_hidden_layers )],
125- ],
122+ "cache_params" : [{0 : batch } for _ in range (num_hidden_layers * 2 )],
126123 }
127124 inputs = dict (
128125 input_ids = torch .randint (
@@ -176,12 +173,7 @@ def get_inputs(
176173 "input_ids" : {0 : batch , 1 : seq_length },
177174 "attention_mask" : {0 : batch , 2 : "seq" },
178175 "cache_position" : {0 : "seq" },
179- "past_key_values" : [
180- # [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
181- # [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
182- [{0 : batch } for _ in range (num_hidden_layers )],
183- [{0 : batch } for _ in range (num_hidden_layers )],
184- ],
176+ "past_key_values" : [{0 : batch } for _ in range (num_hidden_layers * 2 )],
185177 }
186178 inputs = dict (
187179 input_ids = torch .randint (
@@ -222,8 +214,7 @@ def get_inputs(
222214 },
223215 "position_ids" : {0 : batch , 1 : seq_length },
224216 "past_key_values" : [
225- [{0 : batch , 2 : cache_length } for _ in range (num_hidden_layers )],
226- [{0 : batch , 2 : cache_length } for _ in range (num_hidden_layers )],
217+ {0 : batch , 2 : cache_length } for _ in range (num_hidden_layers * 2 )
227218 ],
228219 }
229220
0 commit comments