    get_obj_from_emitting_model_state,
)

-import logging
import warnings

-logger = logging.getLogger(__name__)
+
+MODEL_KWARGS_TYPE = Dict[str, Dict[str, Union[torch.Tensor, List[Optional[torch.Tensor]], List[torch.Tensor], None]]]


@dataclass
@@ -56,9 +56,7 @@ def __init__(self, model: nn.Module, **kwargs) -> None:
        self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", True)
        self.is_huggingface_model = kwargs.pop("is_huggingface_model", False)

-    def _prepare_encoder_decoder_kwargs_for_generation(
-        self, inputs: torch.Tensor, model_kwargs: Dict[str, Any]
-    ) -> Dict[str, Any]:
+    def _prepare_encoder_decoder_kwargs_for_generation(self, inputs: torch.Tensor) -> MODEL_KWARGS_TYPE:
        """Runs encoder and adds to model_kwargs for decoding. Modified from https://github.com/huggingface/transformers/blob/67d074874d285e616393c65a0e670088e1b6b74a/src/transformers/generation/utils.py#L592.

        Args:
@@ -72,40 +70,36 @@ def _prepare_encoder_decoder_kwargs_for_generation(
        encoder = self.model.get_encoder()

        # Create copy of encoder kwargs
-        encoder_kwargs = model_kwargs.copy()
+        encoder_kwargs: Dict[str, bool] = {}

-        # Forward pass
        if self.is_huggingface_model:
            encoder_kwargs["return_dict"] = True

-        # import pdb
-        # pdb.set_trace()
-        # print(encoder_kwargs.keys())
-
-        # assert torch.jit.isinstance(encoder_kwargs, Optional[Dict[str, bool]])
-
-        model_kwargs["encoder_outputs"] = encoder(inputs, **encoder_kwargs)
-
+        # Forward pass
+        # Explicitly call the forward method to assert this is a ScriptModule if JITted
+        model_kwargs = {"encoder_outputs": encoder.forward(inputs)}  # , **encoder_kwargs)
        return model_kwargs

    def _prepare_decoder_ids_for_generation(
        self,
        batch_size: int,
        pad_idx: int = 0,
        device: Optional[torch.device] = None,
-        model_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+        model_kwargs: Optional[MODEL_KWARGS_TYPE] = None,
+    ) -> torch.Tensor:
        """Prepare decoder IDs for generation."""
        if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
-            return model_kwargs.pop("decoder_input_ids")
+            decoder_input_ids = model_kwargs.pop("decoder_input_ids")
+            assert torch.jit.isinstance(decoder_input_ids, torch.Tensor)
+            return decoder_input_ids
        else:
            return torch.ones((batch_size, 1), dtype=torch.long, device=device) * pad_idx

    def _update_model_kwargs_for_generation(
        self,
        outputs: Dict[str, Any],
        model_kwargs: Dict[str, Any],
-    ) -> Dict[str, Any]:
+    ) -> MODEL_KWARGS_TYPE:
        """After a forward pass, update model_kwargs for faster decoding. Modified from https://github.com/huggingface/transformers/blob/67d074874d285e616393c65a0e670088e1b6b74a/src/transformers/generation/utils.py#L692.

        Args:
