
Commit 98e09b5

Convert compare script to tests (OpenGVLab#43)
1 parent 85560ac commit 98e09b5

File tree

6 files changed: +183 -185 lines

.gitignore

Lines changed: 1 addition & 1 deletion
```diff
@@ -8,5 +8,5 @@ data
 checkpoints
 !data/shakespeare/prepare.py

-# downloaded by scripts/compare.py
+# downloaded by our tests
 original_model.py
```

scripts/compare.py

Lines changed: 0 additions & 184 deletions
This file was deleted.

tests/conftest.py

Lines changed: 18 additions & 0 deletions
```python
import os
import sys

import pytest


@pytest.fixture()
def orig_llama():
    wd = os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
    sys.path.append(wd)

    from scripts.download import download_original

    download_original(wd)

    import original_model

    return original_model
```

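A quick illustration of how the fixture is meant to be consumed (a hypothetical test, not part of this commit): pytest matches the `orig_llama` argument name and passes in the reference module that `download_original` fetched as `original_model.py`.

```python
# Hypothetical example of using the fixture; this test is not part of the
# commit. pytest injects the downloaded reference module by matching the
# argument name to the `orig_llama` fixture above.
import torch


def test_reference_rmsnorm_shape(orig_llama) -> None:
    norm = orig_llama.RMSNorm(8, eps=1e-5)  # class defined in Meta's original_model.py
    out = norm(torch.ones(2, 4, 8))
    assert out.shape == (2, 4, 8)
```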
tests/test_model.py

Lines changed: 74 additions & 0 deletions
```python
import torch

import lit_llama.model as lit_llama


def copy_mlp(llama_mlp, orig_llama_mlp) -> None:
    orig_llama_mlp.w1.weight.copy_(llama_mlp.c_fc1.weight)
    orig_llama_mlp.w3.weight.copy_(llama_mlp.c_fc2.weight)
    orig_llama_mlp.w2.weight.copy_(llama_mlp.c_proj.weight)


def copy_attention(llama_attn, orig_llama_attn) -> None:
    n_embd = llama_attn.c_attn.weight.shape[1]
    orig_llama_attn.wq.weight.copy_(llama_attn.c_attn.weight[:n_embd])
    orig_llama_attn.wk.weight.copy_(llama_attn.c_attn.weight[n_embd:-n_embd])
    orig_llama_attn.wv.weight.copy_(llama_attn.c_attn.weight[-n_embd:])
    orig_llama_attn.wo.weight.copy_(llama_attn.c_proj.weight)


def copy_block(llama_block, orig_llama_block) -> None:
    orig_llama_block.attention_norm.weight.copy_(llama_block.rms_1.scale)
    copy_attention(llama_block.attn, orig_llama_block.attention)
    orig_llama_block.ffn_norm.weight.copy_(llama_block.rms_2.scale)
    copy_mlp(llama_block.mlp, orig_llama_block.feed_forward)


def copy_weights(llama_model, orig_llama_model) -> None:
    orig_llama_model.tok_embeddings.weight.copy_(llama_model.transformer.wte.weight)
    for llama_block, orig_llama_block in zip(llama_model.transformer.h, orig_llama_model.layers):
        copy_block(llama_block, orig_llama_block)
    orig_llama_model.norm.weight.copy_(llama_model.transformer.ln_f.scale)
    orig_llama_model.output.weight.copy_(llama_model.lm_head.weight)


@torch.no_grad()
def test_to_orig_llama(orig_llama) -> None:
    block_size = 64
    vocab_size = 32000
    n_layer = 16
    n_head = 16
    n_embd = 32

    llama_config = lit_llama.LLaMAConfig(
        block_size=block_size, vocab_size=vocab_size, n_layer=n_layer, n_head=n_head, n_embd=n_embd
    )
    orig_llama_config = orig_llama.ModelArgs(
        dim=n_embd, n_layers=n_layer, n_heads=n_head, vocab_size=vocab_size, norm_eps=1e-5, max_seq_len=block_size
    )

    batch_size = 3

    token_sample = torch.randint(
        0, orig_llama_config.vocab_size, size=(batch_size, orig_llama_config.max_seq_len), dtype=torch.int64
    )

    llama_model = lit_llama.LLaMA(llama_config)
    orig_llama_model = orig_llama.Transformer(orig_llama_config)

    copy_weights(llama_model, orig_llama_model)

    orig_llama_embed = orig_llama_model.tok_embeddings(token_sample)
    llama_embed = llama_model.transformer.wte(token_sample)
    assert torch.allclose(orig_llama_embed, llama_embed)

    seq_len = token_sample.shape[1]
    mask = torch.full((1, 1, seq_len, seq_len), float("-inf"))
    mask = torch.triu(mask, diagonal=1)
    orig_llama_block_out = orig_llama_model.layers[0](orig_llama_embed, 0, orig_llama_model.freqs_cis[:seq_len], mask)
    llama_block_out = llama_model.transformer.h[0](llama_embed)
    assert torch.allclose(orig_llama_block_out, llama_block_out)

    expected = orig_llama_model(token_sample, 0)
    out = llama_model(token_sample)
    assert torch.allclose(out, expected)
```

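A note on `copy_attention` above: the slicing assumes lit-llama's fused `c_attn` weight stacks the query, key, and value projection matrices row-wise as `[Q; K; V]`, each of shape `(n_embd, n_embd)`. A standalone sketch of that layout assumption:

```python
# Standalone sketch (illustration only) of the slicing used in copy_attention:
# the fused weight is assumed to stack Q, K and V along dim 0 as [Q; K; V].
import torch

n_embd = 4
q = torch.full((n_embd, n_embd), 1.0)
k = torch.full((n_embd, n_embd), 2.0)
v = torch.full((n_embd, n_embd), 3.0)
c_attn_weight = torch.cat([q, k, v], dim=0)  # shape (3 * n_embd, n_embd)

assert torch.equal(c_attn_weight[:n_embd], q)          # block copied into wq
assert torch.equal(c_attn_weight[n_embd:-n_embd], k)   # block copied into wk
assert torch.equal(c_attn_weight[-n_embd:], v)         # block copied into wv
```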
tests/test_rmsnorm.py

Lines changed: 17 additions & 0 deletions
```python
import torch

import lit_llama.model as lit_llama


@torch.no_grad()
def test_rmsnorm(orig_llama) -> None:
    block_size = 16
    vocab_size = 16

    sample = torch.rand(size=(2, block_size, vocab_size), dtype=torch.float32)

    eps = 1e-6
    orig_llama_rmsnorm = orig_llama.RMSNorm(vocab_size, eps=eps)(sample)
    llama_rmsnorm = lit_llama.RMSNorm(vocab_size, eps=eps)(sample)

    assert torch.allclose(orig_llama_rmsnorm, llama_rmsnorm)
```

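For context, both `RMSNorm` classes normalize by the root mean square over the last dimension; with freshly initialized (all-ones) scale weights, as in the test above, they should reduce to the plain normalization sketched below. This is an illustration under that assumption, not the code of either implementation.

```python
# Minimal sketch of RMS normalization with the learnable scale omitted
# (assumed to be all ones at initialization, as in the test above).
import torch


def rms_norm_reference(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # divide by the root mean square over the last dimension
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
```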
tests/test_rope.py

Lines changed: 73 additions & 0 deletions
```python
import torch

import lit_llama.model as lit_llama


def build_rope_cache_old(seq_len: int, n_elem: int, dtype: torch.dtype, base: int = 10000) -> torch.Tensor:
    """This is the `build_rope_cache` implementation we initially intended to use, but it is numerically not
    exactly equivalent to the one in the Meta model. We keep it here for posterity.

    Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/rope/__init__.py
    MIT License: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license
    """  # noqa: E501
    # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype) / n_elem))

    # Create position indexes `[0, 1, ..., seq_len - 1]`
    seq_idx = torch.arange(seq_len, dtype=dtype)

    # Calculate the product of position index and $\theta_i$
    idx_theta = torch.outer(seq_idx, theta)

    # Concatenate so that for row $m$ we have
    # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
    idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

    # Cache them
    cos_cache = idx_theta2.cos()[None, None, :, :]
    sin_cache = idx_theta2.sin()[None, None, :, :]

    return torch.stack((cos_cache, sin_cache), dim=0)


def rotate_neg_half(x: torch.Tensor) -> torch.Tensor:
    # $\frac{d}{2}$
    d_2 = x.shape[-1] // 2
    # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$  # noqa: E501
    return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)


def apply_rope_old(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    """This is the `apply_rope` implementation we initially intended to use, but it is numerically not exactly
    equivalent to the one in the Meta model.

    We keep it here for posterity.
    """
    neg_half_x = rotate_neg_half(x)
    cos, sin = rope_cache
    # truncate to support variable sizes
    T = x.size(2)
    cos = cos[:, :, :T]
    sin = sin[:, :, :T]
    return (x * cos) + (neg_half_x * sin)


@torch.no_grad()
def test_rope(orig_llama) -> None:
    bs, seq_len, n_head, n_embed = 1, 6, 2, 8
    x = torch.randint(0, 10000, size=(bs, seq_len, n_head, n_embed // n_head)).float()

    freqs_cis = orig_llama.precompute_freqs_cis(n_embed // n_head, seq_len)
    llama_rope_cache = lit_llama.build_rope_cache(seq_len, n_embed // n_head, dtype=x.dtype)
    assert torch.equal(freqs_cis, llama_rope_cache)

    llama_x_rope = lit_llama.apply_rope(x.transpose(1, 2), llama_rope_cache).transpose(1, 2)
    orig_llama_x_rope, _ = orig_llama.apply_rotary_emb(x, x, freqs_cis)

    assert torch.equal(llama_x_rope, orig_llama_x_rope)

    # For posterity, we show here that our older implementation we initially wanted to use
    # is not numerically equivalent to Meta's rope implementation
    llama_rope_cache_old = build_rope_cache_old(seq_len, n_embed // n_head, dtype=x.dtype)
    llama_x_rope_old = apply_rope_old(x, llama_rope_cache_old)
    assert not torch.allclose(llama_x_rope_old, orig_llama_x_rope)
```

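The final `assert not torch.allclose(...)` holds because the two formulations pair feature channels differently: Meta's `apply_rotary_emb` rotates adjacent element pairs as complex numbers, while the retired `apply_rope_old` rotates the first half of the head dimension against the second half. A small sketch of the two pairings (illustration only, not code from either implementation):

```python
# Illustration of the mismatched channel pairings behind the last assertion:
# adjacent-pair rotation (the complex formulation) and half-split rotation
# (rotate_neg_half above) act on different coordinate pairs.
import torch

x = torch.arange(8, dtype=torch.float32)           # one head's feature vector
adjacent_pairs = x.reshape(-1, 2)                  # [[0, 1], [2, 3], [4, 5], [6, 7]]
half_split_pairs = torch.stack(x.chunk(2), dim=1)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
assert not torch.equal(adjacent_pairs, half_split_pairs)
```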