from dataclasses import dataclass
- from typing import Any, Dict, List, Optional, Tuple
+ from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
)

import logging
-
+ import warnings
logger = logging.getLogger(__name__)

@@ -47,34 +47,52 @@ class GenerationUtil(nn.Module):
    More examples can be found in the `notebooks` directory of this repository.
    """

+     _huggingface_model_input_values = {
+         "return_dict": True,
+         "use_cache": True,
+         "output_hidden_states": True
+     }
+
    def __init__(self, model: nn.Module, **kwargs) -> None:
        super().__init__()
        self.model = model
        self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", True)
        self.is_huggingface_model = kwargs.pop("is_huggingface_model", False)

-     def _prepare_encoder_decoder_kwargs_for_generation(self, inputs, model_kwargs):
-         """Modified from."""
+     def _prepare_encoder_decoder_kwargs_for_generation(self, inputs: torch.Tensor, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
+         """Runs the encoder and adds its output to model_kwargs for decoding. Modified from https://github.com/huggingface/transformers/blob/67d074874d285e616393c65a0e670088e1b6b74a/src/transformers/generation/utils.py#L592.
+
+         Args:
+             inputs (Tensor): Tokenized starting sequence(s).
+             model_kwargs (Dict[str, Any]): Model keyword arguments to be modified for decoding.
+
+         Returns:
+             Modified model_kwargs with the addition of the encoded input sequence(s).
+         """
        # Get encoder
        encoder = self.model.get_encoder()

-         # Prepare encoder args and encoder kwargs from model kwargs
-         irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
-         encoder_kwargs = {}
-         for argument, value in model_kwargs.items():
-             if not any([argument.startswith(p) for p in irrelevant_prefix]):
-                 encoder_kwargs[argument] = value
+         # Create copy of encoder kwargs
+         encoder_kwargs = model_kwargs.copy()

        # Forward pass
        if self.is_huggingface_model:
            encoder_kwargs["return_dict"] = True
+
        model_kwargs["encoder_outputs"] = encoder(inputs, **encoder_kwargs)

        return model_kwargs

    def _prepare_decoder_ids_for_generation(
        self, batch_size: int, pad_idx: int = 0, device: Optional[torch.device] = None, model_kwargs: Optional[Dict[str, Any]] = None
    ):
+         """Prepare decoder IDs for generation."""
        if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
            return model_kwargs.pop("decoder_input_ids")
        else:
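
A side note on the hunk above: the rewritten `_prepare_encoder_decoder_kwargs_for_generation` copies the caller's `model_kwargs`, runs the encoder once, and stores the result under `"encoder_outputs"` so every subsequent decoding step can reuse it. The sketch below only illustrates that data flow; `ToyEncoder` and the shapes are made up for the example and are not part of this PR.

```python
import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    """Stand-in encoder: embeds token ids and returns hidden states."""

    def __init__(self, vocab_size: int = 100, hidden: int = 8) -> None:
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)

    def forward(self, input_ids: torch.Tensor, **kwargs) -> torch.Tensor:
        return self.embed(input_ids)


encoder = ToyEncoder()
inputs = torch.randint(0, 100, (2, 5))               # (batch, seq_len) token ids
model_kwargs = {"attention_mask": torch.ones(2, 5)}  # hypothetical extra kwargs

# Copy the kwargs so the caller's dict is untouched, then run the encoder once
# and stash its output for reuse at every decoding step.
encoder_kwargs = model_kwargs.copy()
model_kwargs["encoder_outputs"] = encoder(inputs, **encoder_kwargs)

print(model_kwargs["encoder_outputs"].shape)  # torch.Size([2, 5, 8])
```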
@@ -84,16 +102,23 @@ def _update_model_kwargs_for_generation(
        self,
        outputs: Dict[str, Any],
        model_kwargs: Dict[str, Any],
-         is_encoder_decoder: bool = False,
    ) -> Dict[str, Any]:
-         """Modified from."""
+         """After a forward pass, update model_kwargs for faster decoding. Modified from https://github.com/huggingface/transformers/blob/67d074874d285e616393c65a0e670088e1b6b74a/src/transformers/generation/utils.py#L692.
+
+         Args:
+             outputs (Dict[str, Any]): LM output.
+             model_kwargs (Dict[str, Any]): Model keyword args to be modified for future runs.
+
+         Returns:
+             Modified model_kwargs with updated past, token_type_ids, and attention_mask.
+         """
        # Update past
        if "past_key_values" in outputs:
-             model_kwargs["past"] = outputs.past_key_values
+             model_kwargs["past"] = outputs["past_key_values"]
        elif "mems" in outputs:
-             model_kwargs["past"] = outputs.mems
+             model_kwargs["past"] = outputs["mems"]
        elif "past_buckets_states" in outputs:
-             model_kwargs["past"] = outputs.past_buckets_states
+             model_kwargs["past"] = outputs["past_buckets_states"]
        else:
            model_kwargs["past"] = None

@@ -102,13 +127,19 @@ def _update_model_kwargs_for_generation(
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

-         # Update attention mask
-         if not is_encoder_decoder:
+         if not self.is_encoder_decoder:
            if "attention_mask" in model_kwargs:
                attention_mask = model_kwargs["attention_mask"]
                model_kwargs["attention_mask"] = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                )
+         else:
+             if "decoder_attention_mask" in model_kwargs:
+                 decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+                 model_kwargs["decoder_attention_mask"] = torch.cat(
+                     [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
+                     dim=-1,
+                 )

        return model_kwargs

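The mask bookkeeping in the hunk above follows one rule: after each generated token, append a column of ones so the (decoder) attention mask keeps matching the growing sequence. A minimal, self-contained illustration of that `torch.cat`/`new_ones` pattern (toy shapes, not code from this PR):

```python
import torch

attention_mask = torch.ones(2, 4, dtype=torch.long)  # (batch, current_length)

# One decoding step later the sequence is one token longer,
# so the mask grows by a single column of ones.
attention_mask = torch.cat(
    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)

print(attention_mask.shape)  # torch.Size([2, 5])
```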
@@ -132,9 +163,7 @@ def greedy_search(
        while True:
            model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs)
            if self.is_huggingface_model:
-                 model_inputs["return_dict"] = True
-                 model_inputs["use_cache"] = True
-                 model_inputs["output_hidden_states"] = True
+                 model_inputs.update(self._huggingface_model_input_values)

            # Get model output
            outputs = self.model(**model_inputs)
@@ -174,7 +203,7 @@ def beam_search(
        eos_idx: int,
        num_python_workers: int,
        max_inference_batch_size: int,
-         model_kwargs,
+         model_kwargs: Dict[str, Any],
    ) -> torch.Tensor:
        """Beam search implemented using Flashlight Text (https://github.com/flashlight/text).

@@ -257,26 +286,32 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
                    num_samples if timestep > 0 else 1, -1, -1
                )

-                 # Forward pass
+                 # Preprocess inputs for generation
                model_inputs = self.model.prepare_inputs_for_generation(state_and_tokens, **new_model_kwargs)
-                 print(model_inputs.get("use_cache"), model_inputs.get("past_key_values"))
-
                if self.is_huggingface_model:
-                     model_inputs["return_dict"] = True
-                     model_inputs["use_cache"] = True
-                     model_inputs["output_hidden_states"] = True
+                     model_inputs.update(self._huggingface_model_input_values)

-                 print(model_inputs.get("use_cache"), model_inputs.get("past_key_values"))
+                 from types import MappingProxyType

+                 model_inputs = MappingProxyType(model_inputs)
+                 # Forward pass
                outputs = self.model(**model_inputs)
+
+                 # Collect outputs
                output_key = "logits" if self.is_huggingface_model else "decoder_output"
                lm_scores = outputs[output_key]

                # HF optimizations to reduce overhead in future `forward` calls
                if self.is_huggingface_model:
-                     new_model_kwargs = self._update_model_kwargs_for_generation(outputs, new_model_kwargs, is_encoder_decoder=self.is_encoder_decoder)
+                     new_model_kwargs = self._update_model_kwargs_for_generation(outputs, new_model_kwargs)
                    if new_model_kwargs["past"] is not None:
-                         new_model_kwargs["past"] = self.model._reorder_cache(new_model_kwargs["past"], torch.Tensor(num_samples).to(dtype=torch.int32, device=self.model.device))
+                         beam_indices += [start for _ in range(num_samples)]
+                         new_model_kwargs["past"] = self.model._reorder_cache(
+                             new_model_kwargs["past"],
+                             torch.tensor(beam_indices, dtype=torch.int32),  # TODO: verify that beam_indices is the correct reordering
+                         )

                # Keep track of probabilities over vocab for this pairing
                # TODO: clean up duplicate code in these branches
@@ -302,7 +337,7 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
                        )
                    )
                )
-
+
                start += step

            return out_probs, model_states
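
For background on the `_reorder_cache` call above: cached key/value tensors are laid out per hypothesis along the batch dimension, so once beam search duplicates or reorders hypotheses the cache has to be gathered with the matching beam indices. A hedged sketch of that idea with a dummy cache; the tuple layout is an assumption for illustration, since HuggingFace's actual `_reorder_cache` is model-specific:

```python
import torch

# Dummy cache: one layer, (key, value), each (batch=3, heads=2, seq_len=4, head_dim=8).
past = ((torch.randn(3, 2, 4, 8), torch.randn(3, 2, 4, 8)),)

# Suppose beam search kept hypothesis 2 twice and hypothesis 0 once.
beam_idx = torch.tensor([2, 2, 0], dtype=torch.long)

# Gather the cache rows so they line up with the reordered hypotheses.
reordered = tuple(
    tuple(t.index_select(0, beam_idx) for t in layer) for layer in past
)

print(reordered[0][0].shape)  # torch.Size([3, 2, 4, 8]), rows now in order [2, 2, 0]
```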
@@ -375,6 +410,8 @@ def forward(
        num_python_workers: int = 1,
        max_inference_batch_size: int = 16,
    ):
+         """Calls the `generate` method."""
+         warnings.warn("Forward method simply calls `GenerationUtils.generate()`. Please use the `generate` method directly.")
        return self.generate(
            inputs=inputs,
            num_beams=num_beams,
@@ -388,41 +425,42 @@ def forward(
            max_inference_batch_size=max_inference_batch_size,
        )

-
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        num_beams: Optional[int] = None,
        max_len: Optional[int] = None,
        pad_idx: int = 0,
        eos_idx: int = 1,
+         num_python_workers: int = 1,
        beam_threshold: int = 100,
        beam_size_token: Optional[int] = None,
        eos_score: float = -1.0,
-         num_python_workers: int = 1,
        max_inference_batch_size: int = 16,
    ) -> torch.Tensor:
-         """Generation method.
+         """Entrypoint generation method.

        Args:
            input_ids (Tensor): Ids of tokenized input tokens. The 'seed' text for generation.
            num_beams (int): If provided, specifies the number of beams to use in beam search generation.
            max_len (int): Max length to generate responses.
            pad_idx (int): Padding index. Defaults to 0.
            eos_idx (int): End-of-sequence index. Defaults to 1.
+             num_python_workers (int): If > 1, use multiprocessing on CPU.
            beam_size_token (int): Vocab size for the beam search algo to evaluate, can typically default to vocab size of the model.
            beam_threshold (int): Threshold before pruning; specific to beam search.
            eos_score (float): Score to input when `eos_idx` is generated; specific to beam search.
+             max_inference_batch_size (int): In beam search, the maximum number of hypotheses to batch in a single forward pass to avoid OOMs. Defaults to 16.

        Returns:
            Tensor of Tensors containing output sequences as ids.

-         Conditions for generation: \
-             1. `num_beams` == 1 or `num_beams` is None -> greedy search \
+         Conditions for generation:
+             1. `num_beams` == 1 or `num_beams` is None -> greedy search
            2. `num_beams` > 1 -> beam search
        """
        model_kwargs = {}
-
+
        if self.is_encoder_decoder:
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(inputs, model_kwargs)
            inputs = self._prepare_decoder_ids_for_generation(len(inputs), device=inputs.device, model_kwargs=model_kwargs)
@@ -433,6 +471,8 @@ def generate(
            max_len = 256

        if num_beams == 1 or num_beams is None:
+             if num_python_workers > 1:
+                 logger.warning("Multiprocessing is not implemented for greedy search.")
            return self.greedy_search(inputs, max_len, eos_idx, pad_idx=pad_idx, model_kwargs=model_kwargs)
        elif num_beams > 1:
            if beam_size_token is None:
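
Taken together, `generate` dispatches on `num_beams`: greedy search when it is `None` or 1, beam search otherwise. A hedged usage sketch follows; the import path, the `GenerationUtils` name, and the `t5-small` checkpoint are assumptions for illustration and are not verified against this PR:

```python
# Hypothetical usage; module path, class name, and checkpoint are assumptions.
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
from torchtext.prototype.generate import GenerationUtils

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

# Wrap the HF model; kwargs mirror the __init__ shown in the diff above.
generator = GenerationUtils(model, is_encoder_decoder=True, is_huggingface_model=True)

inputs = tokenizer(
    "summarize: the quick brown fox jumps over the lazy dog", return_tensors="pt"
).input_ids

# num_beams == 1 -> greedy search
greedy_ids = generator.generate(inputs, num_beams=1, max_len=30)

# num_beams > 1 -> beam search (beam_size_token typically the model's vocab size)
beam_ids = generator.generate(
    inputs, num_beams=4, max_len=30, beam_size_token=model.config.vocab_size
)

print(tokenizer.decode(greedy_ids[0], skip_special_tokens=True))
```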