
Commit 20f7cc4

Add skip_special_tokens sampling params (#1186)
1 parent 649aa73 commit 20f7cc4

File tree: 4 files changed, +14 additions, -4 deletions


vllm/engine/llm_engine.py

Lines changed: 4 additions & 3 deletions
@@ -387,7 +387,7 @@ def _process_sequence_group_samples(
             child_seqs.append((parent, parent))

         for seq, _ in child_seqs:
-            self._decode_sequence(seq)
+            self._decode_sequence(seq, seq_group.sampling_params)
             self._check_stop(seq, seq_group.sampling_params)

         # Non-beam search case
@@ -621,7 +621,8 @@ def _log_system_stats(
                     f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
         self.last_logging_time = now

-    def _decode_sequence(self, seq: Sequence) -> None:
+    def _decode_sequence(self, seq: Sequence,
+                         sampling_params: SamplingParams) -> None:
         """Decodes the new token for a sequence."""
         (new_tokens, new_output_text, prefix_offset,
          read_offset) = detokenize_incrementally(
@@ -630,7 +631,7 @@ def _decode_sequence(self, seq: Sequence) -> None:
             prev_tokens=seq.tokens,
             prefix_offset=seq.prefix_offset,
             read_offset=seq.read_offset,
-            skip_special_tokens=True,
+            skip_special_tokens=sampling_params.skip_special_tokens,
         )
         if seq.tokens is None:
             seq.tokens = new_tokens
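
For context, the flag these hunks thread through controls whether the detokenizer drops special tokens (BOS/EOS, padding, and similar) from the decoded text. A minimal sketch of the behavior being toggled, using a Hugging Face tokenizer directly; the model name is only an illustrative example, not something this commit prescribes:

# Illustration of what skip_special_tokens toggles at the
# detokenization layer; the model name is an arbitrary example.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
token_ids = tokenizer("Hello").input_ids + [tokenizer.eos_token_id]

print(tokenizer.decode(token_ids, skip_special_tokens=True))   # e.g. "Hello"
print(tokenizer.decode(token_ids, skip_special_tokens=False))  # e.g. "</s>Hello</s>"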

vllm/entrypoints/openai/api_server.py

Lines changed: 2 additions & 0 deletions
@@ -225,6 +225,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
             top_k=request.top_k,
             ignore_eos=request.ignore_eos,
             use_beam_search=request.use_beam_search,
+            skip_special_tokens=request.skip_special_tokens,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
@@ -426,6 +427,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
             max_tokens=request.max_tokens,
             logprobs=request.logprobs,
             use_beam_search=request.use_beam_search,
+            skip_special_tokens=request.skip_special_tokens,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
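
Since both request handlers now forward the field, a client can set it per request. A hypothetical call against a locally running vLLM OpenAI-compatible server; the URL and model name are example values, not part of this diff:

# Hedged example: assumes the server is running locally;
# host, port, and model name are placeholders.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",
        "prompt": "Hello, my name is",
        "max_tokens": 16,
        "skip_special_tokens": False,  # keep special tokens in the returned text
    },
)
print(resp.json()["choices"][0]["text"])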

vllm/entrypoints/openai/protocol.py

Lines changed: 2 additions & 0 deletions
@@ -71,6 +71,7 @@ class ChatCompletionRequest(BaseModel):
     ignore_eos: Optional[bool] = False
     use_beam_search: Optional[bool] = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+    skip_special_tokens: Optional[bool] = True


 class CompletionRequest(BaseModel):
@@ -96,6 +97,7 @@ class CompletionRequest(BaseModel):
     ignore_eos: Optional[bool] = False
     use_beam_search: Optional[bool] = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+    skip_special_tokens: Optional[bool] = True


 class LogProbs(BaseModel):
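
Because the field is declared with a default of True on both pydantic models, requests that omit it keep the previous behavior. A toy model (hypothetical, not vLLM's own class) showing the same pattern:

# ToyRequest mirrors the fields above and exists only to
# demonstrate the pydantic default; it is not part of vLLM.
from typing import List, Optional
from pydantic import BaseModel, Field

class ToyRequest(BaseModel):
    use_beam_search: Optional[bool] = False
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True

print(ToyRequest().skip_special_tokens)                           # True
print(ToyRequest(skip_special_tokens=False).skip_special_tokens)  # False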

vllm/sampling_params.py

Lines changed: 6 additions & 1 deletion
@@ -60,6 +60,8 @@ class SamplingParams:
             tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
         logprobs: Number of log probabilities to return per output token.
+        skip_special_tokens: Whether to skip special tokens in the output.
+            Defaults to true.
     """

     def __init__(
@@ -79,6 +81,7 @@ def __init__(
         ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: Optional[int] = None,
+        skip_special_tokens: bool = True,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n
@@ -103,6 +106,7 @@ def __init__(
         self.ignore_eos = ignore_eos
         self.max_tokens = max_tokens
         self.logprobs = logprobs
+        self.skip_special_tokens = skip_special_tokens

         self._verify_args()
         if self.use_beam_search:
@@ -196,4 +200,5 @@ def __repr__(self) -> str:
             f"stop={self.stop}, "
             f"ignore_eos={self.ignore_eos}, "
             f"max_tokens={self.max_tokens}, "
-            f"logprobs={self.logprobs})")
+            f"logprobs={self.logprobs}, "
+            f"skip_special_tokens={self.skip_special_tokens})")
