Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 7ba41ca

Browse files
committed
wip feb 8/2
1 parent 34a1346 commit 7ba41ca

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

torchtext/prototype/generate.py

Lines changed: 5 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -295,6 +295,11 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
295295
model_inputs = self.model.prepare_inputs_for_generation(state_and_tokens, **new_model_kwargs)
296296
if self.is_huggingface_model:
297297
model_inputs.update(self._huggingface_model_input_values)
298+
if len(prev_step_hyp_idxs) > 1 and model_inputs["past_key_values"] is not None:
299+
model_inputs["past_key_values"] = self.model._reorder_cache(
300+
model_inputs["past_key_values"],
301+
torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32), # I think this is correct?
302+
)
298303

299304
# Forward pass
300305
outputs = self.model(**model_inputs)
@@ -306,14 +311,6 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
306311
# HF optimizations to reduce overhead in future `forward` calls
307312
if self.is_huggingface_model:
308313
new_model_kwargs = self._update_model_kwargs_for_generation(outputs, new_model_kwargs)
309-
if new_model_kwargs["past"] is not None and len(prev_step_hyp_idxs) > 1:
310-
if len(prev_step_hyp_idxs) == 9:
311-
import pdb
312-
pdb.set_trace()
313-
new_model_kwargs["past"] = self.model._reorder_cache(
314-
new_model_kwargs["past"],
315-
torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32), # I think this is correct?
316-
)
317314

318315
# Keep track of probabilities over vocab for this pairing
319316
# TODO: clean up duplicate code in these branches

0 commit comments

Comments
 (0)