This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 8b9a12a ("wip")
1 parent 852b5c4 commit 8b9a12a

File tree: 1 file changed (+21, -35 lines)

torchtext/prototype/generate.py: 21 additions & 35 deletions
@@ -104,15 +104,12 @@ def _update_model_kwargs_for_generation(
         self,
         outputs: Dict[str, Any],
         model_kwargs: Dict[str, Any],
-    ) -> MODEL_KWARGS_TYPE:
+    ) -> None:
         """After a forward pass, update model_kwargs for faster decoding. Modified from https://github.com/huggingface/transformers/blob/67d074874d285e616393c65a0e670088e1b6b74a/src/transformers/generation/utils.py#L692.
 
         Args:
             outputs (Dict[str, Any]): LM output.
             model_kwargs (Dict[str, Any]): Model keyword args to be modified for future runs.
-
-        Returns:
-            Modified model_kwargs w/ updated past, token_type_ids, and attention_mask.
         """
         # Update past
         if "past_key_values" in outputs:
@@ -143,8 +140,6 @@ def _update_model_kwargs_for_generation(
                 dim=-1,
             )
 
-        return model_kwargs
-
     def greedy_search(
         self,
         input_ids: torch.Tensor,
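
Note on the two hunks above: `_update_model_kwargs_for_generation` now mutates the caller's `model_kwargs` in place instead of returning a fresh dict, hence the `-> None` annotation and the removed `return model_kwargs`. A minimal sketch of the in-place pattern (the keys shown are illustrative, not the full set the method handles):

    from typing import Any, Dict

    def update_kwargs_in_place(outputs: Dict[str, Any], model_kwargs: Dict[str, Any]) -> None:
        # Mutate the caller's dict; nothing is returned.
        if "past_key_values" in outputs:
            model_kwargs["past"] = outputs["past_key_values"]

    kwargs = {"attention_mask": None}
    update_kwargs_in_place({"past_key_values": object()}, kwargs)
    # kwargs now carries the cache entry without any reassignment at the call site.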
@@ -227,6 +222,8 @@ def beam_search(
         Returns:
             Tensor of the generated sequences.
         """
+        device = input_ids.device
+
         if self.is_encoder_decoder:
             encoder_output_key = "last_hidden_state" if self.is_huggingface_model else "encoder_output"
             encoder_output = model_kwargs["encoder_outputs"][encoder_output_key]
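
Note: `beam_search` now caches `device = input_ids.device` up front; the hunk at line 273 below swaps `self.model.device` for this cached value. Deriving the device from an input tensor is the safer pattern when the wrapped model exposes no `.device` attribute (plain `nn.Module`s generally don't). A small sketch of the pattern:

    import torch

    input_ids = torch.tensor([[0, 5, 7]])  # toy batch; could live on CPU or GPU
    device = input_ids.device

    # Tensors created during decoding follow the inputs' device automatically.
    token_indices = torch.Tensor([3, 1, 4]).to(dtype=torch.long, device=device).reshape(3, 1)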
@@ -236,9 +233,6 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
 
             i = T  # Hacky access to the current seq in inputs
 
-            # Copy over the `model_kwargs` in order to modify
-            new_model_kwargs = model_kwargs.copy()
-
             # For first timestep, create previous step token_idxs and model_states
             if timestep == 0:
                 prev_step_token_idxs = [-1]
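
Note: dropping `new_model_kwargs = model_kwargs.copy()` lets `update_func` operate on the shared dict directly. `dict.copy()` is shallow, so nested entries such as `encoder_outputs` pointed at the same objects either way; a quick demonstration of that sharing:

    kwargs = {"encoder_outputs": {"last_hidden_state": [1, 2, 3]}}
    copied = kwargs.copy()  # shallow copy: the nested dict is the same object
    copied["encoder_outputs"]["last_hidden_state"].append(4)
    assert kwargs["encoder_outputs"]["last_hidden_state"] == [1, 2, 3, 4]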
@@ -273,7 +267,7 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
                     state_sequences = torch.cat(prev_model_state_sequences[start:end], dim=0)
                     token_indices = (
                         torch.Tensor(prev_step_token_idxs[start:end])
-                        .to(dtype=torch.long, device=self.model.device)
+                        .to(dtype=torch.long, device=device)
                         .reshape(num_samples, 1)
                     )
 
@@ -286,23 +280,24 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
                     ), f"state_and_tokens has shape {state_and_tokens.shape} = expected {(num_samples, timestep + 1)}"
                 else:
                     assert len(prev_model_state_sequences) == 1
-                    state_and_tokens = token_indices = prev_model_state_sequences[0]  # dims: [1, 1]
+                    state_and_tokens = token_indices = prev_model_state_sequences[0].expand(num_beams, -1)  # TODO: Make this more robust
+
 
                 # Cleanup -- combine this with the above
                 if self.is_encoder_decoder:
                     # Expand encoder outputs along the batch dimension so that they match the decoder input state's batch size
                     # This is a view-only operation and doesn't copy
-                    new_model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_for_curr_seq.expand(
-                        num_samples if timestep > 0 else 1, -1, -1
+                    model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_for_curr_seq.expand(
+                        num_samples if timestep > 0 else num_beams, -1, -1
                     )
 
                 # Preprocess inputs for generation
-                model_inputs = self.model.prepare_inputs_for_generation(token_indices, **new_model_kwargs)
+                model_inputs = self.model.prepare_inputs_for_generation(token_indices, **model_kwargs)
                 if self.is_huggingface_model:
                     model_inputs.update(self._huggingface_model_input_values)
-                    if len(prev_step_hyp_idxs) > 1 and model_inputs["past_key_values"] is not None:
+                    if len(prev_step_hyp_idxs) > 1 and model_kwargs["past"] is not None:
                         model_inputs["past_key_values"] = self.model._reorder_cache(
-                            model_inputs["past_key_values"],
+                            model_kwargs["past"],
                             torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32),
                         )
 
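
Note: the hunk above leans on `torch.Tensor.expand` returning a broadcast view rather than a copy (per the existing comment, a "view-only operation"), so re-expanding the encoder output to `num_beams` or `num_samples` rows each step costs no extra memory. A small sketch of that behavior:

    import torch

    encoder_output = torch.randn(1, 4, 8)        # (batch=1, seq_len, hidden)
    expanded = encoder_output.expand(5, -1, -1)  # (5, 4, 8) view over the same storage
    assert expanded.data_ptr() == encoder_output.data_ptr()  # no data was copied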

@@ -315,32 +310,23 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
 
                 # HF optimizations to reduce overhead in future `forward` calls
                 if self.is_huggingface_model:
-                    new_model_kwargs = self._update_model_kwargs_for_generation(outputs, new_model_kwargs)
+                    self._update_model_kwargs_for_generation(outputs, model_kwargs)
 
                 # Keep track of probabilities over vocab for this pairing
-                # TODO: clean up duplicate code in these branches
-                if timestep == 0:
-                    sample_lm_scores = torch.squeeze(lm_scores[:, -1])
+                # TODO: fix how we track the number here?
+                for i in range(lm_scores.shape[0]):
+                    sample_lm_scores = lm_scores[i, -1]
                     out_probs.append(sample_lm_scores.tolist())
+                    # Keep track of sequence and decoder hidden states
                     model_states.append(
                         create_emitting_model_state(
-                            Seq2SeqModelState(timestep=timestep, sequence=state_and_tokens, lm_scores=sample_lm_scores)
-                        )
-                    )
-                else:
-                    for i in range(num_samples):
-                        sample_lm_scores = lm_scores[i, -1]
-                        out_probs.append(sample_lm_scores.tolist())
-                        # Keep track of sequence and decoder hidden states
-                        model_states.append(
-                            create_emitting_model_state(
-                                Seq2SeqModelState(
-                                    timestep=timestep,
-                                    sequence=state_and_tokens[i].unsqueeze(0),
-                                    lm_scores=sample_lm_scores,
-                                )
+                            Seq2SeqModelState(
+                                timestep=timestep,
+                                sequence=state_and_tokens[i].unsqueeze(0),
+                                lm_scores=sample_lm_scores,
                             )
                         )
+                    )
 
                 start += step
 
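Note on the final hunk: the `timestep == 0` special case and its duplicated `else` branch collapse into a single loop over the leading (hypothesis) dimension of `lm_scores`, which works because the first step is now also expanded to `num_beams` rows (the `.expand(num_beams, -1)` change above). A minimal sketch of the unified bookkeeping, with a random tensor standing in for the model's scores:

    import torch

    lm_scores = torch.randn(3, 1, 100)  # (hypotheses, seq, vocab); toy stand-in
    out_probs = []
    for i in range(lm_scores.shape[0]):
        sample_lm_scores = lm_scores[i, -1]  # scores at the last position for hypothesis i
        out_probs.append(sample_lm_scores.tolist())
    assert len(out_probs) == lm_scores.shape[0]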