import logging
import warnings
+
logger = logging.getLogger(__name__)

DEFAULT_MAX_SEQ_LEN = 256
@@ -52,25 +53,23 @@ class GenerationUtils(nn.Module):
    More examples can be found in the `notebooks` directory of this repository.
    """

-    _huggingface_model_input_values = {
-        "return_dict": True,
-        "use_cache": True,
-        "output_hidden_states": True
-    }
+    _huggingface_model_input_values = {"return_dict": True, "use_cache": True, "output_hidden_states": True}

    def __init__(self, model: nn.Module, **kwargs) -> None:
        super().__init__()
        self.model = model
        self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", True)
        self.is_huggingface_model = kwargs.pop("is_huggingface_model", False)
-
-    def _prepare_encoder_decoder_kwargs_for_generation(self, inputs: torch.Tensor, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
+
+    def _prepare_encoder_decoder_kwargs_for_generation(
+        self, inputs: torch.Tensor, model_kwargs: Dict[str, Any]
+    ) -> Dict[str, Any]:
        """Runs encoder and adds to model_kwargs for decoding. Modified from https://github.com/huggingface/transformers/blob/67d074874d285e616393c65a0e670088e1b6b74a/src/transformers/generation/utils.py#L592.

        Args:
            inputs: (Tensor): Tokenized startings sequence(s).
            model_kwargs (Dict[str, Any]): Model keyword arguments to be modified for decoding.
-
+
        Returns:
            Modified model_kwargs with addition of encoded input sequence(s).
        """
@@ -83,19 +82,23 @@ def _prepare_encoder_decoder_kwargs_for_generation(self, inputs: torch.Tensor, m
        # Forward pass
        if self.is_huggingface_model:
            encoder_kwargs["return_dict"] = True
-
+
        # import pdb
        # pdb.set_trace()
        # print(encoder_kwargs.keys())
-
+
        # assert torch.jit.isinstance(encoder_kwargs, Optional[Dict[str, bool]])
-
+
        model_kwargs["encoder_outputs"] = encoder(inputs, **encoder_kwargs)

        return model_kwargs

    def _prepare_decoder_ids_for_generation(
-        self, batch_size: int, pad_idx: int = 0, device: Optional[torch.device] = None, model_kwargs: Optional[Dict[str, Any]] = None
+        self,
+        batch_size: int,
+        pad_idx: int = 0,
+        device: Optional[torch.device] = None,
+        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Prepare decoder IDs for generation."""
        if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
@@ -113,7 +116,7 @@ def _update_model_kwargs_for_generation(
        Args:
            outputs (Dict[str, Any]): LM output.
            model_kwargs (Dict[str, Any]): Model keyword args to be modified for future runs.
-
+
        Returns:
            Modified model_kwargs w/ updated past, token_type_ids, and attention_mask.
        """
@@ -149,7 +152,12 @@ def _update_model_kwargs_for_generation(
        return model_kwargs

    def greedy_search(
-        self, input_ids: torch.Tensor, max_length: int, eos_idx: int, pad_idx: Optional[int] = None, model_kwargs: Optional[Dict[str, Any]] = {}
+        self,
+        input_ids: torch.Tensor,
+        max_length: int,
+        eos_idx: int,
+        pad_idx: Optional[int] = None,
+        model_kwargs: Optional[Dict[str, Any]] = {},
    ) -> torch.Tensor:
        """Greedy search decoding for text generation. Takes the most likely next token every time.
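Conceptually, greedy search is an argmax loop over next-token logits. A minimal standalone sketch of the idea (it assumes a model returning [batch, seq, vocab] logits; it is not the exact body of this method):

tokens = input_ids
for _ in range(max_length):
    logits = model(tokens)[:, -1, :]                   # scores for the next position
    next_token = logits.argmax(dim=-1, keepdim=True)   # most likely token per sequence
    tokens = torch.cat([tokens, next_token], dim=-1)
    if (next_token == eos_idx).all():                  # stop once every sequence emits EOS
        break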
@@ -222,7 +230,7 @@ def beam_search(
            eos_idx (int): End-of-sequence index.
            num_python_workers (int): Number of python workers to use for multiprocessing.
            model_kwargs
-
+
        Returns:
            Tensor of the generated sequences.
        """
@@ -232,9 +240,9 @@ def beam_search(
        def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
            # `emissions` and `N` are unused in this current implementation
-
+
            i = T  # Hacky access to the current seq in inputs
-
+
            # Copy over the `model_kwargs` in order to modify
            new_model_kwargs = model_kwargs.copy()
@@ -259,18 +267,22 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
                max_inference_batch_size, 1000 / (timestep + 1)
            )  # many hypotheses will EOS, so increase the batch size gradually
            curr_beam_size = len(prev_step_token_idxs)
-
+
            # 2. Batched inference to get next tokens
            while start < curr_beam_size:  # catch the remainder
                end = start + step
                if end > curr_beam_size:
                    end = curr_beam_size

-                num_samples = end - start # Is this always just gunna be equal to curr_beam_size?
+                num_samples = end - start  # Is this always just gunna be equal to curr_beam_size?

                if prev_step_token_idxs != [-1]:
                    state_sequences = torch.cat(prev_model_state_sequences[start:end], dim=0)
-                    token_indices = torch.Tensor(prev_step_token_idxs[start:end]).to(dtype=torch.long, device=self.model.device).reshape(num_samples, 1)
+                    token_indices = (
+                        torch.Tensor(prev_step_token_idxs[start:end])
+                        .to(dtype=torch.long, device=self.model.device)
+                        .reshape(num_samples, 1)
+                    )

                    state_and_tokens = torch.cat(
                        [state_sequences, token_indices], dim=-1
@@ -308,14 +320,17 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
                    # HF optimizations to reduce overhead in future `forward` calls
                    if self.is_huggingface_model:
-                        new_model_kwargs = self._update_model_kwargs_for_generation(outputs, new_model_kwargs, is_encoder_decoder=self.is_encoder_decoder)
+                        new_model_kwargs = self._update_model_kwargs_for_generation(
+                            outputs, new_model_kwargs, is_encoder_decoder=self.is_encoder_decoder
+                        )
                        if new_model_kwargs["past"] is not None:
                            import pdb
+
                            pdb.set_trace()
                            beam_indices += [start for _ in range(num_samples)]
                            new_model_kwargs["past"] = self.model._reorder_cache(
                                new_model_kwargs["past"],
-                                torch.Tensor(beam_indices).to(dtype=torch.int32) # I think this is correct?
+                                torch.Tensor(beam_indices).to(dtype=torch.int32),  # I think this is correct?
                            )

                    # Keep track of probabilities over vocab for this pairing
@@ -342,7 +357,7 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
                            )
                        )
                    )
-
+
                start += step

            return out_probs, model_states
@@ -397,11 +412,10 @@ def is_not_neg_one(elem: int) -> bool:
            logger.warning("Multiprocessing has not yet been implemented.")

        all_final_tokens = [beam_decode_step(i) for i in range(len(input_ids))]
-
+
        # 5. Return top hypotheses for all input sequences
        return torch.stack(all_final_tokens, dim=0)

-
    def forward(
        self,
        inputs: Optional[torch.Tensor] = None,
@@ -465,10 +479,12 @@ def generate(
            2. `num_beams` > 1 -> beam search
        """
        model_kwargs = {}
-
+
        if self.is_encoder_decoder:
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(inputs, model_kwargs)
-            inputs = self._prepare_decoder_ids_for_generation(len(inputs), device=inputs.device, model_kwargs=model_kwargs)
+            inputs = self._prepare_decoder_ids_for_generation(
+                len(inputs), device=inputs.device, model_kwargs=model_kwargs
+            )

        if max_length is None:
            # Too hard to try to figure out the exact max_seq_length for each model
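Putting it together, a hedged end-to-end usage sketch with a Hugging Face T5 checkpoint; the checkpoint name and the `num_beams`/`max_length` keyword names are assumptions based on the docstrings above, not confirmed by this diff:

from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

generator = GenerationUtils(model, is_encoder_decoder=True, is_huggingface_model=True)
input_ids = tokenizer("summarize: beam search keeps the top hypotheses.", return_tensors="pt").input_ids
tokens = generator.generate(input_ids, num_beams=1, max_length=64)  # num_beams=1 -> greedy path
print(tokenizer.batch_decode(tokens, skip_special_tokens=True))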