
Commit b5c3210
chkpt
1 parent 476a51c

File tree

2 files changed: +34, -17 lines


test/torchtext_unittest/prototype/test_generate.py

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ def test_warns_when_no_max_len_provided(self, mock) -> None:
     def test_beam_search(self) -> None:
         generation_model = GenerationUtil(self.model)
 
-        tokens = generation_model.generate(self.inputs, num_beams=3, max_len=30)
+        tokens = generation_model.generate(self.inputs, num_beams=3, max_len=100)
 
         generated_text = self.transform.decode(tokens.tolist())
 
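For context, the API this test exercises wraps a model in GenerationUtil and decodes with beam search. A minimal usage sketch, assuming a seq2seq model and a tokenizer transform (the model, transform, and input string below are illustrative stand-ins, not part of this commit):

    from torchtext.prototype.generate import GenerationUtil

    model = ...      # assumed: a seq2seq model exposing prepare_inputs_for_generation
    transform = ...  # assumed: a tokenizer transform with a decode() method

    inputs = transform(["an illustrative input sentence"])  # token ids
    generation_model = GenerationUtil(model)

    # num_beams > 1 selects the beam-search path; max_len bounds the generated length.
    tokens = generation_model.generate(inputs, num_beams=3, max_len=100)
    generated_text = transform.decode(tokens.tolist())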

torchtext/prototype/generate.py

Lines changed: 33 additions & 16 deletions
@@ -83,6 +83,7 @@ def greedy_search(
         decoder_output = outputs[output_key]
 
         # Calculate probabilities and take the most likely next token
+        # Why do we take the last token instead of a mean_pooling across all of them?
         probs = F.log_softmax(decoder_output[:, -1], dim=-1)
         _, next_tokens = torch.topk(probs, 1)
 
@@ -129,38 +130,48 @@ def update_func(emissions_ptr, N, T, prev_step_token_idxs, prev_step_model_state
                         Seq2SeqModelState(
                             timestep=0,
                             hidden_states=None,
-                            sequence=input_ids[:, -1],
+                            sequence=input_ids,
                             lm_scores=None
                         )
                     )
                 ]
 
-            model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs)
-            if self.is_huggingface_model:
-                model_inputs["return_dict"] = True
-                model_inputs["output_hidden_states"] = True
-
-            outputs = self.model(**model_inputs)
-            output_key = "logits" if self.is_huggingface_model else "decoder_output"
-            lm_scores = outputs[output_key]
-
-            model_states = []
+            out_probs, model_states = [], []
             for idx, model_state_ptr in zip(prev_step_token_idxs, prev_step_model_states):
                 if isinstance(idx, int):
-                    idx = torch.Tensor([idx])
-                model_state = get_obj_from_emitting_model_state(model_state_ptr)
+                    idx = torch.Tensor([idx]).to(torch.long)
+
+                # Get previous model state
+                prev_model_state = get_obj_from_emitting_model_state(model_state_ptr)
+
+                # Create new decoder token ids
+                new_input_ids = torch.cat([prev_model_state.sequence[:, -1], idx], dim=-1)
+
+                # Forward pass
+                model_inputs = self.model.prepare_inputs_for_generation(new_input_ids.unsqueeze(dim=0), **model_kwargs)
+                if self.is_huggingface_model:
+                    model_inputs["return_dict"] = True
+                    model_inputs["output_hidden_states"] = True
+
+                outputs = self.model(**model_inputs)
+                output_key = "logits" if self.is_huggingface_model else "decoder_output"
+                lm_scores = outputs[output_key]
+
+                # Keep track of probabilities over vocab for this pairing
+                out_probs.append(torch.squeeze(lm_scores[:, -1]).tolist())
+
+                # Keep track of sequence and decoder hidden states
                 model_states.append(
                     create_emitting_model_state(
                         Seq2SeqModelState(
                             timestep=timestep,
                             hidden_states=outputs["decoder_hidden_states"],
-                            sequence=torch.cat([model_state.sequence, idx], dim=-1),
+                            sequence=new_input_ids.unsqueeze(dim=0),
                             lm_scores=lm_scores
                         )
                     )
                 )
 
-            out_probs = lm_scores[0].tolist() * len(prev_step_token_idxs)
             return out_probs, model_states
 
         options = LexiconFreeSeq2SeqDecoderOptions(
@@ -188,7 +199,13 @@ def update_func(emissions_ptr, N, T, prev_step_token_idxs, prev_step_model_state
         token_scores = [(hyp.tokens, hyp.score) for hyp in hyps]
         max_tokens = max(token_scores, key=lambda x: x[1])
 
-        return torch.Tensor(max_tokens[0]).to(torch.int)
+        filtered = list(filter(lambda x: x != -1, max_tokens[0]))
+        final_tokens = [0] + filtered
+
+        import pdb
+        pdb.set_trace()
+
+        return torch.Tensor(final_tokens).to(torch.long)
 
     def generate(
         self,
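The core of this change is in update_func: the old code ran a single forward pass before the loop and broadcast one score list across every candidate (out_probs = lm_scores[0].tolist() * len(prev_step_token_idxs)), while the new code extends each hypothesis with its own candidate token, reruns the model, and records per-candidate next-token scores. A minimal standalone sketch of that per-hypothesis rescoring pattern, with model_fn and Hypothesis as hypothetical stand-ins for the model forward pass and the decoder's emitting model state:

    from dataclasses import dataclass
    from typing import Callable, List, Tuple

    import torch
    import torch.nn.functional as F

    @dataclass
    class Hypothesis:
        sequence: torch.Tensor  # token ids, shape (1, seq_len)

    def rescore_candidates(
        model_fn: Callable[[torch.Tensor], torch.Tensor],  # ids -> logits, shape (1, seq_len, vocab)
        candidates: List[int],
        hypotheses: List[Hypothesis],
    ) -> Tuple[List[List[float]], List[Hypothesis]]:
        out_probs, new_hypotheses = [], []
        for idx, hyp in zip(candidates, hypotheses):
            token = torch.tensor([idx], dtype=torch.long)
            # Extend this hypothesis's sequence with its own candidate token
            new_ids = torch.cat([hyp.sequence[0], token], dim=-1).unsqueeze(0)
            logits = model_fn(new_ids)
            # Score only the final position: it holds the distribution over the next token
            probs = F.log_softmax(logits[:, -1], dim=-1)
            out_probs.append(probs.squeeze(0).tolist())
            new_hypotheses.append(Hypothesis(sequence=new_ids))
        return out_probs, new_hypotheses

Scoring the final position also answers the question left in the greedy_search comment: in a decoder, the logits at the last position are the prediction for the next token, so both greedy and beam search score that position rather than a pooled average over all positions.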
