This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit daac585

Add HF update methods
1 parent 9afc810 commit daac585

1 file changed: +127 −35 lines changed


torchtext/prototype/generate.py

Lines changed: 127 additions & 35 deletions
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
@@ -26,7 +26,7 @@ class Seq2SeqModelState(object):
     lm_scores: Optional[torch.Tensor]
 
 
-class GenerationUtils:
+class GenerationUtils(nn.Module):
     """Wrapper to provide generation utils for encoder/decoder models and decoder models.
 
     Example:
@@ -51,20 +51,72 @@ class GenerationUtils:
     """
 
     def __init__(self, model: nn.Module, **kwargs) -> None:
+        super().__init__()
         self.model = model
         self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", True)
         self.is_huggingface_model = kwargs.pop("is_huggingface_model", False)
+
+    def _prepare_encoder_decoder_kwargs_for_generation(self, inputs, model_kwargs):
+        """Modified from."""
+        # Get encoder
+        encoder = self.model.get_encoder()
+
+        # Prepare encoder args and encoder kwargs from model kwargs
+        irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
+        encoder_kwargs = {}
+        for argument, value in model_kwargs.items():
+            if not any([argument.startswith(p) for p in irrelevant_prefix]):
+                encoder_kwargs[argument] = value
+
+        # Forward pass
+        if self.is_huggingface_model:
+            encoder_kwargs["return_dict"] = True
+        model_kwargs["encoder_outputs"] = encoder(inputs, **encoder_kwargs)
+
+        return model_kwargs
 
     def _prepare_decoder_ids_for_generation(
-        self, batch_size: int, pad_idx: int = 0, device: Optional[torch.device] = None, **model_kwargs
+        self, batch_size: int, pad_idx: int = 0, device: Optional[torch.device] = None, model_kwargs: Optional[Dict[str, Any]] = None
     ):
         if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
             return model_kwargs.pop("decoder_input_ids")
         else:
             return torch.ones((batch_size, 1), dtype=torch.long, device=device) * pad_idx
 
+    def _update_model_kwargs_for_generation(
+        self,
+        outputs: Dict[str, Any],
+        model_kwargs: Dict[str, Any],
+        is_encoder_decoder: bool = False,
+    ) -> Dict[str, Any]:
+        """Modified from."""
+        # Update past
+        if "past_key_values" in outputs:
+            model_kwargs["past"] = outputs.past_key_values
+        elif "mems" in outputs:
+            model_kwargs["past"] = outputs.mems
+        elif "past_buckets_states" in outputs:
+            model_kwargs["past"] = outputs.past_buckets_states
+        else:
+            model_kwargs["past"] = None
+
+        # Update token_type_ids with last value
+        if "token_type_ids" in model_kwargs:
+            token_type_ids = model_kwargs["token_type_ids"]
+            model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
+
+        # Update attention mask
+        if not is_encoder_decoder:
+            if "attention_mask" in model_kwargs:
+                attention_mask = model_kwargs["attention_mask"]
+                model_kwargs["attention_mask"] = torch.cat(
+                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+                )
+
+        return model_kwargs
+
     def greedy_search(
-        self, input_ids: torch.Tensor, max_length: int, eos_idx: int, pad_idx: Optional[int] = None, **model_kwargs
+        self, input_ids: torch.Tensor, max_length: int, eos_idx: int, pad_idx: Optional[int] = None, model_kwargs: Optional[Dict[str, Any]] = {}
     ) -> torch.Tensor:
         """Greedy search decoding for text generation. Takes the most likely next token every time.
 
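
The hunk above introduces the HF-oriented helpers this commit is named for. Below is a minimal, illustrative sketch (not part of the commit) of what the kwargs-update step is expected to do at each decoding step; FakeOutput and update_model_kwargs are hypothetical stand-ins for a HuggingFace ModelOutput and for the new _update_model_kwargs_for_generation method.

# Illustrative sketch, not part of the commit. `FakeOutput` is a hypothetical
# stand-in for a HuggingFace ModelOutput (dict-like with attribute access).
from typing import Any, Dict

import torch


class FakeOutput(dict):
    def __getattr__(self, name: str) -> Any:
        return self[name]


def update_model_kwargs(outputs: Dict[str, Any], model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False) -> Dict[str, Any]:
    # Carry the decoder cache forward under "past", mirroring the diff above.
    model_kwargs["past"] = outputs.past_key_values if "past_key_values" in outputs else None
    # Decoder-only models also grow the attention mask by one position per step.
    if not is_encoder_decoder and "attention_mask" in model_kwargs:
        attention_mask = model_kwargs["attention_mask"]
        model_kwargs["attention_mask"] = torch.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
        )
    return model_kwargs


outputs = FakeOutput(logits=torch.zeros(1, 1, 32), past_key_values=("dummy-cache",))
kwargs = update_model_kwargs(outputs, {"attention_mask": torch.ones(1, 4, dtype=torch.long)})
print(kwargs["past"])                  # ('dummy-cache',)
print(kwargs["attention_mask"].shape)  # torch.Size([1, 5])
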
@@ -73,7 +125,7 @@ def greedy_search(
             max_length (int): Max length to generate responses.
             eos_idx (int): End of sequence index.
             pad_idx (int): Padding index.
-            **model_kwargs
+            model_kwargs
 
         Returns:
             Batch of sequences decoded by greedy search.
@@ -123,7 +175,8 @@ def beam_search(
         eos_score: float,
         eos_idx: int,
         num_python_workers: int,
-        **model_kwargs,
+        max_inference_batch_size: int,
+        model_kwargs,
     ) -> torch.Tensor:
         """Beam search implemented using Flashlight Text (https://github.com/flashlight/text).
 
@@ -136,7 +189,7 @@
             eos_score (float): Score to input when `eos_idx` is generated.
             eos_idx (int): End-of-sequence index.
             num_python_workers (int): Number of python workers to use for multiprocessing.
-            **model_kwargs
+            model_kwargs
 
         Returns:
             Tensor of the generated sequences.
@@ -147,8 +200,9 @@
 
         def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
             # `emissions` and `N` are unused in this current implementation
-            i = T  # Hacky, but access the current seq in inputs
-
+
+            i = T  # Hacky access to the current seq in inputs
+
             # Copy over the `model_kwargs` in order to modify
             new_model_kwargs = model_kwargs.copy()
 
@@ -161,31 +215,30 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
                     )
                 ]
 
-            encoder_output_indexed = encoder_output[i, :, :].unsqueeze(0) if self.is_encoder_decoder else None
+            encoder_output_for_curr_seq = encoder_output[i, :, :].unsqueeze(0) if self.is_encoder_decoder else None
             prev_model_state_sequences = [
                 get_obj_from_emitting_model_state(state).sequence for state in prev_step_model_states
             ]
             out_probs, model_states = [], []
 
-            # Batch inference of chunks of elements in the beam
             start = 0
-            # TODO: make this configurable to help people get around OOMs.
             # This is the parallelism level at which elements in the beam will be batched
-            MAX_INFERENCE_BATCH_SIZE = 16
             step = min(
-                MAX_INFERENCE_BATCH_SIZE, 1000 / (timestep + 1)
+                max_inference_batch_size, 1000 / (timestep + 1)
             )  # many hypotheses will EOS, so increase the batch size gradually
-            cur_beam_size = len(prev_step_token_idxs)
-            while start < cur_beam_size:  # catch the remainder
+            curr_beam_size = len(prev_step_token_idxs)
+
+            # 2. Batched inference to get next tokens
+            while start < curr_beam_size:  # catch the remainder
                 end = start + step
-                if end > cur_beam_size:
-                    end = cur_beam_size
+                if end > curr_beam_size:
+                    end = curr_beam_size
 
-                num_samples = end - start
+                num_samples = end - start  # Is this always just gunna be equal to curr_beam_size?
 
                 if prev_step_token_idxs != [-1]:
                     state_sequences = torch.cat(prev_model_state_sequences[start:end], dim=0)
-                    token_indices = torch.Tensor(prev_step_token_idxs[start:end]).to(torch.long).reshape(num_samples, 1)
+                    token_indices = torch.Tensor(prev_step_token_idxs[start:end]).to(dtype=torch.long, device=self.model.device).reshape(num_samples, 1)
 
                     state_and_tokens = torch.cat(
                         [state_sequences, token_indices], dim=-1
@@ -198,15 +251,14 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
                     assert len(prev_model_state_sequences) == 1
                     state_and_tokens = prev_model_state_sequences[0]  # dims: [1, 1]
 
-                start += step
-
                 # Cleanup -- combine this with the above
                 if self.is_encoder_decoder:
                     # Expand encoder outputs along the batch dimension so that they match the decoder input state's batch size
                     # This is a view-only operation and doesn't copy
-                    new_model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_indexed.expand(
+                    new_model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_for_curr_seq.expand(
                         num_samples if timestep > 0 else 1, -1, -1
                     )
+
                 # Forward pass
                 model_inputs = self.model.prepare_inputs_for_generation(state_and_tokens, **new_model_kwargs)
 

@@ -218,6 +270,12 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
218270
output_key = "logits" if self.is_huggingface_model else "decoder_output"
219271
lm_scores = outputs[output_key]
220272

273+
# HF optimizations to reduce overhead in future `forward` calls
274+
if self.is_huggingface_model:
275+
new_model_kwargs = self._update_model_kwargs_for_generation(outputs, new_model_kwargs, is_encoder_decoder=self.is_encoder_decoder)
276+
if new_model_kwargs["past"] is not None:
277+
new_model_kwargs["past"] = self.model._reorder_cache(new_model_kwargs["past"], torch.Tensor(num_samples).to(dtype=torch.int32, device=self.model.device))
278+
221279
# Keep track of probabilities over vocab for this pairing
222280
# TODO: clean up duplicate code in these branches
223281
if timestep == 0:
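
For context on the `_reorder_cache` call added above: once the cache is carried between steps, its rows must be re-aligned with whichever hypotheses survive each beam update. The sketch below (not from the commit) illustrates that idea on a toy single-layer cache; the cache layout and beam indices are assumptions for illustration.

# Illustrative sketch, not from the commit: what "reordering the cache" means.
# HF models expose `_reorder_cache(past, beam_idx)`; conceptually each cached
# key/value tensor is index-selected along the batch dimension so cache rows
# follow the surviving beam hypotheses. The one-layer layout here is assumed.
import torch

# Toy cache: one layer of (key, value), each shaped [beam, heads, seq, head_dim].
past = (
    (torch.arange(4.0).view(4, 1, 1, 1), torch.arange(4.0).view(4, 1, 1, 1)),
)
beam_idx = torch.tensor([2, 2, 0, 3])  # hypothesis 1 dropped, hypothesis 2 duplicated

reordered = tuple(tuple(t.index_select(0, beam_idx) for t in layer) for layer in past)
print(reordered[0][0].flatten())  # tensor([2., 2., 0., 3.])
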
@@ -242,9 +300,12 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
                                 )
                             )
                         )
+
+                start += step
 
             return out_probs, model_states
 
+        # 3. Initialize options and decoder from Flashlight Text
         options = LexiconFreeSeq2SeqDecoderOptions(
             beam_size=num_beams,
             beam_size_token=beam_size_token,
@@ -258,14 +319,16 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
             options=options, lm=ZeroLM(), eos_idx=eos_idx, update_func=update_func, max_output_length=max_len
         )
 
-        # Create these as function b/c unnamed functions (lambdas) cause problems w/ MP
-        def select_second_elem_in_tuple(tup: Tuple[List[int], float]) -> float:
-            return tup[1]
+        # 4. Process outputs from beam decoder
+        # TODO: This can definitely be optimized
+        def beam_decode_step(timestep: int) -> torch.Tensor:
+            # Create these as function b/c unnamed functions (lambdas) cause problems w/ MP
+            def select_second_elem_in_tuple(tup: Tuple[List[int], float]) -> float:
+                return tup[1]
 
-        def is_not_neg_one(elem: int) -> bool:
-            return elem != -1
+            def is_not_neg_one(elem: int) -> bool:
+                return elem != -1
 
-        def beam_decode_step(timestep: int) -> torch.Tensor:
             # Decode step takes ptr to encoder emissions, i, and beam size token
             # but actually these aren't currently being used.
             decoder.decode_step(0, timestep, 0)
@@ -292,9 +355,38 @@ def beam_decode_step(timestep: int) -> torch.Tensor:
             logger.warning("Multiprocessing has not yet been implemented.")
 
         all_final_tokens = [beam_decode_step(i) for i in range(len(input_ids))]
-
+
+        # 5. Return top hypotheses for all input sequences
        return torch.stack(all_final_tokens, dim=0)
 
+
+    def forward(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        num_beams: Optional[int] = None,
+        max_len: Optional[int] = None,
+        pad_idx: int = 0,
+        eos_idx: int = 1,
+        beam_threshold: int = 100,
+        beam_size_token: Optional[int] = None,
+        eos_score: float = -1.0,
+        num_python_workers: int = 1,
+        max_inference_batch_size: int = 16,
+    ):
+        return self.generate(
+            inputs=inputs,
+            num_beams=num_beams,
+            max_len=max_len,
+            pad_idx=pad_idx,
+            eos_idx=eos_idx,
+            beam_threshold=beam_threshold,
+            beam_size_token=beam_size_token,
+            eos_score=eos_score,
+            num_python_workers=num_python_workers,
+            max_inference_batch_size=max_inference_batch_size,
+        )
+
+
     def generate(
         self,
         inputs: Optional[torch.Tensor] = None,
@@ -306,6 +398,7 @@ def generate(
         beam_size_token: Optional[int] = None,
         eos_score: float = -1.0,
         num_python_workers: int = 1,
+        max_inference_batch_size: int = 16,
     ) -> torch.Tensor:
         """Generation method.
 
@@ -329,10 +422,8 @@ def generate(
         model_kwargs = {}
 
         if self.is_encoder_decoder:
-            encoder = self.model.get_encoder()
-            # print("inputs size is", inputs.shape)
-            model_kwargs["encoder_outputs"] = encoder(inputs)
-            inputs = self._prepare_decoder_ids_for_generation(len(inputs), device=inputs.device, **model_kwargs)
+            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(inputs, model_kwargs)
+            inputs = self._prepare_decoder_ids_for_generation(len(inputs), device=inputs.device, model_kwargs=model_kwargs)
 
         if max_length is None:
             # Too hard to try to figure out the exact max_seq_length for each model
@@ -356,7 +447,8 @@ def generate(
                 eos_score=eos_score,
                 num_python_workers=num_python_workers,
                 eos_idx=eos_idx,
-                **model_kwargs,
+                max_inference_batch_size=max_inference_batch_size,
+                model_kwargs=model_kwargs,
             )
         else:
             raise ValueError("`num_beams` must be >= 1.")
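
With `GenerationUtils` now an `nn.Module`, the wrapper can be called directly and will route through the new `forward` into `generate`. The following is a hedged usage sketch rather than the project's documented example: the model name, prompt, and decoding settings are assumptions, and running it additionally requires the `transformers` and `flashlight-text` packages to be installed.

# Hedged usage sketch, not part of the commit: wrapping a HuggingFace T5 model
# with the updated GenerationUtils. Model name, prompt, and decoding settings
# are illustrative assumptions.
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

from torchtext.prototype.generate import GenerationUtils

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small").eval()

# `is_huggingface_model=True` enables the new kwargs-update path added here
# (cached `past`, attention-mask growth, `_reorder_cache`).
generator = GenerationUtils(model, is_encoder_decoder=True, is_huggingface_model=True)

inputs = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids

with torch.inference_mode():
    tokens = generator(
        inputs=inputs,
        num_beams=4,
        max_len=50,
        pad_idx=model.config.pad_token_id,       # 0 for T5
        eos_idx=model.config.eos_token_id,       # 1 for T5
        beam_size_token=model.config.vocab_size,
        max_inference_batch_size=16,             # new knob exposed by this commit
    )

print(tokenizer.decode(tokens[0], skip_special_tokens=True))
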
