Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit e2e781f

Browse files
committed
Add generation utils with greedy search and tests
1 parent 7c58534 commit e2e781f

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from unittest.mock import patch
2+
from torchtext.prototype.generate import GenerationUtil
3+
from torchtext.prototype.models import T5_BASE_GENERATION
4+
from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase
5+
import torch
6+
7+
8+
class TestGenerationUtil(TorchtextTestCase):
    """Integration tests for GenerationUtil using the T5-base generation bundle."""

    def setUp(self) -> None:
        """Load the T5-base bundle and build the shared batch of encoded prompts."""
        super().setUp()
        bundle = T5_BASE_GENERATION
        self.transform = bundle.transform()
        self.model = bundle.get_model()
        self.model.eval()
        # Examples taken from T5 Paper and Huggingface
        prompts = [
            "summarize: studies have shown that owning a dog is good for you",
            "translate English to German: That is good.",
            "cola sentence: The course is jumping well.",
            "stsb sentence1: The rhino grazed on the grass. sentence2: A rhino is grazing in a field.",
            "summarize: state authorities dispatched emergency crews tuesday to survey the damage after an onslaught of severe weather in mississippi...",
        ]
        self.inputs = self.transform(prompts)
        torch.manual_seed(0)

    def test_greedy_generate_with_t5(self) -> None:
        """Greedy decoding (num_beams=1, max_len=30) reproduces the reference outputs."""
        generator = GenerationUtil(self.model)

        token_ids = generator.generate(self.inputs, num_beams=1, max_len=30)
        generated_text = self.transform.decode(token_ids.tolist())

        expected_generated_text = [
            "a dog is good for you, according to studies . owning a dog is good for you, according to studies .",
            "Das ist gut.",
            "acceptable",
            "4.0",
            "mississippi authorities dispatch emergency crews to survey damage . severe weather in mississippi has caused extensive damage",
        ]

        self.assertEqual(generated_text, expected_generated_text)

    def test_generate_errors_with_incorrect_beams(self) -> None:
        """num_beams=0 is invalid and must raise ValueError."""
        generator = GenerationUtil(self.model, is_encoder_decoder=True)

        with self.assertRaises(ValueError):
            generator.generate(self.inputs, num_beams=0)

    @patch("logging.Logger.warning")
    def test_warns_when_no_max_len_provided(self, mock) -> None:
        """Omitting max_len should emit the default-length warning via logging."""
        generator = GenerationUtil(self.model)
        generator.generate(self.inputs)
        mock.assert_called_with("`max_len` was not specified. Defaulting to 100 tokens.")

0 commit comments

Comments
 (0)