@@ -275,8 +275,9 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
                ), f"state_and_tokens has shape {state_and_tokens.shape} = expected {(num_samples, timestep + 1)}"
            else:
                assert len(prev_model_state_sequences) == 1
-               state_and_tokens = token_indices = prev_model_state_sequences[0].expand(num_beams, -1)  # TODO: Make this more robust
-
+               state_and_tokens = token_indices = prev_model_state_sequences[0].expand(
+                   num_beams, -1
+               )  # TODO: Make this more robust

            # Cleanup -- combine this with the above
            if self.is_encoder_decoder:
@@ -287,14 +288,14 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
                )

            # Preprocess inputs for generation
-           model_inputs = self.model.prepare_inputs_for_generation(token_indices, **model_kwargs)
+           model_inputs = self.model.prepare_inputs_for_generation(
+               token_indices, **model_kwargs
+           )  # This should technically work with state_and_tokens, but the prepare function has to splice if past (like HF does)
            if self.is_huggingface_model:
                model_inputs.update(self._huggingface_model_input_values)
            if len(prev_step_hyp_idxs) > 1 and model_kwargs["past"] is not None:
-               model_inputs["past_key_values"] = self.model._reorder_cache(
-                   model_kwargs["past"],
-                   torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32),
-               )
+               beam_idxs = torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32)
+               model_inputs["past_key_values"] = self.model._reorder_cache(model_kwargs["past"], beam_idxs)

            # Forward pass
            outputs = self.model(**model_inputs)
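
For context on the two Hugging Face-style hooks exercised above: prepare_inputs_for_generation splices the input ids down to the most recent token once a cache exists (the "splice if past" the diff comment refers to), and _reorder_cache permutes the cached key/value tensors so they line up with the hypotheses that survived the last beam step. Below is a minimal, self-contained sketch of both behaviors, assuming the transformers convention of past_key_values as a per-layer tuple of (key, value) tensors with the beam dimension first; the two free functions are hypothetical stand-ins for the model methods called in the diff, not the actual implementations.

    import torch

    def prepare_inputs_for_generation(token_indices, past_key_values=None):
        # When a cache is present, only the newest token needs to be fed to
        # the decoder -- earlier positions are already encoded in the cache.
        if past_key_values is not None:
            token_indices = token_indices[:, -1:]
        return {"input_ids": token_indices, "past_key_values": past_key_values}

    def reorder_cache(past_key_values, beam_idxs):
        # For every layer and every key/value tensor, select the cache rows
        # belonging to the surviving hypotheses (dim 0 is the beam dimension).
        return tuple(
            tuple(t.index_select(0, beam_idxs) for t in layer_past)
            for layer_past in past_key_values
        )

    if __name__ == "__main__":
        num_beams, num_heads, seq_len, head_dim = 4, 2, 5, 8
        past = tuple(
            (
                torch.randn(num_beams, num_heads, seq_len, head_dim),
                torch.randn(num_beams, num_heads, seq_len, head_dim),
            )
            for _ in range(3)  # 3 decoder layers
        )
        # Hypotheses 2, 0, and 1 survived the last step; hypothesis 2 was duplicated.
        beam_idxs = torch.tensor([2, 2, 0, 1], dtype=torch.int32)
        reordered = reorder_cache(past, beam_idxs)
        assert torch.equal(reordered[0][0][0], past[0][0][2])

        tokens = torch.randint(0, 100, (num_beams, seq_len))
        inputs = prepare_inputs_for_generation(tokens, past_key_values=reordered)
        assert inputs["input_ids"].shape == (num_beams, 1)  # spliced to last token

This also illustrates why the reorder only fires when len(prev_step_hyp_idxs) > 1 and a past exists: with a single hypothesis or no cache there is nothing to permute.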