
Commit 3c52732

youkaichao and LunrEclipse authored and committed
[misc] hide best_of from engine (vllm-project#9261)
Co-authored-by: Brendan Wong <[email protected]>
Signed-off-by: Sumit Dubey <[email protected]>
1 parent 283caf0 commit 3c52732

File tree

14 files changed: +46 −73 lines

tests/entrypoints/openai/test_metrics.py
Lines changed: 0 additions & 4 deletions

@@ -70,7 +70,6 @@ async def client(server):
     [("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
      ("_count", _NUM_REQUESTS)],
     "vllm:request_params_n": [("_count", _NUM_REQUESTS)],
-    "vllm:request_params_best_of": [("_count", _NUM_REQUESTS)],
     "vllm:prompt_tokens": [("_total",
                             _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)],
     "vllm:generation_tokens":

@@ -151,9 +150,6 @@ async def test_metrics_counts(client: openai.AsyncOpenAI):
     "vllm:request_params_n_sum",
     "vllm:request_params_n_bucket",
     "vllm:request_params_n_count",
-    "vllm:request_params_best_of_sum",
-    "vllm:request_params_best_of_bucket",
-    "vllm:request_params_best_of_count",
     "vllm:num_preemptions_total",
     "vllm:prompt_tokens_total",
     "vllm:generation_tokens_total",

tests/metrics/test_metrics.py
Lines changed: 0 additions & 1 deletion

@@ -326,7 +326,6 @@ def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
         "vllm:e2e_request_latency_seconds",
         "vllm:request_prompt_tokens",
         "vllm:request_generation_tokens",
-        "vllm:request_params_best_of",
         "vllm:request_params_n",
     ]
     for metric_name in request_histogram_metrics:

tests/tracing/test_tracing.py
Lines changed: 0 additions & 4 deletions

@@ -98,8 +98,6 @@ def test_traces(trace_service):
         SpanAttributes.LLM_REQUEST_TOP_P) == sampling_params.top_p
     assert attributes.get(
         SpanAttributes.LLM_REQUEST_MAX_TOKENS) == sampling_params.max_tokens
-    assert attributes.get(
-        SpanAttributes.LLM_REQUEST_BEST_OF) == sampling_params.best_of
     assert attributes.get(SpanAttributes.LLM_REQUEST_N) == sampling_params.n
     assert attributes.get(SpanAttributes.LLM_USAGE_PROMPT_TOKENS) == len(
         outputs[0].prompt_token_ids)

@@ -155,8 +153,6 @@ def test_traces_with_detailed_steps(trace_service):
         SpanAttributes.LLM_REQUEST_TOP_P) == sampling_params.top_p
     assert attributes.get(
         SpanAttributes.LLM_REQUEST_MAX_TOKENS) == sampling_params.max_tokens
-    assert attributes.get(
-        SpanAttributes.LLM_REQUEST_BEST_OF) == sampling_params.best_of
     assert attributes.get(SpanAttributes.LLM_REQUEST_N) == sampling_params.n
     assert attributes.get(SpanAttributes.LLM_USAGE_PROMPT_TOKENS) == len(
         outputs[0].prompt_token_ids)

vllm/core/scheduler.py
Lines changed: 1 addition & 1 deletion

@@ -1205,7 +1205,7 @@ def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool:
         # async_output_proc is allowed only when we have a single sequence
         # in the sequence group
         no_single_seq = seq_group.sampling_params is None or (
-            seq_group.sampling_params.best_of == 1)
+            seq_group.sampling_params.n == 1)
         return no_single_seq

     def schedule(
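
The scheduler change above gates async output processing on `n` instead of `best_of`. A minimal, self-contained sketch of that guard (the dataclasses below are simplified stand-ins, not vLLM's actual SamplingParams/SequenceGroup types):

from dataclasses import dataclass
from typing import Optional

@dataclass
class _SamplingParams:      # stand-in: only the field the guard needs
    n: int = 1

@dataclass
class _SeqGroup:            # stand-in for a scheduled sequence group
    sampling_params: Optional[_SamplingParams] = None

def allow_async_output_proc(seq_group: _SeqGroup) -> bool:
    # Async output processing is allowed only for single-sequence requests,
    # i.e. no sampling params at all or sampling_params.n == 1.
    return (seq_group.sampling_params is None
            or seq_group.sampling_params.n == 1)

assert allow_async_output_proc(_SeqGroup(_SamplingParams(n=1)))
assert not allow_async_output_proc(_SeqGroup(_SamplingParams(n=4)))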

vllm/engine/llm_engine.py
Lines changed: 2 additions & 9 deletions

@@ -767,7 +767,7 @@ def add_request(
         Details:
             - Set arrival_time to the current time if it is None.
             - Set prompt_token_ids to the encoded prompt if it is None.
-            - Create `best_of` number of :class:`~vllm.Sequence` objects.
+            - Create `n` number of :class:`~vllm.Sequence` objects.
             - Create a :class:`~vllm.SequenceGroup` object
               from the list of :class:`~vllm.Sequence`.
             - Add the :class:`~vllm.SequenceGroup` object to the scheduler.

@@ -1242,8 +1242,7 @@ def _advance_to_next_step(
             if seq_group_metadata.do_sample:
                 assert len(sequence_group_outputs.samples) == 1, (
                     "Async output processor expects a single sample"
-                    " (i.e sampling_params.n == 1 and no "
-                    "sampling_params.best_of > 1)")
+                    " (i.e sampling_params.n == 1)")
                 sample = sequence_group_outputs.samples[0]

                 assert len(seq_group.seqs) == 1

@@ -1612,7 +1611,6 @@ def _get_stats(self,
         # Metadata
         num_prompt_tokens_requests: List[int] = []
         num_generation_tokens_requests: List[int] = []
-        best_of_requests: List[int] = []
         n_requests: List[int] = []
         finished_reason_requests: List[str] = []

@@ -1683,8 +1681,6 @@
                         for seq in seq_group.get_finished_seqs()
                     ])
                     if seq_group.sampling_params is not None:
-                        best_of_requests.append(
-                            seq_group.sampling_params.best_of)
                         n_requests.append(seq_group.sampling_params.n)
                     finished_reason_requests.extend([
                         SequenceStatus.get_finished_reason(seq.status)

@@ -1737,7 +1733,6 @@ def _get_stats(self,
             # Metadata
             num_prompt_tokens_requests=num_prompt_tokens_requests,
             num_generation_tokens_requests=num_generation_tokens_requests,
-            best_of_requests=best_of_requests,
             n_requests=n_requests,
             finished_reason_requests=finished_reason_requests,
         )

@@ -1824,8 +1819,6 @@ def create_trace_span(self, seq_group: SequenceGroup) -> None:
                                        seq_group.sampling_params.top_p)
                 seq_span.set_attribute(SpanAttributes.LLM_REQUEST_MAX_TOKENS,
                                        seq_group.sampling_params.max_tokens)
-                seq_span.set_attribute(SpanAttributes.LLM_REQUEST_BEST_OF,
-                                       seq_group.sampling_params.best_of)
                 seq_span.set_attribute(SpanAttributes.LLM_REQUEST_N,
                                        seq_group.sampling_params.n)
                 seq_span.set_attribute(SpanAttributes.LLM_USAGE_NUM_SEQUENCES,
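
On the tracing side, the request span now records `n` and no longer records `best_of`. A hedged, standalone illustration using the generic OpenTelemetry API (the literal attribute keys and the helper below are assumptions for illustration, not vLLM's SpanAttributes constants or its create_trace_span implementation):

from opentelemetry import trace

tracer = trace.get_tracer("llm_request_tracer")  # hypothetical tracer name

def record_request_span(max_tokens: int, n: int) -> None:
    # Emit one span per finished request with its sampling parameters.
    with tracer.start_as_current_span("llm_request") as span:
        span.set_attribute("gen_ai.request.max_tokens", max_tokens)
        span.set_attribute("gen_ai.request.n", n)
        # A best_of attribute is intentionally no longer emitted.

record_request_span(max_tokens=128, n=2)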

vllm/engine/metrics.py
Lines changed: 0 additions & 8 deletions

@@ -134,12 +134,6 @@ def __init__(self, labelnames: List[str], max_model_len: int):
             labelnames=labelnames,
             buckets=build_1_2_5_buckets(max_model_len),
         )
-        self.histogram_best_of_request = self._histogram_cls(
-            name="vllm:request_params_best_of",
-            documentation="Histogram of the best_of request parameter.",
-            labelnames=labelnames,
-            buckets=[1, 2, 5, 10, 20],
-        )
         self.histogram_n_request = self._histogram_cls(
             name="vllm:request_params_n",
             documentation="Histogram of the n request parameter.",

@@ -473,8 +467,6 @@ def _log_prometheus(self, stats: Stats) -> None:
             self.metrics.histogram_num_generation_tokens_request,
             stats.num_generation_tokens_requests)
         self._log_histogram(self.metrics.histogram_n_request, stats.n_requests)
-        self._log_histogram(self.metrics.histogram_best_of_request,
-                            stats.best_of_requests)

     def _log_prometheus_interval(self, prompt_throughput: float,
                                  generation_throughput: float) -> None:
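
After this deletion, `vllm:request_params_n` is the only request-parameter histogram left in this group. A standalone sketch of an equivalent metric using prometheus_client directly (vLLM wraps its own histogram class; the `model_name` label and the bucket boundaries here are assumptions for illustration):

from prometheus_client import Histogram

histogram_n_request = Histogram(
    name="vllm:request_params_n",
    documentation="Histogram of the n request parameter.",
    labelnames=["model_name"],
    buckets=[1, 2, 5, 10, 20],
)

def log_n_requests(model_name: str, n_requests: list) -> None:
    # Observe one sample per finished request, mirroring _log_histogram.
    for n in n_requests:
        histogram_n_request.labels(model_name).observe(n)

log_n_requests("facebook/opt-125m", [1, 1, 4])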

vllm/engine/metrics_types.py
Lines changed: 0 additions & 1 deletion

@@ -49,7 +49,6 @@ class Stats:
     # Metadata
     num_prompt_tokens_requests: List[int]
     num_generation_tokens_requests: List[int]
-    best_of_requests: List[int]
     n_requests: List[int]
     finished_reason_requests: List[str]

vllm/engine/output_processor/single_step.py
Lines changed: 1 addition & 1 deletion

@@ -112,7 +112,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                                         outputs: SequenceGroupOutput,
                                         is_async: bool) -> None:
         sampling_params = seq_group.sampling_params
-        if sampling_params.best_of == 1:
+        if sampling_params.n == 1:
             # only have one output sample
             sample = outputs.samples[0]
             # only have one sequence

vllm/model_executor/layers/sampler.py
Lines changed: 8 additions & 9 deletions

@@ -508,7 +508,7 @@ def _random_sample(
         same as the length of selected_seq_groups. If the corresponding
         seq_group has do_sample=False, tuple contains ([], [])
     """
-    # Find the maximum best_of value of the prompt phase requests.
+    # Find the maximum n value of the prompt phase requests.
    random_samples = random_samples.cpu()
    sample_idx = 0
    results: SampleResultType = []

@@ -523,9 +523,9 @@
         num_parent_seqs = len(seq_ids)
         if is_prompt:
             # Prompt phase.
-            parent_ids = [0] * sampling_params.best_of
+            parent_ids = [0] * sampling_params.n
             next_token_ids = random_samples[
-                sample_idx, :sampling_params.best_of].tolist()
+                sample_idx, :sampling_params.n].tolist()
         else:
             # Generation phase.
             parent_ids = list(range(num_parent_seqs))

@@ -570,7 +570,7 @@ def _beam_search_sample(
         is_prompt = seq_group.is_prompt
         seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
         num_parent_seqs = len(seq_ids)
-        beam_width = sampling_params.best_of
+        beam_width = sampling_params.n
         seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
         if is_prompt:
             # Prompt phase.

@@ -797,12 +797,11 @@ def _sample_with_torch(
                                                greedy_samples)

     elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
-        max_best_of_in_batch = 1
+        max_n_in_batch = 1
         for seq_group in seq_groups:
             if seq_group.is_prompt:
                 sampling_params = seq_group.sampling_params
-                max_best_of_in_batch = max(max_best_of_in_batch,
-                                           sampling_params.best_of)
+                max_n_in_batch = max(max_n_in_batch, sampling_params.n)
         seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
                           seq_groups)

@@ -812,13 +811,13 @@
                 probs[long_sample_indices],
                 sampling_tensors.top_ks[long_sample_indices],
                 sampling_tensors.top_ps[long_sample_indices],
-                max_best_of_in_batch,
+                max_n_in_batch,
                 seq_groups_arg,
             )
         else:
             multinomial_samples[sampling_type] = _multinomial(
                 probs[long_sample_indices],
-                max_best_of_in_batch,
+                max_n_in_batch,
                 seq_groups=seq_groups_arg)

         if sampled_token_ids_tensor is not None:
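
The sampler now tracks the batch-wide maximum as `max_n_in_batch` rather than `max_best_of_in_batch`: every prompt-phase row is sampled that many times, and each request keeps only its own first `n` draws. A rough sketch of that idea (this is not vLLM's `_multinomial`, which also handles seeded generators and the surrounding sampling tensors):

import torch

def sample_prompt_phase(probs: torch.Tensor, ns: list) -> list:
    # probs: [num_prompt_seq_groups, vocab_size]; ns[i] is request i's n.
    max_n_in_batch = max(ns, default=1)
    samples = torch.multinomial(probs, max_n_in_batch, replacement=True)
    # Each request slices out only the first n of the shared sample columns.
    return [samples[i, :n].tolist() for i, n in enumerate(ns)]

probs = torch.softmax(torch.randn(2, 32), dim=-1)
print(sample_prompt_phase(probs, [1, 3]))  # e.g. [[17], [4, 29, 4]]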

vllm/outputs.py
Lines changed: 1 addition & 1 deletion

@@ -141,7 +141,7 @@ def from_seq_group(cls, seq_group: SequenceGroup,
             top_n_seqs = seqs
         else:
             # Get the top-n sequences.
-            n = sampling_params.n
+            n = sampling_params._real_n or sampling_params.n
             sorting_key = lambda seq: seq.get_cumulative_logprob()
             sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
             top_n_seqs = sorted_seqs[:n]
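
The `_real_n or n` fallback above suggests that the companion SamplingParams change (not shown in this diff) widens `n` to `best_of` for generation and stashes the user's original `n` in `_real_n`, so output assembly returns only the top `_real_n` sequences by cumulative logprob. A hedged toy illustration of that selection step, with simplified stand-in classes:

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class _Params:              # stand-in for SamplingParams
    n: int
    _real_n: Optional[int] = None

@dataclass
class _Seq:                 # stand-in for Sequence
    text: str
    cumulative_logprob: float

def top_n_outputs(params: _Params, seqs: List[_Seq]) -> List[_Seq]:
    # Prefer the user's original n when the engine widened n internally.
    n = params._real_n or params.n
    ranked = sorted(seqs, key=lambda s: s.cumulative_logprob, reverse=True)
    return ranked[:n]

seqs = [_Seq("a", -1.0), _Seq("b", -0.2), _Seq("c", -3.0)]
print([s.text for s in top_n_outputs(_Params(n=3, _real_n=1), seqs)])  # ['b']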
