@@ -1,16 +1,19 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Sequence
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Callable, Optional, TypeVar
 
 import torch
 
+from vllm import SamplingParams
 from vllm.v1.sample.logits_processor.interface import (BatchUpdate,
                                                        LogitsProcessor,
                                                        MoveDirectionality)
 
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
 
+T = TypeVar("T")
+
 
 class MinPLogitsProcessor(LogitsProcessor):
@@ -130,49 +133,17 @@ def is_argmax_invariant(self) -> bool:
         return False
 
     def update_state(self, batch_update: Optional[BatchUpdate]):
-        if not batch_update:
-            return
-
-        needs_update: bool = False
-        # Process added requests.
-        for index, params, _, _ in batch_update.added:
-            if lb := params.logit_bias:
-                self.biases[index] = lb
-                needs_update = True
-            else:
-                # Drop biases metadata at batch index
-                if self.biases.pop(index, None) is not None:
-                    # If a new request replaces an old request which
-                    # specified biases, we should update processor tensors
-                    needs_update = True
-
-        if self.biases:
-            # Process removed requests.
-            for index in batch_update.removed:
-                if self.biases.pop(index, None):
-                    needs_update = True
-
-            # Process moved requests, unidirectional (a->b) and swap (a<->b)
-            for a_index, b_index, direct in batch_update.moved:
-                if direct == MoveDirectionality.UNIDIRECTIONAL:
-                    if (a_entry := self.biases.pop(a_index, None)) is None:
-                        if self.biases.pop(b_index, None) is not None:
-                            needs_update = True
-                    else:
-                        self.biases[b_index] = a_entry
-                        needs_update = True
-                else:
-                    a_entry = self.biases.pop(a_index, None)
-                    if (b_entry := self.biases.pop(b_index, None)) is not None:
-                        self.biases[a_index] = b_entry
-                        needs_update = True
-                    if a_entry is not None:
-                        self.biases[b_index] = a_entry
-                        needs_update = True
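+        # The extractor returns the request's logit_bias dict, or None when
+        # unset so that any stale entry left at this batch index is dropped.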
+        needs_update = process_dict_updates(
+            self.biases, batch_update,
+            lambda params, _, __: params.logit_bias or None)
 
         # Update tensors if needed.
         if needs_update:
-            reqs, tok_ids, biases = [], [], []
+            reqs: list[int] = []
+            tok_ids: list[int] = []
+            biases: list[float] = []
             for req, lb in self.biases.items():
                 reqs.extend([req] * len(lb))
                 tok_ids.extend(lb.keys())
@@ -216,52 +187,20 @@ def is_argmax_invariant(self) -> bool:
         of the argmax operation in greedy sampling."""
         return False
 
-    def update_state(self, batch_update: Optional[BatchUpdate]):
-        needs_update = False
-
-        if batch_update:
-            # Process added requests.
-            for index, params, _, output_tok_ids in batch_update.added:
-                if ((min_tokens := params.min_tokens)
-                        and len(output_tok_ids) < min_tokens):
-                    # Replace request metadata at batch index
-                    self.min_toks[index] = (min_tokens, output_tok_ids,
-                                            params.all_stop_token_ids)
-                    needs_update = True
-                else:
-                    # Drop min_toks metadata at batch index
-                    if self.min_toks.pop(index, None) is not None:
-                        # If a new request replaces an old request which
-                        # specified min_toks, we should update processor tensors
-                        needs_update = True
-
-            if self.min_toks:
-                # Process removed requests.
-                for index in batch_update.removed:
-                    if self.min_toks.pop(index, None):
-                        needs_update = True
-
-                # Process moved requests, unidirectional (a->b) and
-                # swapped (a<->b)
-                for a_index, b_index, direct in batch_update.moved:
-                    if direct == MoveDirectionality.UNIDIRECTIONAL:
-                        if (a_entry := self.min_toks.pop(a_index,
-                                                         None)) is None:
-                            if self.min_toks.pop(b_index, None) is not None:
-                                needs_update = True
-                        else:
-                            self.min_toks[b_index] = a_entry
-                            needs_update = True
-                    else:
-                        a_entry = self.min_toks.pop(a_index, None)
-                        if (b_entry := self.min_toks.pop(b_index,
-                                                         None)) is not None:
-                            self.min_toks[a_index] = b_entry
-                            needs_update = True
-                        if a_entry is not None:
-                            self.min_toks[b_index] = a_entry
-                            needs_update = True
+    @staticmethod
+    def add_request(
+        params: SamplingParams, _: list[int], output_tok_ids: list[int]
+    ) -> Optional[tuple[int, Sequence[int], set[int]]]:
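+        # Extractor for process_dict_updates: keep state only while the
+        # request still needs min-tokens enforcement.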
+        min_tokens = params.min_tokens
+        if not min_tokens or len(output_tok_ids) >= min_tokens:
+            return None
+        return min_tokens, output_tok_ids, params.all_stop_token_ids
 
+    def update_state(self, batch_update: Optional[BatchUpdate]):
+        needs_update = process_dict_updates(self.min_toks, batch_update,
+                                            self.add_request)
         if self.min_toks:
             # Check for any requests that have attained their min tokens.
             to_remove = tuple(index for index, (min_toks, out_tok_ids,
@@ -295,3 +234,48 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
         # Inhibit EOS token for requests which have not reached min length
         logits[self.logits_slice] = -float("inf")
         return logits
+
+
+def process_dict_updates(
+    req_entries: dict[int, T], batch_update: Optional[BatchUpdate],
+    new_state: Callable[[SamplingParams, list[int], list[int]], Optional[T]]
+) -> bool:
+    """Utility function to update dict state for sparse LogitsProcessors."""
+
+    if not batch_update:
+        # Nothing to do.
+        return False
+
+    updated = False
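+    # Process added requests; a new request that replaces an index holding
+    # stale state must also trigger an update.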
+    for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
+        if (state := new_state(params, prompt_tok_ids,
+                               output_tok_ids)) is not None:
+            req_entries[index] = state
+            updated = True
+        elif req_entries.pop(index, None) is not None:
+            updated = True
+
+    if req_entries:
+        # Process removed requests.
+        for index in batch_update.removed:
+            if req_entries.pop(index, None):
+                updated = True
+
+        # Process moved requests, unidirectional (a->b) and
+        # swapped (a<->b)
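+        # Pop both ends up front: a unidirectional move discards any entry
+        # at the destination, while a swap re-inserts it at a_index below.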
+        for a_index, b_index, direct in batch_update.moved:
+            a_entry = req_entries.pop(a_index, None)
+            b_entry = req_entries.pop(b_index, None)
+            if a_entry is not None:
+                req_entries[b_index] = a_entry
+                updated = True
+            if b_entry is not None:
+                updated = True
+                if direct == MoveDirectionality.SWAP:
+                    req_entries[a_index] = b_entry
+
+    return updated
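
For context, here is a minimal sketch of how a sparse processor drives process_dict_updates. It assumes BatchUpdate can be constructed directly with the added/removed/moved sequences unpacked above plus a batch_size field, and that SamplingParams accepts a logit_bias mapping; the batch contents are hypothetical and not part of this change.

    # Hypothetical driver, mirroring LogitBiasLogitsProcessor.update_state.
    biases: dict[int, dict[int, float]] = {}

    # Request 0 carries a logit bias; request 1 does not, so the extractor
    # returns None and no entry is stored for index 1.
    update = BatchUpdate(
        batch_size=2,
        removed=[],
        moved=[],
        added=[(0, SamplingParams(logit_bias={7: 1.5}), [], []),
               (1, SamplingParams(), [], [])],
    )
    changed = process_dict_updates(
        biases, update, lambda params, _, __: params.logit_bias or None)
    assert changed and biases == {0: {7: 1.5}}

    # A swap relocates the surviving entry; a unidirectional move would
    # discard whatever sat at the destination index instead.
    swap = BatchUpdate(batch_size=2, removed=[], added=[],
                       moved=[(0, 1, MoveDirectionality.SWAP)])
    process_dict_updates(biases, swap,
                         lambda params, _, __: params.logit_bias or None)
    assert biases == {1: {7: 1.5}}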