Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 852b5c4

Browse files
committed
Only decode new tokens
1 parent c43654e commit 852b5c4

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

torchtext/prototype/generate.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,7 @@ def _prepare_encoder_decoder_kwargs_for_generation(self, inputs: torch.Tensor) -
8282

8383
# Forward pass
8484
# Explicitly call forward method to assert this is a ScriptModule if JITted
85-
model_kwargs = {"encoder_outputs": encoder.forward(inputs)} # , **encoder_kwargs)
85+
model_kwargs = {"encoder_outputs": encoder.forward(inputs, **encoder_kwargs)}
8686
return model_kwargs
8787

8888
def _prepare_decoder_ids_for_generation(
@@ -286,7 +286,7 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
286286
), f"state_and_tokens has shape {state_and_tokens.shape} = expected {(num_samples, timestep + 1)}"
287287
else:
288288
assert len(prev_model_state_sequences) == 1
289-
state_and_tokens = prev_model_state_sequences[0] # dims: [1, 1]
289+
state_and_tokens = token_indices = prev_model_state_sequences[0] # dims: [1, 1]
290290

291291
# Cleanup -- combine this with the above
292292
if self.is_encoder_decoder:
@@ -297,13 +297,13 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
297297
)
298298

299299
# Preprocess inputs for generation
300-
model_inputs = self.model.prepare_inputs_for_generation(state_and_tokens, **new_model_kwargs)
300+
model_inputs = self.model.prepare_inputs_for_generation(token_indices, **new_model_kwargs)
301301
if self.is_huggingface_model:
302302
model_inputs.update(self._huggingface_model_input_values)
303303
if len(prev_step_hyp_idxs) > 1 and model_inputs["past_key_values"] is not None:
304304
model_inputs["past_key_values"] = self.model._reorder_cache(
305305
model_inputs["past_key_values"],
306-
torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32), # I think this is correct?
306+
torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32),
307307
)
308308

309309
# Forward pass

0 commit comments

Comments
 (0)