Commit c06170c

Add a flag to include stop string in output text (#1976)
1 parent 614856d · commit c06170c

File tree

2 files changed (+32, −24 lines)


vllm/engine/llm_engine.py

Lines changed: 4 additions & 3 deletions

@@ -682,9 +682,10 @@ def _check_stop(self, seq: Sequence,
         """Stop the finished sequences."""
         for stop_str in sampling_params.stop:
             if seq.output_text.endswith(stop_str):
-                # Truncate the output text so that the stop string is
-                # not included in the output.
-                seq.output_text = seq.output_text[:-len(stop_str)]
+                if not sampling_params.include_stop_str_in_output:
+                    # Truncate the output text so that the stop string is
+                    # not included in the output.
+                    seq.output_text = seq.output_text[:-len(stop_str)]
                 seq.status = SequenceStatus.FINISHED_STOPPED
                 return
         if seq.get_last_token_id() in sampling_params.stop_token_ids:
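
For readers skimming the change, a minimal self-contained sketch of the behavior this hunk toggles; the helper below is hypothetical (not part of vLLM) and only mirrors the truncation logic shown above:

def apply_stop_string(output_text: str, stop_str: str,
                      include_stop_str_in_output: bool) -> str:
    # Hypothetical helper mirroring the hunk above: strip the matched stop
    # string from the generated text unless the new flag asks to keep it.
    if output_text.endswith(stop_str) and not include_stop_str_in_output:
        return output_text[:-len(stop_str)]
    return output_text


# With the flag off (the default) the stop string is stripped; with it on, it is kept.
assert apply_stop_string("Hello world.###", "###", False) == "Hello world."
assert apply_stop_string("Hello world.###", "###", True) == "Hello world.###"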

vllm/sampling_params.py

Lines changed: 28 additions & 21 deletions

@@ -2,6 +2,7 @@
 from enum import IntEnum
 from functools import cached_property
 from typing import Callable, List, Optional, Union
+
 import torch

 _SAMPLING_EPS = 1e-5

@@ -70,6 +71,8 @@ class SamplingParams:
         stop_token_ids: List of tokens that stop the generation when they are
             generated. The returned output will contain the stop tokens unless
             the stop tokens are special tokens.
+        include_stop_str_in_output: Whether to include the stop strings in output
+            text. Defaults to False.
         ignore_eos: Whether to ignore the EOS token and continue generating
             tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.

@@ -103,6 +106,7 @@ def __init__(
         early_stopping: Union[bool, str] = False,
         stop: Optional[Union[str, List[str]]] = None,
         stop_token_ids: Optional[List[int]] = None,
+        include_stop_str_in_output: bool = False,
         ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: Optional[int] = None,

@@ -140,6 +144,7 @@ def __init__(
         self.skip_special_tokens = skip_special_tokens
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.logits_processors = logits_processors
+        self.include_stop_str_in_output = include_stop_str_in_output
         self._verify_args()
         if self.use_beam_search:
             self._verify_beam_search()

@@ -227,24 +232,26 @@ def sampling_type(self) -> SamplingType:
         return SamplingType.RANDOM

     def __repr__(self) -> str:
-        return (f"SamplingParams(n={self.n}, "
-                f"best_of={self.best_of}, "
-                f"presence_penalty={self.presence_penalty}, "
-                f"frequency_penalty={self.frequency_penalty}, "
-                f"repetition_penalty={self.repetition_penalty}, "
-                f"temperature={self.temperature}, "
-                f"top_p={self.top_p}, "
-                f"top_k={self.top_k}, "
-                f"min_p={self.min_p}, "
-                f"use_beam_search={self.use_beam_search}, "
-                f"length_penalty={self.length_penalty}, "
-                f"early_stopping={self.early_stopping}, "
-                f"stop={self.stop}, "
-                f"stop_token_ids={self.stop_token_ids}, "
-                f"ignore_eos={self.ignore_eos}, "
-                f"max_tokens={self.max_tokens}, "
-                f"logprobs={self.logprobs}, "
-                f"prompt_logprobs={self.prompt_logprobs}, "
-                f"skip_special_tokens={self.skip_special_tokens}, "
-                "spaces_between_special_tokens="
-                f"{self.spaces_between_special_tokens})")
+        return (
+            f"SamplingParams(n={self.n}, "
+            f"best_of={self.best_of}, "
+            f"presence_penalty={self.presence_penalty}, "
+            f"frequency_penalty={self.frequency_penalty}, "
+            f"repetition_penalty={self.repetition_penalty}, "
+            f"temperature={self.temperature}, "
+            f"top_p={self.top_p}, "
+            f"top_k={self.top_k}, "
+            f"min_p={self.min_p}, "
+            f"use_beam_search={self.use_beam_search}, "
+            f"length_penalty={self.length_penalty}, "
+            f"early_stopping={self.early_stopping}, "
+            f"stop={self.stop}, "
+            f"stop_token_ids={self.stop_token_ids}, "
+            f"include_stop_str_in_output={self.include_stop_str_in_output}, "
+            f"ignore_eos={self.ignore_eos}, "
+            f"max_tokens={self.max_tokens}, "
+            f"logprobs={self.logprobs}, "
+            f"prompt_logprobs={self.prompt_logprobs}, "
+            f"skip_special_tokens={self.skip_special_tokens}, "
+            "spaces_between_special_tokens="
+            f"{self.spaces_between_special_tokens})")
