Skip to content

Commit fa8a2fa

Browse files
njhill and zhewenl
authored and committed
[LogitsProcs] Deduplicate built-in LP implementation logic (vllm-project#23362)
Signed-off-by: Nick Hill <[email protected]>
1 parent ed51563 commit fa8a2fa

File tree

4 files changed

+95
-143
lines changed

4 files changed

+95
-143
lines changed

examples/offline_inference/logits_processor.py

Lines changed: 11 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ class object.
4242
from vllm.v1.sample.logits_processor import (
4343
BatchUpdate,
4444
LogitsProcessor,
45-
MoveDirectionality,
4645
)
46+
from vllm.v1.sample.logits_processor.builtin import process_dict_updates
4747

4848

4949
# Hypothetical custom logits processor
@@ -53,38 +53,22 @@ class DummyLogitsProcessor(LogitsProcessor):
5353
def __init__(
5454
self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
5555
):
56-
self.req_info: dict[int, SamplingParams] = {}
56+
self.req_info: dict[int, int] = {}
5757

5858
def is_argmax_invariant(self) -> bool:
5959
"""Never impacts greedy sampling"""
6060
return False
6161

6262
def update_state(self, batch_update: Optional[BatchUpdate]):
63-
if not batch_update:
64-
return
65-
66-
# Process added requests.
67-
for index, params, _, _ in batch_update.added:
68-
assert params is not None
69-
if params.extra_args and (
70-
target_token := params.extra_args.get("target_token")
71-
):
72-
self.req_info[index] = target_token
73-
74-
if self.req_info:
75-
# Process removed requests.
76-
for index in batch_update.removed:
77-
self.req_info.pop(index, None)
78-
79-
# Process moved requests, unidirectional move (a->b) and swap
80-
# (a<->b)
81-
for adx, bdx, direct in batch_update.moved:
82-
a_val = self.req_info.pop(adx, None)
83-
b_val = self.req_info.pop(bdx, None)
84-
if a_val is not None:
85-
self.req_info[bdx] = a_val
86-
if direct == MoveDirectionality.SWAP and b_val is not None:
87-
self.req_info[adx] = b_val
63+
process_dict_updates(
64+
self.req_info,
65+
batch_update,
66+
# This function returns the LP's per-request state based on the
67+
# request details, or None if this LP does not apply to the
68+
# request.
69+
lambda params, _, __: params.extra_args
70+
and (params.extra_args.get("target_token")),
71+
)
8872

8973
def apply(self, logits: torch.Tensor) -> torch.Tensor:
9074
if not self.req_info:

tests/v1/logits_processors/utils.py

Lines changed: 9 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@
88
import torch
99

1010
from vllm.config import VllmConfig
11-
from vllm.sampling_params import SamplingParams
1211
from vllm.v1.sample.logits_processor import (LOGITSPROCS_GROUP, BatchUpdate,
13-
LogitsProcessor,
14-
MoveDirectionality)
12+
LogitsProcessor)
13+
from vllm.v1.sample.logits_processor.builtin import process_dict_updates
1514

1615
MODEL_NAME = "facebook/opt-125m"
1716
POOLING_MODEL_NAME = "BAAI/bge-base-en-v1.5"
@@ -45,37 +44,19 @@ class DummyLogitsProcessor(LogitsProcessor):
4544

4645
def __init__(self, vllm_config: "VllmConfig", device: torch.device,
4746
is_pin_memory: bool):
48-
self.req_info: dict[int, SamplingParams] = {}
47+
self.req_info: dict[int, int] = {}
4948

5049
def is_argmax_invariant(self) -> bool:
5150
"""Never impacts greedy sampling"""
5251
return False
5352

5453
def update_state(self, batch_update: Optional[BatchUpdate]):
55-
if not batch_update:
56-
return
57-
58-
# Process added requests.
59-
for index, params, _, _ in batch_update.added:
60-
assert params is not None
61-
if params.extra_args and (target_token :=
62-
params.extra_args.get("target_token")):
63-
self.req_info[index] = target_token
64-
65-
if self.req_info:
66-
# Process removed requests.
67-
for index in batch_update.removed:
68-
self.req_info.pop(index, None)
69-
70-
# Process moved requests, unidirectional move (a->b) and swap
71-
# (a<->b)
72-
for adx, bdx, direct in batch_update.moved:
73-
a_val = self.req_info.pop(adx, None)
74-
b_val = self.req_info.pop(bdx, None)
75-
if a_val is not None:
76-
self.req_info[bdx] = a_val
77-
if direct == MoveDirectionality.SWAP and b_val is not None:
78-
self.req_info[adx] = b_val
54+
process_dict_updates(
55+
self.req_info,
56+
batch_update,
57+
lambda params, _, __: params.extra_args and
58+
(params.extra_args.get("target_token")),
59+
)
7960

