@@ -99,15 +99,12 @@ def _update_model_kwargs_for_generation(
         self,
         outputs: Dict[str, Any],
         model_kwargs: Dict[str, Any],
-    ) -> MODEL_KWARGS_TYPE:
+    ) -> None:
         """After a forward pass, update model_kwargs for faster decoding. Modified from https://github.com/huggingface/transformers/blob/67d074874d285e616393c65a0e670088e1b6b74a/src/transformers/generation/utils.py#L692.

         Args:
             outputs (Dict[str, Any]): LM output.
             model_kwargs (Dict[str, Any]): Model keyword args to be modified for future runs.
-
-        Returns:
-            Modified model_kwargs w/ updated past, token_type_ids, and attention_mask.
         """
         # Update past
         if "past_key_values" in outputs:
@@ -138,8 +135,6 @@ def _update_model_kwargs_for_generation(
                 dim=-1,
             )

-        return model_kwargs
-
     def greedy_search(
         self,
         input_ids: torch.Tensor,
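The two hunks above change `_update_model_kwargs_for_generation` from returning a fresh dict to mutating `model_kwargs` in place and returning `None`. A minimal sketch of the difference, with a hypothetical `update_in_place` standing in for the real method (the `"past"` key mirrors the `model_kwargs["past"]` read later in this diff):

```python
from typing import Any, Dict

def update_in_place(outputs: Dict[str, Any], model_kwargs: Dict[str, Any]) -> None:
    # Mutate the caller's dict directly instead of returning a modified copy.
    # (Illustrative only; the real method also updates attention_mask etc.)
    if "past_key_values" in outputs:
        model_kwargs["past"] = outputs["past_key_values"]

kwargs = {"attention_mask": None}
update_in_place({"past_key_values": ("dummy",)}, kwargs)
assert "past" in kwargs  # the caller observes the update without reassignment
```

Because the caller's dict is updated directly, the `new_model_kwargs = model_kwargs.copy()` bookkeeping removed later in this diff is no longer needed.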
@@ -222,6 +217,8 @@ def beam_search(
         Returns:
             Tensor of the generated sequences.
         """
+        device = input_ids.device
+
         if self.is_encoder_decoder:
             encoder_output_key = "last_hidden_state" if self.is_huggingface_model else "encoder_output"
             encoder_output = model_kwargs["encoder_outputs"][encoder_output_key]
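Caching `device = input_ids.device` once at the top of `beam_search` replaces the `self.model.device` lookups removed below. One plausible motivation, sketched here: tensors always carry a `.device`, while a plain `torch.nn.Module` has no such attribute (it is a convenience property on Hugging Face models, not part of `nn.Module`):

```python
import torch

model = torch.nn.Linear(4, 4)                 # plain module: no `.device`
input_ids = torch.zeros(2, 3, dtype=torch.long)

device = input_ids.device                     # always available on tensors
assert not hasattr(model, "device")           # accessing it would raise AttributeError
```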
@@ -231,9 +228,6 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_

             i = T  # Hacky access to the current seq in inputs

-            # Copy over the `model_kwargs` in order to modify
-            new_model_kwargs = model_kwargs.copy()
-
             # For first timestep, create previous step token_idxs and model_states
             if timestep == 0:
                 prev_step_token_idxs = [-1]
@@ -268,7 +262,7 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
                     state_sequences = torch.cat(prev_model_state_sequences[start:end], dim=0)
                     token_indices = (
                         torch.Tensor(prev_step_token_idxs[start:end])
-                        .to(dtype=torch.long, device=self.model.device)
+                        .to(dtype=torch.long, device=device)
                         .reshape(num_samples, 1)
                     )

@@ -281,23 +275,24 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
                     ), f"state_and_tokens has shape {state_and_tokens.shape} = expected {(num_samples, timestep + 1)}"
                 else:
                     assert len(prev_model_state_sequences) == 1
-                    state_and_tokens = token_indices = prev_model_state_sequences[0]  # dims: [1, 1]
+                    state_and_tokens = token_indices = prev_model_state_sequences[0].expand(num_beams, -1)  # TODO: Make this more robust
+

                 # Cleanup -- combine this with the above
                 if self.is_encoder_decoder:
                     # Expand encoder outputs along the batch dimension so that they match the decoder input state's batch size
                     # This is a view-only operation and doesn't copy
-                    new_model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_for_curr_seq.expand(
-                        num_samples if timestep > 0 else 1, -1, -1
+                    model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_for_curr_seq.expand(
+                        num_samples if timestep > 0 else num_beams, -1, -1
                     )

                 # Preprocess inputs for generation
-                model_inputs = self.model.prepare_inputs_for_generation(token_indices, **new_model_kwargs)
+                model_inputs = self.model.prepare_inputs_for_generation(token_indices, **model_kwargs)
                 if self.is_huggingface_model:
                     model_inputs.update(self._huggingface_model_input_values)
-                    if len(prev_step_hyp_idxs) > 1 and model_inputs["past_key_values"] is not None:
+                    if len(prev_step_hyp_idxs) > 1 and model_kwargs["past"] is not None:
                         model_inputs["past_key_values"] = self.model._reorder_cache(
-                            model_inputs["past_key_values"],
+                            model_kwargs["past"],
                             torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32),
                         )

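When more than one hypothesis survives a step, the cached key/value states have to be reshuffled so each cache row follows its hypothesis; the diff routes this through the model's `_reorder_cache`. A rough sketch of the underlying operation, assuming the usual Hugging Face layout of `past_key_values` as a tuple of per-layer `(key, value)` tensors with the beam on dim 0 (`reorder_cache` below is illustrative, not the library function):

```python
import torch

def reorder_cache(past, beam_idx):
    # Select row beam_idx[j] of every cached tensor so that hypothesis j
    # carries forward the key/value history of its parent hypothesis.
    return tuple(tuple(t.index_select(0, beam_idx) for t in layer) for layer in past)

num_hyps, heads, seq_len, head_dim = 3, 2, 5, 8
layer = (torch.randn(num_hyps, heads, seq_len, head_dim),) * 2  # (key, value); same tensor twice is fine for a demo
past = (layer,)
beam_idx = torch.tensor([2, 0, 0])  # new hyp 0 continues old hyp 2; hyps 1 and 2 fork from old hyp 0
reordered = reorder_cache(past, beam_idx)
assert torch.equal(reordered[0][0][1], past[0][0][0])  # row 1 now holds old hyp 0's cache
```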
@@ -310,32 +305,23 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_

                 # HF optimizations to reduce overhead in future `forward` calls
                 if self.is_huggingface_model:
-                    new_model_kwargs = self._update_model_kwargs_for_generation(outputs, new_model_kwargs)
+                    self._update_model_kwargs_for_generation(outputs, model_kwargs)

                 # Keep track of probabilities over vocab for this pairing
-                # TODO: clean up duplicate code in these branches
-                if timestep == 0:
-                    sample_lm_scores = torch.squeeze(lm_scores[:, -1])
+                # TODO: fix how we track the number here?
+                for i in range(lm_scores.shape[0]):
+                    sample_lm_scores = lm_scores[i, -1]
                     out_probs.append(sample_lm_scores.tolist())
+                    # Keep track of sequence and decoder hidden states
                     model_states.append(
                         create_emitting_model_state(
-                            Seq2SeqModelState(timestep=timestep, sequence=state_and_tokens, lm_scores=sample_lm_scores)
-                        )
-                    )
-                else:
-                    for i in range(num_samples):
-                        sample_lm_scores = lm_scores[i, -1]
-                        out_probs.append(sample_lm_scores.tolist())
-                        # Keep track of sequence and decoder hidden states
-                        model_states.append(
-                            create_emitting_model_state(
-                                Seq2SeqModelState(
-                                    timestep=timestep,
-                                    sequence=state_and_tokens[i].unsqueeze(0),
-                                    lm_scores=sample_lm_scores,
-                                )
+                            Seq2SeqModelState(
+                                timestep=timestep,
+                                sequence=state_and_tokens[i].unsqueeze(0),
+                                lm_scores=sample_lm_scores,
                             )
                         )
+                    )

                 start += step

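The last hunk drops the `timestep == 0` special case and always loops over `lm_scores.shape[0]`, which works because an earlier hunk expands the initial sequence to `num_beams` rows via `prev_model_state_sequences[0].expand(num_beams, -1)`, so every timestep sees one row per hypothesis. A quick self-contained check that `expand` produces those rows as a view rather than a copy, matching the "view-only operation and doesn't copy" comment in the diff:

```python
import torch

num_beams = 4
seq = torch.tensor([[101, 7, 9]])             # [1, T] initial hypothesis
beamed = seq.expand(num_beams, -1)            # [num_beams, T] broadcast view
assert beamed.shape == (num_beams, 3)
assert beamed.data_ptr() == seq.data_ptr()    # same storage: a view, not a copy
```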