@@ -295,6 +295,11 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
295
295
model_inputs = self .model .prepare_inputs_for_generation (state_and_tokens , ** new_model_kwargs )
296
296
if self .is_huggingface_model :
297
297
model_inputs .update (self ._huggingface_model_input_values )
298
+ if len (prev_step_hyp_idxs ) > 1 and model_inputs ["past_key_values" ] is not None :
299
+ model_inputs ["past_key_values" ] = self .model ._reorder_cache (
300
+ model_inputs ["past_key_values" ],
301
+ torch .Tensor (prev_step_hyp_idxs ).to (dtype = torch .int32 ), # I think this is correct?
302
+ )
298
303
299
304
# Forward pass
300
305
outputs = self .model (** model_inputs )
@@ -306,14 +311,6 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
306
311
# HF optimizations to reduce overhead in future `forward` calls
307
312
if self .is_huggingface_model :
308
313
new_model_kwargs = self ._update_model_kwargs_for_generation (outputs , new_model_kwargs )
309
- if new_model_kwargs ["past" ] is not None and len (prev_step_hyp_idxs ) > 1 :
310
- if len (prev_step_hyp_idxs ) == 9 :
311
- import pdb
312
- pdb .set_trace ()
313
- new_model_kwargs ["past" ] = self .model ._reorder_cache (
314
- new_model_kwargs ["past" ],
315
- torch .Tensor (prev_step_hyp_idxs ).to (dtype = torch .int32 ), # I think this is correct?
316
- )
317
314
318
315
# Keep track of probabilities over vocab for this pairing
319
316
# TODO: clean up duplicate code in these branches
0 commit comments