Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 57c3a00

Browse files
committed
repro unbounded beam_idx
1 parent 7f27bc1 commit 57c3a00

File tree

1 file changed

+21
-1
lines changed

1 file changed

+21
-1
lines changed

torchtext/prototype/generate.py

Lines changed: 21 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -300,7 +300,27 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
300300
model_inputs.update(self._huggingface_model_input_values)
301301
if len(prev_step_hyp_idxs) > 1 and model_kwargs["past"] is not None:
302302
beam_idxs = torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32)
303-
model_inputs["past_key_values"] = self.model._reorder_cache(model_kwargs["past"], beam_idxs)
303+
304+
# We could store this in model_kwargs
305+
num_hyps_in_prev_step = model_kwargs["past"][0][0].shape[0]
306+
307+
num_finished_hyps_in_step = num_hyps_in_prev_step - len(prev_step_hyp_idxs)
308+
if num_finished_hyps_in_step > 0:
309+
beam_idxs = F.pad(beam_idxs, (0, num_finished_hyps_in_step), "constant", 0)
310+
311+
reordered_cached = model_kwargs["past"] #self.model._reorder_cache(model_kwargs["past"], beam_idxs)
312+
313+
if num_finished_hyps_in_step > 0:
314+
sliced_cache = ()
315+
for states in reordered_cached:
316+
sliced_state = ()
317+
for state in states:
318+
sliced_state = sliced_state + (state[:len(prev_step_hyp_idxs)],)
319+
sliced_cache = sliced_cache + (sliced_state,)
320+
reordered_cached = sliced_cache
321+
322+
model_inputs["past_key_values"] = reordered_cached
323+
304324

305325
# Forward pass
306326
outputs = self.model(**model_inputs)

0 commit comments

Comments (0)