
Commit b699de2

Add generation utils with greedy search and tests

1 parent c008115

File tree

5 files changed: +290 -3 lines


notebooks/hf_vs_tt_t5.ipynb

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Ensuring the TorchText T5 implementation matches other OSS implementations\n",
    "\n",
    "> In order to run this notebook, you will need to install the huggingface library with the following command: `pip install transformers`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import T5Model\n",
    "from torchtext.prototype.models import T5_BASE\n",
    "\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_sentence = [\"translate to Spanish: My name is Joe\"]\n",
    "output_sentence = [\"Me llamo Joe\"]\n",
    "\n",
    "transform = T5_BASE.transform()\n",
    "tt_t5_model = T5_BASE.get_model()\n",
    "\n",
    "hf_t5_model = T5Model.from_pretrained(\"t5-base\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_sentence = transform(input_sentence)\n",
    "tokenized_output = transform(output_sentence)\n",
    "\n",
    "tt_output = tt_t5_model(encoder_tokens=tokenized_sentence, decoder_tokens=tokenized_output)\n",
    "hf_output = hf_t5_model(input_ids=tokenized_sentence, decoder_input_ids=tokenized_output, return_dict=True)\n",
    "\n",
    "assert torch.all(tt_output[\"encoder_output\"].eq(hf_output[\"encoder_last_hidden_state\"]))\n",
    "assert torch.all(tt_output[\"decoder_output\"].eq(hf_output[\"last_hidden_state\"]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.13 ('torchtext39')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "63c8862cb56f124e3ee7674b73de745eeb216416a9b24f78d1fcb7c775bff1b7"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
Lines changed: 151 additions & 0 deletions
@@ -0,0 +1,151 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In order to run this notebook, you will need to install the huggingface library with the following command: `pip install transformers`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/homebrew/Caskroom/miniforge/base/envs/torchtext39/lib/python3.9/site-packages/tqdm-4.64.0-py3.9.egg/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from transformers import T5ForConditionalGeneration, T5Tokenizer, BartForConditionalGeneration, BartTokenizer, GPT2LMHeadModel, GPT2Tokenizer\n",
    "from torchtext.prototype.generate import GenerationUtil"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "t5 = T5ForConditionalGeneration.from_pretrained(\"t5-base\")\n",
    "bart = BartForConditionalGeneration.from_pretrained(\"facebook/bart-large-cnn\")\n",
    "gpt2 = GPT2LMHeadModel.from_pretrained(\"gpt2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/homebrew/Caskroom/miniforge/base/envs/torchtext39/lib/python3.9/site-packages/transformers/models/t5/tokenization_t5.py:164: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n",
      "For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n",
      "- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.\n",
      "- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n",
      "- To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['owning a dog is good for you, according to studies. a dog is']\n"
     ]
    }
   ],
   "source": [
    "# Testing Huggingface's T5\n",
    "test_sequence = [\"summarize: studies have shown that owning a dog is good for you\"]\n",
    "generative_hf_t5 = GenerationUtil(t5, is_encoder_decoder=True, is_huggingface_model=True)\n",
    "t5_tokenizer = T5Tokenizer.from_pretrained(\"t5-base\")\n",
    "test_sequence_tk = t5_tokenizer(test_sequence, return_tensors=\"pt\").input_ids\n",
    "tokens = generative_hf_t5.generate(test_sequence_tk, max_len=20, pad_idx=t5.config.pad_token_id)\n",
    "print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['PG. PG&E said it scheduled the blackouts in response to forecasts for high winds.']\n"
     ]
    }
   ],
   "source": [
    "# Testing Huggingface's BART\n",
    "test_sequence = [\"PG&E stated it scheduled the blackouts in response to forecasts for high winds \"\n",
    "    \"amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were \"\n",
    "    \"scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow.\"]\n",
    "generative_hf_bart = GenerationUtil(bart, is_encoder_decoder=True, is_huggingface_model=True)\n",
    "bart_tokenizer = BartTokenizer.from_pretrained(\"facebook/bart-large-cnn\")\n",
    "test_sequence_tk = bart_tokenizer(test_sequence, return_tensors=\"pt\").input_ids\n",
    "tokens = generative_hf_bart.generate(test_sequence_tk, max_len=20, pad_idx=bart.config.pad_token_id)\n",
    "print(bart_tokenizer.batch_decode(tokens, skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[\"I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to\"]\n"
     ]
    }
   ],
   "source": [
    "# Testing Huggingface's GPT2\n",
    "test_sequence = [\"I enjoy walking with my cute dog\"]\n",
    "generative_hf_gpt2 = GenerationUtil(gpt2, is_encoder_decoder=False, is_huggingface_model=True)\n",
    "gpt2_tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n",
    "test_sequence_tk = gpt2_tokenizer(test_sequence, return_tensors=\"pt\").input_ids\n",
    "tokens = generative_hf_gpt2.generate(test_sequence_tk, max_len=20, pad_idx=gpt2.config.pad_token_id)\n",
    "print(gpt2_tokenizer.batch_decode(tokens, skip_special_tokens=True))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.13 ('torchtext39')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "63c8862cb56f124e3ee7674b73de745eeb216416a9b24f78d1fcb7c775bff1b7"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
from unittest.mock import patch
from torchtext.prototype.generate import GenerationUtil
from torchtext.prototype.models import T5_BASE_GENERATION
from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase
import torch


class TestGenerationUtil(TorchtextTestCase):
    def setUp(self) -> None:
        super().setUp()
        t5_base = T5_BASE_GENERATION
        self.transform = t5_base.transform()
        self.model = t5_base.get_model()
        self.model.eval()
        # Examples taken from T5 Paper and Huggingface
        self.inputs = self.transform(
            [
                "summarize: studies have shown that owning a dog is good for you",
                "translate English to German: That is good.",
                "cola sentence: The course is jumping well.",
                "stsb sentence1: The rhino grazed on the grass. sentence2: A rhino is grazing in a field.",
                "summarize: state authorities dispatched emergency crews tuesday to survey the damage after an onslaught of severe weather in mississippi...",
            ]
        )
        torch.manual_seed(0)

    def test_greedy_generate_with_t5(self) -> None:
        generation_model = GenerationUtil(self.model)

        tokens = generation_model.generate(self.inputs, num_beams=1, max_len=30)
        generated_text = self.transform.decode(tokens.tolist())

        expected_generated_text = [
            "a dog is good for you, according to studies . owning a dog is good for you, according to studies .",
            "Das ist gut.",
            "acceptable",
            "4.0",
            "mississippi authorities dispatch emergency crews to survey damage . severe weather in mississippi has caused extensive damage",
        ]

        self.assertEqual(generated_text, expected_generated_text)

    def test_generate_errors_with_incorrect_beams(self) -> None:
        generation_model = GenerationUtil(self.model, is_encoder_decoder=True)

        with self.assertRaises(ValueError):
            generation_model.generate(self.inputs, num_beams=0)

    @patch("logging.Logger.warning")
    def test_warns_when_no_max_len_provided(self, mock) -> None:
        generation_model = GenerationUtil(self.model)
        generation_model.generate(self.inputs)
        mock.assert_called_with("`max_len` was not specified. Defaulting to 100 tokens.")

torchtext/prototype/generate.py

Lines changed: 3 additions & 2 deletions
@@ -29,6 +29,7 @@ class GenerationUtil:
 
     More examples can be found in the `notebooks` directory of this repository.
     """
+
     def __init__(self, model: nn.Module, is_encoder_decoder: bool = True, is_huggingface_model: bool = False) -> None:
         self.model = model
         self.is_encoder_decoder = is_encoder_decoder
@@ -53,7 +54,7 @@ def greedy_search(
             eos_idx (int): End of sequence index.
             pad_idx (int): Padding index.
             **model_kwargs
-
+
         Returns:
             Batch of sequences decoded by greedy search.
         """
@@ -125,7 +126,7 @@ def generate(
             encoder = self.model.get_encoder()
             model_kwargs["encoder_outputs"] = encoder(inputs)
             inputs = self._prepare_decoder_ids_for_generation(len(inputs), device=inputs.device, **model_kwargs)
-
+
         if max_len is None:
             # Too hard to try to figure out the exact max_seq_length for each model
             logger.warning("`max_len` was not specified. Defaulting to 256 tokens.")
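The `greedy_search` docstring above spells out the decoding contract: emit one token per step, always taking the argmax, stopping a sequence at `eos_idx` and padding finished rows with `pad_idx` until `max_len` is reached. A minimal, self-contained sketch of such a loop (not the repository's exact implementation; `step_fn` is a hypothetical stand-in for a model forward pass that returns per-position logits):

import torch


def greedy_decode(step_fn, input_ids: torch.Tensor, max_len: int, eos_idx: int, pad_idx: int) -> torch.Tensor:
    # Track which sequences in the batch have already emitted EOS.
    finished = torch.zeros(input_ids.size(0), dtype=torch.bool)
    for _ in range(max_len):
        logits = step_fn(input_ids)                    # (batch, seq_len, vocab)
        next_tokens = logits[:, -1, :].argmax(dim=-1)  # greedy pick at the last position
        # Finished sequences keep receiving padding so the batch stays rectangular.
        next_tokens = torch.where(finished, torch.full_like(next_tokens, pad_idx), next_tokens)
        input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1)
        finished |= next_tokens.eq(eos_idx)
        if bool(finished.all()):
            break
    return input_ids


# Toy usage: a "model" whose logits always favor token 1, which we declare EOS.
def toy_step(ids: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.one_hot(torch.ones_like(ids), num_classes=5).float()


print(greedy_decode(toy_step, torch.zeros(2, 1, dtype=torch.long), max_len=4, eos_idx=1, pad_idx=0))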

torchtext/prototype/models/t5/model.py

Lines changed: 0 additions & 1 deletion
@@ -134,7 +134,6 @@ def __init__(
         for p in self.parameters():
             p.requires_grad = False
 
-    @torch.jit.ignore
     def prepare_inputs_for_generation(self, input_ids, encoder_outputs):
         return {"decoder_tokens": input_ids, "encoder_outputs": encoder_outputs}
 
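Removing `@torch.jit.ignore` exposes `prepare_inputs_for_generation` to TorchScript compilation, presumably so the hook stays usable once the model is scripted for generation. The hook is the seam between model and generation loop: each decoding step hands it the tokens generated so far plus the cached encoder outputs, and the loop splats the returned dict straight into the model's forward. A self-contained illustration of that contract; `DummyModel` is hypothetical, and only the returned dict mirrors the T5 hook above:

import torch
import torch.nn as nn


class DummyModel(nn.Module):
    # Hypothetical stand-in for the real T5; only the hook's return shape matches.
    def prepare_inputs_for_generation(self, input_ids, encoder_outputs):
        return {"decoder_tokens": input_ids, "encoder_outputs": encoder_outputs}

    def forward(self, decoder_tokens, encoder_outputs):
        # A real model would return next-token logits here.
        return decoder_tokens


model = DummyModel()
decoder_ids = torch.zeros(2, 1, dtype=torch.long)  # batch of start tokens
model_inputs = model.prepare_inputs_for_generation(decoder_ids, encoder_outputs=None)
outputs = model(**model_inputs)  # the loop never needs to know the model's argument names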
