Commit fa47956 ("wip")
1 parent: e64dc8b

Note: this repository was archived by the owner on Sep 10, 2025, and is now read-only.

1 file changed: torchtext/prototype/generate.py (21 additions, 35 deletions)
@@ -99,15 +99,12 @@ def _update_model_kwargs_for_generation(
         self,
         outputs: Dict[str, Any],
         model_kwargs: Dict[str, Any],
-    ) -> MODEL_KWARGS_TYPE:
+    ) -> None:
         """After a forward pass, update model_kwargs for faster decoding. Modified from https://github.com/huggingface/transformers/blob/67d074874d285e616393c65a0e670088e1b6b74a/src/transformers/generation/utils.py#L692.
 
         Args:
             outputs (Dict[str, Any]): LM output.
             model_kwargs (Dict[str, Any]): Model keyword args to be modified for future runs.
-
-        Returns:
-            Modified model_kwargs w/ updated past, token_type_ids, and attention_mask.
         """
         # Update past
         if "past_key_values" in outputs:
@@ -138,8 +135,6 @@ def _update_model_kwargs_for_generation(
                 dim=-1,
             )
 
-        return model_kwargs
-
     def greedy_search(
         self,
         input_ids: torch.Tensor,
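Taken together, the two hunks above change `_update_model_kwargs_for_generation` from returning a fresh kwargs dict to mutating the caller's `model_kwargs` in place, which is why the `return model_kwargs` line disappears (later hunks drop the matching `model_kwargs.copy()` on the caller side). A minimal sketch of the in-place pattern, assuming the HF-style keys named in the removed docstring ("past", "attention_mask"); the exact keys torchtext handles may differ:

from typing import Any, Dict

import torch


def update_model_kwargs_in_place(outputs: Dict[str, Any], model_kwargs: Dict[str, Any]) -> None:
    # Carry the KV cache forward so the next forward pass only processes the newest token.
    if "past_key_values" in outputs:
        model_kwargs["past"] = outputs["past_key_values"]
    # Extend the attention mask by one position for the token just generated.
    if "attention_mask" in model_kwargs:
        mask = model_kwargs["attention_mask"]
        model_kwargs["attention_mask"] = torch.cat([mask, mask.new_ones((mask.shape[0], 1))], dim=-1)

Because the dict is updated in place, callers no longer need to rebind the result.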
@@ -222,6 +217,8 @@ def beam_search(
         Returns:
             Tensor of the generated sequences.
         """
+        device = input_ids.device
+
         if self.is_encoder_decoder:
             encoder_output_key = "last_hidden_state" if self.is_huggingface_model else "encoder_output"
             encoder_output = model_kwargs["encoder_outputs"][encoder_output_key]
@@ -231,9 +228,6 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
 
             i = T  # Hacky access to the current seq in inputs
 
-            # Copy over the `model_kwargs` in order to modify
-            new_model_kwargs = model_kwargs.copy()
-
             # For first timestep, create previous step token_idxs and model_states
             if timestep == 0:
                 prev_step_token_idxs = [-1]
@@ -268,7 +262,7 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
                 state_sequences = torch.cat(prev_model_state_sequences[start:end], dim=0)
                 token_indices = (
                     torch.Tensor(prev_step_token_idxs[start:end])
-                    .to(dtype=torch.long, device=self.model.device)
+                    .to(dtype=torch.long, device=device)
                     .reshape(num_samples, 1)
                 )
 
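The two device-related hunks above replace `self.model.device` with a `device` captured once from `input_ids` at the top of `beam_search`. Every tensor carries a `.device` attribute, whereas a plain `nn.Module` has none (HuggingFace's `PreTrainedModel` adds one, but this class may wrap arbitrary modules). A hypothetical illustration:

import torch
import torch.nn as nn

model = nn.Linear(4, 4)  # a plain module, as this generator may wrap
input_ids = torch.zeros(1, 4, dtype=torch.long)

device = input_ids.device  # always defined on a tensor
# model.device             # AttributeError on a plain nn.Module
device_via_params = next(model.parameters()).device  # common fallback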
@@ -281,23 +275,24 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
                 ), f"state_and_tokens has shape {state_and_tokens.shape} = expected {(num_samples, timestep + 1)}"
             else:
                 assert len(prev_model_state_sequences) == 1
-                state_and_tokens = token_indices = prev_model_state_sequences[0]  # dims: [1, 1]
+                state_and_tokens = token_indices = prev_model_state_sequences[0].expand(num_beams, -1)  # TODO: Make this more robust
+
 
             # Cleanup -- combine this with the above
             if self.is_encoder_decoder:
                 # Expand encoder outputs along the batch dimension so that they match the decoder input state's batch size
                 # This is a view-only operation and doesn't copy
-                new_model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_for_curr_seq.expand(
-                    num_samples if timestep > 0 else 1, -1, -1
+                model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_for_curr_seq.expand(
+                    num_samples if timestep > 0 else num_beams, -1, -1
                 )
 
             # Preprocess inputs for generation
-            model_inputs = self.model.prepare_inputs_for_generation(token_indices, **new_model_kwargs)
+            model_inputs = self.model.prepare_inputs_for_generation(token_indices, **model_kwargs)
             if self.is_huggingface_model:
                 model_inputs.update(self._huggingface_model_input_values)
-                if len(prev_step_hyp_idxs) > 1 and model_inputs["past_key_values"] is not None:
+                if len(prev_step_hyp_idxs) > 1 and model_kwargs["past"] is not None:
                     model_inputs["past_key_values"] = self.model._reorder_cache(
-                        model_inputs["past_key_values"],
+                        model_kwargs["past"],
                         torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32),
                     )
 
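This hunk also switches `_reorder_cache` to read the KV cache from `model_kwargs["past"]` (populated by the in-place update sketched earlier) instead of from `model_inputs`. The `expand` calls it introduces rely on broadcasting views; the diff's own comment says the operation "doesn't copy", which a small standalone check confirms:

import torch

encoder_output = torch.randn(1, 7, 16)  # (batch=1, src_len, hidden)
num_beams = 4

# expand() broadcasts the singleton batch dim to num_beams without copying.
expanded = encoder_output.expand(num_beams, -1, -1)
print(expanded.shape)                                    # torch.Size([4, 7, 16])
print(expanded.data_ptr() == encoder_output.data_ptr())  # True -> same storage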
@@ -310,32 +305,23 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
 
             # HF optimizations to reduce overhead in future `forward` calls
             if self.is_huggingface_model:
-                new_model_kwargs = self._update_model_kwargs_for_generation(outputs, new_model_kwargs)
+                self._update_model_kwargs_for_generation(outputs, model_kwargs)
 
             # Keep track of probabilities over vocab for this pairing
-            # TODO: clean up duplicate code in these branches
-            if timestep == 0:
-                sample_lm_scores = torch.squeeze(lm_scores[:, -1])
+            # TODO: fix how we track the number here?
+            for i in range(lm_scores.shape[0]):
+                sample_lm_scores = lm_scores[i, -1]
                 out_probs.append(sample_lm_scores.tolist())
+                # Keep track of sequence and decoder hidden states
                 model_states.append(
                     create_emitting_model_state(
-                        Seq2SeqModelState(timestep=timestep, sequence=state_and_tokens, lm_scores=sample_lm_scores)
-                    )
-                )
-            else:
-                for i in range(num_samples):
-                    sample_lm_scores = lm_scores[i, -1]
-                    out_probs.append(sample_lm_scores.tolist())
-                    # Keep track of sequence and decoder hidden states
-                    model_states.append(
-                        create_emitting_model_state(
-                            Seq2SeqModelState(
-                                timestep=timestep,
-                                sequence=state_and_tokens[i].unsqueeze(0),
-                                lm_scores=sample_lm_scores,
-                            )
+                        Seq2SeqModelState(
+                            timestep=timestep,
+                            sequence=state_and_tokens[i].unsqueeze(0),
+                            lm_scores=sample_lm_scores,
                         )
                     )
+                )
 
             start += step
 
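The final hunk collapses the `timestep == 0` special case and the general branch into a single loop: because the initial state is now expanded to `num_beams` rows (see the earlier `expand(num_beams, -1)` hunk), every timestep can iterate uniformly over the rows of `lm_scores`. A rough sketch of the unified shape handling, with illustrative dimensions:

import torch

num_beams, vocab_size = 4, 10
# Works at every timestep once the first step is also expanded to num_beams rows.
lm_scores = torch.randn(num_beams, 1, vocab_size)
state_and_tokens = torch.zeros(num_beams, 1, dtype=torch.long)

out_probs, sequences = [], []
for i in range(lm_scores.shape[0]):      # one pass per beam hypothesis
    sample_lm_scores = lm_scores[i, -1]  # scores for the latest position
    out_probs.append(sample_lm_scores.tolist())
    sequences.append(state_and_tokens[i].unsqueeze(0))  # keep 2-D per-beam history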