     get_obj_from_emitting_model_state,
 )
 
-import logging
 import warnings
 
-logger = logging.getLogger(__name__)
+
+MODEL_KWARGS_TYPE = Dict[str, Dict[str, Union[torch.Tensor, List[Optional[torch.Tensor]], List[torch.Tensor], None]]]
 
 DEFAULT_MAX_SEQ_LEN = 256
 
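As a reading aid for the new alias, here is a minimal sketch of a value that satisfies MODEL_KWARGS_TYPE. The nesting under "encoder_outputs"/"last_hidden_state" and the tensor shape are assumptions inferred from how model_kwargs is consumed later in this diff, not part of the change itself.

# Illustrative only: one possible model_kwargs value of the new MODEL_KWARGS_TYPE.
# The key names and the tensor shape are assumptions based on later hunks in this diff.
from typing import Dict, List, Optional, Union
import torch

MODEL_KWARGS_TYPE = Dict[str, Dict[str, Union[torch.Tensor, List[Optional[torch.Tensor]], List[torch.Tensor], None]]]

model_kwargs: MODEL_KWARGS_TYPE = {
    "encoder_outputs": {
        "last_hidden_state": torch.zeros(2, 7, 512),  # (batch, src_len, hidden) -- illustrative shape only
    }
}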
@@ -61,9 +61,7 @@ def __init__(self, model: nn.Module, **kwargs) -> None:
         self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", True)
         self.is_huggingface_model = kwargs.pop("is_huggingface_model", False)
 
-    def _prepare_encoder_decoder_kwargs_for_generation(
-        self, inputs: torch.Tensor, model_kwargs: Dict[str, Any]
-    ) -> Dict[str, Any]:
+    def _prepare_encoder_decoder_kwargs_for_generation(self, inputs: torch.Tensor) -> MODEL_KWARGS_TYPE:
         """Runs encoder and adds to model_kwargs for decoding. Modified from https://github.com/huggingface/transformers/blob/67d074874d285e616393c65a0e670088e1b6b74a/src/transformers/generation/utils.py#L592.
 
         Args:
@@ -77,40 +75,36 @@ def _prepare_encoder_decoder_kwargs_for_generation(
         encoder = self.model.get_encoder()
 
         # Create copy of encoder kwargs
-        encoder_kwargs = model_kwargs.copy()
+        encoder_kwargs: Dict[str, bool] = {}
 
-        # Forward pass
         if self.is_huggingface_model:
             encoder_kwargs["return_dict"] = True
 
-        # import pdb
-        # pdb.set_trace()
-        # print(encoder_kwargs.keys())
-
-        # assert torch.jit.isinstance(encoder_kwargs, Optional[Dict[str, bool]])
-
-        model_kwargs["encoder_outputs"] = encoder(inputs, **encoder_kwargs)
-
+        # Forward pass
+        # Explicitly call forward method to assert this is a ScriptModule if JITted
+        model_kwargs = {"encoder_outputs": encoder.forward(inputs)}  # , **encoder_kwargs)
         return model_kwargs
 
     def _prepare_decoder_ids_for_generation(
         self,
         batch_size: int,
         pad_idx: int = 0,
         device: Optional[torch.device] = None,
-        model_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+        model_kwargs: Optional[MODEL_KWARGS_TYPE] = None,
+    ) -> torch.Tensor:
         """Prepare decoder IDs for generation."""
         if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
-            return model_kwargs.pop("decoder_input_ids")
+            decoder_input_ids = model_kwargs.pop("decoder_input_ids")
+            assert torch.jit.isinstance(decoder_input_ids, torch.Tensor)
+            return decoder_input_ids
         else:
             return torch.ones((batch_size, 1), dtype=torch.long, device=device) * pad_idx
 
     def _update_model_kwargs_for_generation(
         self,
         outputs: Dict[str, Any],
         model_kwargs: Dict[str, Any],
-    ) -> Dict[str, Any]:
+    ) -> MODEL_KWARGS_TYPE:
         """After a forward pass, update model_kwargs for faster decoding. Modified from https://github.com/huggingface/transformers/blob/67d074874d285e616393c65a0e670088e1b6b74a/src/transformers/generation/utils.py#L692.
 
         Args:
@@ -157,7 +151,7 @@ def greedy_search(
         max_length: int,
         eos_idx: int,
         pad_idx: Optional[int] = None,
-        model_kwargs: Optional[Dict[str, Any]] = {},
+        model_kwargs: Optional[MODEL_KWARGS_TYPE] = {},
     ) -> torch.Tensor:
         """Greedy search decoding for text generation. Takes the most likely next token every time.
 
@@ -189,9 +183,8 @@ def greedy_search(
             _, next_tokens = torch.topk(probs, 1)
 
             # For any finished sequences, padding idx should be the last token
-            if eos_idx is not None:
-                if pad_idx is not None:
-                    next_tokens = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences)
+            if eos_idx is not None and pad_idx is not None:
+                next_tokens = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences)
 
             # Append the next tokens to the previous tokens
             input_ids = torch.cat([input_ids, next_tokens], dim=-1)
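A toy example (not part of the diff) of the masking expression kept above: once a sequence has finished, multiplying by unfinished_sequences zeroes out its proposed token and pad_idx is written instead.

# Toy illustration of the finished-sequence masking in greedy_search.
import torch

next_tokens = torch.tensor([[7], [3], [9]])           # proposed next token per sequence
unfinished_sequences = torch.tensor([[1], [0], [1]])  # the middle sequence already hit EOS
pad_idx = 0

masked = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences)
print(masked.squeeze(-1).tolist())  # [7, 0, 9] -- the finished sequence keeps emitting pad_idx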
@@ -238,7 +231,7 @@ def beam_search(
         encoder_output_key = "last_hidden_state" if self.is_huggingface_model else "encoder_output"
         encoder_output = model_kwargs["encoder_outputs"][encoder_output_key]
 
-        def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
+        def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_step_model_states, timestep):
             # `emissions` and `N` are unused in this current implementation
 
             i = T  # Hacky access to the current seq in inputs
@@ -274,7 +267,7 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
                 if end > curr_beam_size:
                     end = curr_beam_size
 
-                num_samples = end - start  # Is this always just gunna be equal to curr_beam_size?
+                num_samples = end - start
 
                 if prev_step_token_idxs != [-1]:
                     state_sequences = torch.cat(prev_model_state_sequences[start:end], dim=0)
@@ -308,9 +301,6 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
                     if self.is_huggingface_model:
                         model_inputs.update(self._huggingface_model_input_values)
 
-                    from typing import MappingProxyType
-
-                    model_inputs = MappingProxyType(model_inputs)
                     # Forward pass
                     outputs = self.model(**model_inputs)
 
@@ -320,17 +310,14 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
 
                     # HF optimizations to reduce overhead in future `forward` calls
                     if self.is_huggingface_model:
-                        new_model_kwargs = self._update_model_kwargs_for_generation(
-                            outputs, new_model_kwargs, is_encoder_decoder=self.is_encoder_decoder
-                        )
-                        if new_model_kwargs["past"] is not None:
-                            import pdb
-
-                            pdb.set_trace()
-                            beam_indices += [start for _ in range(num_samples)]
+                        new_model_kwargs = self._update_model_kwargs_for_generation(outputs, new_model_kwargs)
+                        if new_model_kwargs["past"] is not None and len(prev_step_hyp_idxs) > 1:
+                            if len(prev_step_hyp_idxs) == 9:
+                                import pdb
+                                pdb.set_trace()
                             new_model_kwargs["past"] = self.model._reorder_cache(
                                 new_model_kwargs["past"],
-                                torch.Tensor(beam_indices).to(dtype=torch.int32),  # I think this is correct?
+                                torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32),  # I think this is correct?
                             )
 
                     # Keep track of probabilities over vocab for this pairing
@@ -409,7 +396,7 @@ def is_not_neg_one(elem: int) -> bool:
             return final_tokens_as_tensors
 
         if num_python_workers > 1:
-            logger.warning("Multiprocessing has not yet been implemented.")
+            warnings.warn("Multiprocessing has not yet been implemented.")
 
         all_final_tokens = [beam_decode_step(i) for i in range(len(input_ids))]
 
@@ -478,28 +465,28 @@ def generate(
             1. `num_beams` == 1 or `num_beams` is None -> greedy search
             2. `num_beams` > 1 -> beam search
         """
-        model_kwargs = {}
+        model_kwargs: MODEL_KWARGS_TYPE = {}
 
         if self.is_encoder_decoder:
-            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(inputs, model_kwargs)
+            assert torch.jit.isinstance(inputs, torch.Tensor)
+            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(inputs)
             inputs = self._prepare_decoder_ids_for_generation(
                 len(inputs), device=inputs.device, model_kwargs=model_kwargs
             )
 
         if max_length is None:
             # Too hard to try to figure out the exact max_seq_length for each model
-            logger.warning(f"`max_length` was not specified. Defaulting to {DEFAULT_MAX_SEQ_LEN} tokens.")
-            max_length = DEFAULT_MAX_SEQ_LEN
+            warnings.warn("`max_length` was not specified. Defaulting to 256 tokens.")
+            max_length = 256
 
-        if num_beams == 1 or num_beams is None:
+        if num_beams is None or num_beams == 1:
             if num_python_workers > 1:
-                logger.warning(f"Multiprocessing is not implemented for greedy search.")
+                warnings.warn(f"Multiprocessing is not implemented for greedy search.")
             return self.greedy_search(inputs, max_length, eos_idx, pad_idx=pad_idx, model_kwargs=model_kwargs)
         elif num_beams > 1:
             if beam_size_token is None:
                 raise ValueError(
-                    "`beam_size_token` must be specified for beam search. \
-                        If confused about what to put, you can default to the vocab size of the model you are using."
+                    "`beam_size_token` must be specified for beam search. If confused about what to put, you can default to the vocab size of the model you are using."
                 )
             return self.beam_search(
                 inputs,
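For orientation, a rough usage sketch of the dispatch behaviour touched by the last hunk: num_beams of None or 1 routes to greedy_search, anything larger to beam_search, which requires beam_size_token. The wrapper class name, its import path, and the exact generate() keyword signature are assumptions and are not confirmed by this diff.

# Hypothetical usage sketch; GenerationUtils, its import path, and any keyword names beyond
# those referenced in the hunk above (num_beams, max_length, eos_idx, pad_idx, beam_size_token)
# are assumptions.
from torchtext.prototype.generate import GenerationUtils  # assumed import path
from transformers import T5ForConditionalGeneration, T5Tokenizer

model = T5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")
inputs = tokenizer(["translate English to German: the cat sat"], return_tensors="pt").input_ids

generator = GenerationUtils(model, is_encoder_decoder=True, is_huggingface_model=True)

# num_beams is None or 1 -> greedy_search; max_length=None falls back to 256 with a warning
greedy_tokens = generator.generate(inputs, num_beams=1, max_length=None, eos_idx=1, pad_idx=0)

# num_beams > 1 -> beam_search; beam_size_token must be given (the model's vocab size is a sane default)
beam_tokens = generator.generate(
    inputs,
    num_beams=4,
    max_length=64,
    eos_idx=1,
    pad_idx=0,
    beam_size_token=model.config.vocab_size,
)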