Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit d622466

Browse files
committed
wip feb 9
1 parent fa47956 commit d622466

File tree

1 file changed

+8 / -7 lines changed

1 file changed

+8
-7
lines changed

torchtext/prototype/generate.py

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -275,8 +275,9 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
275275
), f"state_and_tokens has shape {state_and_tokens.shape} = expected {(num_samples, timestep + 1)}"
276276
else:
277277
assert len(prev_model_state_sequences) == 1
278-
state_and_tokens = token_indices = prev_model_state_sequences[0].expand(num_beams, -1) # TODO: Make this more robust
279-
278+
state_and_tokens = token_indices = prev_model_state_sequences[0].expand(
279+
num_beams, -1
280+
) # TODO: Make this more robust
280281

281282
# Cleanup -- combine this with the above
282283
if self.is_encoder_decoder:
@@ -287,14 +288,14 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
287288
)
288289

289290
# Preprocess inputs for generation
290-
model_inputs = self.model.prepare_inputs_for_generation(token_indices, **model_kwargs)
291+
model_inputs = self.model.prepare_inputs_for_generation(
292+
token_indices, **model_kwargs
293+
) # This should technically work with state_and_tokens, but the prepare function has to splice if past (like HF does)
291294
if self.is_huggingface_model:
292295
model_inputs.update(self._huggingface_model_input_values)
293296
if len(prev_step_hyp_idxs) > 1 and model_kwargs["past"] is not None:
294-
model_inputs["past_key_values"] = self.model._reorder_cache(
295-
model_kwargs["past"],
296-
torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32),
297-
)
297+
beam_idxs = torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32)
298+
model_inputs["past_key_values"] = self.model._reorder_cache(model_kwargs["past"], beam_idxs)
298299

299300
# Forward pass
300301
outputs = self.model(**model_inputs)

0 commit comments

Comments
 (0)