@@ -82,7 +82,7 @@ def _prepare_encoder_decoder_kwargs_for_generation(self, inputs: torch.Tensor) -
 
         # Forward pass
         # Explicitly call forward method to assert this is a ScriptModule if JITted
-        model_kwargs = {"encoder_outputs": encoder.forward(inputs)}  # , **encoder_kwargs)
+        model_kwargs = {"encoder_outputs": encoder.forward(inputs, **encoder_kwargs)}
         return model_kwargs
 
     def _prepare_decoder_ids_for_generation(
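The first hunk stops dropping the extra encoder arguments: `**encoder_kwargs` is now forwarded into the explicit `encoder.forward(...)` call instead of being left behind in a commented-out fragment. A minimal sketch of the intended call shape, assuming `encoder_kwargs` is collected by the caller (the helper name and the `attention_mask` key below are illustrative, not from the repository):

```python
import torch


def build_encoder_model_kwargs(encoder, inputs: torch.Tensor, **encoder_kwargs):
    # Call .forward() explicitly so a TorchScript-compiled encoder takes the same code path.
    return {"encoder_outputs": encoder.forward(inputs, **encoder_kwargs)}


# Hypothetical usage: extra encoder-side kwargs now reach the encoder.
# model_kwargs = build_encoder_model_kwargs(model.get_encoder(), input_ids, attention_mask=mask)
```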
@@ -286,7 +286,7 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
                     ), f"state_and_tokens has shape {state_and_tokens.shape} = expected {(num_samples, timestep + 1)}"
                 else:
                     assert len(prev_model_state_sequences) == 1
-                    state_and_tokens = prev_model_state_sequences[0]  # dims: [1, 1]
+                    state_and_tokens = token_indices = prev_model_state_sequences[0]  # dims: [1, 1]
 
                 # Cleanup -- combine this with the above
                 if self.is_encoder_decoder:
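The chained assignment above matters because the next hunk switches `prepare_inputs_for_generation` over to `token_indices`; in the single-hypothesis branch both names must therefore point at the same `[1, 1]` tensor. A small sketch of that aliasing, with placeholder data:

```python
import torch

# Single surviving hypothesis: one [1, 1] tensor plays both roles.
prev_model_state_sequences = [torch.zeros(1, 1, dtype=torch.long)]

assert len(prev_model_state_sequences) == 1
state_and_tokens = token_indices = prev_model_state_sequences[0]
assert state_and_tokens is token_indices  # same object, no copy made
```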
@@ -297,13 +297,13 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
                     )
 
                 # Preprocess inputs for generation
-                model_inputs = self.model.prepare_inputs_for_generation(state_and_tokens, **new_model_kwargs)
+                model_inputs = self.model.prepare_inputs_for_generation(token_indices, **new_model_kwargs)
                 if self.is_huggingface_model:
                     model_inputs.update(self._huggingface_model_input_values)
                     if len(prev_step_hyp_idxs) > 1 and model_inputs["past_key_values"] is not None:
                         model_inputs["past_key_values"] = self.model._reorder_cache(
                             model_inputs["past_key_values"],
-                            torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32),  # I think this is correct?
+                            torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32),
                         )
 
                 # Forward pass
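For the Hugging Face branch, `_reorder_cache` realigns the cached keys/values with the surviving hypotheses before the next forward pass; the removed trailing comment simply drops the earlier uncertainty about the int32 index tensor. A hedged sketch of the typical reordering (Hugging Face models provide their own `_reorder_cache`; the helper below only illustrates the effect):

```python
import torch


def reorder_cache_sketch(past_key_values, prev_step_hyp_idxs):
    # Index tensor built the same way as in the hunk above.
    beam_idx = torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32)
    # Select the cache rows (batch/hypothesis dimension 0) that match the kept hypotheses.
    return tuple(
        tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
        for layer_past in past_key_values
    )
```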