@@ -77,7 +77,7 @@ def _prepare_encoder_decoder_kwargs_for_generation(self, inputs: torch.Tensor) -
77
77
78
78
# Forward pass
79
79
# Explicitly call forward method to assert to assert this is a ScriptModule if JITted
80
- model_kwargs = {"encoder_outputs" : encoder .forward (inputs )} # , **encoder_kwargs)
80
+ model_kwargs = {"encoder_outputs" : encoder .forward (inputs , ** encoder_kwargs )}
81
81
return model_kwargs
82
82
83
83
def _prepare_decoder_ids_for_generation (
@@ -281,7 +281,7 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
281
281
), f"state_and_tokens has shape { state_and_tokens .shape } = expected { (num_samples , timestep + 1 )} "
282
282
else :
283
283
assert len (prev_model_state_sequences ) == 1
284
- state_and_tokens = prev_model_state_sequences [0 ] # dims: [1, 1]
284
+ state_and_tokens = token_indices = prev_model_state_sequences [0 ] # dims: [1, 1]
285
285
286
286
# Cleanup -- combine this with the above
287
287
if self .is_encoder_decoder :
@@ -292,13 +292,13 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
292
292
)
293
293
294
294
# Preprocess inputs for generation
295
- model_inputs = self .model .prepare_inputs_for_generation (state_and_tokens , ** new_model_kwargs )
295
+ model_inputs = self .model .prepare_inputs_for_generation (token_indices , ** new_model_kwargs )
296
296
if self .is_huggingface_model :
297
297
model_inputs .update (self ._huggingface_model_input_values )
298
298
if len (prev_step_hyp_idxs ) > 1 and model_inputs ["past_key_values" ] is not None :
299
299
model_inputs ["past_key_values" ] = self .model ._reorder_cache (
300
300
model_inputs ["past_key_values" ],
301
- torch .Tensor (prev_step_hyp_idxs ).to (dtype = torch .int32 ), # I think this is correct?
301
+ torch .Tensor (prev_step_hyp_idxs ).to (dtype = torch .int32 ),
302
302
)
303
303
304
304
# Forward pass
0 commit comments