
Commit 40b6b38

Jialin and 22quinn authored

[Core] Switch Flat logprob control from environment variable to SamplingParams (#28914)

Signed-off-by: Jialin Ouyang <[email protected]>
Co-authored-by: 22quinn <[email protected]>
1 parent da94c7c commit 40b6b38


6 files changed: +33 -41 lines changed

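In short: opting into flat logprobs moves from a process-wide environment variable to a per-request field on SamplingParams. A minimal before/after sketch (the model name and logprob counts below are placeholders, not taken from this commit):

    from vllm import LLM, SamplingParams

    # Before this commit, flat logprobs were enabled process-wide:
    #   VLLM_FLAT_LOGPROBS=1 python my_script.py
    # After this commit, the choice is made per request:
    llm = LLM(model="facebook/opt-125m")  # placeholder model
    params = SamplingParams(
        logprobs=5,          # top-5 logprobs per generated token
        prompt_logprobs=5,   # top-5 logprobs per prompt token
        flat_logprobs=True,  # new: request FlatLogprobs containers
    )
    outputs = llm.generate(["Hello, my name is"], params)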

tests/samplers/test_logprobs.py

Lines changed: 1 addition & 2 deletions

@@ -24,9 +24,7 @@ def test_ranks(
     greedy,
     flat_logprobs,
     example_prompts,
-    monkeypatch: pytest.MonkeyPatch,
 ):
-    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1" if flat_logprobs else "0")
     with vllm_runner(model, dtype=dtype, max_logprobs=MAX_LOGPROBS) as vllm_model:
         tokenizer = vllm_model.llm.get_tokenizer()
         example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts]
@@ -36,6 +34,7 @@ def test_ranks(
             max_tokens=MAX_TOKENS,
             logprobs=NUM_TOP_LOGPROBS,
             prompt_logprobs=NUM_PROMPT_LOGPROBS,
+            flat_logprobs=flat_logprobs,
         )
         results = vllm_model.generate_w_logprobs(example_prompts, sampling_params)

tests/test_logprobs.py

Lines changed: 10 additions & 22 deletions

@@ -2,8 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 
-import pytest
-
 from vllm.logprobs import (
     FlatLogprobs,
     Logprob,
@@ -14,24 +12,20 @@
 )
 
 
-def test_create_logprobs_non_flat(monkeypatch: pytest.MonkeyPatch) -> None:
-    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0")
-
-    prompt_logprobs = create_prompt_logprobs()
+def test_create_logprobs_non_flat() -> None:
+    prompt_logprobs = create_prompt_logprobs(flat_logprobs=False)
     assert isinstance(prompt_logprobs, list)
     # Ensure first prompt position logprobs is None
     assert len(prompt_logprobs) == 1
     assert prompt_logprobs[0] is None
 
-    sample_logprobs = create_sample_logprobs()
+    sample_logprobs = create_sample_logprobs(flat_logprobs=False)
     assert isinstance(sample_logprobs, list)
     assert len(sample_logprobs) == 0
 
 
-def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None:
-    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1")
-
-    prompt_logprobs = create_prompt_logprobs()
+def test_create_logprobs_flat() -> None:
+    prompt_logprobs = create_prompt_logprobs(flat_logprobs=True)
     assert isinstance(prompt_logprobs, FlatLogprobs)
     assert prompt_logprobs.start_indices == [0]
     assert prompt_logprobs.end_indices == [0]
@@ -43,7 +37,7 @@ def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None:
     assert len(prompt_logprobs) == 1
     assert prompt_logprobs[0] == dict()
 
-    sample_logprobs = create_sample_logprobs()
+    sample_logprobs = create_sample_logprobs(flat_logprobs=True)
     assert isinstance(sample_logprobs, FlatLogprobs)
     assert len(sample_logprobs.start_indices) == 0
     assert len(sample_logprobs.end_indices) == 0
@@ -54,11 +48,8 @@ def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None:
     assert len(sample_logprobs) == 0
 
 
-def test_append_logprobs_for_next_position_none_flat(
-    monkeypatch: pytest.MonkeyPatch,
-) -> None:
-    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0")
-    logprobs = create_sample_logprobs()
+def test_append_logprobs_for_next_position_none_flat() -> None:
+    logprobs = create_sample_logprobs(flat_logprobs=False)
     append_logprobs_for_next_position(
         logprobs,
         token_ids=[1],
@@ -85,11 +76,8 @@ def test_append_logprobs_for_next_position_none_flat(
     ]
 
 
-def test_append_logprobs_for_next_position_flat(
-    monkeypatch: pytest.MonkeyPatch,
-) -> None:
-    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1")
-    logprobs = create_sample_logprobs()
+def test_append_logprobs_for_next_position_flat() -> None:
+    logprobs = create_sample_logprobs(flat_logprobs=True)
     append_logprobs_for_next_position(
         logprobs,
         token_ids=[1],
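The tests above also show the shape of FlatLogprobs: per-position entries are addressed by parallel start_indices/end_indices arrays over flat storage, so a request produces a handful of long lists instead of one small dict object per position, which is where the GC savings come from. A simplified sketch of that layout (illustrative only; vLLM's actual FlatLogprobs stores full Logprob fields and implements the indexing/iteration API):

    from dataclasses import dataclass, field

    @dataclass
    class FlatLogprobsSketch:
        # Position i's entries live at [start_indices[i], end_indices[i]).
        start_indices: list[int] = field(default_factory=list)
        end_indices: list[int] = field(default_factory=list)
        token_ids: list[int] = field(default_factory=list)
        logprobs: list[float] = field(default_factory=list)

        def append(self, entries: dict[int, float] | None) -> None:
            # One flat append per position instead of one dict object per
            # position, so the GC tracks a few lists, not O(positions) dicts.
            self.start_indices.append(len(self.token_ids))
            for token_id, logprob in (entries or {}).items():
                self.token_ids.append(token_id)
                self.logprobs.append(logprob)
            self.end_indices.append(len(self.token_ids))

        def __len__(self) -> int:
            return len(self.start_indices)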

vllm/envs.py

Lines changed: 0 additions & 6 deletions

@@ -225,7 +225,6 @@
     VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
     VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256
     VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
-    VLLM_FLAT_LOGPROBS: bool = False
 
 
 def get_default_cache_root():
@@ -1499,11 +1498,6 @@ def get_vllm_port() -> int | None:
     "VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices(
         "VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"]
     ),
-    # Flag to enable FlatLogprobs whose GC overhead is significantly smaller than
-    # the original list[dict[int, Logprob]] approach.
-    # After enabled, PromptLogprobs and SampleLogprobs would populated as
-    # FlatLogprobs.
-    "VLLM_FLAT_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLAT_LOGPROBS", "0"))),
 }
 
 # --8<-- [end:env-vars-definition]

vllm/logprobs.py

Lines changed: 4 additions & 6 deletions

@@ -5,8 +5,6 @@
 from dataclasses import dataclass, field
 from typing import overload
 
-import vllm.envs as envs
-
 
 # We use dataclass for now because it is used for
 # openai server output, and msgspec is not serializable.
@@ -161,17 +159,17 @@ def __iter__(self) -> Iterator[LogprobsOnePosition]:
 SampleLogprobs = FlatLogprobs | list[LogprobsOnePosition]
 
 
-def create_prompt_logprobs() -> PromptLogprobs:
+def create_prompt_logprobs(flat_logprobs: bool) -> PromptLogprobs:
     """Creates a container to store prompt logprobs for a request"""
-    logprobs = FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else []
+    logprobs = FlatLogprobs() if flat_logprobs else []
     # NOTE: logprob of first prompt token is None.
     logprobs.append(None)
     return logprobs
 
 
-def create_sample_logprobs() -> SampleLogprobs:
+def create_sample_logprobs(flat_logprobs: bool) -> SampleLogprobs:
     """Creates a container to store decode logprobs for a request"""
-    return FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else []
+    return FlatLogprobs() if flat_logprobs else []
 
 
 def append_logprobs_for_next_position(
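With the environment variable gone, the factory functions take the flag explicitly, so callers decide the container type per request. The behavior, as asserted by the tests above:

    from vllm.logprobs import (
        FlatLogprobs,
        create_prompt_logprobs,
        create_sample_logprobs,
    )

    # Flat containers when requested, plain lists otherwise.
    assert isinstance(create_prompt_logprobs(flat_logprobs=True), FlatLogprobs)
    assert isinstance(create_prompt_logprobs(flat_logprobs=False), list)
    assert isinstance(create_sample_logprobs(flat_logprobs=True), FlatLogprobs)
    assert isinstance(create_sample_logprobs(flat_logprobs=False), list)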

vllm/sampling_params.py

Lines changed: 6 additions & 0 deletions

@@ -204,6 +204,12 @@ class SamplingParams(
     prompt_logprobs: int | None = None
     """Number of log probabilities to return per prompt token.
     When set to -1, return all `vocab_size` log probabilities."""
+    flat_logprobs: bool = False
+    """Whether to return logprobs in the flattened format (i.e. FlatLogprobs)
+    for better performance.
+    NOTE: The GC cost of FlatLogprobs is significantly smaller than that of
+    list[dict[int, Logprob]]. When enabled, PromptLogprobs and
+    SampleLogprobs are populated as FlatLogprobs."""
     # NOTE: This parameter is only exposed at the engine level for now.
     # It is not exposed in the OpenAI API server, as the OpenAI API does
     # not support returning only a list of token IDs.
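Since SampleLogprobs is the union FlatLogprobs | list[LogprobsOnePosition] and FlatLogprobs supports len, indexing, and iteration (see __iter__ in vllm/logprobs.py above), downstream code can consume either shape uniformly. A hedged sketch of such a consumer (assumes each position yields a {token_id: Logprob} mapping whose values carry a logprob attribute, as in vllm/logprobs.py; None or empty positions are skipped):

    from vllm.logprobs import SampleLogprobs

    def best_logprob_per_position(sample_logprobs: SampleLogprobs) -> list[float]:
        # Identical code path for flat and nested containers: iterating
        # either one yields a {token_id: Logprob} mapping per position.
        return [
            max(entry.logprob for entry in position.values())
            for position in sample_logprobs
            if position
        ]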

vllm/v1/engine/logprobs.py

Lines changed: 12 additions & 5 deletions

@@ -43,15 +43,22 @@ def from_new_request(
         tokenizer: AnyTokenizer | None,
         request: EngineCoreRequest,
     ) -> "LogprobsProcessor":
-        assert request.sampling_params is not None
-        num_logprobs = request.sampling_params.logprobs
-        num_prompt_logprobs = request.sampling_params.prompt_logprobs
+        sampling_params = request.sampling_params
+        assert sampling_params is not None
+        num_logprobs = sampling_params.logprobs
+        num_prompt_logprobs = sampling_params.prompt_logprobs
         return cls(
             tokenizer=tokenizer,
             cumulative_logprob=(None if num_logprobs is None else 0.0),
-            logprobs=(None if num_logprobs is None else create_sample_logprobs()),
+            logprobs=(
+                None
+                if num_logprobs is None
+                else create_sample_logprobs(sampling_params.flat_logprobs)
+            ),
             prompt_logprobs=(
-                None if num_prompt_logprobs is None else create_prompt_logprobs()
+                None
+                if num_prompt_logprobs is None
+                else create_prompt_logprobs(sampling_params.flat_logprobs)
             ),
             num_prompt_logprobs=num_prompt_logprobs,
             num_logprobs=num_logprobs,
