Commit db1bb28

Add generate.py and prepare_shakespeare.py tests (OpenGVLab#42)
1 parent 98e09b5 commit db1bb28

File tree: 10 files changed, +195 −45 lines changed

.github/workflows/cpu-tests.yml

Lines changed: 2 additions & 2 deletions
```diff
@@ -36,9 +36,9 @@ jobs:
 
       - name: Install dependencies
         run: |
-          pip install pytest .
+          pip install pytest . -r requirements.txt
           pip list
 
       - name: Run tests
         run: |
-          pytest -v --durations=10
+          pytest -v --durations=10 --disable-pytest-warnings --strict-markers
```
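Note: `--strict-markers` makes pytest error out on any marker that is not registered in the project configuration (so a typo like `@pytest.mark.parametrise` fails loudly instead of being silently ignored), while `--disable-pytest-warnings` hides the warnings summary from the CI logs.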

.gitignore

Lines changed: 1 addition & 0 deletions
```diff
@@ -2,6 +2,7 @@ __pycache__
 .idea
 .DS_Store
 *.egg-info
+build
 
 # data
 data
```

generate.py

Lines changed: 13 additions & 13 deletions
```diff
@@ -1,6 +1,7 @@
 import os
 import sys
 import time
+from pathlib import Path
 from typing import Optional
 
 import lightning as L
@@ -47,7 +48,7 @@ def generate(
 
         # forward
         logits = model(idx_cond)
-        logits = logits[:, -1, :] / temperature
+        logits = logits[:, -1] / temperature
 
         # optionally crop the logits to only the top k options
         if top_k is not None:
@@ -58,7 +59,7 @@ def generate(
         idx_next = torch.multinomial(probs, num_samples=1)
 
         # concatenate the new column
-        idx[:, t] = idx_next
+        idx[:, t:] = idx_next
 
     return idx
 
@@ -73,8 +74,8 @@ def main(
     # compilation fails as it does not support torch.complex64 for RoPE
     # compile: bool = False,
     accelerator: str = "auto",
-    checkpoint_path: Optional[str] = None,
-    tokenizer_path: Optional[str] = None,
+    checkpoint_path: Optional[Path] = None,
+    tokenizer_path: Optional[Path] = None,
     model_size: str = "7B",
     quantize: bool = False,
 ) -> None:
@@ -95,12 +96,11 @@ def main(
         quantize: Whether to quantize the model using the `LLM.int8()` method
     """
     if not checkpoint_path:
-        checkpoint_path = f"./checkpoints/lit-llama/{model_size}/state_dict.pth"
+        checkpoint_path = Path(f"./checkpoints/lit-llama/{model_size}/state_dict.pth")
     if not tokenizer_path:
-        tokenizer_path = "./checkpoints/lit-llama/tokenizer.model"
-
-    assert os.path.isfile(checkpoint_path)
-    assert os.path.isfile(tokenizer_path)
+        tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model")
+    assert checkpoint_path.is_file()
+    assert tokenizer_path.is_file()
 
     fabric = L.Fabric(accelerator=accelerator, devices=1)
 
@@ -128,8 +128,8 @@ def main(
     model = fabric.setup_module(model)
 
     tokenizer = Tokenizer(tokenizer_path)
-    encoded_prompt = tokenizer.encode(prompt, bos=True, eos=False).to(fabric.device)
-    encoded_prompt = encoded_prompt[None, :]
+    encoded_prompt = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)
+    encoded_prompt = encoded_prompt[None, :]  # add batch dimension
 
     L.seed_everything(1234)
     t0 = time.time()
@@ -141,8 +141,8 @@ def main(
             model.config.block_size,  # type: ignore[union-attr,arg-type]
             temperature=temperature,
             top_k=top_k,
-        )
-        print(tokenizer.decode(y[0]))
+        )[0]  # unpack batch dimension
+        print(tokenizer.decode(y))
 
     print(f"Time for inference: {time.time() - t0:.02f} seconds", file=sys.stderr)
     print(f"Memory used (GB): {torch.cuda.max_memory_reserved() / 1e9:.02f}", file=sys.stderr)
```

lit_llama/model.py

Lines changed: 10 additions & 10 deletions
```diff
@@ -70,14 +70,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.scale * x_normed
 
 
-llama_configs = {
-    "7B": dict(n_layer=32, n_head=32, n_embd=4096),
-    "13B": dict(n_layer=40, n_head=40, n_embd=5120),
-    "30B": dict(n_layer=60, n_head=52, n_embd=6656),
-    "65B": dict(n_layer=80, n_head=64, n_embd=8192),
-}
-
-
 @dataclass
 class LLaMAConfig:
     block_size: int = 4096
@@ -88,7 +80,15 @@ class LLaMAConfig:
 
     @classmethod
     def from_name(cls, name: str) -> Self:
-        return cls(**llama_configs[name])
+        return llama_configs[name]
+
+
+llama_configs = {
+    "7B": LLaMAConfig(n_layer=32, n_head=32, n_embd=4096),
+    "13B": LLaMAConfig(n_layer=40, n_head=40, n_embd=5120),
+    "30B": LLaMAConfig(n_layer=60, n_head=52, n_embd=6656),
+    "65B": LLaMAConfig(n_layer=80, n_head=64, n_embd=8192),
+}
 
 
 class CausalSelfAttention(nn.Module):
@@ -206,7 +206,7 @@ def forward(self, idx: torch.Tensor) -> torch.Tensor:
             x = block(x)
         x = self.transformer.ln_f(x)
 
-        logits = self.lm_head(x)
+        logits = self.lm_head(x)  # (b, t, vocab_size)
 
         return logits
 
```
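Note the semantic shift here: `from_name` used to construct a fresh `LLaMAConfig` from a dict of kwargs; after this change it hands back the pre-built instance stored in `llama_configs`, so repeated lookups share one object. A minimal sketch of the consequence, using a simplified stand-in dataclass (illustrative only, not the repo's class):

```python
from dataclasses import dataclass


@dataclass
class Config:  # simplified stand-in for LLaMAConfig
    n_layer: int = 32

    @classmethod
    def from_name(cls, name: str) -> "Config":
        return configs[name]  # returns the shared, pre-built instance


configs = {"7B": Config(n_layer=32)}

a = Config.from_name("7B")
b = Config.from_name("7B")
assert a is b            # the very same object every time
a.n_layer = 99
assert b.n_layer == 99   # mutations are visible to every later caller
```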

lit_llama/tokenizer.py

Lines changed: 9 additions & 4 deletions
```diff
@@ -1,13 +1,16 @@
 import os
+from pathlib import Path
+from typing import Optional
+
 import torch
 from sentencepiece import SentencePieceProcessor, SentencePieceTrainer
 
 
 class Tokenizer:
     """Tokenizer for LLaMA."""
 
-    def __init__(self, model_path: str) -> None:
-        self.processor = SentencePieceProcessor(model_file=model_path)
+    def __init__(self, model_path: Path) -> None:
+        self.processor = SentencePieceProcessor(model_file=str(model_path))
         self.bos_id = self.processor.bos_id()
         self.eos_id = self.processor.eos_id()
         self.pad_id = self.processor.pad_id()
@@ -16,13 +19,15 @@ def __init__(self, model_path: str) -> None:
     def vocab_size(self) -> int:
         return self.processor.vocab_size()
 
-    def encode(self, string: str, bos: bool = True, eos: bool = False) -> torch.Tensor:
+    def encode(
+        self, string: str, bos: bool = True, eos: bool = False, device: Optional[torch.device] = None
+    ) -> torch.Tensor:
         tokens = self.processor.encode(string)
         if bos:
             tokens = [self.bos_id] + tokens
         if eos:
             tokens = tokens + [self.eos_id]
-        return torch.tensor(tokens, dtype=torch.int)
+        return torch.tensor(tokens, dtype=torch.int, device=device)
 
     def decode(self, tokens: torch.Tensor) -> str:
         return self.processor.decode(tokens.tolist())
```
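The new `device` keyword lets callers materialize the token tensor directly on the target device, replacing the `tokenizer.encode(...).to(fabric.device)` pattern that generate.py used before (build on CPU, then copy). A usage sketch, assuming a SentencePiece model exists at the shown path (the path itself is illustrative):

```python
from pathlib import Path

import torch

from lit_llama.tokenizer import Tokenizer

tokenizer = Tokenizer(Path("checkpoints/lit-llama/tokenizer.model"))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One allocation on the target device, instead of CPU tensor + copy:
tokens = tokenizer.encode("hello world", bos=True, eos=False, device=device)
assert tokens.device.type == device.type
```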

requirements.txt

Lines changed: 1 addition & 1 deletion
```diff
@@ -3,5 +3,5 @@ lightning>=2.0.0
 sentencepiece
 tqdm # convert_checkpoint.py
 numpy # train.py dataset memmap
-jsonargparse # generate.py, convert_checkpoint.py CLI
+jsonargparse[signatures] # generate.py, convert_checkpoint.py CLI
 bitsandbytes # quantization.py
```
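The `[signatures]` extra installs the optional dependencies jsonargparse needs to build CLI parameters from type hints and docstrings, which is what lets the new `test_cli` tests find the functions' docstrings in the `--help` output. A rough sketch of the mechanism (a hypothetical `example.py`, not a file in this repo):

```python
# jsonargparse.CLI derives flags from the function signature; with the
# [signatures] extra it can also pull help text out of the docstring.
from jsonargparse import CLI


def main(model_size: str = "7B", temperature: float = 0.8) -> None:
    """Generates text samples.

    Args:
        model_size: Which model variant to load.
        temperature: Sampling temperature.
    """
    print(model_size, temperature)


if __name__ == "__main__":
    CLI(main)  # `python example.py -h` now shows the docstring above
```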

scripts/prepare_shakespeare.py

Lines changed: 17 additions & 13 deletions
```diff
@@ -19,17 +19,20 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
-import os
 import sys
-import requests
+from pathlib import Path
+
 import numpy as np
+import requests
 
 
-def prepare(destination_path: str = "data/shakespeare") -> None:
-    os.makedirs(destination_path, exist_ok=True)
+def prepare(destination_path: Path = Path("data/shakespeare")) -> None:
+    """Prepare the "Tiny Shakespeare" dataset."""
+    destination_path.mkdir(parents=True, exist_ok=True)
+
     # download the tiny shakespeare dataset
-    input_file_path = os.path.join(destination_path, "input.txt")
-    if not os.path.exists(input_file_path):
+    input_file_path = destination_path / "input.txt"
+    if not input_file_path.exists():
         data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
         with open(input_file_path, "w") as f:
             f.write(requests.get(data_url).text)
@@ -40,10 +43,10 @@ def prepare(destination_path: str = "data/shakespeare") -> None:
     train_data = data[: int(n * 0.9)]
     val_data = data[int(n * 0.9) :]
 
-    from tokenizer import Tokenizer
-
+    from lit_llama.tokenizer import Tokenizer
+
     Tokenizer.train(input=input_file_path, destination=destination_path)
-    tokenizer = Tokenizer(os.path.join(destination_path, "tokenizer.model"))
+    tokenizer = Tokenizer(destination_path / "tokenizer.model")
     train_ids = tokenizer.encode(train_data)
     val_ids = tokenizer.encode(val_data)
     print(f"train has {len(train_ids):,} tokens")
@@ -52,13 +55,14 @@ def prepare(destination_path: str = "data/shakespeare") -> None:
     # export to bin files
     train_ids = np.array(train_ids, dtype=np.uint16)
     val_ids = np.array(val_ids, dtype=np.uint16)
-    train_ids.tofile(os.path.join(destination_path, "train.bin"))
-    val_ids.tofile(os.path.join(destination_path, "val.bin"))
+    train_ids.tofile(destination_path / "train.bin")
+    val_ids.tofile(destination_path / "val.bin")
 
 
 if __name__ == "__main__":
-    wd = os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
-    sys.path.append(wd)
+    # support running without installing as a package
+    wd = Path(__file__).parent.parent.resolve()
+    sys.path.append(str(wd))
 
     from jsonargparse import CLI
 
```
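The script changes are essentially a mechanical `os.path` → `pathlib` translation. For reference, a quick sketch of the correspondences used above (assertions hold on any platform):

```python
import os
from pathlib import Path

p = Path("data/shakespeare")
assert str(p / "input.txt") == os.path.join("data/shakespeare", "input.txt")
assert p.exists() == os.path.exists(p)           # Path.exists vs os.path.exists
assert p.resolve() == Path(os.path.realpath(p))  # Path.resolve vs os.path.realpath
# Path.mkdir(parents=True, exist_ok=True) plays the role of os.makedirs(..., exist_ok=True)
```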

tests/test_basic_functionality.py

Lines changed: 0 additions & 2 deletions
This file was deleted.

tests/test_generate.py

Lines changed: 119 additions & 0 deletions
New file:

```python
import functools
import subprocess
import sys
from contextlib import redirect_stdout
from io import StringIO
from pathlib import Path
from unittest import mock
from unittest.mock import Mock, PropertyMock, call, ANY

import pytest
import torch

wd = Path(__file__).parent.parent.absolute()


@functools.lru_cache(maxsize=1)
def load_generate_script():
    sys.path.append(str(wd))

    import generate

    return generate


@pytest.mark.parametrize("B", (1, 2))
def test_generate(B):
    generate = load_generate_script()

    T, C = 5, 3
    logits = torch.randn(B, T, C)
    input_idx = torch.randint(10, size=(B, T))

    model = Mock(return_value=logits)
    max_new_tokens = 20

    multinomial_results = []
    original_multinomial = torch.multinomial

    def multinomial(*args, **kwargs):
        out = original_multinomial(*args, **kwargs)
        multinomial_results.append(out)
        return out

    with mock.patch("torch.multinomial", multinomial):
        out = generate.generate(model, input_idx, max_new_tokens, max_seq_length=10)

    assert out.shape == (B, T + max_new_tokens)
    multinomial_results = torch.hstack(multinomial_results)
    expected = torch.cat((input_idx, multinomial_results), dim=1)
    assert out.shape == expected.shape
    torch.testing.assert_close(out, expected)


def test_main(tmp_path, monkeypatch):
    generate = load_generate_script()

    checkpoint_path = tmp_path / "ckpt"
    checkpoint_path.touch()
    tokenizer_path = tmp_path / "tokenizer"
    tokenizer_path.touch()

    class FabricMock(PropertyMock):
        @property
        def device(self):
            return torch.device("cpu")

    fabric_mock = FabricMock()
    monkeypatch.setattr(generate.L, "Fabric", fabric_mock)
    model_mock = Mock()
    monkeypatch.setattr(generate.LLaMA, "from_name", model_mock)
    load_mock = Mock()
    monkeypatch.setattr(generate.torch, "load", load_mock)
    tokenizer_mock = Mock()
    tokenizer_mock.return_value.encode.return_value = torch.tensor([[1, 2, 3]])
    tokenizer_mock.return_value.decode.return_value = "foo bar baz"
    monkeypatch.setattr(generate, "Tokenizer", tokenizer_mock)
    generate_mock = Mock()
    generate_mock.return_value = torch.tensor([[3, 2, 1]])
    monkeypatch.setattr(generate, "generate", generate_mock)

    num_samples = 2
    out = StringIO()
    with redirect_stdout(out):
        generate.main(
            checkpoint_path=checkpoint_path,
            tokenizer_path=tokenizer_path,
            model_size="1T",
            accelerator="litpu",
            temperature=2.0,
            top_k=2,
            num_samples=num_samples,
        )

    model_mock.assert_called_once_with("1T")
    load_mock.assert_called_once_with(checkpoint_path)
    tokenizer_mock.assert_called_once_with(tokenizer_path)
    assert len(tokenizer_mock.return_value.decode.mock_calls) == num_samples
    assert torch.allclose(tokenizer_mock.return_value.decode.call_args[0][0], generate_mock.return_value)
    model = model_mock.return_value
    assert fabric_mock.mock_calls == [
        call(accelerator="litpu", devices=1),
        call().device.__enter__(),
        call().device.__exit__(None, None, None),
        call().setup_module(model),
    ]
    model = fabric_mock.return_value.setup_module.return_value
    assert (
        generate_mock.mock_calls
        == [call(model, ANY, 50, model.config.block_size, temperature=2.0, top_k=2)] * num_samples
    )
    # only the generated result is printed to stdout
    assert out.getvalue() == "foo bar baz\n" * num_samples


def test_cli():
    cli_path = wd / "generate.py"
    output = subprocess.check_output([sys.executable, cli_path, "-h"])
    output = str(output.decode())
    assert "Generates text samples" in output
```

tests/test_prepare_shakespeare.py

Lines changed: 23 additions & 0 deletions
New file:

```python
import os
import subprocess
import sys
from pathlib import Path

wd = (Path(__file__).parent.parent / "scripts").absolute()


def test_prepare(tmp_path):
    sys.path.append(str(wd))

    import prepare_shakespeare

    prepare_shakespeare.prepare(tmp_path)

    assert set(os.listdir(tmp_path)) == {"train.bin", "tokenizer.model", "tokenizer.vocab", "input.txt", "val.bin"}


def test_cli():
    cli_path = wd / "prepare_shakespeare.py"
    output = subprocess.check_output([sys.executable, cli_path, "-h"])
    output = str(output.decode())
    assert 'Prepare the "Tiny Shakespeare"' in output
```
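Note that unlike the mocked `generate` tests, `test_prepare` exercises `prepare` end to end: it downloads the Tiny Shakespeare corpus via `requests.get` and trains a SentencePiece tokenizer into `tmp_path`, so it needs network access and is noticeably slower.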
