
Commit 7ddcc1c
Test Flashlight Text 12/22/22
1 parent: b699de2

File tree: 2 files changed, +121 −9 lines


test/torchtext_unittest/prototype/test_generate.py — 21 additions, 5 deletions
@@ -16,10 +16,10 @@ def setUp(self) -> None:
         self.inputs = self.transform(
             [
                 "summarize: studies have shown that owning a dog is good for you",
-                "translate English to German: That is good.",
-                "cola sentence: The course is jumping well.",
-                "stsb sentence1: The rhino grazed on the grass. sentence2: A rhino is grazing in a field.",
-                "summarize: state authorities dispatched emergency crews tuesday to survey the damage after an onslaught of severe weather in mississippi...",
+                # "translate English to German: That is good.",
+                # "cola sentence: The course is jumping well.",
+                # "stsb sentence1: The rhino grazed on the grass. sentence2: A rhino is grazing in a field.",
+                # "summarize: state authorities dispatched emergency crews tuesday to survey the damage after an onslaught of severe weather in mississippi...",
             ]
         )
         torch.manual_seed(0)
@@ -50,4 +50,20 @@ def test_generate_errors_with_incorrect_beams(self) -> None:
     def test_warns_when_no_max_len_provided(self, mock) -> None:
         generation_model = GenerationUtil(self.model)
         generation_model.generate(self.inputs)
-        mock.assert_called_with("`max_len` was not specified. Defaulting to 100 tokens.")
+        mock.assert_called_with("`max_len` was not specified. Defaulting to 256 tokens.")
+
+    def test_beam_search(self) -> None:
+        generation_model = GenerationUtil(self.model)
+
+        tokens = generation_model.generate(self.inputs, num_beams=3, max_len=30)
+
+        import pdb
+
+        pdb.set_trace()
+
+        generated_text = self.transform.decode(tokens.tolist())
+
+        import pdb
+        pdb.set_trace()
+
+
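For orientation, here is a minimal sketch of the end-to-end call the new test_beam_search is driving at, once the pdb traces are removed. It assumes the prototype T5 generation bundle that this test suite builds its model and transform from in setUp; the bundle name and its helper calls are assumptions, not part of this diff.

    from torchtext.prototype.generate import GenerationUtil
    from torchtext.prototype.models import T5_BASE_GENERATION  # assumed bundle, mirroring the test's setUp

    # Build the encoder/decoder model and its text transform (assumed helpers).
    model = T5_BASE_GENERATION.get_model()
    transform = T5_BASE_GENERATION.transform()

    inputs = transform(["summarize: studies have shown that owning a dog is good for you"])

    generation_model = GenerationUtil(model)
    # num_beams > 1 routes generate() to the new flashlight-backed beam_search.
    tokens = generation_model.generate(inputs, num_beams=3, max_len=30)
    generated_text = transform.decode(tokens.tolist())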

torchtext/prototype/generate.py — 100 additions, 4 deletions
@@ -1,14 +1,25 @@
-from typing import Optional
+from dataclasses import dataclass
+from typing import List, Optional
 
 import torch
 import torch.nn.functional as F
 from torch import nn
+from flashlight.lib.text.decoder import LexiconFreeSeq2SeqDecoder, LexiconFreeSeq2SeqDecoderOptions, ZeroLM, create_emitting_model_state, get_obj_from_emitting_model_state
 
 import logging
 
 logger = logging.getLogger(__name__)
 
 
+@dataclass
+class Seq2SeqModelState(object):
+
+    timestep: int
+    hidden_states: List[torch.Tensor]
+    sequence: torch.Tensor
+    lm_scores: torch.Tensor
+
+
 class GenerationUtil:
     """Wrapper to provide generation utils for encoder/decoder models and decoder models.
 
@@ -92,8 +103,90 @@ def greedy_search(
 
         return input_ids
 
-    def beam_search(self, input_ids: torch.Tensor, num_beams: int, max_len: Optional[int]) -> torch.Tensor:
-        raise NotImplementedError()
+    def beam_search(
+        self,
+        input_ids: torch.Tensor,
+        num_beams: int,
+        max_len: int,
+        vocab_size: int,
+        eos_idx: int = 1,
+        **model_kwargs
+    ) -> torch.Tensor:
+
+        # Is this right?
+        T = max_len
+        N = vocab_size
+
+        def update_func(emissions_ptr, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
+            # `emissions_ptr` should always be the same (from encoder output)
+            # N is not needed
+            # T is not needed
+
+            if timestep == 0:
+                prev_step_token_idxs = input_ids
+                prev_step_model_states = [
+                    create_emitting_model_state(
+                        Seq2SeqModelState(
+                            timestep=0,
+                            hidden_states=None,
+                            sequence=input_ids,
+                            lm_scores=None
+                        )
+                    )
+                ]
+
+            model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs)
+            if self.is_huggingface_model:
+                model_inputs["return_dict"] = True
+                model_inputs["output_hidden_states"] = True
+
+            outputs = self.model(**model_inputs)
+            output_key = "logits" if self.is_huggingface_model else "decoder_output"
+            lm_scores = outputs[output_key]
+
+            model_states = []
+            for idx, model_state_ptr in zip(prev_step_token_idxs, prev_step_model_states):
+                model_state = get_obj_from_emitting_model_state(model_state_ptr)
+                model_states.append(
+                    create_emitting_model_state(
+                        Seq2SeqModelState(
+                            timestep=timestep,
+                            hidden_states=outputs["decoder_hidden_states"],
+                            sequence=torch.cat([model_state.sequence[:, -1], idx], dim=-1),
+                            lm_scores=lm_scores
+                        )
+                    )
+                )
+
+            import pdb
+            pdb.set_trace()
+
+            out_probs = lm_scores[0][0].tolist() * len(prev_step_token_idxs)
+            return out_probs, model_states
+
+        options = LexiconFreeSeq2SeqDecoderOptions(
+            beam_size=num_beams,
+            beam_size_token=self.model.config.vocab_size,
+            beam_threshold=1000,
+            lm_weight=0.0,
+            eos_score=0.0,
+            log_add=True,
+        )
+
+        decoder = LexiconFreeSeq2SeqDecoder(
+            options=options,
+            lm=ZeroLM(),
+            eos_idx=eos_idx,
+            update_func=update_func,
+            max_output_length=max_len
+        )
+
+        emissions = model_kwargs["encoder_outputs"].get("encoder_output")
+
+        decoder.decode_step(emissions.data_ptr(), T, N)
+        hyps = decoder.get_all_final_hypothesis()
+
+        return hyps
 
     def generate(
         self,
@@ -135,6 +228,9 @@ def generate(
         if num_beams == 1 or num_beams is None:
             return self.greedy_search(inputs, max_len, eos_idx, pad_idx=pad_idx, **model_kwargs)
         elif num_beams > 1:
-            return self.beam_search(inputs, num_beams, max_len)
+            if torch.has_cuda:
+                logger.warning("No CUDA parallelization available for beam search yet.")
+                # Implement some sort of multiprocessing here
+            return self.beam_search(inputs, num_beams, vocab_size=self.model.config.vocab_size, max_len=max_len, **model_kwargs)
         else:
             raise ValueError("`num_beams` must be >= 1.")
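For readers unfamiliar with the flashlight bindings, the sketch below isolates how the decoder above is wired: an update_func callback that scores next tokens for each live hypothesis and refreshes per-hypothesis model state, handed to LexiconFreeSeq2SeqDecoder together with LexiconFreeSeq2SeqDecoderOptions. The flashlight names and call signatures are taken from the diff; the toy vocabulary, the uniform scores, the flat score-list shape, and the ToyState stand-in are assumptions for illustration only, not verified behavior of this commit.

    from dataclasses import dataclass

    import torch
    from flashlight.lib.text.decoder import (
        LexiconFreeSeq2SeqDecoder,
        LexiconFreeSeq2SeqDecoderOptions,
        ZeroLM,
        create_emitting_model_state,
        get_obj_from_emitting_model_state,
    )

    VOCAB_SIZE = 8   # assumed toy vocabulary size (N above)
    MAX_LEN = 5      # assumed toy output length (T above)
    EOS_IDX = 1


    @dataclass
    class ToyState:
        # Stand-in for Seq2SeqModelState: just enough to show the round trip
        # through create_emitting_model_state / get_obj_from_emitting_model_state.
        timestep: int


    def update_func(emissions_ptr, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
        # A real implementation would run the decoder stack here; this toy version
        # gives every token the same score so the example stays self-contained.
        num_hyps = max(len(prev_step_token_idxs), 1)
        for state_ptr in prev_step_model_states:
            _ = get_obj_from_emitting_model_state(state_ptr)  # recover the previous ToyState
        out_probs = [0.0] * VOCAB_SIZE * num_hyps             # flat list, mirroring the diff (assumed shape)
        model_states = [create_emitting_model_state(ToyState(timestep)) for _ in range(num_hyps)]
        return out_probs, model_states


    options = LexiconFreeSeq2SeqDecoderOptions(
        beam_size=3,
        beam_size_token=VOCAB_SIZE,
        beam_threshold=1000,
        lm_weight=0.0,
        eos_score=0.0,
        log_add=True,
    )
    decoder = LexiconFreeSeq2SeqDecoder(
        options=options,
        lm=ZeroLM(),
        eos_idx=EOS_IDX,
        update_func=update_func,
        max_output_length=MAX_LEN,
    )

    emissions = torch.zeros(1, MAX_LEN, VOCAB_SIZE)  # stand-in for the encoder output buffer
    decoder.decode_step(emissions.data_ptr(), MAX_LEN, VOCAB_SIZE)
    hyps = decoder.get_all_final_hypothesis()
    print(len(hyps))  # number of final hypotheses from the toy run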
