@@ -104,15 +104,12 @@ def _update_model_kwargs_for_generation(
         self,
         outputs: Dict[str, Any],
         model_kwargs: Dict[str, Any],
-    ) -> MODEL_KWARGS_TYPE:
+    ) -> None:
         """After a forward pass, update model_kwargs for faster decoding. Modified from https://github.com/huggingface/transformers/blob/67d074874d285e616393c65a0e670088e1b6b74a/src/transformers/generation/utils.py#L692.
 
         Args:
             outputs (Dict[str, Any]): LM output.
             model_kwargs (Dict[str, Any]): Model keyword args to be modified for future runs.
-
-        Returns:
-            Modified model_kwargs w/ updated past, token_type_ids, and attention_mask.
         """
         # Update past
         if "past_key_values" in outputs:
@@ -143,8 +140,6 @@ def _update_model_kwargs_for_generation(
                 dim=-1,
             )
 
-        return model_kwargs
-
     def greedy_search(
         self,
         input_ids: torch.Tensor,
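
With the return dropped, call sites change from model_kwargs = self._update_model_kwargs_for_generation(...) to a bare call (see the beam_search hunk further down). This relies on Python's reference semantics; a tiny illustration:

    def mutate(d: dict) -> None:
        d["past"] = "kv-cache"


    kwargs: dict = {}
    alias = kwargs  # both names reference the same dict object
    mutate(kwargs)
    assert alias["past"] == "kv-cache"  # the mutation is visible through every reference
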
@@ -227,6 +222,8 @@ def beam_search(
         Returns:
             Tensor of the generated sequences.
         """
+        device = input_ids.device
+
         if self.is_encoder_decoder:
             encoder_output_key = "last_hidden_state" if self.is_huggingface_model else "encoder_output"
             encoder_output = model_kwargs["encoder_outputs"][encoder_output_key]
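
Caching device = input_ids.device once at the top of beam_search gives the nested update_func a stable handle for tensor placement, used in a later hunk in place of self.model.device. A sketch of the placement idiom, with a hypothetical helper name:

    import torch


    def make_token_indices(token_idxs: list, device: torch.device) -> torch.Tensor:
        # torch.tensor(...) can place data directly on the target device,
        # avoiding the CPU-then-copy round trip of torch.Tensor(...).to(...)
        return torch.tensor(token_idxs, dtype=torch.long, device=device).reshape(-1, 1)


    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(make_token_indices([3, 7, 7], device).shape)  # torch.Size([3, 1])
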
@@ -236,9 +233,6 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
 
             i = T  # Hacky access to the current seq in inputs
 
-            # Copy over the `model_kwargs` in order to modify
-            new_model_kwargs = model_kwargs.copy()
-
             # For first timestep, create previous step token_idxs and model_states
             if timestep == 0:
                 prev_step_token_idxs = [-1]
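
Dropping the per-step model_kwargs.copy() removes an allocation from every decoding step and gives up little safety, since dict.copy() is shallow and nested values such as encoder_outputs were shared all along. A short demonstration with illustrative values:

    kwargs = {"encoder_outputs": {"last_hidden_state": [1, 2, 3]}}
    shallow = kwargs.copy()
    shallow["encoder_outputs"]["last_hidden_state"].append(4)
    # The nested dict is shared, so the old copy never isolated mutations anyway:
    assert kwargs["encoder_outputs"]["last_hidden_state"] == [1, 2, 3, 4]
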
@@ -273,7 +267,7 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
                 state_sequences = torch.cat(prev_model_state_sequences[start:end], dim=0)
                 token_indices = (
                     torch.Tensor(prev_step_token_idxs[start:end])
-                    .to(dtype=torch.long, device=self.model.device)
+                    .to(dtype=torch.long, device=device)
                     .reshape(num_samples, 1)
                 )
 
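
The device fix above matters because plain torch.nn.Module objects have no .device attribute (Hugging Face's PreTrainedModel adds one as a property), so self.model.device could raise AttributeError for non-HF wrapped models. If a module-derived device were ever needed, a common fallback looks like this hypothetical helper:

    import torch
    from torch import nn


    def module_device(module: nn.Module) -> torch.device:
        # Infer placement from the first parameter; nn.Module itself has no .device
        try:
            return next(module.parameters()).device
        except StopIteration:  # parameter-less module
            return torch.device("cpu")


    print(module_device(nn.Linear(4, 4)))  # cpu, unless the module was moved
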
@@ -286,23 +280,24 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
                 ), f"state_and_tokens has shape {state_and_tokens.shape} = expected {(num_samples, timestep + 1)}"
             else:
                 assert len(prev_model_state_sequences) == 1
-                state_and_tokens = token_indices = prev_model_state_sequences[0]  # dims: [1, 1]
+                state_and_tokens = token_indices = prev_model_state_sequences[0].expand(num_beams, -1)  # TODO: Make this more robust
+
 
             # Cleanup -- combine this with the above
             if self.is_encoder_decoder:
                 # Expand encoder outputs along the batch dimension so that they match the decoder input state's batch size
                 # This is a view-only operation and doesn't copy
-                new_model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_for_curr_seq.expand(
-                    num_samples if timestep > 0 else 1, -1, -1
+                model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_for_curr_seq.expand(
+                    num_samples if timestep > 0 else num_beams, -1, -1
                 )
 
             # Preprocess inputs for generation
-            model_inputs = self.model.prepare_inputs_for_generation(token_indices, **new_model_kwargs)
+            model_inputs = self.model.prepare_inputs_for_generation(token_indices, **model_kwargs)
             if self.is_huggingface_model:
                 model_inputs.update(self._huggingface_model_input_values)
-                if len(prev_step_hyp_idxs) > 1 and model_inputs["past_key_values"] is not None:
+                if len(prev_step_hyp_idxs) > 1 and model_kwargs["past"] is not None:
                     model_inputs["past_key_values"] = self.model._reorder_cache(
-                        model_inputs["past_key_values"],
+                        model_kwargs["past"],
                         torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32),
                     )
 
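
Two details in the hunk above: the first-step state is widened with .expand(num_beams, -1) so its leading dimension matches the beam count, and the KV cache handed to _reorder_cache now comes from model_kwargs["past"] rather than model_inputs. As the diff's own comment notes, expand is a broadcast view, so the widening copies nothing; illustrative shapes:

    import torch

    num_beams = 4
    first_step_state = torch.tensor([[101]])         # dims: [1, 1], e.g. a lone BOS token
    beamed = first_step_state.expand(num_beams, -1)  # dims: [4, 1]
    assert beamed.shape == (4, 1)
    assert beamed.data_ptr() == first_step_state.data_ptr()  # same storage: a view, not a copy
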
@@ -315,32 +310,23 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
 
             # HF optimizations to reduce overhead in future `forward` calls
             if self.is_huggingface_model:
-                new_model_kwargs = self._update_model_kwargs_for_generation(outputs, new_model_kwargs)
+                self._update_model_kwargs_for_generation(outputs, model_kwargs)
 
             # Keep track of probabilities over vocab for this pairing
-            # TODO: clean up duplicate code in these branches
-            if timestep == 0:
-                sample_lm_scores = torch.squeeze(lm_scores[:, -1])
+            # TODO: fix how we track the number here?
+            for i in range(lm_scores.shape[0]):
+                sample_lm_scores = lm_scores[i, -1]
                 out_probs.append(sample_lm_scores.tolist())
+                # Keep track of sequence and decoder hidden states
                 model_states.append(
                     create_emitting_model_state(
-                        Seq2SeqModelState(timestep=timestep, sequence=state_and_tokens, lm_scores=sample_lm_scores)
-                    )
-                )
-            else:
-                for i in range(num_samples):
-                    sample_lm_scores = lm_scores[i, -1]
-                    out_probs.append(sample_lm_scores.tolist())
-                    # Keep track of sequence and decoder hidden states
-                    model_states.append(
-                        create_emitting_model_state(
-                            Seq2SeqModelState(
-                                timestep=timestep,
-                                sequence=state_and_tokens[i].unsqueeze(0),
-                                lm_scores=sample_lm_scores,
-                            )
+                        Seq2SeqModelState(
+                            timestep=timestep,
+                            sequence=state_and_tokens[i].unsqueeze(0),
+                            lm_scores=sample_lm_scores,
                         )
                     )
+                )
 
             start += step
 
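
The last hunk collapses the duplicated timestep == 0 / else branches into one loop over lm_scores.shape[0], which is num_beams at the first step (thanks to the expand above) and num_samples afterwards. A reduced sketch of the now single code path, with illustrative shapes:

    import torch


    def collect_final_scores(lm_scores: torch.Tensor) -> list:
        # One branch-free path: take the last position's scores for every row
        return [lm_scores[i, -1].tolist() for i in range(lm_scores.shape[0])]


    first_step = torch.randn(4, 1, 10)  # [num_beams, seq, vocab]
    later_step = torch.randn(6, 3, 10)  # [num_samples, seq, vocab]
    assert len(collect_final_scores(first_step)) == 4
    assert len(collect_final_scores(later_step)) == 6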