This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 34a1346

wip feb 8

1 parent 62f5ca8 commit 34a1346

File tree: 3 files changed, +42 -48 lines


test/torchtext_unittest/prototype/test_generate.py

Lines changed: 6 additions & 1 deletion
@@ -46,7 +46,7 @@ def test_generate_errors_with_incorrect_beams(self) -> None:
         with self.assertRaises(ValueError):
             generation_model.generate(self.inputs, num_beams=0)
 
-    @patch("logging.Logger.warning")
+    @patch("warnings.warn")
     def test_warns_when_no_max_len_provided(self, mock) -> None:
         generation_model = GenerationUtil(self.model)
         generation_model.generate(self.inputs)
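Since the util now emits warnings via warnings.warn rather than a logger, the test patches warnings.warn directly. A minimal sketch of what such a patched test can assert (the assertion itself is an assumption, not part of this commit):

from unittest.mock import patch

@patch("warnings.warn")
def test_warns_when_no_max_len_provided(self, mock) -> None:
    generation_model = GenerationUtil(self.model)
    generation_model.generate(self.inputs)
    mock.assert_called_once()  # the default max_len warning should fire exactly once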
@@ -91,12 +91,17 @@ def test_hf_DELETE(self) -> None:
             max_len=100,
             pad_idx=t5.config.pad_token_id,
             num_beams=10,
+            beam_size_token=t5.config.vocab_size,
         )
         end = time.time() - start
         print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True), end)
         exit()
 
     def test_jit_generate(self) -> None:
+        # jitted_model = torch.jit.script(self.model)
+        # encoder = jitted_model.get_encoder()
+
+
         generation_model = GenerationUtil(self.model)
         torch.jit.script(generation_model)
torchtext/prototype/generate.py

Lines changed: 32 additions & 45 deletions
@@ -12,10 +12,10 @@
     get_obj_from_emitting_model_state,
 )
 
-import logging
 import warnings
 
-logger = logging.getLogger(__name__)
+
+MODEL_KWARGS_TYPE = Dict[str, Dict[str, Union[torch.Tensor, List[Optional[torch.Tensor]], List[torch.Tensor], None]]]
 
 
 @dataclass
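The new MODEL_KWARGS_TYPE alias replaces the loose Dict[str, Any] annotations with a container type TorchScript can check. Illustrative only; the "last_hidden_state" key mirrors the HF-style encoder output used later in this file, and the hidden size is an arbitrary assumption:

from typing import Dict, List, Optional, Union
import torch

MODEL_KWARGS_TYPE = Dict[str, Dict[str, Union[torch.Tensor, List[Optional[torch.Tensor]], List[torch.Tensor], None]]]

# A conforming value: encoder outputs keyed by name, each entry a tensor.
model_kwargs: MODEL_KWARGS_TYPE = {
    "encoder_outputs": {"last_hidden_state": torch.zeros(1, 5, 512)}
}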
@@ -56,9 +56,7 @@ def __init__(self, model: nn.Module, **kwargs) -> None:
         self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", True)
         self.is_huggingface_model = kwargs.pop("is_huggingface_model", False)
 
-    def _prepare_encoder_decoder_kwargs_for_generation(
-        self, inputs: torch.Tensor, model_kwargs: Dict[str, Any]
-    ) -> Dict[str, Any]:
+    def _prepare_encoder_decoder_kwargs_for_generation(self, inputs: torch.Tensor) -> MODEL_KWARGS_TYPE:
         """Runs encoder and adds to model_kwargs for decoding. Modified from https://github.com/huggingface/transformers/blob/67d074874d285e616393c65a0e670088e1b6b74a/src/transformers/generation/utils.py#L592.
 
         Args:
@@ -72,40 +70,36 @@ def _prepare_encoder_decoder_kwargs_for_generation(
         encoder = self.model.get_encoder()
 
         # Create copy of encoder kwargs
-        encoder_kwargs = model_kwargs.copy()
+        encoder_kwargs: Dict[str, bool] = {}
 
-        # Forward pass
         if self.is_huggingface_model:
             encoder_kwargs["return_dict"] = True
 
-        # import pdb
-        # pdb.set_trace()
-        # print(encoder_kwargs.keys())
-
-        # assert torch.jit.isinstance(encoder_kwargs, Optional[Dict[str, bool]])
-
-        model_kwargs["encoder_outputs"] = encoder(inputs, **encoder_kwargs)
-
+        # Forward pass
+        # Explicitly call the forward method to assert this is a ScriptModule if JITted
+        model_kwargs = {"encoder_outputs": encoder.forward(inputs)}  # , **encoder_kwargs)
         return model_kwargs
 
     def _prepare_decoder_ids_for_generation(
         self,
         batch_size: int,
         pad_idx: int = 0,
         device: Optional[torch.device] = None,
-        model_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+        model_kwargs: Optional[MODEL_KWARGS_TYPE] = None,
+    ) -> torch.Tensor:
         """Prepare decoder IDs for generation."""
         if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
-            return model_kwargs.pop("decoder_input_ids")
+            decoder_input_ids = model_kwargs.pop("decoder_input_ids")
+            assert torch.jit.isinstance(decoder_input_ids, torch.Tensor)
+            return decoder_input_ids
         else:
             return torch.ones((batch_size, 1), dtype=torch.long, device=device) * pad_idx
 
     def _update_model_kwargs_for_generation(
         self,
         outputs: Dict[str, Any],
         model_kwargs: Dict[str, Any],
-    ) -> Dict[str, Any]:
+    ) -> MODEL_KWARGS_TYPE:
         """After a forward pass, update model_kwargs for faster decoding. Modified from https://github.com/huggingface/transformers/blob/67d074874d285e616393c65a0e670088e1b6b74a/src/transformers/generation/utils.py#L692.
 
         Args:
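For reference, when model_kwargs carries no decoder_input_ids, the prepared decoder IDs are simply a (batch_size, 1) column filled with pad_idx. A standalone sketch of that default path:

import torch

batch_size, pad_idx = 4, 0
decoder_ids = torch.ones((batch_size, 1), dtype=torch.long) * pad_idx
print(decoder_ids.shape)  # torch.Size([4, 1])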
@@ -152,7 +146,7 @@ def greedy_search(
         max_len: int,
         eos_idx: int,
         pad_idx: Optional[int] = None,
-        model_kwargs: Optional[Dict[str, Any]] = {},
+        model_kwargs: Optional[MODEL_KWARGS_TYPE] = {},
     ) -> torch.Tensor:
         """Greedy search decoding for text generation. Takes the most likely next token every time.
@@ -184,9 +178,8 @@ def greedy_search(
             _, next_tokens = torch.topk(probs, 1)
 
             # For any finished sequences, padding idx should be the last token
-            if eos_idx is not None:
-                if pad_idx is not None:
-                    next_tokens = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences)
+            if eos_idx is not None and pad_idx is not None:
+                next_tokens = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences)
 
             # Append the next tokens to the previous tokens
             input_ids = torch.cat([input_ids, next_tokens], dim=-1)
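The collapsed condition keeps the original behaviour: once a sequence has finished, its next token is forced to pad_idx. A toy illustration of the masking arithmetic:

import torch

pad_idx = 0
next_tokens = torch.tensor([[5], [7], [9]])
unfinished_sequences = torch.tensor([[1], [0], [1]])  # middle sequence already emitted EOS
next_tokens = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences)
print(next_tokens.squeeze(1).tolist())  # [5, 0, 9]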
@@ -233,7 +226,7 @@ def beam_search(
         encoder_output_key = "last_hidden_state" if self.is_huggingface_model else "encoder_output"
         encoder_output = model_kwargs["encoder_outputs"][encoder_output_key]
 
-        def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
+        def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_step_model_states, timestep):
             # `emissions` and `N` are unused in this current implementation
 
             i = T  # Hacky access to the current seq in inputs
@@ -269,7 +262,7 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
             if end > curr_beam_size:
                 end = curr_beam_size
 
-            num_samples = end - start  # Is this always just gunna be equal to curr_beam_size?
+            num_samples = end - start
 
             if prev_step_token_idxs != [-1]:
                 state_sequences = torch.cat(prev_model_state_sequences[start:end], dim=0)
@@ -303,9 +296,6 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
                 if self.is_huggingface_model:
                     model_inputs.update(self._huggingface_model_input_values)
 
-                from typing import MappingProxyType
-
-                model_inputs = MappingProxyType(model_inputs)
                 # Forward pass
                 outputs = self.model(**model_inputs)
 
@@ -315,17 +305,14 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
 
                 # HF optimizations to reduce overhead in future `forward` calls
                 if self.is_huggingface_model:
-                    new_model_kwargs = self._update_model_kwargs_for_generation(
-                        outputs, new_model_kwargs, is_encoder_decoder=self.is_encoder_decoder
-                    )
-                    if new_model_kwargs["past"] is not None:
-                        import pdb
-
-                        pdb.set_trace()
-                        beam_indices += [start for _ in range(num_samples)]
+                    new_model_kwargs = self._update_model_kwargs_for_generation(outputs, new_model_kwargs)
+                    if new_model_kwargs["past"] is not None and len(prev_step_hyp_idxs) > 1:
+                        if len(prev_step_hyp_idxs) == 9:
+                            import pdb
+                            pdb.set_trace()
                         new_model_kwargs["past"] = self.model._reorder_cache(
                             new_model_kwargs["past"],
-                            torch.Tensor(beam_indices).to(dtype=torch.int32),  # I think this is correct?
+                            torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32),  # I think this is correct?
                         )
 
                 # Keep track of probabilities over vocab for this pairing
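Passing prev_step_hyp_idxs to _reorder_cache lets the HF past key/value cache follow the hypotheses the beam actually kept at the previous step. Illustrative only, for a single cached tensor laid out as (num_hyps, ...); real HF caches are nested tuples handled by the model's own _reorder_cache:

import torch

past_layer = torch.arange(4 * 2 * 3, dtype=torch.float32).reshape(4, 2, 3)  # (num_hyps, heads, dim)
prev_step_hyp_idxs = torch.tensor([2, 0, 2, 1], dtype=torch.int32)
reordered = past_layer.index_select(0, prev_step_hyp_idxs.to(torch.long))
print(reordered.shape)  # torch.Size([4, 2, 3])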
@@ -404,7 +391,7 @@ def is_not_neg_one(elem: int) -> bool:
             return final_tokens_as_tensors
 
         if num_python_workers > 1:
-            logger.warning("Multiprocessing has not yet been implemented.")
+            warnings.warn("Multiprocessing has not yet been implemented.")
 
         all_final_tokens = [beam_decode_step(i) for i in range(len(input_ids))]
@@ -473,28 +460,28 @@ def generate(
         1. `num_beams` == 1 or `num_beams` is None -> greedy search
         2. `num_beams` > 1 -> beam search
         """
-        model_kwargs = {}
+        model_kwargs: MODEL_KWARGS_TYPE = {}
 
         if self.is_encoder_decoder:
-            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(inputs, model_kwargs)
+            assert torch.jit.isinstance(inputs, torch.Tensor)
+            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(inputs)
             inputs = self._prepare_decoder_ids_for_generation(
                 len(inputs), device=inputs.device, model_kwargs=model_kwargs
             )
 
         if max_len is None:
             # Too hard to try to figure out the exact max_seq_length for each model
-            logger.warning("`max_len` was not specified. Defaulting to 256 tokens.")
+            warnings.warn("`max_len` was not specified. Defaulting to 256 tokens.")
             max_len = 256
 
-        if num_beams == 1 or num_beams is None:
+        if num_beams is None or num_beams == 1:
             if num_python_workers > 1:
-                logger.warning(f"Multiprocessing is not implemented for greedy search.")
+                warnings.warn(f"Multiprocessing is not implemented for greedy search.")
             return self.greedy_search(inputs, max_len, eos_idx, pad_idx=pad_idx, model_kwargs=model_kwargs)
         elif num_beams > 1:
             if beam_size_token is None:
                 raise ValueError(
-                    "`beam_size_token` must be specified for beam search. \
-                    If confused about what to put, you can default to the vocab size of the model you are using."
+                    "`beam_size_token` must be specified for beam search. If confused about what to put, you can default to the vocab size of the model you are using."
                 )
             return self.beam_search(
                 inputs,
torchtext/prototype/models/t5/model.py

Lines changed: 4 additions & 2 deletions
@@ -134,13 +134,15 @@ def __init__(
         for p in self.parameters():
             p.requires_grad = False
 
-    def prepare_inputs_for_generation(self, input_ids, encoder_outputs):
+    @torch.jit.export
+    def prepare_inputs_for_generation(self, input_ids: torch.Tensor, encoder_outputs: torch.Tensor) -> Dict[str, torch.Tensor]:
         return {"decoder_tokens": input_ids, "encoder_outputs": encoder_outputs}
 
+    @torch.jit.export
     def get_encoder(self) -> T5Encoder:
         return self.encoder
 
-    @torch.jit.ignore
+    @torch.jit.export
     def get_decoder(self) -> Optional[T5Decoder]:
         if self.decoder is None:
             warnings.warn("Decoder is not set on this model.")