@@ -152,7 +146,7 @@ def greedy_search(
        max_len: int,
        eos_idx: int,
        pad_idx: Optional[int] = None,
-        model_kwargs: Optional[Dict[str, Any]] = {},
+        model_kwargs: Optional[MODEL_KWARGS_TYPE] = {},
    ) -> torch.Tensor:
        """Greedy search decoding for text generation. Takes the most likely next token every time.
@@ -184,9 +178,8 @@ def greedy_search(
            _, next_tokens = torch.topk(probs, 1)

            # For any finished sequences, padding idx should be the last token
-            if eos_idx is not None:
-                if pad_idx is not None:
-                    next_tokens = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences)
+            if eos_idx is not None and pad_idx is not None:
+                next_tokens = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences)

            # Append the next tokens to the previous tokens
            input_ids = torch.cat([input_ids, next_tokens], dim=-1)
@@ -233,7 +226,7 @@ def beam_search(
        encoder_output_key = "last_hidden_state" if self.is_huggingface_model else "encoder_output"
        encoder_output = model_kwargs["encoder_outputs"][encoder_output_key]

-        def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
+        def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_step_model_states, timestep):
            # `emissions` and `N` are unused in this current implementation

            i = T  # Hacky access to the current seq in inputs
@@ -269,7 +262,7 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
            if end > curr_beam_size:
                end = curr_beam_size

-            num_samples = end - start  # Is this always just gunna be equal to curr_beam_size?
+            num_samples = end - start

            if prev_step_token_idxs != [-1]:
                state_sequences = torch.cat(prev_model_state_sequences[start:end], dim=0)
@@ -303,9 +296,6 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
            if self.is_huggingface_model:
                model_inputs.update(self._huggingface_model_input_values)

-            from typing import MappingProxyType
-
-            model_inputs = MappingProxyType(model_inputs)
            # Forward pass
            outputs = self.model(**model_inputs)
@@ -315,17 +305,14 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
            # HF optimizations to reduce overhead in future `forward` calls
            if self.is_huggingface_model:
-                new_model_kwargs = self._update_model_kwargs_for_generation(
-                    outputs, new_model_kwargs, is_encoder_decoder=self.is_encoder_decoder
-                )
-                if new_model_kwargs["past"] is not None:
-                    import pdb
-
-                    pdb.set_trace()
-                    beam_indices += [start for _ in range(num_samples)]
+                new_model_kwargs = self._update_model_kwargs_for_generation(outputs, new_model_kwargs)
+                if new_model_kwargs["past"] is not None and len(prev_step_hyp_idxs) > 1:
+                    if len(prev_step_hyp_idxs) == 9:
+                        import pdb
+                        pdb.set_trace()
                    new_model_kwargs["past"] = self.model._reorder_cache(
                        new_model_kwargs["past"],
-                        torch.Tensor(beam_indices).to(dtype=torch.int32),  # I think this is correct?
+                        torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32),  # I think this is correct?
                    )

            # Keep track of probabilities over vocab for this pairing
@@ -404,7 +391,7 @@ def is_not_neg_one(elem: int) -> bool:
            return final_tokens_as_tensors

        if num_python_workers > 1:
-            logger.warning("Multiprocessing has not yet been implemented.")
+            warnings.warn("Multiprocessing has not yet been implemented.")

        all_final_tokens = [beam_decode_step(i) for i in range(len(input_ids))]
@@ -473,28 +460,28 @@ def generate(
        1. `num_beams` == 1 or `num_beams` is None -> greedy search
        2. `num_beams` > 1 -> beam search
        """
-        model_kwargs = {}
+        model_kwargs: MODEL_KWARGS_TYPE = {}

        if self.is_encoder_decoder:
-            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(inputs, model_kwargs)
+            assert torch.jit.isinstance(inputs, torch.Tensor)
+            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(inputs)
            inputs = self._prepare_decoder_ids_for_generation(
                len(inputs), device=inputs.device, model_kwargs=model_kwargs
            )

        if max_len is None:
            # Too hard to try to figure out the exact max_seq_length for each model
-            logger.warning("`max_len` was not specified. Defaulting to 256 tokens.")
+            warnings.warn("`max_len` was not specified. Defaulting to 256 tokens.")
            max_len = 256

-        if num_beams == 1 or num_beams is None:
+        if num_beams is None or num_beams == 1:
            if num_python_workers > 1:
-                logger.warning(f"Multiprocessing is not implemented for greedy search.")
+                warnings.warn(f"Multiprocessing is not implemented for greedy search.")
            return self.greedy_search(inputs, max_len, eos_idx, pad_idx=pad_idx, model_kwargs=model_kwargs)
        elif num_beams > 1:
            if beam_size_token is None:
                raise ValueError(
-                    "`beam_size_token` must be specified for beam search. \
-                    If confused about what to put, you can default to the vocab size of the model you are using."
+                    "`beam_size_token` must be specified for beam search. If confused about what to put, you can default to the vocab size of the model you are using."
                )
            return self.beam_search(
                inputs,