import logging
import warnings
+
logger = logging.getLogger(__name__)
@@ -47,25 +48,23 @@ class GenerationUtil(nn.Module):
    More examples can be found in the `notebooks` directory of this repository.
    """

-    _huggingface_model_input_values = {
-        "return_dict": True,
-        "use_cache": True,
-        "output_hidden_states": True
-    }
+    _huggingface_model_input_values = {"return_dict": True, "use_cache": True, "output_hidden_states": True}

    def __init__(self, model: nn.Module, **kwargs) -> None:
        super().__init__()
        self.model = model
        self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", True)
        self.is_huggingface_model = kwargs.pop("is_huggingface_model", False)
-
-    def _prepare_encoder_decoder_kwargs_for_generation(self, inputs: torch.Tensor, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
+
+    def _prepare_encoder_decoder_kwargs_for_generation(
+        self, inputs: torch.Tensor, model_kwargs: Dict[str, Any]
+    ) -> Dict[str, Any]:
        """Runs encoder and adds to model_kwargs for decoding. Modified from https://github.com/huggingface/transformers/blob/67d074874d285e616393c65a0e670088e1b6b74a/src/transformers/generation/utils.py#L592.

        Args:
            inputs (Tensor): Tokenized starting sequence(s).
            model_kwargs (Dict[str, Any]): Model keyword arguments to be modified for decoding.
-
+
        Returns:
            Modified model_kwargs with addition of encoded input sequence(s).
        """
@@ -78,19 +77,23 @@ def _prepare_encoder_decoder_kwargs_for_generation(self, inputs: torch.Tensor, m
        # Forward pass
        if self.is_huggingface_model:
            encoder_kwargs["return_dict"] = True
-
+
        # import pdb
        # pdb.set_trace()
        # print(encoder_kwargs.keys())
-
+
        # assert torch.jit.isinstance(encoder_kwargs, Optional[Dict[str, bool]])
-
+
        model_kwargs["encoder_outputs"] = encoder(inputs, **encoder_kwargs)

        return model_kwargs

    def _prepare_decoder_ids_for_generation(
-        self, batch_size: int, pad_idx: int = 0, device: Optional[torch.device] = None, model_kwargs: Optional[Dict[str, Any]] = None
+        self,
+        batch_size: int,
+        pad_idx: int = 0,
+        device: Optional[torch.device] = None,
+        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Prepare decoder IDs for generation."""
        if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
@@ -108,7 +111,7 @@ def _update_model_kwargs_for_generation(
        Args:
            outputs (Dict[str, Any]): LM output.
            model_kwargs (Dict[str, Any]): Model keyword args to be modified for future runs.
-
+
        Returns:
            Modified model_kwargs w/ updated past, token_type_ids, and attention_mask.
        """
@@ -144,7 +147,12 @@ def _update_model_kwargs_for_generation(
        return model_kwargs

    def greedy_search(
-        self, input_ids: torch.Tensor, max_len: int, eos_idx: int, pad_idx: Optional[int] = None, model_kwargs: Optional[Dict[str, Any]] = {}
+        self,
+        input_ids: torch.Tensor,
+        max_len: int,
+        eos_idx: int,
+        pad_idx: Optional[int] = None,
+        model_kwargs: Optional[Dict[str, Any]] = {},
    ) -> torch.Tensor:
        """Greedy search decoding for text generation. Takes the most likely next token every time.
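Aside (not part of this commit): the reformatted `greedy_search` signature keeps the mutable default `model_kwargs: Optional[Dict[str, Any]] = {}`, which Python evaluates once and shares across calls. A minimal sketch of the usual `None`-default pattern, assuming the method body can treat a missing dict as empty:

    def greedy_search(self, input_ids, max_len, eos_idx, pad_idx=None, model_kwargs=None):
        # Build a fresh dict per call instead of mutating one shared default object
        if model_kwargs is None:
            model_kwargs = {}
        ...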
@@ -217,7 +225,7 @@ def beam_search(
            eos_idx (int): End-of-sequence index.
            num_python_workers (int): Number of python workers to use for multiprocessing.
            model_kwargs
-
+
        Returns:
            Tensor of the generated sequences.
        """
@@ -227,9 +235,9 @@ def beam_search(
        def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
            # `emissions` and `N` are unused in this current implementation
-
+
            i = T  # Hacky access to the current seq in inputs
-
+
            # Copy over the `model_kwargs` in order to modify
            new_model_kwargs = model_kwargs.copy()
@@ -254,18 +262,22 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
                max_inference_batch_size, 1000 / (timestep + 1)
            )  # many hypotheses will EOS, so increase the batch size gradually
            curr_beam_size = len(prev_step_token_idxs)
-
+
            # 2. Batched inference to get next tokens
            while start < curr_beam_size:  # catch the remainder
                end = start + step
                if end > curr_beam_size:
                    end = curr_beam_size

-                num_samples = end - start # Is this always just gunna be equal to curr_beam_size?
+                num_samples = end - start  # Is this always just gunna be equal to curr_beam_size?

                if prev_step_token_idxs != [-1]:
                    state_sequences = torch.cat(prev_model_state_sequences[start:end], dim=0)
-                    token_indices = torch.Tensor(prev_step_token_idxs[start:end]).to(dtype=torch.long, device=self.model.device).reshape(num_samples, 1)
+                    token_indices = (
+                        torch.Tensor(prev_step_token_idxs[start:end])
+                        .to(dtype=torch.long, device=self.model.device)
+                        .reshape(num_samples, 1)
+                    )

                    state_and_tokens = torch.cat(
                        [state_sequences, token_indices], dim=-1
@@ -303,14 +315,17 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
                    # HF optimizations to reduce overhead in future `forward` calls
                    if self.is_huggingface_model:
-                        new_model_kwargs = self._update_model_kwargs_for_generation(outputs, new_model_kwargs, is_encoder_decoder=self.is_encoder_decoder)
+                        new_model_kwargs = self._update_model_kwargs_for_generation(
+                            outputs, new_model_kwargs, is_encoder_decoder=self.is_encoder_decoder
+                        )
                        if new_model_kwargs["past"] is not None:
                            import pdb
+
                            pdb.set_trace()
                            beam_indices += [start for _ in range(num_samples)]
                            new_model_kwargs["past"] = self.model._reorder_cache(
                                new_model_kwargs["past"],
-                                torch.Tensor(beam_indices).to(dtype=torch.int32) # I think this is correct?
+                                torch.Tensor(beam_indices).to(dtype=torch.int32),  # I think this is correct?
                            )

                    # Keep track of probabilities over vocab for this pairing
@@ -337,7 +352,7 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
                        )
                    )
                )
-
+
                start += step

            return out_probs, model_states
@@ -392,11 +407,10 @@ def is_not_neg_one(elem: int) -> bool:
            logger.warning("Multiprocessing has not yet been implemented.")

        all_final_tokens = [beam_decode_step(i) for i in range(len(input_ids))]
-
+
        # 5. Return top hypotheses for all input sequences
        return torch.stack(all_final_tokens, dim=0)

-
    def forward(
        self,
        inputs: Optional[torch.Tensor] = None,
@@ -460,10 +474,12 @@ def generate(
            2. `num_beams` > 1 -> beam search
        """
        model_kwargs = {}
-
+
        if self.is_encoder_decoder:
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(inputs, model_kwargs)
-            inputs = self._prepare_decoder_ids_for_generation(len(inputs), device=inputs.device, model_kwargs=model_kwargs)
+            inputs = self._prepare_decoder_ids_for_generation(
+                len(inputs), device=inputs.device, model_kwargs=model_kwargs
+            )

        if max_len is None:
            # Too hard to try to figure out the exact max_seq_length for each model
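For context, a rough usage sketch of the `GenerationUtil` wrapper this diff reformats. The model name, tokenizer calls, and any `generate` arguments beyond `num_beams` and `max_len` are assumptions for illustration, not taken from this commit:

    from transformers import T5ForConditionalGeneration, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")

    # Flags mirror the kwargs popped in __init__ above
    generator = GenerationUtil(model, is_encoder_decoder=True, is_huggingface_model=True)

    input_ids = tokenizer("translate English to German: the house is small.", return_tensors="pt").input_ids
    tokens = generator.generate(input_ids, num_beams=1, max_len=50)    # greedy search
    # tokens = generator.generate(input_ids, num_beams=4, max_len=50)  # beam search when num_beams > 1
    print(tokenizer.batch_decode(tokens, skip_special_tokens=True))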