8061
def apply(self, logits: torch.Tensor) -> torch.Tensor:
8162
if not self.req_info:

vllm/v1/sample/logits_processor/builtin.py

Lines changed: 62 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
from collections.abc import Sequence
4-
from typing import TYPE_CHECKING, Optional
4+
from typing import TYPE_CHECKING, Callable, Optional, TypeVar
55

66
import torch
77

8+
from vllm import SamplingParams
89
from vllm.v1.sample.logits_processor.interface import (BatchUpdate,
910
LogitsProcessor,
1011
MoveDirectionality)
1112

1213
if TYPE_CHECKING:
1314
from vllm.config import VllmConfig
1415

16+
T = TypeVar("T")
17+
1518

1619
class MinPLogitsProcessor(LogitsProcessor):
1720

@@ -130,49 +133,15 @@ def is_argmax_invariant(self) -> bool:
130133
return False
131134

132135
def update_state(self, batch_update: Optional[BatchUpdate]):
133-
if not batch_update:
134-
return
135-
136-
needs_update: bool = False
137-
# Process added requests.
138-
for index, params, _, _ in batch_update.added:
139-
if lb := params.logit_bias:
140-
self.biases[index] = lb
141-
needs_update = True
142-
else:
143-
# Drop biases metadata at batch index
144-
if self.biases.pop(index, None) is not None:
145-
# If a new request replaces an old request which
146-
# specified biases, we should update processor tensors
147-
needs_update = True
148-
149-
if self.biases:
150-
# Process removed requests.
151-
for index in batch_update.removed:
152-
if self.biases.pop(index, None):
153-
needs_update = True
154-
155-
# Process moved requests, unidirectional (a->b) and swap (a<->b)
156-
for a_index, b_index, direct in batch_update.moved:
157-
if direct == MoveDirectionality.UNIDIRECTIONAL:
158-
if (a_entry := self.biases.pop(a_index, None)) is None:
159-
if self.biases.pop(b_index, None) is not None:
160-
needs_update = True
161-
else:
162-
self.biases[b_index] = a_entry
163-
needs_update = True
164-
else:
165-
a_entry = self.biases.pop(a_index, None)
166-
if (b_entry := self.biases.pop(b_index, None)) is not None:
167-
self.biases[a_index] = b_entry
168-
needs_update = True
169-
if a_entry is not None:
170-
self.biases[b_index] = a_entry
171-
needs_update = True
136+
needs_update = process_dict_updates(
137+
self.biases, batch_update,
138+
lambda params, _, __: params.logit_bias or None)
172139

