Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 476a51c

Browse files
committed
ckkpt
1 parent 7ddcc1c commit 476a51c

File tree

2 files changed

+11
-13
lines changed

2 files changed

+11
-13
lines changed

test/torchtext_unittest/prototype/test_generate.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,6 @@ def test_beam_search(self) -> None:
5757

5858
tokens = generation_model.generate(self.inputs, num_beams=3, max_len=30)
5959

60-
import pdb
61-
62-
pdb.set_trace()
63-
6460
generated_text = self.transform.decode(tokens.tolist())
6561

6662
import pdb

torchtext/prototype/generate.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def update_func(emissions_ptr, N, T, prev_step_token_idxs, prev_step_model_state
129129
Seq2SeqModelState(
130130
timestep=0,
131131
hidden_states=None,
132-
sequence=input_ids,
132+
sequence=input_ids[:, -1],
133133
lm_scores=None
134134
)
135135
)
@@ -146,30 +146,29 @@ def update_func(emissions_ptr, N, T, prev_step_token_idxs, prev_step_model_state
146146

147147
model_states = []
148148
for idx, model_state_ptr in zip(prev_step_token_idxs, prev_step_model_states):
149+
if isinstance(idx, int):
150+
idx = torch.Tensor([idx])
149151
model_state = get_obj_from_emitting_model_state(model_state_ptr)
150152
model_states.append(
151153
create_emitting_model_state(
152154
Seq2SeqModelState(
153155
timestep=timestep,
154156
hidden_states=outputs["decoder_hidden_states"],
155-
sequence=torch.cat([model_state.sequence[:, -1], idx], dim=-1),
157+
sequence=torch.cat([model_state.sequence, idx], dim=-1),
156158
lm_scores=lm_scores
157159
)
158160
)
159161
)
160162

161-
import pdb
162-
pdb.set_trace()
163-
164-
out_probs = lm_scores[0][0].tolist() * len(prev_step_token_idxs)
163+
out_probs = lm_scores[0].tolist() * len(prev_step_token_idxs)
165164
return out_probs, model_states
166165

167166
options = LexiconFreeSeq2SeqDecoderOptions(
168167
beam_size=num_beams,
169168
beam_size_token=self.model.config.vocab_size,
170-
beam_threshold=1000,
169+
beam_threshold=50,
171170
lm_weight=0.0,
172-
eos_score=0.0,
171+
eos_score=1.0,
173172
log_add=True,
174173
)
175174

@@ -186,7 +185,10 @@ def update_func(emissions_ptr, N, T, prev_step_token_idxs, prev_step_model_state
186185
decoder.decode_step(emissions.data_ptr(), T, N)
187186
hyps = decoder.get_all_final_hypothesis()
188187

189-
return hyps
188+
token_scores = [(hyp.tokens, hyp.score) for hyp in hyps]
189+
max_tokens = max(token_scores, key=lambda x: x[1])
190+
191+
return torch.Tensor(max_tokens[0]).to(torch.int)
190192

191193
def generate(
192194
self,

0 commit comments

Comments (0)