This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 9afc810

jacobkahn authored and joecummings committed

Batch decoder inference

1 parent 2ddc638 commit 9afc810

File tree

1 file changed: +87 -52 lines changed


torchtext/prototype/generate.py

Lines changed: 87 additions & 52 deletions
@@ -4,7 +4,13 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
-from flashlight.lib.text.decoder import LexiconFreeSeq2SeqDecoder, LexiconFreeSeq2SeqDecoderOptions, ZeroLM, create_emitting_model_state, get_obj_from_emitting_model_state
+from flashlight.lib.text.decoder import (
+    LexiconFreeSeq2SeqDecoder,
+    LexiconFreeSeq2SeqDecoderOptions,
+    ZeroLM,
+    create_emitting_model_state,
+    get_obj_from_emitting_model_state,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -108,7 +114,7 @@ def greedy_search(
         return input_ids
 
     def beam_search(
-        self,
+        self,
         input_ids: torch.Tensor,
         num_beams: int,
         max_len: int,
@@ -117,7 +123,7 @@ def beam_search(
         eos_score: float,
         eos_idx: int,
         num_python_workers: int,
-        **model_kwargs
+        **model_kwargs,
     ) -> torch.Tensor:
         """Beam search implemented using Flashlight Text (https://github.com/flashlight/text).
 
@@ -145,40 +151,65 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
 
             # Copy over the `model_kwargs` in order to modify
             new_model_kwargs = model_kwargs.copy()
-
+
             # For first timestep, create previous step token_idxs and model_states
             if timestep == 0:
                 prev_step_token_idxs = [-1]
                 prev_step_model_states = [
                     create_emitting_model_state(
-                        Seq2SeqModelState(
-                            timestep=0,
-                            sequence=input_ids[i].unsqueeze(0),
-                            lm_scores=None
-                        )
+                        Seq2SeqModelState(timestep=0, sequence=input_ids[i].unsqueeze(0), lm_scores=None)
                     )
                 ]
-
-            if self.is_encoder_decoder:
-                # Get the correct encoded seq from the full `encoder_output`` and put it in the correct format
-                new_model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output[i, :, :].unsqueeze(0)
 
+            encoder_output_indexed = encoder_output[i, :, :].unsqueeze(0) if self.is_encoder_decoder else None
+            prev_model_state_sequences = [
+                get_obj_from_emitting_model_state(state).sequence for state in prev_step_model_states
+            ]
             out_probs, model_states = [], []
-            for idx, model_state_ptr in zip(prev_step_token_idxs, prev_step_model_states):
-                # Convert `idx` into a Tensor b/c it's always returned as a native python `int`
-                idx = torch.Tensor([idx]).to(torch.long)
-
-                # Get previous model state
-                prev_model_state = get_obj_from_emitting_model_state(model_state_ptr)
-
-                # Create new decoder token ids
-                if idx != -1:
-                    new_input_ids = torch.cat([prev_model_state.sequence, idx.unsqueeze(0)], dim=-1)
+
+            # Batch inference of chunks of elements in the beam
+            start = 0
+            # TODO: make this configurable to help people get around OOMs.
+            # This is the parallelism level at which elements in the beam will be batched
+            MAX_INFERENCE_BATCH_SIZE = 16
+            step = min(
+                MAX_INFERENCE_BATCH_SIZE, 1000 / (timestep + 1)
+            )  # many hypotheses will EOS, so increase the batch size gradually
+            cur_beam_size = len(prev_step_token_idxs)
+            while start < cur_beam_size:  # catch the remainder
+                end = start + step
+                if end > cur_beam_size:
+                    end = cur_beam_size
+
+                num_samples = end - start
+
+                if prev_step_token_idxs != [-1]:
+                    state_sequences = torch.cat(prev_model_state_sequences[start:end], dim=0)
+                    token_indices = torch.Tensor(prev_step_token_idxs[start:end]).to(torch.long).reshape(num_samples, 1)
+
+                    state_and_tokens = torch.cat(
+                        [state_sequences, token_indices], dim=-1
+                    )  # [batch_size x (timestep + 1)]
+                    assert state_and_tokens.shape == (
+                        num_samples,
+                        timestep + 1,
+                    ), f"state_and_tokens has shape {state_and_tokens.shape} = expected {(num_samples, timestep + 1)}"
                 else:
-                    new_input_ids = prev_model_state.sequence
-
+                    assert len(prev_model_state_sequences) == 1
+                    state_and_tokens = prev_model_state_sequences[0]  # dims: [1, 1]
+
+                start += step
+
+                # Cleanup -- combine this with the above
+                if self.is_encoder_decoder:
+                    # Expand encoder outputs along the batch dimension so that they match the decoder input state's batch size
+                    # This is a view-only operation and doesn't copy
+                    new_model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_indexed.expand(
+                        num_samples if timestep > 0 else 1, -1, -1
+                    )
                 # Forward pass
-                model_inputs = self.model.prepare_inputs_for_generation(new_input_ids, **new_model_kwargs)
+                model_inputs = self.model.prepare_inputs_for_generation(state_and_tokens, **new_model_kwargs)
+
                 if self.is_huggingface_model:
                     model_inputs["return_dict"] = True
                     model_inputs["output_hidden_states"] = True
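Note: the hunk above is the core of this change. Instead of running one forward pass per beam hypothesis, hypotheses are scored in chunks of up to MAX_INFERENCE_BATCH_SIZE. The following is a minimal standalone sketch of the same chunking idea; `score_batch` and the tensor shapes are illustrative stand-ins rather than the torchtext API, and the timestep-0 special case from the hunk is omitted.

    import torch

    MAX_INFERENCE_BATCH_SIZE = 16

    def batched_beam_scores(prev_sequences, prev_token_idxs, timestep, score_batch):
        # prev_sequences: list of [1, timestep] LongTensors, one per live hypothesis
        # prev_token_idxs: list of ints, the last chosen token for each hypothesis
        # score_batch: callable mapping a [B, timestep + 1] LongTensor to [B, vocab] scores
        out_probs = []
        # Cap the chunk size; late timesteps shrink it further (never below 1)
        step = max(1, min(MAX_INFERENCE_BATCH_SIZE, 1000 // (timestep + 1)))
        start = 0
        cur_beam_size = len(prev_token_idxs)
        while start < cur_beam_size:  # catch the remainder
            end = min(start + step, cur_beam_size)
            num_samples = end - start
            # Concatenate the chunk of hypotheses into one decoder input of shape [num_samples, timestep + 1]
            state_sequences = torch.cat(prev_sequences[start:end], dim=0)
            token_indices = torch.tensor(prev_token_idxs[start:end], dtype=torch.long).reshape(num_samples, 1)
            state_and_tokens = torch.cat([state_sequences, token_indices], dim=-1)
            # A single forward pass scores the whole chunk
            lm_scores = score_batch(state_and_tokens)
            out_probs.extend(lm_scores[i].tolist() for i in range(num_samples))
            start += step
        return out_probs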
@@ -188,18 +219,29 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
                 lm_scores = outputs[output_key]
 
                 # Keep track of probabilities over vocab for this pairing
-                out_probs.append(torch.squeeze(lm_scores[:, -1]).tolist())
-
-                # Keep track of sequence and decoder hidden states
-                model_states.append(
-                    create_emitting_model_state(
-                        Seq2SeqModelState(
-                            timestep=timestep,
-                            sequence=new_input_ids,
-                            lm_scores=lm_scores
+                # TODO: clean up duplicate code in these branches
+                if timestep == 0:
+                    sample_lm_scores = torch.squeeze(lm_scores[:, -1])
+                    out_probs.append(sample_lm_scores.tolist())
+                    model_states.append(
+                        create_emitting_model_state(
+                            Seq2SeqModelState(timestep=timestep, sequence=state_and_tokens, lm_scores=sample_lm_scores)
                         )
                     )
-                )
+                else:
+                    for i in range(num_samples):
+                        sample_lm_scores = lm_scores[i, -1]
+                        out_probs.append(sample_lm_scores.tolist())
+                        # Keep track of sequence and decoder hidden states
+                        model_states.append(
+                            create_emitting_model_state(
+                                Seq2SeqModelState(
+                                    timestep=timestep,
+                                    sequence=state_and_tokens[i].unsqueeze(0),
+                                    lm_scores=sample_lm_scores,
+                                )
+                            )
+                        )
 
             return out_probs, model_states
 
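After each batched forward pass, the scores have to be split back into one (out_probs, model_state) pair per hypothesis, which is what the else branch above does via create_emitting_model_state. A rough sketch of that unbatching step follows; HypothesisState here is a plain stand-in for Seq2SeqModelState, not the torchtext class.

    from dataclasses import dataclass
    from typing import List, Optional, Tuple

    import torch

    @dataclass
    class HypothesisState:  # illustrative stand-in for Seq2SeqModelState
        timestep: int
        sequence: torch.Tensor            # [1, timestep + 1]
        lm_scores: Optional[torch.Tensor]

    def split_batched_scores(
        lm_scores: torch.Tensor,          # [num_samples, seq_len, vocab] decoder output for one chunk
        state_and_tokens: torch.Tensor,   # [num_samples, timestep + 1] decoder input for the same chunk
        timestep: int,
    ) -> Tuple[List[List[float]], List[HypothesisState]]:
        out_probs, model_states = [], []
        for i in range(lm_scores.size(0)):
            sample_lm_scores = lm_scores[i, -1]  # vocab scores at the last position
            out_probs.append(sample_lm_scores.tolist())
            model_states.append(
                HypothesisState(
                    timestep=timestep,
                    sequence=state_and_tokens[i].unsqueeze(0),
                    lm_scores=sample_lm_scores,
                )
            )
        return out_probs, model_states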

@@ -213,17 +255,13 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
         )
 
         decoder = LexiconFreeSeq2SeqDecoder(
-            options=options,
-            lm=ZeroLM(),
-            eos_idx=eos_idx,
-            update_func=update_func,
-            max_output_length=max_len
+            options=options, lm=ZeroLM(), eos_idx=eos_idx, update_func=update_func, max_output_length=max_len
         )
 
         # Create these as function b/c unnamed functions (lambdas) cause problems w/ MP
         def select_second_elem_in_tuple(tup: Tuple[List[int], float]) -> float:
             return tup[1]
-
+
         def is_not_neg_one(elem: int) -> bool:
             return elem != -1
 
@@ -235,12 +273,7 @@ def beam_decode_step(timestep: int) -> torch.Tensor:
 
             # Find the best beam
             token_scores = [(hyp.tokens, hyp.score) for hyp in hyps]
-            final_tokens = list(
-                filter(
-                    is_not_neg_one,
-                    max(token_scores, key=select_second_elem_in_tuple)[0]
-                )
-            )
+            final_tokens = list(filter(is_not_neg_one, max(token_scores, key=select_second_elem_in_tuple)[0]))
 
             # Hack, but have to prepend the input tokens if decoder-only model
             if not self.is_encoder_decoder:
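For reference, the collapsed one-liner above picks the highest-scoring hypothesis and drops the -1 placeholder tokens. An equivalent expanded sketch is shown below; token_scores is assumed to be the same list of (tokens, score) pairs built from the decoder hypotheses, and the real code uses named helpers instead of a lambda because lambdas do not pickle for multiprocessing.

    from typing import List, Tuple

    def best_hypothesis_tokens(token_scores: List[Tuple[List[int], float]]) -> List[int]:
        # Pick the (tokens, score) pair with the highest score
        best_tokens, _best_score = max(token_scores, key=lambda pair: pair[1])
        # Drop -1 entries, which are placeholders rather than real token ids
        return [tok for tok in best_tokens if tok != -1]

    # Example: best_hypothesis_tokens([([5, 7, -1], 0.2), ([5, 9, 11], 0.9)]) -> [5, 9, 11]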
@@ -249,15 +282,15 @@ def beam_decode_step(timestep: int) -> torch.Tensor:
             # Makeshift padding so that we can stack the tensors
             while len(final_tokens) < max_len:
                 final_tokens += [0]
-
+
             # Convert from list to tensors
             final_tokens_as_tensors = torch.Tensor(final_tokens).to(torch.long)
 
             return final_tokens_as_tensors
 
         if num_python_workers > 1:
             logger.warning("Multiprocessing has not yet been implemented.")
-
+
         all_final_tokens = [beam_decode_step(i) for i in range(len(input_ids))]
 
         return torch.stack(all_final_tokens, dim=0)
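The makeshift padding above exists so that per-input token lists of different lengths can be stacked into a single [batch_size, max_len] tensor. A small sketch of that step, under the same assumption that 0 is an acceptable pad value:

    import torch

    def pad_and_stack(all_final_tokens, max_len, pad_value=0):
        # all_final_tokens: list of python lists of token ids, one per input in the batch
        padded = []
        for tokens in all_final_tokens:
            padded.append(torch.tensor(tokens + [pad_value] * (max_len - len(tokens)), dtype=torch.long))
        return torch.stack(padded, dim=0)  # [batch_size, max_len]

    # Example: pad_and_stack([[2, 4, 6], [2, 8]], max_len=4).shape == torch.Size([2, 4])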
@@ -297,6 +330,7 @@ def generate(
 
         if self.is_encoder_decoder:
             encoder = self.model.get_encoder()
+            # print("inputs size is", inputs.shape)
             model_kwargs["encoder_outputs"] = encoder(inputs)
             inputs = self._prepare_decoder_ids_for_generation(len(inputs), device=inputs.device, **model_kwargs)
 
@@ -309,7 +343,8 @@ def generate(
             return self.greedy_search(inputs, max_length, eos_idx, pad_idx=pad_idx, **model_kwargs)
         elif num_beams > 1:
             if beam_size_token is None:
-                raise ValueError("`beam_size_token` must be specified for beam search. \
+                raise ValueError(
+                    "`beam_size_token` must be specified for beam search. \
                     If confused about what to put, you can default to the vocab size of the model you are using."
                 )
             return self.beam_search(
@@ -321,7 +356,7 @@ def generate(
                 eos_score=eos_score,
                 num_python_workers=num_python_workers,
                 eos_idx=eos_idx,
-                **model_kwargs
+                **model_kwargs,
             )
         else:
             raise ValueError("`num_beams` must be >= 1.")
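One detail worth noting from the batching hunk earlier in this diff: encoder outputs are broadcast across each chunk with Tensor.expand, which returns a view over the same storage rather than copying, so a larger inference batch does not multiply encoder memory. A quick illustration with made-up shapes:

    import torch

    encoder_out = torch.randn(1, 37, 512)                  # [1, src_len, hidden]
    expanded = encoder_out.expand(16, -1, -1)              # [16, src_len, hidden], no data copied
    assert expanded.data_ptr() == encoder_out.data_ptr()   # same underlying storage
    repeated = encoder_out.repeat(16, 1, 1)                # repeat() would materialize 16 copies
    assert repeated.data_ptr() != encoder_out.data_ptr()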

0 commit comments