173140
# Update tensors if needed.
174141
if needs_update:
175-
reqs, tok_ids, biases = [], [], []
142+
reqs: list[int] = []
143+
tok_ids: list[int] = []
144+
biases: list[float] = []
176145
for req, lb in self.biases.items():
177146
reqs.extend([req] * len(lb))
178147
tok_ids.extend(lb.keys())
@@ -216,52 +185,18 @@ def is_argmax_invariant(self) -> bool:
216185
of the argmax operation in greedy sampling."""
217186
return False
218187

219-
def update_state(self, batch_update: Optional[BatchUpdate]):
220-
needs_update = False
221-
222-
if batch_update:
223-
# Process added requests.
224-
for index, params, _, output_tok_ids in batch_update.added:
225-
if ((min_tokens := params.min_tokens)
226-
and len(output_tok_ids) < min_tokens):
227-
# Replace request metadata at batch index
228-
self.min_toks[index] = (min_tokens, output_tok_ids,
229-
params.all_stop_token_ids)
230-
needs_update = True
231-
else:
232-
# Drop min_toks metadata at batch index
233-
if self.min_toks.pop(index, None) is not None:
234-
# If a new request replaces an old request which
235-
# specified min_toks, we should update processor tensors
236-
needs_update = True
237-
238-
if self.min_toks:
239-
# Process removed requests.
240-
for index in batch_update.removed:
241-
if self.min_toks.pop(index, None):
242-
needs_update = True
243-
244-
# Process moved requests, unidirectional (a->b) and
245-
# swapped (a<->b)
246-
for a_index, b_index, direct in batch_update.moved:
247-
if direct == MoveDirectionality.UNIDIRECTIONAL:
248-
if (a_entry := self.min_toks.pop(a_index,
249-
None)) is None:
250-
if self.min_toks.pop(b_index, None) is not None:
251-
needs_update = True
252-
else:
253-
self.min_toks[b_index] = a_entry
254-
needs_update = True
255-
else:
256-
a_entry = self.min_toks.pop(a_index, None)
257-
if (b_entry := self.min_toks.pop(b_index,
258-
None)) is not None:
259-
self.min_toks[a_index] = b_entry
260-
needs_update = True
261-
if a_entry is not None:
262-
self.min_toks[b_index] = a_entry
263-
needs_update = True
188+
@staticmethod
189+
def add_request(
190+
params: SamplingParams, _: list[int], output_tok_ids: list[int]
191+
) -> Optional[tuple[int, Sequence[int], set[int]]]:
192+
min_tokens = params.min_tokens
193+
if not min_tokens or len(output_tok_ids) >= min_tokens:
194+
return None
195+
return min_tokens, output_tok_ids, params.all_stop_token_ids
264196

197+
def update_state(self, batch_update: Optional[BatchUpdate]):
198+
needs_update = process_dict_updates(self.min_toks, batch_update,
199+
self.add_request)
265200
if self.min_toks:
266201
# Check for any requests that have attained their min tokens.
267202
to_remove = tuple(index for index, (min_toks, out_tok_ids,
@@ -295,3 +230,44 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
295230
# Inhibit EOS token for requests which have not reached min length
296231
logits[self.logits_slice] = -float("inf")
297232
return logits
233+
234+
235+
def process_dict_updates(
236+
req_entries: dict[int, T], batch_update: Optional[BatchUpdate],
237+
new_state: Callable[[SamplingParams, list[int], list[int]], Optional[T]]
238+
) -> bool:
239+
"""Utility function to update dict state for sparse LogitsProcessors."""
240+
241+
if not batch_update:
242+
# Nothing to do.
243+
return False
244+
245+
updated = False
246+
for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
247+
if (state := new_state(params, prompt_tok_ids,
248+
output_tok_ids)) is not None:
249+
req_entries[index] = state
250+
updated = True
251+
elif req_entries.pop(index, None) is not None:
252+
updated = True
253+
254+
if req_entries:
255+
# Process removed requests.
256+
for index in batch_update.removed:
257+
if req_entries.pop(index, None):
258+
updated = True
259+
260+
# Process moved requests, unidirectional (a->b) and
261+
# swapped (a<->b)
262+
for a_index, b_index, direct in batch_update.moved:
263+
a_entry = req_entries.pop(a_index, None)
264+
b_entry = req_entries.pop(b_index, None)
265+
if a_entry is not None:
266+
req_entries[b_index] = a_entry
267+
updated = True
268+
if b_entry is not None:
269+
updated = True
270+
if direct == MoveDirectionality.SWAP:
271+
req_entries[a_index] = b_entry
272+
273+
return updated

vllm/v1/sample/logits_processor/interface.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,16 @@ class BatchUpdate:
4444
# Key assumption: the `output_tok_ids` list (which is an element of each
4545
# tuple in `added`) is a reference to the request's running output tokens
4646
# list; via this reference, the logits processors always see the latest
47-
# list of generated output tokens
47+
# list of generated output tokens.
48+
#
49+
# NOTE:
50+
# * Added or moved requests may replace existing requests with the same
51+
# index.
52+
# * Operations should be processed in the following order:
53+
# - removed, added, moved
4854
removed: Sequence[RemovedRequest]
49-
moved: Sequence[MovedRequest]
5055
added: Sequence[AddedRequest]
56+
moved: Sequence[MovedRequest]
5157

5258

5359
class LogitsProcessor(ABC):
@@ -59,6 +65,11 @@ def __init__(self, vllm_config: "VllmConfig", device: torch.device,
5965

6066
@abstractmethod
6167
def apply(self, logits: torch.Tensor) -> torch.Tensor:
68+
"""Apply LogitsProcessor to batch logits tensor.
69+
70+
The updated tensor must be returned but may be
71+
modified in-place.
72+
"""
6273
raise NotImplementedError
6374

6475
@abstractmethod

0 commit comments

Comments
 (0)