@@ -83,6 +83,7 @@ def greedy_search(
         decoder_output = outputs[output_key]

         # Calculate probabilities and take the most likely next token
+        # Why do we take the last token instead of a mean_pooling across all of them?
         probs = F.log_softmax(decoder_output[:, -1], dim=-1)
         _, next_tokens = torch.topk(probs, 1)

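To answer the question in the new comment: in an autoregressive decoder, the logits at position `i` predict the token at position `i + 1`, so only the final position carries the distribution for the next step; mean-pooling would mix it with predictions for tokens that have already been generated. A minimal sketch with dummy logits (shapes are illustrative, not taken from this diff):

```python
import torch
import torch.nn.functional as F

# Dummy decoder output: batch of 2, sequence length 5, vocab size 10.
decoder_output = torch.randn(2, 5, 10)

# Position i scores the token that follows position i, so the last
# position is the only one relevant to choosing the next token.
probs = F.log_softmax(decoder_output[:, -1], dim=-1)  # shape: (2, 10)
_, next_tokens = torch.topk(probs, 1)                 # shape: (2, 1)
```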
@@ -129,38 +130,48 @@ def update_func(emissions_ptr, N, T, prev_step_token_idxs, prev_step_model_state
                    Seq2SeqModelState(
                        timestep=0,
                        hidden_states=None,
-                        sequence=input_ids[:, -1],
+                        sequence=input_ids,
                        lm_scores=None
                    )
                )
            ]

-            model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs)
-            if self.is_huggingface_model:
-                model_inputs["return_dict"] = True
-                model_inputs["output_hidden_states"] = True
-
-            outputs = self.model(**model_inputs)
-            output_key = "logits" if self.is_huggingface_model else "decoder_output"
-            lm_scores = outputs[output_key]
-
-            model_states = []
+            out_probs, model_states = [], []

            for idx, model_state_ptr in zip(prev_step_token_idxs, prev_step_model_states):
                if isinstance(idx, int):
-                    idx = torch.Tensor([idx])
-                model_state = get_obj_from_emitting_model_state(model_state_ptr)
+                    idx = torch.Tensor([idx]).to(torch.long)
+
+                # Get previous model state
+                prev_model_state = get_obj_from_emitting_model_state(model_state_ptr)
+
+                # Create new decoder token ids
+                new_input_ids = torch.cat([prev_model_state.sequence[:, -1], idx], dim=-1)
+
+                # Forward pass
+                model_inputs = self.model.prepare_inputs_for_generation(new_input_ids.unsqueeze(dim=0), **model_kwargs)
+                if self.is_huggingface_model:
+                    model_inputs["return_dict"] = True
+                    model_inputs["output_hidden_states"] = True
+
+                outputs = self.model(**model_inputs)
+                output_key = "logits" if self.is_huggingface_model else "decoder_output"
+                lm_scores = outputs[output_key]
+
+                # Keep track of probabilities over vocab for this pairing
+                out_probs.append(torch.squeeze(lm_scores[:, -1]).tolist())
+
+                # Keep track of sequence and decoder hidden states
                model_states.append(
                    create_emitting_model_state(
                        Seq2SeqModelState(
                            timestep=timestep,
                            hidden_states=outputs["decoder_hidden_states"],
-                            sequence=torch.cat([model_state.sequence, idx], dim=-1),
+                            sequence=new_input_ids.unsqueeze(dim=0),
                            lm_scores=lm_scores
                        )
                    )
                )

-            out_probs = lm_scores[0].tolist() * len(prev_step_token_idxs)
            return out_probs, model_states

        options = LexiconFreeSeq2SeqDecoderOptions(
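The key behavioral change in this hunk: the old code ran a single forward pass outside the loop and copied `lm_scores[0]` for every candidate, so all beam hypotheses received identical probabilities; the rewrite scores each (previous state, candidate token) pairing with its own forward pass. A minimal sketch of the sequence bookkeeping, using a simplified stand-in for `Seq2SeqModelState` (the fields and token values here are hypothetical):

```python
import torch
from dataclasses import dataclass
from typing import Optional

@dataclass
class Seq2SeqModelState:  # simplified stand-in for the real state object
    timestep: int
    sequence: torch.Tensor  # shape: (1, seq_len)
    lm_scores: Optional[torch.Tensor]

prev = Seq2SeqModelState(timestep=3, sequence=torch.tensor([[0, 57, 12]]), lm_scores=None)
idx = torch.Tensor([431]).to(torch.long)  # candidate token proposed by the beam

# Mirrors the update above: previous last token + candidate become the
# next decoder input, then the batch dimension is restored.
new_input_ids = torch.cat([prev.sequence[:, -1], idx], dim=-1)
print(new_input_ids.unsqueeze(dim=0))  # tensor([[ 12, 431]])
```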
@@ -188,7 +199,11 @@ def update_func(emissions_ptr, N, T, prev_step_token_idxs, prev_step_model_state
        token_scores = [(hyp.tokens, hyp.score) for hyp in hyps]
        max_tokens = max(token_scores, key=lambda x: x[1])

-        return torch.Tensor(max_tokens[0]).to(torch.int)
+        # Drop the -1 padding entries and re-prepend the start token (id 0)
+        filtered = list(filter(lambda x: x != -1, max_tokens[0]))
+        final_tokens = [0] + filtered
+
+        return torch.Tensor(final_tokens).to(torch.long)

    def generate(
        self,
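The winning hypothesis is now post-processed before being returned: the `-1` entries the decoder uses as padding are filtered out, a leading `0` restores the decoder start token (for T5, the pad token id 0 doubles as the decoder start id), and the dtype switches from `torch.int` to `torch.long` so the ids can be fed directly to an embedding layer. A small sketch with made-up token ids:

```python
import torch

hyp_tokens = [57, 12, 431, -1, -1]  # hypothetical best hypothesis, -1 = padding

filtered = list(filter(lambda x: x != -1, hyp_tokens))
final_tokens = [0] + filtered  # re-prepend the decoder start token

print(torch.Tensor(final_tokens).to(torch.long))  # tensor([  0,  57,  12, 431])
```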