Commit c409960

Fix tests without pip install (OpenGVLab#45)
1 parent ef1ea26 commit c409960

8 files changed: +28 -32 lines


.github/workflows/cpu-tests.yml

Lines changed: 4 additions & 2 deletions
@@ -34,11 +34,13 @@ jobs:
       with:
         python-version: ${{ matrix.python-version }}

-      - name: Install dependencies
+      - name: Run tests without the package installed
         run: |
-          pip install pytest . -r requirements.txt
+          pip install pytest -r requirements.txt
           pip list
+          pytest --disable-pytest-warnings --strict-markers

       - name: Run tests
         run: |
+          pip install . --no-deps
           pytest -v --durations=10 --disable-pytest-warnings --strict-markers
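
Net effect: the suite now runs twice, first against the bare checkout with only pytest and the pinned requirements installed, then once more after `pip install . --no-deps` (the dependencies are already present from the previous step), so both the installed and the uninstalled import paths are exercised on CI.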

lit_llama/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+from lit_llama.model import LLaMAConfig, LLaMA, RMSNorm, build_rope_cache, apply_rope
+from lit_llama.tokenizer import Tokenizer
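
With the package root now re-exporting the public names, callers can import from `lit_llama` directly instead of reaching into submodules, which is what the `scripts/prepare_shakespeare.py` change below relies on. A small sketch of the equivalence (illustrative, assuming the package is importable):

    # After this commit, both spellings resolve to the same object.
    from lit_llama import LLaMA, LLaMAConfig, Tokenizer
    from lit_llama.model import LLaMA as LLaMA_via_submodule

    assert LLaMA is LLaMA_via_submodule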

scripts/prepare_shakespeare.py

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ def prepare(destination_path: Path = Path("data/shakespeare")) -> None:
     train_data = data[: int(n * 0.9)]
     val_data = data[int(n * 0.9) :]

-    from lit_llama.tokenizer import Tokenizer
+    from lit_llama import Tokenizer

     Tokenizer.train(input=input_file_path, destination=destination_path)
     tokenizer = Tokenizer(destination_path / "tokenizer.model")

tests/conftest.py

Lines changed: 14 additions & 3 deletions
@@ -1,13 +1,14 @@
-import os
 import sys
+from pathlib import Path

 import pytest

+wd = Path(__file__).parent.parent.absolute()
+

 @pytest.fixture()
 def orig_llama():
-    wd = os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
-    sys.path.append(wd)
+    sys.path.append(str(wd))

     from scripts.download import download_original

@@ -16,3 +17,13 @@ def orig_llama():
     import original_model

     return original_model
+
+
+@pytest.fixture()
+def lit_llama():
+    # this adds support for running tests without the package installed
+    sys.path.append(str(wd))
+
+    import lit_llama
+
+    return lit_llama
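
The module-level `wd` plus the new `lit_llama` fixture are what let the suite run against a plain checkout: any test that requests the fixture gets the repo root appended to `sys.path` before the import, so `import lit_llama` resolves from the source tree whether or not the package is pip-installed. A minimal sketch of a consuming test (hypothetical, not part of this commit):

    # Hypothetical test: requesting the `lit_llama` fixture runs the
    # sys.path append in conftest.py first, so the import succeeds
    # without `pip install`.
    def test_config_roundtrip(lit_llama):
        config = lit_llama.LLaMAConfig(block_size=64, vocab_size=100)
        assert config.block_size == 64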

tests/test_generate.py

Lines changed: 4 additions & 17 deletions
@@ -5,7 +5,7 @@
 from io import StringIO
 from pathlib import Path
 from unittest import mock
-from unittest.mock import Mock, PropertyMock, call, ANY
+from unittest.mock import Mock, call, ANY

 import pytest
 import torch
@@ -59,13 +59,12 @@ def test_main(tmp_path, monkeypatch):
     tokenizer_path = tmp_path / "tokenizer"
     tokenizer_path.touch()

-    class FabricMock(PropertyMock):
+    class FabricMock(Mock):
         @property
         def device(self):
             return torch.device("cpu")

-    fabric_mock = FabricMock()
-    monkeypatch.setattr(generate.L, "Fabric", fabric_mock)
+    monkeypatch.setattr(generate.L, "Fabric", FabricMock)
     model_mock = Mock()
     monkeypatch.setattr(generate.LLaMA, "from_name", model_mock)
     load_mock = Mock()
@@ -85,7 +84,6 @@ def device(self):
         checkpoint_path=checkpoint_path,
         tokenizer_path=tokenizer_path,
         model_size="1T",
-        accelerator="litpu",
         temperature=2.0,
         top_k=2,
         num_samples=num_samples,
@@ -96,18 +94,7 @@ def device(self):
     tokenizer_mock.assert_called_once_with(tokenizer_path)
     assert len(tokenizer_mock.return_value.decode.mock_calls) == num_samples
     assert torch.allclose(tokenizer_mock.return_value.decode.call_args[0][0], generate_mock.return_value)
-    model = model_mock.return_value
-    assert fabric_mock.mock_calls == [
-        call(accelerator="litpu", devices=1),
-        call().device.__enter__(),
-        call().device.__exit__(None, None, None),
-        call().setup_module(model),
-    ]
-    model = fabric_mock.return_value.setup_module.return_value
-    assert (
-        generate_mock.mock_calls
-        == [call(model, ANY, 50, model.config.block_size, temperature=2.0, top_k=2)] * num_samples
-    )
+    assert generate_mock.mock_calls == [call(ANY, ANY, 50, ANY, temperature=2.0, top_k=2)] * num_samples
     # only the generated result is printed to stdout
     assert out.getvalue() == "foo bar baz\n" * num_samples
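
The switch from `PropertyMock` to a plain `Mock` subclass works because a `@property` is a data descriptor on the class: attribute lookup finds it before `Mock`'s `__getattr__`-based auto-mocking kicks in, so `.device` returns a real `torch.device` while every other attribute stays an auto-created mock. A standalone illustration of that behavior (not part of the test file):

    import torch
    from unittest.mock import Mock

    class FabricMock(Mock):
        @property
        def device(self):
            return torch.device("cpu")

    fabric = FabricMock()
    assert isinstance(fabric.device, torch.device)  # the property wins
    assert isinstance(fabric.setup_module, Mock)    # everything else is auto-mocked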

tests/test_model.py

Lines changed: 1 addition & 3 deletions
@@ -1,7 +1,5 @@
 import torch

-import lit_llama.model as lit_llama
-

 def copy_mlp(llama_mlp, orig_llama_mlp) -> None:
     orig_llama_mlp.w1.weight.copy_(llama_mlp.c_fc1.weight)
@@ -33,7+31,7 @@ def copy_weights(llama_model, orig_llama_model) -> None:


 @torch.no_grad()
-def test_to_orig_llama(orig_llama) -> None:
+def test_to_orig_llama(lit_llama, orig_llama) -> None:
     block_size = 64
     vocab_size = 32000
     n_layer = 16

tests/test_rmsnorm.py

Lines changed: 1 addition & 3 deletions
@@ -1,10 +1,8 @@
 import torch

-import lit_llama.model as lit_llama
-

 @torch.no_grad()
-def test_rmsnorm(orig_llama) -> None:
+def test_rmsnorm(lit_llama, orig_llama) -> None:
     block_size = 16
     vocab_size = 16


tests/test_rope.py

Lines changed: 1 addition & 3 deletions
@@ -1,7 +1,5 @@
 import torch

-import lit_llama.model as lit_llama
-

 def build_rope_cache_old(seq_len: int, n_elem: int, dtype: torch.dtype, base: int = 10000) -> torch.Tensor:
     """This is the `build_rope_cache` implementation we initially intended to use, but it is numerically not
@@ -53,7 +51,7 @@ def apply_rope_old(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:


 @torch.no_grad()
-def test_rope(orig_llama) -> None:
+def test_rope(lit_llama, orig_llama) -> None:
     bs, seq_len, n_head, n_embed = 1, 6, 2, 8
     x = torch.randint(0, 10000, size=(bs, seq_len, n_head, n_embed // n_head)).float()

5957
