Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 7a95b93

Browse files
committed
wip
1 parent 365de76 commit 7a95b93

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

torchtext/prototype/generate.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def greedy_search(
133133
model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs)
134134
if self.is_huggingface_model:
135135
model_inputs["return_dict"] = True
136+
model_inputs["use_cache"] = True
136137
model_inputs["output_hidden_states"] = True
137138

138139
# Get model output
@@ -258,11 +259,15 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
258259

259260
# Forward pass
260261
model_inputs = self.model.prepare_inputs_for_generation(state_and_tokens, **new_model_kwargs)
262+
print(model_inputs.get("use_cache"), model_inputs.get("past_key_values"))
261263

262264
if self.is_huggingface_model:
263265
model_inputs["return_dict"] = True
266+
model_inputs["use_cache"] = True
264267
model_inputs["output_hidden_states"] = True
265268

269+
print(model_inputs.get("use_cache"), model_inputs.get("past_key_values"))
270+
266271
outputs = self.model(**model_inputs)
267272
output_key = "logits" if self.is_huggingface_model else "decoder_output"
268273
lm_scores = outputs[output_key]

0 commit comments

Comments
 (0)