
Commit c43654e

wip feb 8/2

1 parent e64f1ef commit c43654e

File tree: 1 file changed (+5, -8 lines)


torchtext/prototype/generate.py

Lines changed: 5 additions & 8 deletions
@@ -300,6 +300,11 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
 model_inputs = self.model.prepare_inputs_for_generation(state_and_tokens, **new_model_kwargs)
 if self.is_huggingface_model:
     model_inputs.update(self._huggingface_model_input_values)
+    if len(prev_step_hyp_idxs) > 1 and model_inputs["past_key_values"] is not None:
+        model_inputs["past_key_values"] = self.model._reorder_cache(
+            model_inputs["past_key_values"],
+            torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32),  # I think this is correct?
+        )
 
 # Forward pass
 outputs = self.model(**model_inputs)
@@ -311,14 +316,6 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
 # HF optimizations to reduce overhead in future `forward` calls
 if self.is_huggingface_model:
     new_model_kwargs = self._update_model_kwargs_for_generation(outputs, new_model_kwargs)
-    if new_model_kwargs["past"] is not None and len(prev_step_hyp_idxs) > 1:
-        if len(prev_step_hyp_idxs) == 9:
-            import pdb
-            pdb.set_trace()
-        new_model_kwargs["past"] = self.model._reorder_cache(
-            new_model_kwargs["past"],
-            torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32),  # I think this is correct?
-        )
 
 # Keep track of probabilities over vocab for this pairing
 # TODO: clean up duplicate code in these branches
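
Context for the change above: `_reorder_cache` is the hook that HuggingFace Transformers generation models expose for permuting cached key/value states so they line up with the hypotheses kept at the previous decoding step. The commit moves that reordering from the post-forward handling of `new_model_kwargs["past"]` (deleted, together with the stray pdb breakpoint) to the `model_inputs["past_key_values"]` prepared before the forward pass. As a rough, hypothetical sketch only, and not the torchtext or Transformers implementation, reordering a cache stored in the legacy tuple-of-tuples layout, with tensors shaped [batch, num_heads, seq_len, head_dim], could look like this; the function name and toy shapes below are invented for illustration:

    import torch

    def reorder_cache_sketch(past_key_values, beam_idx):
        # For each layer's cached key/value tensors, keep only the rows
        # (hypotheses) that survived the previous step, in the given order.
        return tuple(
            tuple(state.index_select(0, beam_idx) for state in layer_past)
            for layer_past in past_key_values
        )

    # Toy usage: 2 layers, 3 hypotheses, 4 heads, cache length 5, head dim 8.
    past = tuple((torch.randn(3, 4, 5, 8), torch.randn(3, 4, 5, 8)) for _ in range(2))
    beam_idx = torch.tensor([2, 0, 0], dtype=torch.long)
    reordered = reorder_cache_sketch(past, beam_idx)
    assert reordered[0][0].shape == (3, 4, 5, 8)

On the author's "I think this is correct?" note: torch.index_select accepts int32 as well as int64 index tensors, so the int32 cast in the commit is plausible; int64 (long) is simply the more common convention for index tensors in PyTorch.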
