Skip to content

Commit e0ff920

Browse files
authored
[BUGFIX] Do not return ignored sequences twice in async llm engine (#2258)
1 parent face83c commit e0ff920

File tree

2 files changed

+7
-22
lines changed

2 files changed

+7
-22
lines changed

vllm/engine/async_llm_engine.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -183,20 +183,18 @@ async def step_async(self) -> List[RequestOutput]:
183183
and updates the scheduler with the model outputs. Finally, it decodes
184184
the sequences and returns the newly generated results.
185185
"""
186-
seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
187-
if scheduler_outputs.is_empty():
188-
return ignored
186+
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
189187

190188
# Execute the model.
191-
output = await self._run_workers_async(
189+
output = (await self._run_workers_async(
192190
"execute_model",
193191
seq_group_metadata_list=seq_group_metadata_list,
194192
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
195193
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
196194
blocks_to_copy=scheduler_outputs.blocks_to_copy,
197-
)
195+
)) if not scheduler_outputs.is_empty() else []
198196

199-
return self._process_model_outputs(output, scheduler_outputs) + ignored
197+
return self._process_model_outputs(output, scheduler_outputs)
200198

201199
async def _run_workers_async(
202200
self,

vllm/engine/llm_engine.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
from vllm.outputs import RequestOutput
1515
from vllm.sampling_params import SamplingParams
1616
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
17-
SequenceGroupMetadata, SequenceGroupOutput,
18-
SequenceOutput, SequenceStatus)
17+
SequenceGroupOutput, SequenceOutput, SequenceStatus)
1918
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
2019
get_tokenizer)
2120
from vllm.utils import Counter
@@ -328,16 +327,6 @@ def has_unfinished_requests(self) -> bool:
328327
"""Returns True if there are unfinished requests."""
329328
return self.scheduler.has_unfinished_seqs()
330329

331-
def _schedule(
332-
self
333-
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
334-
List[RequestOutput]]:
335-
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
336-
return seq_group_metadata_list, scheduler_outputs, [
337-
RequestOutput.from_seq_group(seq_group)
338-
for seq_group in scheduler_outputs.ignored_seq_groups
339-
]
340-
341330
def _check_beam_search_early_stopping(
342331
self,
343332
early_stopping: Union[bool, str],
@@ -586,9 +575,7 @@ def step(self) -> List[RequestOutput]:
586575
and updates the scheduler with the model outputs. Finally, it decodes
587576
the sequences and returns the newly generated results.
588577
"""
589-
seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
590-
if scheduler_outputs.is_empty():
591-
return ignored
578+
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
592579

593580
# Execute the model.
594581
output = self._run_workers(
@@ -597,7 +584,7 @@ def step(self) -> List[RequestOutput]:
597584
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
598585
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
599586
blocks_to_copy=scheduler_outputs.blocks_to_copy,
600-
)
587+
) if not scheduler_outputs.is_empty() else []
601588

602589
return self._process_model_outputs(output, scheduler_outputs)
603590

0 commit comments

Comments (0)