@@ -129,7 +129,7 @@ def update_func(emissions_ptr, N, T, prev_step_token_idxs, prev_step_model_state
                         Seq2SeqModelState(
                             timestep=0,
                             hidden_states=None,
-                            sequence=input_ids,
+                            sequence=input_ids[:, -1],
                             lm_scores=None
                         )
                     )
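This hunk seeds the initial `Seq2SeqModelState.sequence` with only the last prompt token instead of the full `input_ids`, so each decode step can grow the stored sequence by exactly one token (the next hunk drops the matching `[:, -1]` re-slice from the `torch.cat`). A toy walk-through of that accumulation, using plain tensors and made-up token ids rather than the real decoder state objects:

```python
import torch

# Hypothetical single-sequence prompt; only its last token seeds the state.
input_ids = torch.tensor([[0, 11, 12, 13]])
sequence = input_ids[:, -1]  # tensor([13])

# Stand-ins for the candidate indices handed to update_func on later steps.
for idx in (21, 22):
    idx = torch.Tensor([idx])  # ints get wrapped, mirroring the new isinstance check
    # Append the newly chosen token to the stored sequence as-is.
    sequence = torch.cat([sequence, idx], dim=-1)

# Last prompt token followed by the generated ids. Note torch.Tensor([...])
# is float32, so on recent PyTorch the running sequence is promoted to float.
print(sequence)  # tensor([13., 21., 22.])
```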
@@ -146,30 +146,29 @@ def update_func(emissions_ptr, N, T, prev_step_token_idxs, prev_step_model_state
 
             model_states = []
             for idx, model_state_ptr in zip(prev_step_token_idxs, prev_step_model_states):
+                if isinstance(idx, int):
+                    idx = torch.Tensor([idx])
                 model_state = get_obj_from_emitting_model_state(model_state_ptr)
                 model_states.append(
                     create_emitting_model_state(
                         Seq2SeqModelState(
                             timestep=timestep,
                             hidden_states=outputs["decoder_hidden_states"],
-                            sequence=torch.cat([model_state.sequence[:, -1], idx], dim=-1),
+                            sequence=torch.cat([model_state.sequence, idx], dim=-1),
                             lm_scores=lm_scores
                         )
                     )
                 )
 
-            import pdb
-            pdb.set_trace()
-
-            out_probs = lm_scores[0][0].tolist() * len(prev_step_token_idxs)
+            out_probs = lm_scores[0].tolist() * len(prev_step_token_idxs)
             return out_probs, model_states
 
         options = LexiconFreeSeq2SeqDecoderOptions(
             beam_size=num_beams,
             beam_size_token=self.model.config.vocab_size,
-            beam_threshold=1000,
+            beam_threshold=50,
             lm_weight=0.0,
-            eos_score=0.0,
+            eos_score=1.0,
             log_add=True,
         )
 
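Besides dropping the debugging breakpoint, this hunk builds the per-step scores from `lm_scores[0]` rather than `lm_scores[0][0]`; because multiplying a Python list by an integer repeats it, every candidate in `prev_step_token_idxs` receives the same vocabulary-sized score block for the step. The decoder options are also tightened: `beam_threshold` falls from 1000 to 50 (hypotheses far behind the best get pruned) and `eos_score` rises to 1.0 (a bonus applied when EOS is emitted, which presumably nudges hypotheses toward finishing). A self-contained sketch of the `out_probs` construction with made-up shapes:

```python
import torch

vocab_size = 8       # stands in for self.model.config.vocab_size
num_candidates = 3   # stands in for len(prev_step_token_idxs)

# lm_scores in the diff holds the model's scores over the vocabulary for the
# current step; a made-up (1, vocab_size) log-softmax plays that role here.
lm_scores = torch.log_softmax(torch.randn(1, vocab_size), dim=-1)

# List repetition: one flat list with the same vocab-sized block per candidate.
out_probs = lm_scores[0].tolist() * num_candidates
assert len(out_probs) == vocab_size * num_candidates
```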
@@ -186,7 +185,10 @@ def update_func(emissions_ptr, N, T, prev_step_token_idxs, prev_step_model_state
         decoder.decode_step(emissions.data_ptr(), T, N)
         hyps = decoder.get_all_final_hypothesis()
 
-        return hyps
+        token_scores = [(hyp.tokens, hyp.score) for hyp in hyps]
+        max_tokens = max(token_scores, key=lambda x: x[1])
+
+        return torch.Tensor(max_tokens[0]).to(torch.int)
 
     def generate(
         self,
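The final hunk changes the return value: instead of handing back every final hypothesis from the flashlight decoder, it keeps only the highest-scoring one and returns its token ids as an integer tensor. A toy version of that selection, with a plain dataclass standing in for the binding's hypothesis objects (which expose `.tokens` and `.score`):

```python
from dataclasses import dataclass
from typing import List

import torch


@dataclass
class Hyp:  # stand-in for the flashlight hypothesis object, not the real class
    tokens: List[int]
    score: float


hyps = [
    Hyp(tokens=[13, 21, 22, 2], score=-3.1),
    Hyp(tokens=[13, 21, 24, 2], score=-2.4),  # best score
    Hyp(tokens=[13, 25, 2], score=-5.0),
]

token_scores = [(hyp.tokens, hyp.score) for hyp in hyps]
best_tokens, _ = max(token_scores, key=lambda x: x[1])

print(torch.Tensor(best_tokens).to(torch.int))
# tensor([13, 21, 24,  2], dtype=torch.int32)
```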