Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit e64dc8b

Browse files
committed
Only decode new tokens
1 parent 7ba41ca commit e64dc8b

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

torchtext/prototype/generate.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def _prepare_encoder_decoder_kwargs_for_generation(self, inputs: torch.Tensor) -
7777

7878
# Forward pass
7979
# Explicitly call forward method to assert to assert this is a ScriptModule if JITted
80-
model_kwargs = {"encoder_outputs": encoder.forward(inputs)} # , **encoder_kwargs)
80+
model_kwargs = {"encoder_outputs": encoder.forward(inputs, **encoder_kwargs)}
8181
return model_kwargs
8282

8383
def _prepare_decoder_ids_for_generation(
@@ -281,7 +281,7 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
281281
), f"state_and_tokens has shape {state_and_tokens.shape} = expected {(num_samples, timestep + 1)}"
282282
else:
283283
assert len(prev_model_state_sequences) == 1
284-
state_and_tokens = prev_model_state_sequences[0] # dims: [1, 1]
284+
state_and_tokens = token_indices = prev_model_state_sequences[0] # dims: [1, 1]
285285

286286
# Cleanup -- combine this with the above
287287
if self.is_encoder_decoder:
@@ -292,13 +292,13 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
292292
)
293293

294294
# Preprocess inputs for generation
295-
model_inputs = self.model.prepare_inputs_for_generation(state_and_tokens, **new_model_kwargs)
295+
model_inputs = self.model.prepare_inputs_for_generation(token_indices, **new_model_kwargs)
296296
if self.is_huggingface_model:
297297
model_inputs.update(self._huggingface_model_input_values)
298298
if len(prev_step_hyp_idxs) > 1 and model_inputs["past_key_values"] is not None:
299299
model_inputs["past_key_values"] = self.model._reorder_cache(
300300
model_inputs["past_key_values"],
301-
torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32), # I think this is correct?
301+
torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32),
302302
)
303303

304304
# Forward pass

0 commit comments

Comments
 (0)