Commit 7e6544c

[Perf] Parallelize fill_bitmask to accelerate high-throughput guided decoding (#21862)

Signed-off-by: Benjamin Chislett <[email protected]>

1 parent 8e6c7e8 · commit 7e6544c

3 files changed: +105 -42 lines

vllm/v1/structured_output/__init__.py
Lines changed: 92 additions & 39 deletions
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import multiprocessing
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import Future, ThreadPoolExecutor
 from typing import TYPE_CHECKING, Optional
 
 from vllm.config import VllmConfig
@@ -40,6 +40,17 @@ def __init__(self, vllm_config: VllmConfig):
         self._grammar_bitmask: Optional[torch.Tensor] = None
         self._full_mask = torch.tensor(-1, dtype=torch.int32)
 
+        max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
+        self.fill_bitmask_parallel_threshold = 128
+        if self.fill_bitmask_parallel_threshold < max_batch_size:
+            self.fill_bitmask_parallel_batch_size = 16
+            # Use:
+            # - at least 1 CPU
+            # - at most half the number of CPUs or 8, whichever is less
+            max_workers = max(1, min(multiprocessing.cpu_count() // 2, 8))
+            self.executor_for_fillmask = ThreadPoolExecutor(
+                max_workers=max_workers)
+
         if not self.vllm_config.model_config.skip_tokenizer_init:
             # The default max_workers if not specified is the number of
             # CPUs * 5, which is way too high since these tasks are CPU-bound,
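The `# Use:` comment above encodes the executor sizing rule. As a standalone illustration (not part of this commit; `fillmask_worker_count` is a hypothetical name used only here), the rule reads:

import multiprocessing

def fillmask_worker_count() -> int:
    # At least 1 worker; at most half the visible CPUs, capped at 8.
    return max(1, min(multiprocessing.cpu_count() // 2, 8))

# e.g. 2 workers on a 4-core host, 8 workers on a 16-core (or larger) host.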
@@ -120,6 +131,26 @@ def _async_create_grammar(
         assert self.backend is not None
         return self.backend.compile_grammar(request_type, grammar_spec)
 
+    def _fill_bitmasks(
+        self,
+        batch: list[tuple[StructuredOutputGrammar, int, bool]],
+    ) -> None:
+        assert self._grammar_bitmask is not None
+        for grammar, index, apply_bitmask in batch:
+            if apply_bitmask and not grammar.is_terminated():
+                grammar.fill_bitmask(self._grammar_bitmask, index)
+            else:
+                # Note that for thinking support, we will need to
+                # reset the relevant part of the bitmask for consequent
+                # requests here.
+                self._grammar_bitmask[index].fill_(self._full_mask)
+
+    def _async_submit_fill_bitmask(
+        self,
+        batch: list[tuple[StructuredOutputGrammar, int, bool]],
+    ) -> Future:
+        return self.executor_for_fillmask.submit(self._fill_bitmasks, batch)
+
     def grammar_bitmask(
         self,
         requests: dict[str, Request],
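The two helpers added above are the building blocks of the parallel path used later in `grammar_bitmask`: work is chunked into fixed-size batches, each batch is submitted to the thread pool, and the caller blocks on every `Future` before the bitmask is used. A minimal standalone sketch of that submit/collect pattern (toy data, not vLLM code):

from concurrent.futures import Future, ThreadPoolExecutor

def fill_all(values: list[int], batch_size: int = 16) -> list[int]:
    results = [0] * len(values)

    def fill_batch(batch: list[tuple[int, int]]) -> None:
        # Stand-in for per-request bitmask filling; workers write disjoint slots.
        for index, value in batch:
            results[index] = value * 2

    with ThreadPoolExecutor(max_workers=4) as pool:
        promises: list[Future] = []
        batch: list[tuple[int, int]] = []
        for index, value in enumerate(values):
            batch.append((index, value))
            if len(batch) == batch_size:
                promises.append(pool.submit(fill_batch, batch))
                batch = []
        if batch:
            promises.append(pool.submit(fill_batch, batch))
        for promise in promises:
            promise.result()  # wait for completion and surface any exception
    return results

Threads only pay off here when the per-batch work spends most of its time outside the GIL (e.g. in native bitmask-filling code), which is the premise of this change.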
@@ -146,7 +177,6 @@ def grammar_bitmask(
             self.backend.allocate_token_bitmask(
                 max_batch_size * (1 + max_num_spec_tokens))
 
-        bitmask_tensor = self._grammar_bitmask
         # Generate a batched bitmask for all structured output requests.
         # When speculative decoding is enabled, we need to include multiple
         # masks for each request, one for each possible bonus token position.
@@ -155,47 +185,61 @@ def grammar_bitmask(
         ordered_seq = sorted(structured_output_request_ids.items(),
                              key=lambda x: x[1])
 
-        # Note that for thinking support, we will need to
-        # reset the relevant part of the bitmask for consequent
-        # request here.
-        bitmask_tensor[:(len(ordered_seq) * (1 + max_num_spec_tokens))].fill_(
-            self._full_mask)
-
-        # NOTE: This outer loop can likely be parallelized to improve
-        # performance of bitmask generation for large batches.
-        for req_id, _ in ordered_seq:
-            request = requests[req_id]
-            structured_output_request = request.structured_output_request
-
-            if TYPE_CHECKING:
-                assert structured_output_request is not None
-                assert structured_output_request.grammar is not None
-            apply_bitmask: bool = True
-            if self.reasoner is not None:
-                if structured_output_request.reasoning_ended is None:
-                    structured_output_request.reasoning_ended = \
-                        self.reasoner.is_reasoning_end(request.prompt_token_ids)
-                apply_bitmask = structured_output_request.reasoning_ended
-
-            state_advancements = 0
-            req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None]
-            for i, token in enumerate(req_tokens):
-                if apply_bitmask and not \
-                        structured_output_request.grammar.is_terminated():
-                    structured_output_request.grammar.fill_bitmask(
-                        bitmask_tensor, cumulative_index)
-                if token is not None:
-                    # In order to generate the correct bitmask for each
-                    # position in the speculative sequence, we advance
-                    # the FSM state for each speculative token and rollback
-                    # to restore the previous state when we are finished.
+        # Optimized parallel filling of bitmasks for
+        # non-spec, large-batch-size cases
+        if len(ordered_seq) > self.fill_bitmask_parallel_threshold and \
+                max_num_spec_tokens == 0:
+            promises = []
+            batch = []
+            for req_id, _ in ordered_seq:
+                request = requests[req_id]
+                structured_output_request = request.structured_output_request
+                if TYPE_CHECKING:
+                    assert structured_output_request is not None
+                    assert structured_output_request.grammar is not None
+
+                apply_bitmask = self.should_fill_bitmask(request)
+                batch.append((structured_output_request.grammar,
+                              cumulative_index, apply_bitmask))
+                if len(batch) == self.fill_bitmask_parallel_batch_size:
+                    promises.append(self._async_submit_fill_bitmask(batch))
+                    batch = []
+
+                cumulative_index += 1
+            if batch:
+                promises.append(self._async_submit_fill_bitmask(batch))
+
+            # Wait for all bitmask filling tasks to complete.
+            for promise in promises:
+                promise.result()
+        else:
+            # Fallback to serial filling of bitmasks for small-batch-size cases
+            for req_id, _ in ordered_seq:
+                request = requests[req_id]
+                structured_output_request = request.structured_output_request
+
+                if TYPE_CHECKING:
+                    assert structured_output_request is not None
+                    assert structured_output_request.grammar is not None
+                apply_bitmask = self.should_fill_bitmask(request)
+
+                state_advancements = 0
+                req_tokens = scheduled_spec_decode_tokens.get(req_id, [])
+                for i, token in enumerate(req_tokens + [None]):
+                    self._fill_bitmasks([(structured_output_request.grammar,
+                                          cumulative_index, apply_bitmask)])
+
+                    if apply_bitmask and token is not None and \
+                            not structured_output_request.grammar.is_terminated():
                         assert structured_output_request.grammar.accept_tokens(
                             req_id, [token])
                         state_advancements += 1
-                cumulative_index += 1
-            if state_advancements > 0:
-                structured_output_request.grammar.rollback(state_advancements)
+                    cumulative_index += 1
+                if state_advancements > 0:
+                    structured_output_request.grammar.rollback(
+                        state_advancements)
 
+        bitmask_tensor = self._grammar_bitmask
         if cumulative_index < bitmask_tensor.shape[0]:
             bitmask_tensor = bitmask_tensor[:cumulative_index]
 
@@ -204,6 +248,15 @@
         # and deserialization when sending this to the GPU workers.
         return bitmask_tensor.numpy()
 
+    def should_fill_bitmask(self, request: Request) -> bool:
+        if self.reasoner is not None:
+            assert request.structured_output_request is not None
+            if request.structured_output_request.reasoning_ended is None:
+                request.structured_output_request.reasoning_ended = \
+                    self.reasoner.is_reasoning_end(request.prompt_token_ids)
+            return request.structured_output_request.reasoning_ended
+        return True
+
     def should_advance(self, request: Request) -> bool:
         if not request.use_structured_output:
             return False
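In the serial fallback above, each speculative position needs a bitmask computed from the grammar state after the preceding speculative tokens, so the FSM is advanced once per accepted token and then rolled back, leaving the request's real state untouched. A toy sketch of that advance/rollback bookkeeping (ToyGrammar and masks_for_spec_positions are stand-ins for illustration, not the vLLM interface):

class ToyGrammar:
    """Stand-in for a grammar FSM; only the bookkeeping shape matters."""

    def __init__(self) -> None:
        self.state = 0

    def accept_tokens(self, tokens: list[int]) -> bool:
        self.state += len(tokens)
        return True

    def rollback(self, num_tokens: int) -> None:
        self.state -= num_tokens

    def fill_bitmask(self, bitmask: list[int], index: int) -> None:
        bitmask[index] = self.state  # stand-in for the allowed-token mask


def masks_for_spec_positions(grammar: ToyGrammar, spec_tokens: list[int],
                             bitmask: list[int], start: int) -> int:
    """Fill one mask per position (spec tokens plus one bonus slot),
    advancing the FSM per token and rewinding afterwards."""
    advancements = 0
    index = start
    for token in spec_tokens + [None]:
        grammar.fill_bitmask(bitmask, index)
        if token is not None:
            assert grammar.accept_tokens([token])
            advancements += 1
        index += 1
    if advancements > 0:
        grammar.rollback(advancements)  # restore the pre-speculation state
    return index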

vllm/v1/structured_output/backend_xgrammar.py
Lines changed: 6 additions & 1 deletion
@@ -148,20 +148,24 @@ class XgrammarGrammar(StructuredOutputGrammar):
                                       repr=False,
                                       hash=False,
                                       init=False)
+    _is_terminated: bool = field(default=False, repr=False, hash=False)
 
     def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
         """Accepts a list of tokens and advances the FSM.
 
         Returns True if the FSM was advanced successfully.
         Returns False if the FSM failed to advance.
         """
+        if self._is_terminated:
+            return False
         for token in tokens:
             if not self.matcher.accept_token(token):
                 logger.error(
                     "Failed to advance FSM for request %s "
                     "for tokens %s. Please file an issue.", request_id, token)
                 return False
             self.num_processed_tokens += 1
+        self._is_terminated = self.matcher.is_terminated()
         return True
 
     def validate_tokens(self, tokens: list[int]) -> list[int]:
@@ -184,12 +188,13 @@ def validate_tokens(self, tokens: list[int]) -> list[int]:
     def rollback(self, num_tokens: int) -> None:
         self.matcher.rollback(num_tokens)
         self.num_processed_tokens -= num_tokens
+        self._is_terminated = self.matcher.is_terminated()
 
     def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
         self.matcher.fill_next_token_bitmask(bitmask, idx)
 
     def is_terminated(self) -> bool:
-        return self.matcher.is_terminated()
+        return self._is_terminated
 
     def reset(self):
         self.num_processed_tokens = 0
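With this change, is_terminated() becomes a cached attribute read that is refreshed only when the matcher state actually changes (accept or rollback); the bitmask-filling loop calls it once per request per step, so a cheap read there helps. A minimal sketch of the same caching pattern, with a hypothetical FakeMatcher standing in for xgrammar's matcher:

class FakeMatcher:
    def __init__(self, limit: int) -> None:
        self.accepted = 0
        self.limit = limit

    def accept_token(self, token: int) -> bool:
        self.accepted += 1
        return True

    def rollback(self, n: int) -> None:
        self.accepted -= n

    def is_terminated(self) -> bool:
        return self.accepted >= self.limit


class CachedGrammar:
    def __init__(self, matcher: FakeMatcher) -> None:
        self.matcher = matcher
        self._is_terminated = False

    def accept_tokens(self, tokens: list[int]) -> bool:
        if self._is_terminated:
            return False  # refuse to advance a finished grammar
        for token in tokens:
            if not self.matcher.accept_token(token):
                return False
        self._is_terminated = self.matcher.is_terminated()  # refresh cache
        return True

    def rollback(self, num_tokens: int) -> None:
        self.matcher.rollback(num_tokens)
        self._is_terminated = self.matcher.is_terminated()  # refresh cache

    def is_terminated(self) -> bool:
        return self._is_terminated  # cheap read on the hot path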

vllm/v1/worker/gpu_model_runner.py
Lines changed: 7 additions & 2 deletions
@@ -1324,17 +1324,22 @@ def apply_grammar_bitmask(
             cumulative_index += 1 + num_spec_tokens
         grammar_bitmask = sorted_bitmask
 
+        # If the grammar bitmask and the logits have the same shape
+        # we don't need to pass indices to the kernel,
+        # since the bitmask is already aligned with the logits.
+        skip_out_indices = grammar_bitmask.shape[0] == logits.shape[0]
+
         # Serialization of np.ndarray is much more efficient than a tensor,
         # so we receive it in that format.
-        grammar_bitmask = torch.from_numpy(grammar_bitmask)
+        grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous()
 
         # Force use of the torch.compile implementation from xgrammar to work
         # around issues with the Triton kernel in concurrent structured output
         # scenarios. See PR #19565 and issues #19493, #18376 for details.
         xgr_torch_compile.apply_token_bitmask_inplace_torch_compile(
             logits,
             grammar_bitmask.to(self.device, non_blocking=True),
-            indices=out_indices,
+            indices=out_indices if not skip_out_indices else None,
         )
 
     def sync_and_slice_intermediate_tensors(
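The shape check above means that when every logits row has exactly one bitmask row, the kernel can skip the index gather entirely (indices=None); the gather path is typically only needed when the bitmask covers a subset of rows, or several rows per request under speculative decoding. An illustrative sketch in plain PyTorch, with a boolean `allowed` mask standing in for xgrammar's packed int32 bitmask (this is not the actual kernel):

import torch

def apply_allowed_mask(logits: torch.Tensor, allowed: torch.Tensor,
                       out_indices: list[int]) -> None:
    """Assumes `allowed` (bool, [rows, vocab]) lives on the same device
    as `logits`; rows not listed in out_indices are left untouched."""
    if allowed.shape[0] == logits.shape[0]:
        # Aligned case: one mask row per logits row, no index gather needed.
        logits[~allowed] = float("-inf")
        return
    # General case: route each mask row to the logits row it belongs to.
    rows = torch.tensor(out_indices, device=logits.device)
    logits[rows] = logits[rows].masked_fill(~allowed, float("-inf"))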
