Skip to content

Commit f1c8520

Browse files
authored
[BugFix] Fix input positions for long context with sliding window (#2088)
1 parent 096827c commit f1c8520

File tree

5 files changed

+75
-17
lines changed

5 files changed

+75
-17
lines changed

tests/conftest.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
from typing import List, Optional, Tuple
23

34
import pytest
@@ -7,21 +8,32 @@
78
from vllm import LLM, SamplingParams
89
from vllm.transformers_utils.tokenizer import get_tokenizer
910

10-
_TEST_PROMPTS = [
11-
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
12-
"Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
13-
"Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
14-
"Describe the basic components of a neural network and how it can be trained.",
15-
"Write a short story about a robot that dreams for the first time.",
16-
"Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.",
17-
"Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.",
18-
"Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'",
19-
]
11+
_TEST_PROMPTS = ["prompts/example.txt"]
12+
_LONG_PROMPTS = ["prompts/summary.txt"]
13+
14+
15+
def _read_prompts(filename: str) -> str:
16+
prompts = []
17+
with open(filename, "r") as f:
18+
prompt = f.readline()
19+
prompts.append(prompt)
20+
return prompts
2021

2122

2223
@pytest.fixture
def example_prompts() -> List[str]:
    """Short example prompts loaded from every file in _TEST_PROMPTS."""
    loaded: List[str] = []
    for prompt_file in _TEST_PROMPTS:
        # Paths in _TEST_PROMPTS are relative to the tests/ directory.
        loaded.extend(_read_prompts(os.path.join("tests", prompt_file)))
    return loaded
29+
30+
31+
@pytest.fixture
def example_long_prompts() -> List[str]:
    """Long-context prompts loaded from every file in _LONG_PROMPTS."""
    loaded: List[str] = []
    for prompt_file in _LONG_PROMPTS:
        # Paths in _LONG_PROMPTS are relative to the tests/ directory.
        loaded.extend(_read_prompts(os.path.join("tests", prompt_file)))
    return loaded
2537

2638

2739
_STR_DTYPE_TO_TORCH_DTYPE = {

tests/models/test_mistral.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
2+
3+
Run `pytest tests/models/test_mistral.py --forked`.
4+
"""
5+
import pytest
6+
7+
MODELS = [
8+
"mistralai/Mistral-7B-Instruct-v0.1",
9+
]
10+
11+
12+
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(
    hf_runner,
    vllm_runner,
    example_long_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    """Greedy decoding must produce identical text and token ids in HF and vLLM."""
    hf_model = hf_runner(model, dtype=dtype)
    hf_outputs = hf_model.generate_greedy(example_long_prompts, max_tokens)
    # Release the HF model before loading the vLLM copy.
    del hf_model

    vllm_model = vllm_runner(model, dtype=dtype)
    vllm_outputs = vllm_model.generate_greedy(example_long_prompts, max_tokens)
    del vllm_model

    for i in range(len(example_long_prompts)):
        hf_output_ids, hf_output_str = hf_outputs[i]
        vllm_output_ids, vllm_output_str = vllm_outputs[i]
        assert hf_output_str == vllm_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
        assert hf_output_ids == vllm_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")

tests/prompts/example.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.
2+
Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.
3+
Compare and contrast artificial intelligence with human intelligence in terms of processing information.
4+
Describe the basic components of a neural network and how it can be trained.
5+
Write a short story about a robot that dreams for the first time.
6+
Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.
7+
Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.
8+
Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'

0 commit comments

Comments
 (0)