Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 7f27bc1

Browse files
committed
wip feb 9
1 parent 8b9a12a commit 7f27bc1

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

torchtext/prototype/generate.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -280,8 +280,9 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
280280
), f"state_and_tokens has shape {state_and_tokens.shape} = expected {(num_samples, timestep + 1)}"
281281
else:
282282
assert len(prev_model_state_sequences) == 1
283-
state_and_tokens = token_indices = prev_model_state_sequences[0].expand(num_beams, -1) # TODO: Make this more robust
284-
283+
state_and_tokens = token_indices = prev_model_state_sequences[0].expand(
284+
num_beams, -1
285+
) # TODO: Make this more robust
285286

286287
# Cleanup -- combine this with the above
287288
if self.is_encoder_decoder:
@@ -292,14 +293,14 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
292293
)
293294

294295
# Preprocess inputs for generation
295-
model_inputs = self.model.prepare_inputs_for_generation(token_indices, **model_kwargs)
296+
model_inputs = self.model.prepare_inputs_for_generation(
297+
token_indices, **model_kwargs
298+
) # This should technically work with state_and_tokens, but the prepare function has to splice if past (like HF does)
296299
if self.is_huggingface_model:
297300
model_inputs.update(self._huggingface_model_input_values)
298301
if len(prev_step_hyp_idxs) > 1 and model_kwargs["past"] is not None:
299-
model_inputs["past_key_values"] = self.model._reorder_cache(
300-
model_kwargs["past"],
301-
torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32),
302-
)
302+
beam_idxs = torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32)
303+
model_inputs["past_key_values"] = self.model._reorder_cache(model_kwargs["past"], beam_idxs)
303304

304305
# Forward pass
305306
outputs = self.model(**model_inputs)

0 commit comments

Comments
 (0)