@@ -4,7 +4,13 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
-from flashlight.lib.text.decoder import LexiconFreeSeq2SeqDecoder, LexiconFreeSeq2SeqDecoderOptions, ZeroLM, create_emitting_model_state, get_obj_from_emitting_model_state
+from flashlight.lib.text.decoder import (
+    LexiconFreeSeq2SeqDecoder,
+    LexiconFreeSeq2SeqDecoderOptions,
+    ZeroLM,
+    create_emitting_model_state,
+    get_obj_from_emitting_model_state,
+)
 
 
 logger = logging.getLogger(__name__)
@@ -108,7 +114,7 @@ def greedy_search(
         return input_ids
 
     def beam_search(
-        self,
+        self,
         input_ids: torch.Tensor,
         num_beams: int,
         max_len: int,
@@ -117,7 +123,7 @@ def beam_search(
         eos_score: float,
         eos_idx: int,
         num_python_workers: int,
-        **model_kwargs
+        **model_kwargs,
     ) -> torch.Tensor:
         """Beam search implemented using Flashlight Text (https://github.com/flashlight/text).
 
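Note on the decoder callback (editor's illustration, not part of the diff): Flashlight's `LexiconFreeSeq2SeqDecoder` drives generation by calling the user-supplied `update_func` once per timestep. Inferred from the code in this file, the contract is: for each surviving hypothesis it receives a previous token id and an opaque state handle, and it must return per-hypothesis vocabulary scores plus new state handles. A minimal per-hypothesis sketch, essentially the pre-PR shape of the code; `model` stands in for whatever callable scores the sequence:

```python
# Hedged sketch of the update_func contract inferred from this diff -- not the
# literal torchtext implementation.
def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
    out_probs, model_states = [], []
    for idx, state_ptr in zip(prev_step_token_idxs, prev_step_model_states):
        prev = get_obj_from_emitting_model_state(state_ptr)  # unwrap Seq2SeqModelState
        # -1 is the sentinel used at timestep 0, before any token has been emitted
        if idx != -1:
            seq = torch.cat([prev.sequence, torch.tensor([[idx]], dtype=torch.long)], dim=-1)
        else:
            seq = prev.sequence
        lm_scores = model(seq)[:, -1]  # assumed: scores over the vocab at the last position
        out_probs.append(lm_scores.squeeze(0).tolist())
        model_states.append(
            create_emitting_model_state(
                Seq2SeqModelState(timestep=timestep, sequence=seq, lm_scores=lm_scores)
            )
        )
    return out_probs, model_states
```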
@@ -145,40 +151,65 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
 
             # Copy over the `model_kwargs` in order to modify
             new_model_kwargs = model_kwargs.copy()
-
+
             # For first timestep, create previous step token_idxs and model_states
             if timestep == 0:
                 prev_step_token_idxs = [-1]
                 prev_step_model_states = [
                     create_emitting_model_state(
-                        Seq2SeqModelState(
-                            timestep=0,
-                            sequence=input_ids[i].unsqueeze(0),
-                            lm_scores=None
-                        )
+                        Seq2SeqModelState(timestep=0, sequence=input_ids[i].unsqueeze(0), lm_scores=None)
                     )
                 ]
-
-            if self.is_encoder_decoder:
-                # Get the correct encoded seq from the full `encoder_output` and put it in the correct format
-                new_model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output[i, :, :].unsqueeze(0)
 
+            encoder_output_indexed = encoder_output[i, :, :].unsqueeze(0) if self.is_encoder_decoder else None
+            prev_model_state_sequences = [
+                get_obj_from_emitting_model_state(state).sequence for state in prev_step_model_states
+            ]
             out_probs, model_states = [], []
-            for idx, model_state_ptr in zip(prev_step_token_idxs, prev_step_model_states):
-                # Convert `idx` into a Tensor b/c it's always returned as a native python `int`
-                idx = torch.Tensor([idx]).to(torch.long)
-
-                # Get previous model state
-                prev_model_state = get_obj_from_emitting_model_state(model_state_ptr)
-
-                # Create new decoder token ids
-                if idx != -1:
-                    new_input_ids = torch.cat([prev_model_state.sequence, idx.unsqueeze(0)], dim=-1)
+
+            # Batched inference over chunks of elements in the beam
+            start = 0
+            # TODO: make this configurable to help people get around OOMs.
+            # This is the parallelism level at which elements in the beam will be batched.
+            MAX_INFERENCE_BATCH_SIZE = 16
+            step = min(
+                MAX_INFERENCE_BATCH_SIZE, max(1, 1000 // (timestep + 1))
+            )  # sequences lengthen as decoding proceeds, so shrink the chunk size at later timesteps
+            cur_beam_size = len(prev_step_token_idxs)
+            while start < cur_beam_size:  # catch the remainder
+                end = start + step
+                if end > cur_beam_size:
+                    end = cur_beam_size
+
+                num_samples = end - start
+
+                if prev_step_token_idxs != [-1]:
+                    state_sequences = torch.cat(prev_model_state_sequences[start:end], dim=0)
+                    token_indices = torch.Tensor(prev_step_token_idxs[start:end]).to(torch.long).reshape(num_samples, 1)
+
+                    state_and_tokens = torch.cat(
+                        [state_sequences, token_indices], dim=-1
+                    )  # [batch_size x (timestep + 1)]
+                    assert state_and_tokens.shape == (
+                        num_samples,
+                        timestep + 1,
+                    ), f"state_and_tokens has shape {state_and_tokens.shape}; expected {(num_samples, timestep + 1)}"
                 else:
-                    new_input_ids = prev_model_state.sequence
-
+                    assert len(prev_model_state_sequences) == 1
+                    state_and_tokens = prev_model_state_sequences[0]  # dims: [1, 1]
+
+                start += step
+
+                # Cleanup TODO: combine this with the branching above
+                if self.is_encoder_decoder:
+                    # Expand encoder outputs along the batch dimension so that they match the decoder input state's batch size
+                    # This is a view-only operation and doesn't copy
+                    new_model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_indexed.expand(
+                        num_samples if timestep > 0 else 1, -1, -1
+                    )
                 # Forward pass
-                model_inputs = self.model.prepare_inputs_for_generation(new_input_ids, **new_model_kwargs)
+                model_inputs = self.model.prepare_inputs_for_generation(state_and_tokens, **new_model_kwargs)
+
                 if self.is_huggingface_model:
                     model_inputs["return_dict"] = True
                     model_inputs["output_hidden_states"] = True
@@ -188,18 +219,29 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
                 lm_scores = outputs[output_key]
 
                 # Keep track of probabilities over vocab for this pairing
-                out_probs.append(torch.squeeze(lm_scores[:, -1]).tolist())
-
-                # Keep track of sequence and decoder hidden states
-                model_states.append(
-                    create_emitting_model_state(
-                        Seq2SeqModelState(
-                            timestep=timestep,
-                            sequence=new_input_ids,
-                            lm_scores=lm_scores
+                # TODO: clean up duplicate code in these branches
+                if timestep == 0:
+                    sample_lm_scores = torch.squeeze(lm_scores[:, -1])
+                    out_probs.append(sample_lm_scores.tolist())
+                    model_states.append(
+                        create_emitting_model_state(
+                            Seq2SeqModelState(timestep=timestep, sequence=state_and_tokens, lm_scores=sample_lm_scores)
                         )
                     )
-                )
+                else:
+                    for i in range(num_samples):
+                        sample_lm_scores = lm_scores[i, -1]
+                        out_probs.append(sample_lm_scores.tolist())
+                        # Keep track of sequence and decoder hidden states
+                        model_states.append(
+                            create_emitting_model_state(
+                                Seq2SeqModelState(
+                                    timestep=timestep,
+                                    sequence=state_and_tokens[i].unsqueeze(0),
+                                    lm_scores=sample_lm_scores,
+                                )
+                            )
+                        )
 
             return out_probs, model_states
 
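The core change in the previous two hunks is that beam hypotheses are now scored in chunks instead of one at a time. Stripped of the Flashlight plumbing, the pattern is roughly the following (editor's sketch with a dummy scorer; `max_batch` plays the role of `MAX_INFERENCE_BATCH_SIZE`):

```python
import torch
from typing import Callable, List

def score_in_chunks(
    sequences: List[torch.Tensor],  # each [1, seq_len]
    scorer: Callable[[torch.Tensor], torch.Tensor],
    max_batch: int = 16,
) -> List[torch.Tensor]:
    """Score all hypotheses in batches of up to `max_batch` to bound peak memory."""
    out: List[torch.Tensor] = []
    start = 0
    while start < len(sequences):
        end = min(start + max_batch, len(sequences))
        batch = torch.cat(sequences[start:end], dim=0)  # [chunk, seq_len]
        out.extend(scorer(batch).unbind(0))             # one score row per hypothesis
        start = end
    return out

# Example with a dummy scorer returning zeros over a 100-token vocab:
seqs = [torch.zeros(1, 5, dtype=torch.long) for _ in range(40)]
scores = score_in_chunks(seqs, lambda b: torch.zeros(b.shape[0], 100))
assert len(scores) == 40 and scores[0].shape == (100,)
```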
@@ -213,17 +255,13 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
         )
 
         decoder = LexiconFreeSeq2SeqDecoder(
-            options=options,
-            lm=ZeroLM(),
-            eos_idx=eos_idx,
-            update_func=update_func,
-            max_output_length=max_len
+            options=options, lm=ZeroLM(), eos_idx=eos_idx, update_func=update_func, max_output_length=max_len
         )
 
         # Create these as functions b/c unnamed functions (lambdas) cause problems w/ MP
         def select_second_elem_in_tuple(tup: Tuple[List[int], float]) -> float:
             return tup[1]
-
+
         def is_not_neg_one(elem: int) -> bool:
             return elem != -1
 
@@ -235,12 +273,7 @@ def beam_decode_step(timestep: int) -> torch.Tensor:
 
             # Find the best beam
             token_scores = [(hyp.tokens, hyp.score) for hyp in hyps]
-            final_tokens = list(
-                filter(
-                    is_not_neg_one,
-                    max(token_scores, key=select_second_elem_in_tuple)[0]
-                )
-            )
+            final_tokens = list(filter(is_not_neg_one, max(token_scores, key=select_second_elem_in_tuple)[0]))
 
             # Hack, but have to prepend the input tokens if decoder-only model
             if not self.is_encoder_decoder:
@@ -249,15 +282,15 @@ def beam_decode_step(timestep: int) -> torch.Tensor:
             # Makeshift padding so that we can stack the tensors
             while len(final_tokens) < max_len:
                 final_tokens += [0]
-
+
             # Convert from list to tensors
             final_tokens_as_tensors = torch.Tensor(final_tokens).to(torch.long)
 
             return final_tokens_as_tensors
 
         if num_python_workers > 1:
             logger.warning("Multiprocessing has not yet been implemented.")
-
+
         all_final_tokens = [beam_decode_step(i) for i in range(len(input_ids))]
 
         return torch.stack(all_final_tokens, dim=0)
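The makeshift padding above (append zeros until `max_len`, then stack) could eventually lean on `torch.nn.utils.rnn.pad_sequence`; note it pads to the longest hypothesis in the batch rather than to a fixed `max_len`, so it is only a partial replacement (editor's illustration, not part of the diff):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

hyps = [torch.tensor([5, 8, 2]), torch.tensor([5, 9, 13, 2])]
padded = pad_sequence(hyps, batch_first=True, padding_value=0)  # shape: [2, 4]
```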
@@ -297,6 +330,7 @@ def generate(
 
         if self.is_encoder_decoder:
             encoder = self.model.get_encoder()
+            # print("inputs size is", inputs.shape)
             model_kwargs["encoder_outputs"] = encoder(inputs)
             inputs = self._prepare_decoder_ids_for_generation(len(inputs), device=inputs.device, **model_kwargs)
 
@@ -309,7 +343,8 @@ def generate(
             return self.greedy_search(inputs, max_length, eos_idx, pad_idx=pad_idx, **model_kwargs)
         elif num_beams > 1:
             if beam_size_token is None:
-                raise ValueError("`beam_size_token` must be specified for beam search. \
+                raise ValueError(
+                    "`beam_size_token` must be specified for beam search. \
                     If confused about what to put, you can default to the vocab size of the model you are using."
                 )
             return self.beam_search(
@@ -321,7 +356,7 @@ def generate(
                 eos_score=eos_score,
                 num_python_workers=num_python_workers,
                 eos_idx=eos_idx,
-                **model_kwargs
+                **model_kwargs,
             )
         else:
             raise ValueError("`num_beams` must be >= 1.")
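For anyone trying the new path end to end, a hypothetical invocation follows; the wrapper class name and the exact keyword set are assumptions pieced together from the signatures visible in this diff, not a guaranteed API:

```python
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

model = T5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")

generator = GenerationUtils(model)  # hypothetical: the wrapper class this diff modifies

inputs = tokenizer("translate English to German: the house is small.", return_tensors="pt").input_ids
tokens = generator.generate(
    inputs,
    num_beams=4,
    max_length=64,
    beam_size_token=model.config.vocab_size,  # per the ValueError hint above
    eos_idx=model.config.eos_token_id,
    eos_score=1.0,
    num_python_workers=1,
)
print(tokenizer.decode(tokens[0], skip_special_tokens=True))
```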