@@ -300,6 +300,11 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
300
300
model_inputs = self .model .prepare_inputs_for_generation (state_and_tokens , ** new_model_kwargs )
301
301
if self .is_huggingface_model :
302
302
model_inputs .update (self ._huggingface_model_input_values )
303
+ if len (prev_step_hyp_idxs ) > 1 and model_inputs ["past_key_values" ] is not None :
304
+ model_inputs ["past_key_values" ] = self .model ._reorder_cache (
305
+ model_inputs ["past_key_values" ],
306
+ torch .Tensor (prev_step_hyp_idxs ).to (dtype = torch .int32 ), # I think this is correct?
307
+ )
303
308
304
309
# Forward pass
305
310
outputs = self .model (** model_inputs )
@@ -311,14 +316,6 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
311
316
# HF optimizations to reduce overhead in future `forward` calls
312
317
if self .is_huggingface_model :
313
318
new_model_kwargs = self ._update_model_kwargs_for_generation (outputs , new_model_kwargs )
314
- if new_model_kwargs ["past" ] is not None and len (prev_step_hyp_idxs ) > 1 :
315
- if len (prev_step_hyp_idxs ) == 9 :
316
- import pdb
317
- pdb .set_trace ()
318
- new_model_kwargs ["past" ] = self .model ._reorder_cache (
319
- new_model_kwargs ["past" ],
320
- torch .Tensor (prev_step_hyp_idxs ).to (dtype = torch .int32 ), # I think this is correct?
321
- )
322
319
323
320
# Keep track of probabilities over vocab for this pairing
324
321
# TODO: clean up duplicate code in these branches
0 commit comments