 from dataclasses import dataclass
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
@@ -26,7 +26,7 @@ class Seq2SeqModelState(object):
     lm_scores: Optional[torch.Tensor]
 
 
-class GenerationUtils:
+class GenerationUtils(nn.Module):
     """Wrapper to provide generation utils for encoder/decoder models and decoder models.
 
     Example:
@@ -51,20 +51,72 @@ class GenerationUtils:
     """
 
     def __init__(self, model: nn.Module, **kwargs) -> None:
+        super().__init__()
         self.model = model
         self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", True)
         self.is_huggingface_model = kwargs.pop("is_huggingface_model", False)
+
+    def _prepare_encoder_decoder_kwargs_for_generation(self, inputs, model_kwargs):
+        """Modified from."""
+        # Get encoder
+        encoder = self.model.get_encoder()
+
+        # Prepare encoder args and encoder kwargs from model kwargs
+        irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
+        encoder_kwargs = {}
+        for argument, value in model_kwargs.items():
+            if not any([argument.startswith(p) for p in irrelevant_prefix]):
+                encoder_kwargs[argument] = value
+
+        # Forward pass
+        if self.is_huggingface_model:
+            encoder_kwargs["return_dict"] = True
+        model_kwargs["encoder_outputs"] = encoder(inputs, **encoder_kwargs)
+
+        return model_kwargs
 
     def _prepare_decoder_ids_for_generation(
-        self, batch_size: int, pad_idx: int = 0, device: Optional[torch.device] = None, **model_kwargs
+        self, batch_size: int, pad_idx: int = 0, device: Optional[torch.device] = None, model_kwargs: Optional[Dict[str, Any]] = None
     ):
         if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
             return model_kwargs.pop("decoder_input_ids")
         else:
             return torch.ones((batch_size, 1), dtype=torch.long, device=device) * pad_idx
 
+    def _update_model_kwargs_for_generation(
+        self,
+        outputs: Dict[str, Any],
+        model_kwargs: Dict[str, Any],
+        is_encoder_decoder: bool = False,
+    ) -> Dict[str, Any]:
+        """Modified from."""
+        # Update past
+        if "past_key_values" in outputs:
+            model_kwargs["past"] = outputs.past_key_values
+        elif "mems" in outputs:
+            model_kwargs["past"] = outputs.mems
+        elif "past_buckets_states" in outputs:
+            model_kwargs["past"] = outputs.past_buckets_states
+        else:
+            model_kwargs["past"] = None
+
+        # Update token_type_ids with last value
+        if "token_type_ids" in model_kwargs:
+            token_type_ids = model_kwargs["token_type_ids"]
+            model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
+
+        # Update attention mask
+        if not is_encoder_decoder:
+            if "attention_mask" in model_kwargs:
+                attention_mask = model_kwargs["attention_mask"]
+                model_kwargs["attention_mask"] = torch.cat(
+                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+                )
+
+        return model_kwargs
+
     def greedy_search(
-        self, input_ids: torch.Tensor, max_length: int, eos_idx: int, pad_idx: Optional[int] = None, **model_kwargs
+        self, input_ids: torch.Tensor, max_length: int, eos_idx: int, pad_idx: Optional[int] = None, model_kwargs: Optional[Dict[str, Any]] = {}
     ) -> torch.Tensor:
         """Greedy search decoding for text generation. Takes the most likely next token every time.
 
@@ -73,7 +125,7 @@ def greedy_search(
             max_length (int): Max length to generate responses.
             eos_idx (int): End of sequence index.
             pad_idx (int): Padding index.
-            **model_kwargs
+            model_kwargs
 
         Returns:
             Batch of sequences decoded by greedy search.
@@ -123,7 +175,8 @@ def beam_search(
         eos_score: float,
         eos_idx: int,
         num_python_workers: int,
-        **model_kwargs,
+        max_inference_batch_size: int,
+        model_kwargs,
     ) -> torch.Tensor:
         """Beam search implemented using Flashlight Text (https://github.com/flashlight/text).
 
@@ -136,7 +189,7 @@ def beam_search(
             eos_score (float): Score to input when `eos_idx` is generated.
             eos_idx (int): End-of-sequence index.
             num_python_workers (int): Number of python workers to use for multiprocessing.
-            **model_kwargs
+            model_kwargs
 
         Returns:
             Tensor of the generated sequences.
@@ -147,8 +200,9 @@ def beam_search(
 
         def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
             # `emissions` and `N` are unused in this current implementation
-            i = T  # Hacky, but access the current seq in inputs
-
+
+            i = T  # Hacky access to the current seq in inputs
+
             # Copy over the `model_kwargs` in order to modify
             new_model_kwargs = model_kwargs.copy()
 
@@ -161,31 +215,30 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
                     )
                 ]
 
-            encoder_output_indexed = encoder_output[i, :, :].unsqueeze(0) if self.is_encoder_decoder else None
+            encoder_output_for_curr_seq = encoder_output[i, :, :].unsqueeze(0) if self.is_encoder_decoder else None
             prev_model_state_sequences = [
                 get_obj_from_emitting_model_state(state).sequence for state in prev_step_model_states
             ]
             out_probs, model_states = [], []
 
-            # Batch inference of chunks of elements in the beam
             start = 0
-            # TODO: make this configurable to help people get around OOMs.
             # This is the parallelism level at which elements in the beam will be batched
-            MAX_INFERENCE_BATCH_SIZE = 16
             step = min(
-                MAX_INFERENCE_BATCH_SIZE, 1000 / (timestep + 1)
+                max_inference_batch_size, 1000 / (timestep + 1)
             )  # many hypotheses will EOS, so increase the batch size gradually
-            cur_beam_size = len(prev_step_token_idxs)
-            while start < cur_beam_size:  # catch the remainder
+            curr_beam_size = len(prev_step_token_idxs)
+
+            # 2. Batched inference to get next tokens
+            while start < curr_beam_size:  # catch the remainder
                 end = start + step
-                if end > cur_beam_size:
-                    end = cur_beam_size
+                if end > curr_beam_size:
+                    end = curr_beam_size
 
-                num_samples = end - start
+                num_samples = end - start  # Is this always just going to be equal to curr_beam_size?
 
                 if prev_step_token_idxs != [-1]:
                     state_sequences = torch.cat(prev_model_state_sequences[start:end], dim=0)
-                    token_indices = torch.Tensor(prev_step_token_idxs[start:end]).to(torch.long).reshape(num_samples, 1)
+                    token_indices = torch.Tensor(prev_step_token_idxs[start:end]).to(dtype=torch.long, device=self.model.device).reshape(num_samples, 1)
 
                     state_and_tokens = torch.cat(
                         [state_sequences, token_indices], dim=-1
@@ -198,15 +251,14 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
                     assert len(prev_model_state_sequences) == 1
                     state_and_tokens = prev_model_state_sequences[0]  # dims: [1, 1]
 
-                start += step
-
                 # Cleanup -- combine this with the above
                 if self.is_encoder_decoder:
                     # Expand encoder outputs along the batch dimension so that they match the decoder input state's batch size
                     # This is a view-only operation and doesn't copy
-                    new_model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_indexed.expand(
+                    new_model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_for_curr_seq.expand(
                         num_samples if timestep > 0 else 1, -1, -1
                     )
+
                 # Forward pass
                 model_inputs = self.model.prepare_inputs_for_generation(state_and_tokens, **new_model_kwargs)
 
@@ -218,6 +270,12 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
                 output_key = "logits" if self.is_huggingface_model else "decoder_output"
                 lm_scores = outputs[output_key]
 
+                # HF optimizations to reduce overhead in future `forward` calls
+                if self.is_huggingface_model:
+                    new_model_kwargs = self._update_model_kwargs_for_generation(outputs, new_model_kwargs, is_encoder_decoder=self.is_encoder_decoder)
+                    if new_model_kwargs["past"] is not None:
+                        new_model_kwargs["past"] = self.model._reorder_cache(new_model_kwargs["past"], torch.Tensor(num_samples).to(dtype=torch.int32, device=self.model.device))
+
                 # Keep track of probabilities over vocab for this pairing
                 # TODO: clean up duplicate code in these branches
                 if timestep == 0:
@@ -242,9 +300,12 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
                             )
                         )
                     )
+
+                start += step
 
             return out_probs, model_states
 
+        # 3. Initialize options and decoder from Flashlight Text
         options = LexiconFreeSeq2SeqDecoderOptions(
             beam_size=num_beams,
             beam_size_token=beam_size_token,
@@ -258,14 +319,16 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
             options=options, lm=ZeroLM(), eos_idx=eos_idx, update_func=update_func, max_output_length=max_len
         )
 
-        # Create these as function b/c unnamed functions (lambdas) cause problems w/ MP
-        def select_second_elem_in_tuple(tup: Tuple[List[int], float]) -> float:
-            return tup[1]
+        # 4. Process outputs from beam decoder
+        # TODO: This can definitely be optimized
+        def beam_decode_step(timestep: int) -> torch.Tensor:
+            # Create these as function b/c unnamed functions (lambdas) cause problems w/ MP
+            def select_second_elem_in_tuple(tup: Tuple[List[int], float]) -> float:
+                return tup[1]
 
-        def is_not_neg_one(elem: int) -> bool:
-            return elem != -1
+            def is_not_neg_one(elem: int) -> bool:
+                return elem != -1
 
-        def beam_decode_step(timestep: int) -> torch.Tensor:
             # Decode step takes ptr to encoder emissions, i, and beam size token
             # but actually these aren't currently being used.
             decoder.decode_step(0, timestep, 0)
@@ -292,9 +355,38 @@ def beam_decode_step(timestep: int) -> torch.Tensor:
             logger.warning("Multiprocessing has not yet been implemented.")
 
         all_final_tokens = [beam_decode_step(i) for i in range(len(input_ids))]
-
+
+        # 5. Return top hypotheses for all input sequences
         return torch.stack(all_final_tokens, dim=0)
 
+
+    def forward(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        num_beams: Optional[int] = None,
+        max_len: Optional[int] = None,
+        pad_idx: int = 0,
+        eos_idx: int = 1,
+        beam_threshold: int = 100,
+        beam_size_token: Optional[int] = None,
+        eos_score: float = -1.0,
+        num_python_workers: int = 1,
+        max_inference_batch_size: int = 16,
+    ):
+        return self.generate(
+            inputs=inputs,
+            num_beams=num_beams,
+            max_len=max_len,
+            pad_idx=pad_idx,
+            eos_idx=eos_idx,
+            beam_threshold=beam_threshold,
+            beam_size_token=beam_size_token,
+            eos_score=eos_score,
+            num_python_workers=num_python_workers,
+            max_inference_batch_size=max_inference_batch_size,
+        )
+
+
     def generate(
         self,
         inputs: Optional[torch.Tensor] = None,
@@ -306,6 +398,7 @@ def generate(
         beam_size_token: Optional[int] = None,
         eos_score: float = -1.0,
         num_python_workers: int = 1,
+        max_inference_batch_size: int = 16,
     ) -> torch.Tensor:
         """Generation method.
 
@@ -329,10 +422,8 @@ def generate(
         model_kwargs = {}
 
         if self.is_encoder_decoder:
-            encoder = self.model.get_encoder()
-            # print("inputs size is", inputs.shape)
-            model_kwargs["encoder_outputs"] = encoder(inputs)
-            inputs = self._prepare_decoder_ids_for_generation(len(inputs), device=inputs.device, **model_kwargs)
+            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(inputs, model_kwargs)
+            inputs = self._prepare_decoder_ids_for_generation(len(inputs), device=inputs.device, model_kwargs=model_kwargs)
 
         if max_length is None:
             # Too hard to try to figure out the exact max_seq_length for each model
@@ -356,7 +447,8 @@ def generate(
                 eos_score=eos_score,
                 num_python_workers=num_python_workers,
                 eos_idx=eos_idx,
-                **model_kwargs,
+                max_inference_batch_size=max_inference_batch_size,
+                model_kwargs=model_kwargs,
             )
         else:
             raise ValueError("`num_beams` must be >= 1.")
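
A rough usage sketch (not taken from this PR) of how the wrapper above might be driven end to end. `seq2seq_model`, `source_ids`, and `vocab_size` are placeholders; the keyword arguments mirror the `generate`/`forward` signatures in this diff.

# Usage sketch with placeholder objects (not part of this PR).
# `seq2seq_model` is assumed to be an nn.Module exposing get_encoder() and
# prepare_inputs_for_generation(), e.g. a HuggingFace-style encoder/decoder model;
# `source_ids` is a LongTensor of shape (batch, src_len).
generator = GenerationUtils(seq2seq_model, is_encoder_decoder=True, is_huggingface_model=True)
output_ids = generator.generate(
    inputs=source_ids,
    num_beams=4,
    eos_idx=1,
    pad_idx=0,
    beam_size_token=vocab_size,       # assumed to match the model's vocabulary size
    max_inference_batch_size=16,      # new knob introduced by this diff
)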
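For reference, a self-contained sketch of the greedy-search idea that the `greedy_search` docstring describes: take the argmax token at every step and stop once every sequence in the batch has emitted `eos_idx`. `toy_decoder_step` is a stand-in for a real model forward pass and is not part of this PR.

import torch


def greedy_decode(decoder_step, input_ids, max_length, eos_idx):
    # Track which sequences in the batch have already emitted EOS
    finished = torch.zeros(input_ids.size(0), dtype=torch.bool)
    for _ in range(max_length):
        logits = decoder_step(input_ids)                 # (batch, seq_len, vocab)
        next_tokens = logits[:, -1, :].argmax(dim=-1)    # most likely next token per sequence
        # Once a sequence is finished, keep emitting EOS (acts like padding)
        next_tokens = torch.where(finished, torch.full_like(next_tokens, eos_idx), next_tokens)
        input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1)
        finished |= next_tokens.eq(eos_idx)
        if finished.all():
            break
    return input_ids


# Toy stand-in for a model forward pass: random logits over a 10-token vocabulary.
def toy_decoder_step(ids):
    return torch.randn(ids.size(0), ids.size(1), 10)


print(greedy_decode(toy_decoder_step, torch.zeros(2, 1, dtype=torch.long), max_length=5, eos_idx=1))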