@@ -280,8 +280,9 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
280
280
), f"state_and_tokens has shape { state_and_tokens .shape } = expected { (num_samples , timestep + 1 )} "
281
281
else :
282
282
assert len (prev_model_state_sequences ) == 1
283
- state_and_tokens = token_indices = prev_model_state_sequences [0 ].expand (num_beams , - 1 ) # TODO: Make this more robust
284
-
283
+ state_and_tokens = token_indices = prev_model_state_sequences [0 ].expand (
284
+ num_beams , - 1
285
+ ) # TODO: Make this more robust
285
286
286
287
# Cleanup -- combine this with the above
287
288
if self .is_encoder_decoder :
@@ -292,14 +293,14 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
292
293
)
293
294
294
295
# Preprocess inputs for generation
295
- model_inputs = self .model .prepare_inputs_for_generation (token_indices , ** model_kwargs )
296
+ model_inputs = self .model .prepare_inputs_for_generation (
297
+ token_indices , ** model_kwargs
298
+ ) # This should technically work with state_and_tokens, but the prepare function has to splice if past (like HF does)
296
299
if self .is_huggingface_model :
297
300
model_inputs .update (self ._huggingface_model_input_values )
298
301
if len (prev_step_hyp_idxs ) > 1 and model_kwargs ["past" ] is not None :
299
- model_inputs ["past_key_values" ] = self .model ._reorder_cache (
300
- model_kwargs ["past" ],
301
- torch .Tensor (prev_step_hyp_idxs ).to (dtype = torch .int32 ),
302
- )
302
+ beam_idxs = torch .Tensor (prev_step_hyp_idxs ).to (dtype = torch .int32 )
303
+ model_inputs ["past_key_values" ] = self .model ._reorder_cache (model_kwargs ["past" ], beam_idxs )
303
304
304
305
# Forward pass
305
306
outputs = self .model (** model_inputs )
0 commit comments