from dataclasses import dataclass
- from typing import Any, Dict, List, Optional, Tuple
+ from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F

    get_obj_from_emitting_model_state,
)

+ import logging
+ import warnings
logger = logging.getLogger(__name__)

DEFAULT_MAX_SEQ_LEN = 256
@@ -50,34 +52,52 @@ class GenerationUtils(nn.Module):
    More examples can be found in the `notebooks` directory of this repository.
    """

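+     # Default kwargs applied to every HuggingFace model forward call made during generation.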
+     _huggingface_model_input_values = {
+         "return_dict": True,
+         "use_cache": True,
+         "output_hidden_states": True
+     }
+ 
    def __init__(self, model: nn.Module, **kwargs) -> None:
        super().__init__()
        self.model = model
        self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", True)
        self.is_huggingface_model = kwargs.pop("is_huggingface_model", False)

-     def _prepare_encoder_decoder_kwargs_for_generation(self, inputs, model_kwargs):
-         """Modified from."""
+     def _prepare_encoder_decoder_kwargs_for_generation(self, inputs: torch.Tensor, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
+         """Runs the encoder and adds its outputs to `model_kwargs` for decoding. Modified from https://github.com/huggingface/transformers/blob/67d074874d285e616393c65a0e670088e1b6b74a/src/transformers/generation/utils.py#L592.
+ 
+         Args:
+             inputs (Tensor): Tokenized starting sequence(s).
+             model_kwargs (Dict[str, Any]): Model keyword arguments to be modified for decoding.
+ 
+         Returns:
+             Modified model_kwargs with the addition of the encoded input sequence(s).
+         """
        # Get encoder
        encoder = self.model.get_encoder()

-         # Prepare encoder args and encoder kwargs from model kwargs
-         irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
-         encoder_kwargs = {}
-         for argument, value in model_kwargs.items():
-             if not any([argument.startswith(p) for p in irrelevant_prefix]):
-                 encoder_kwargs[argument] = value
+         # Create a copy of the model kwargs to pass to the encoder
+         encoder_kwargs = model_kwargs.copy()

        # Forward pass
        if self.is_huggingface_model:
            encoder_kwargs["return_dict"] = True

        model_kwargs["encoder_outputs"] = encoder(inputs, **encoder_kwargs)

        return model_kwargs

    def _prepare_decoder_ids_for_generation(
        self, batch_size: int, pad_idx: int = 0, device: Optional[torch.device] = None, model_kwargs: Optional[Dict[str, Any]] = None
    ):
+         """Prepare decoder IDs for generation."""
        if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
            return model_kwargs.pop("decoder_input_ids")
        else:
@@ -87,16 +107,23 @@ def _update_model_kwargs_for_generation(
        self,
        outputs: Dict[str, Any],
        model_kwargs: Dict[str, Any],
-         is_encoder_decoder: bool = False,
    ) -> Dict[str, Any]:
-         """Modified from."""
+         """After a forward pass, update model_kwargs for faster decoding. Modified from https://github.com/huggingface/transformers/blob/67d074874d285e616393c65a0e670088e1b6b74a/src/transformers/generation/utils.py#L692.
+ 
+         Args:
+             outputs (Dict[str, Any]): LM output.
+             model_kwargs (Dict[str, Any]): Model keyword args to be modified for future runs.
+ 
+         Returns:
+             Modified model_kwargs w/ updated past, token_type_ids, and attention_mask.
+         """
        # Update past
        if "past_key_values" in outputs:
-             model_kwargs["past"] = outputs.past_key_values
+             model_kwargs["past"] = outputs["past_key_values"]
        elif "mems" in outputs:
-             model_kwargs["past"] = outputs.mems
+             model_kwargs["past"] = outputs["mems"]
        elif "past_buckets_states" in outputs:
-             model_kwargs["past"] = outputs.past_buckets_states
+             model_kwargs["past"] = outputs["past_buckets_states"]
        else:
            model_kwargs["past"] = None

@@ -105,13 +132,19 @@ def _update_model_kwargs_for_generation(
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

-         # Update attention mask
-         if not is_encoder_decoder:
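+         # Extend the (decoder) attention mask by one position to cover the newly generated token.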
+         if not self.is_encoder_decoder:
            if "attention_mask" in model_kwargs:
                attention_mask = model_kwargs["attention_mask"]
                model_kwargs["attention_mask"] = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                )
+         else:
+             if "decoder_attention_mask" in model_kwargs:
+                 decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+                 model_kwargs["decoder_attention_mask"] = torch.cat(
+                     [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
+                     dim=-1,
+                 )

        return model_kwargs

@@ -135,9 +168,7 @@ def greedy_search(
        while True:
            model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs)
            if self.is_huggingface_model:
-                 model_inputs["return_dict"] = True
-                 model_inputs["use_cache"] = True
-                 model_inputs["output_hidden_states"] = True
+                 model_inputs.update(self._huggingface_model_input_values)

            # Get model output
            outputs = self.model(**model_inputs)
@@ -177,7 +208,7 @@ def beam_search(
        eos_idx: int,
        num_python_workers: int,
        max_inference_batch_size: int,
-         model_kwargs,
+         model_kwargs: Dict[str, Any],
    ) -> torch.Tensor:
        """Beam search implemented using Flashlight Text (https://github.com/flashlight/text).

@@ -260,26 +291,32 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
                    num_samples if timestep > 0 else 1, -1, -1
                )

-                 # Forward pass
+                 # Preprocess inputs for generation
                model_inputs = self.model.prepare_inputs_for_generation(state_and_tokens, **new_model_kwargs)
-                 print(model_inputs.get("use_cache"), model_inputs.get("past_key_values"))
- 
                if self.is_huggingface_model:
-                     model_inputs["return_dict"] = True
-                     model_inputs["use_cache"] = True
-                     model_inputs["output_hidden_states"] = True
+                     model_inputs.update(self._huggingface_model_input_values)

-                 print(model_inputs.get("use_cache"), model_inputs.get("past_key_values"))
+                 from types import MappingProxyType

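+                 # MappingProxyType provides a read-only view, so the prepared inputs cannot be mutated accidentally before the forward pass.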
+                 model_inputs = MappingProxyType(model_inputs)
+                 # Forward pass
                outputs = self.model(**model_inputs)
+ 
+                 # Collect outputs
                output_key = "logits" if self.is_huggingface_model else "decoder_output"
                lm_scores = outputs[output_key]

                # HF optimizations to reduce overhead in future `forward` calls
                if self.is_huggingface_model:
-                     new_model_kwargs = self._update_model_kwargs_for_generation(outputs, new_model_kwargs, is_encoder_decoder=self.is_encoder_decoder)
+                     new_model_kwargs = self._update_model_kwargs_for_generation(outputs, new_model_kwargs)
                    if new_model_kwargs["past"] is not None:
-                         new_model_kwargs["past"] = self.model._reorder_cache(new_model_kwargs["past"], torch.Tensor(num_samples).to(dtype=torch.int32, device=self.model.device))
+                         beam_indices += [start for _ in range(num_samples)]
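+                         # Reorder the cached key/value states so they follow the beam hypotheses kept at this step.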
+                         new_model_kwargs["past"] = self.model._reorder_cache(
+                             new_model_kwargs["past"],
+                             torch.Tensor(beam_indices).to(dtype=torch.int32)  # TODO: verify that the beam indices are correct
+                         )

                # Keep track of probabilities over vocab for this pairing
                # TODO: clean up duplicate code in these branches
@@ -305,7 +342,7 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
                        )
                    )
                )
- 
+ 
                start += step

            return out_probs, model_states
@@ -378,6 +415,8 @@ def forward(
        num_python_workers: int = 1,
        max_inference_batch_size: int = 16,
    ):
+         """Calls self.generate() method."""
+         warnings.warn("Forward method simply calls `GenerationUtils.generate()`. Please use generate method directly.")
        return self.generate(
            inputs=inputs,
            num_beams=num_beams,
@@ -391,41 +430,42 @@ def forward(
            max_inference_batch_size=max_inference_batch_size,
        )

- 
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        num_beams: Optional[int] = None,
        max_length: Optional[int] = None,
        pad_idx: int = 0,
        eos_idx: int = 1,
+         num_python_workers: int = 1,
        beam_threshold: int = 100,
        beam_size_token: Optional[int] = None,
        eos_score: float = -1.0,
-         num_python_workers: int = 1,
        max_inference_batch_size: int = 16,
    ) -> torch.Tensor:
-         """Generation method.
+         """Entrypoint generation method.

        Args:
            input_ids (Tensor): Ids of tokenized input tokens. The 'seed' text for generation.
            num_beams (int): If provided, specifies the number of beams to use in beam search generation.
            max_length (int): Max length to generate responses.
            pad_idx (int): Padding index. Defaults to 0.
            eos_idx (int): End-of-sequence index. Defaults to 1.
+             num_python_workers (int): If > 1, uses multiprocessing on CPU.
            beam_size_token (int): Vocab size for the beam search algo to evaluate, can typically default to vocab size of the model.
            beam_threshold (int): Threshold before pruning; specific to beam search.
            eos_score (float): Score to input when `eos_idx` is generated; specific to beam search.
+             max_inference_batch_size (int): In beam search, to avoid OOMs, hypotheses can be batched at most this many at a time; defaults to 16.

        Returns:
            Tensor of Tensors containing output sequences as ids.

-         Conditions for generation: \
-             1. `num_beams` == 1 or `num_beams` is None -> greedy search \
+         Conditions for generation:
+             1. `num_beams` == 1 or `num_beams` is None -> greedy search
            2. `num_beams` > 1 -> beam search
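+ 
+         Example:
+             >>> # Hypothetical usage sketch; `t5_model` and `input_ids` are placeholders, assuming a HuggingFace encoder-decoder model.
+             >>> generator = GenerationUtils(t5_model, is_huggingface_model=True)
+             >>> tokens = generator.generate(input_ids, num_beams=4, max_length=50, beam_size_token=t5_model.config.vocab_size)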
        """
        model_kwargs = {}
- 
+ 
        if self.is_encoder_decoder:
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(inputs, model_kwargs)
            inputs = self._prepare_decoder_ids_for_generation(len(inputs), device=inputs.device, model_kwargs=model_kwargs)
@@ -436,6 +476,8 @@ def generate(
            max_length = DEFAULT_MAX_SEQ_LEN

        if num_beams == 1 or num_beams is None:
+             if num_python_workers > 1:
+                 logger.warning("Multiprocessing is not implemented for greedy search.")
            return self.greedy_search(inputs, max_length, eos_idx, pad_idx=pad_idx, model_kwargs=model_kwargs)
        elif num_beams > 1:
            if beam_size_token is None: