Skip to content

Commit accac82

Browse files
authored
[Sampler] Introduce logprobs mode for logging (#21398)
Signed-off-by: Lu Fang <[email protected]>
1 parent 23637dc commit accac82

File tree

7 files changed

+83
-13
lines changed

7 files changed

+83
-13
lines changed

tests/v1/sample/test_logprobs.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
assert_incr_detok_str_matches_non_incr_detok_str,
1313
compute_correct_cumulative_logprob, get_test_batch)
1414
from vllm import SamplingParams
15+
from vllm.config import LogprobsMode
1516

1617
from ...conftest import HfRunner, VllmRunner
1718

@@ -426,3 +427,45 @@ def test_zero_logprobs(vllm_model, example_prompts,
426427
# prompt token
427428
assert prompt_logprobs is not None
428429
assert len(prompt_token_ids) == len(prompt_logprobs)
430+
431+
432+
@pytest.mark.parametrize(
433+
"logprobs_mode",
434+
["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"])
435+
def test_logprobs_mode(logprobs_mode: LogprobsMode,
436+
monkeypatch: pytest.MonkeyPatch):
437+
"""Test with LLM engine with different logprobs_mode.
438+
For logprobs, we should have non-positive values.
439+
For logits, we should expect at least one positive value.
440+
"""
441+
from vllm import LLM
442+
with monkeypatch.context() as m:
443+
m.setenv("VLLM_USE_V1", "1")
444+
445+
llm = LLM(
446+
"facebook/opt-125m",
447+
max_logprobs=5,
448+
enable_prefix_caching=False,
449+
# 2 other llms alive during whole session
450+
gpu_memory_utilization=0.05,
451+
max_model_len=16,
452+
logprobs_mode=logprobs_mode)
453+
vllm_sampling_params = SamplingParams(logprobs=1)
454+
results = llm.generate(["Hello world"],
455+
sampling_params=vllm_sampling_params)
456+
457+
total_token_with_logprobs = 0
458+
positive_values = 0
459+
for output in results[0].outputs:
460+
for logprobs in output.logprobs:
461+
for token_id in logprobs:
462+
logprob = logprobs[token_id]
463+
if "logprobs" in logprobs_mode:
464+
assert logprob.logprob <= 0
465+
if logprob.logprob > 0:
466+
positive_values = positive_values + 1
467+
total_token_with_logprobs = total_token_with_logprobs + 1
468+
assert total_token_with_logprobs >= len(results[0].outputs)
469+
if "logits" in logprobs_mode:
470+
assert positive_values > 0
471+
del llm

vllm/config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,8 @@ def is_init_field(cls: ConfigType, name: str) -> bool:
219219

220220
TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
221221
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
222+
LogprobsMode = Literal["raw_logprobs", "raw_logits", "processed_logprobs",
223+
"processed_logits"]
222224

223225

224226
@config
@@ -316,6 +318,13 @@ class ModelConfig:
316318
"""Maximum number of log probabilities to return when `logprobs` is
317319
specified in `SamplingParams`. The default value comes from the default for the
318320
OpenAI Chat Completions API."""
321+
logprobs_mode: LogprobsMode = "raw_logprobs"
322+
"""Indicates the content returned in the logprobs and prompt_logprobs.
323+
Supported modes:
324+
1) raw_logprobs, 2) processed_logprobs, 3) raw_logits, 4) processed_logits.
325+
Raw means the values before applying logit processors, like bad words.
326+
Processed means the values after applying such processors.
327+
"""
319328
disable_sliding_window: bool = False
320329
"""Whether to disable sliding window. If True, we will disable the sliding
321330
window functionality of the model, capping to sliding window size. If the

vllm/engine/arg_utils.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@
2626
DetailedTraceModules, Device, DeviceConfig,
2727
DistributedExecutorBackend, GuidedDecodingBackend,
2828
GuidedDecodingBackendV1, HfOverrides, KVEventsConfig,
29-
KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
30-
ModelConfig, ModelDType, ModelImpl, MultiModalConfig,
31-
ObservabilityConfig, ParallelConfig, PoolerConfig,
32-
PrefixCachingHashAlgo, PromptAdapterConfig,
33-
SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
34-
TaskOption, TokenizerMode, VllmConfig, get_attr_docs,
35-
get_field)
29+
KVTransferConfig, LoadConfig, LoadFormat,
30+
LogprobsMode, LoRAConfig, ModelConfig, ModelDType,
31+
ModelImpl, MultiModalConfig, ObservabilityConfig,
32+
ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
33+
PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
34+
SpeculativeConfig, TaskOption, TokenizerMode,
35+
VllmConfig, get_attr_docs, get_field)
3636
from vllm.logger import init_logger
3737
from vllm.platforms import CpuArchEnum, current_platform
3838
from vllm.plugins import load_general_plugins
@@ -324,6 +324,7 @@ class EngineArgs:
324324
SchedulerConfig.long_prefill_token_threshold
325325
max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
326326
max_logprobs: int = ModelConfig.max_logprobs
327+
logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
327328
disable_log_stats: bool = False
328329
revision: Optional[str] = ModelConfig.revision
329330
code_revision: Optional[str] = ModelConfig.code_revision
@@ -490,6 +491,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
490491
**model_kwargs["max_seq_len_to_capture"])
491492
model_group.add_argument("--max-logprobs",
492493
**model_kwargs["max_logprobs"])
494+
model_group.add_argument("--logprobs-mode",
495+
**model_kwargs["logprobs_mode"])
493496
model_group.add_argument("--disable-sliding-window",
494497
**model_kwargs["disable_sliding_window"])
495498
model_group.add_argument("--disable-cascade-attn",
@@ -892,6 +895,7 @@ def create_model_config(self) -> ModelConfig:
892895
enforce_eager=self.enforce_eager,
893896
max_seq_len_to_capture=self.max_seq_len_to_capture,
894897
max_logprobs=self.max_logprobs,
898+
logprobs_mode=self.logprobs_mode,
895899
disable_sliding_window=self.disable_sliding_window,
896900
disable_cascade_attn=self.disable_cascade_attn,
897901
skip_tokenizer_init=self.skip_tokenizer_init,

vllm/v1/sample/sampler.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torch
66
import torch.nn as nn
77

8+
from vllm.config import LogprobsMode
89
from vllm.utils import is_pin_memory_available
910
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
1011
from vllm.v1.sample.metadata import SamplingMetadata
@@ -18,10 +19,11 @@
1819

1920
class Sampler(nn.Module):
2021

21-
def __init__(self):
22+
def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
2223
super().__init__()
2324
self.topk_topp_sampler = TopKTopPSampler()
2425
self.pin_memory = is_pin_memory_available()
26+
self.logprobs_mode = logprobs_mode
2527

2628
def forward(
2729
self,
@@ -36,7 +38,10 @@ def forward(
3638
# See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501
3739
num_logprobs = sampling_metadata.max_num_logprobs
3840
if num_logprobs is not None:
39-
raw_logprobs = self.compute_logprobs(logits)
41+
if self.logprobs_mode == "raw_logprobs":
42+
raw_logprobs = self.compute_logprobs(logits)
43+
elif self.logprobs_mode == "raw_logits":
44+
raw_logprobs = logits.clone()
4045

4146
# Use float32 for the logits.
4247
logits = logits.to(torch.float32)
@@ -51,6 +56,14 @@ def forward(
5156

5257
# Apply penalties (e.g., min_tokens, freq_penalties).
5358
logits = self.apply_penalties(logits, sampling_metadata)
59+
60+
# Get the processed logprobs or logits.
61+
if num_logprobs is not None:
62+
if self.logprobs_mode == "processed_logprobs":
63+
raw_logprobs = self.compute_logprobs(logits)
64+
elif self.logprobs_mode == "processed_logits":
65+
raw_logprobs = logits.clone()
66+
5467
# Sample the next token.
5568
sampled = self.sample(logits, sampling_metadata)
5669
# Convert sampled token ids to int64 (long) type to ensure compatibility

vllm/v1/sample/tpu/sampler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
class Sampler(nn.Module):
1616

1717
def __init__(self):
18+
# TODO(houseroad): Add support for logprobs_mode.
1819
super().__init__()
1920
self.topk_topp_sampler = TopKTopPSampler()
2021

vllm/v1/worker/gpu_input_batch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ def add_request(
389389

390390
def remove_request(self, req_id: str) -> Optional[int]:
391391
"""This method must always be followed by a call to condense().
392-
392+
393393
Args:
394394
req_id: request to remove
395395
@@ -590,7 +590,7 @@ def condense(self) -> None:
590590

591591
def refresh_metadata(self):
592592
"""Apply batch updates, reset input batch at end of step
593-
593+
594594
* Apply batch add/remove/permute to logits procs' states
595595
* If batch state is modified, update sampling metadata
596596
"""

vllm/v1/worker/gpu_model_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def __init__(
151151
self.encoder_cache_size = encoder_cache_size
152152

153153
# Sampler
154-
self.sampler = Sampler()
154+
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
155155

156156
self.eplb_state: Optional[EplbState] = None
157157
"""
@@ -1996,7 +1996,7 @@ def maybe_randomize_inputs(self, input_ids: torch.Tensor):
19961996
Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set.
19971997
This is to help balance expert-selection
19981998
- during profile_run
1999-
- during DP rank dummy run
1999+
- during DP rank dummy run
20002000
"""
20012001
dp_size = self.vllm_config.parallel_config.data_parallel_size
20022002
randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1

0 commit comments

Comments
 (0)