Skip to content

Commit cb08cd0

Browse files
authored
[Minor] Fix duplication of ignored seq group in engine step (#1666)
1 parent 2a2c135 commit cb08cd0

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

tests/test_regression.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
"""Containing tests that check for regressions in vLLM's behavior.
2+
3+
It should include tests that are reported by users and making sure they
4+
will never happen again.
5+
6+
"""
7+
from vllm import LLM, SamplingParams
8+
9+
10+
def test_duplicated_ignored_sequence_group():
11+
"""https://github.com/vllm-project/vllm/issues/1655"""
12+
13+
sampling_params = SamplingParams(temperature=0.01,
14+
top_p=0.1,
15+
max_tokens=256)
16+
llm = LLM(model="facebook/opt-125m",
17+
max_num_batched_tokens=4096,
18+
tensor_parallel_size=1)
19+
prompts = ["This is a short prompt", "This is a very long prompt " * 1000]
20+
outputs = llm.generate(prompts, sampling_params=sampling_params)
21+
22+
assert len(prompts) == len(outputs)
23+
24+
25+
if __name__ == "__main__":
26+
import pytest
27+
pytest.main([__file__])

vllm/engine/llm_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,7 @@ def step(self) -> List[RequestOutput]:
567567
blocks_to_copy=scheduler_outputs.blocks_to_copy,
568568
)
569569

570-
return self._process_model_outputs(output, scheduler_outputs) + ignored
570+
return self._process_model_outputs(output, scheduler_outputs)
571571

572572
def _log_system_stats(
573573
self,

0 commit comments

Comments
 (0)