 from __future__ import annotations
 
 import multiprocessing
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import Future, ThreadPoolExecutor
 from typing import TYPE_CHECKING, Optional
 
 from vllm.config import VllmConfig
@@ -40,6 +40,17 @@ def __init__(self, vllm_config: VllmConfig):
         self._grammar_bitmask: Optional[torch.Tensor] = None
         self._full_mask = torch.tensor(-1, dtype=torch.int32)
 
+        max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
+        self.fill_bitmask_parallel_threshold = 128
+        if self.fill_bitmask_parallel_threshold < max_batch_size:
+            self.fill_bitmask_parallel_batch_size = 16
+            # Use:
+            # - at least 1 CPU
+            # - at most half the number of CPUs or 8, whichever is less
+            max_workers = max(1, min(multiprocessing.cpu_count() // 2, 8))
+            self.executor_for_fillmask = ThreadPoolExecutor(
+                max_workers=max_workers)
+
         if not self.vllm_config.model_config.skip_tokenizer_init:
             # The default max_workers if not specified is the number of
             # CPUs * 5, which is way too high since these tasks are CPU-bound,
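For reference, a minimal standalone sketch (not part of the diff) of what the worker-count heuristic above resolves to for a few hypothetical host CPU counts; note the executor is only created when the scheduler's max_num_seqs exceeds the 128-request threshold:

# Standalone sketch: the worker-count heuristic evaluated for a few
# hypothetical CPU counts (the counts themselves are just examples).
for cpus in (1, 4, 16, 64):
    workers = max(1, min(cpus // 2, 8))
    print(f"{cpus:>2} CPUs -> {workers} fill-bitmask workers")
# 1 -> 1, 4 -> 2, 16 -> 8, 64 -> 8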
@@ -120,6 +131,26 @@ def _async_create_grammar(
         assert self.backend is not None
         return self.backend.compile_grammar(request_type, grammar_spec)
 
+    def _fill_bitmasks(
+        self,
+        batch: list[tuple[StructuredOutputGrammar, int, bool]],
+    ) -> None:
+        assert self._grammar_bitmask is not None
+        for grammar, index, apply_bitmask in batch:
+            if apply_bitmask and not grammar.is_terminated():
+                grammar.fill_bitmask(self._grammar_bitmask, index)
+            else:
+                # Note that for thinking support, we will need to
+                # reset the relevant part of the bitmask for consequent
+                # requests here.
+                self._grammar_bitmask[index].fill_(self._full_mask)
+
+    def _async_submit_fill_bitmask(
+        self,
+        batch: list[tuple[StructuredOutputGrammar, int, bool]],
+    ) -> Future:
+        return self.executor_for_fillmask.submit(self._fill_bitmasks, batch)
+
     def grammar_bitmask(
         self,
         requests: dict[str, Request],
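To make the submit-and-wait flow concrete before it appears inside grammar_bitmask below, here is a minimal standalone sketch of the same pattern; fill_one, fill_batch, and the item counts are hypothetical stand-ins for the real per-request fill_bitmask work:

# Standalone sketch of the chunk -> submit -> wait pattern used by the new
# helpers; fill_one/fill_batch are placeholders for real bitmask filling.
from concurrent.futures import Future, ThreadPoolExecutor

def fill_one(item: int) -> None:
    pass  # stand-in for grammar.fill_bitmask(bitmask, index)

def fill_batch(batch: list[int]) -> None:
    for item in batch:
        fill_one(item)

executor = ThreadPoolExecutor(max_workers=4)
items = list(range(100))
batch_size = 16

promises: list[Future] = []
for start in range(0, len(items), batch_size):
    # Each sub-batch becomes one task so per-task overhead stays small.
    promises.append(executor.submit(fill_batch, items[start:start + batch_size]))

# Block until every sub-batch is done, surfacing any worker exception.
for promise in promises:
    promise.result()

Using threads rather than processes lets every worker write directly into the one shared bitmask tensor, so nothing has to be serialized back to the main thread.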
@@ -146,7 +177,6 @@ def grammar_bitmask(
                 self.backend.allocate_token_bitmask(
                     max_batch_size * (1 + max_num_spec_tokens))
 
-        bitmask_tensor = self._grammar_bitmask
         # Generate a batched bitmask for all structured output requests.
         # When speculative decoding is enabled, we need to include multiple
         # masks for each request, one for each possible bonus token position.
@@ -155,47 +185,61 @@ def grammar_bitmask(
         ordered_seq = sorted(structured_output_request_ids.items(),
                              key=lambda x: x[1])
 
-        # Note that for thinking support, we will need to
-        # reset the relevant part of the bitmask for consequent
-        # request here.
-        bitmask_tensor[:(len(ordered_seq) * (1 + max_num_spec_tokens))].fill_(
-            self._full_mask)
-
-        # NOTE: This outer loop can likely be parallelized to improve
-        # performance of bitmask generation for large batches.
-        for req_id, _ in ordered_seq:
-            request = requests[req_id]
-            structured_output_request = request.structured_output_request
-
-            if TYPE_CHECKING:
-                assert structured_output_request is not None
-                assert structured_output_request.grammar is not None
-            apply_bitmask: bool = True
-            if self.reasoner is not None:
-                if structured_output_request.reasoning_ended is None:
-                    structured_output_request.reasoning_ended = \
-                        self.reasoner.is_reasoning_end(request.prompt_token_ids)
-                apply_bitmask = structured_output_request.reasoning_ended
-
-            state_advancements = 0
-            req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None]
-            for i, token in enumerate(req_tokens):
-                if apply_bitmask and not \
-                        structured_output_request.grammar.is_terminated():
-                    structured_output_request.grammar.fill_bitmask(
-                        bitmask_tensor, cumulative_index)
-                if token is not None:
-                    # In order to generate the correct bitmask for each
-                    # position in the speculative sequence, we advance
-                    # the FSM state for each speculative token and rollback
-                    # to restore the previous state when we are finished.
+        # Optimized parallel filling of bitmasks for
+        # non-spec, large-batch-size cases
+        if len(ordered_seq) > self.fill_bitmask_parallel_threshold and \
+                max_num_spec_tokens == 0:
+            promises = []
+            batch = []
+            for req_id, _ in ordered_seq:
+                request = requests[req_id]
+                structured_output_request = request.structured_output_request
+                if TYPE_CHECKING:
+                    assert structured_output_request is not None
+                    assert structured_output_request.grammar is not None
+
+                apply_bitmask = self.should_fill_bitmask(request)
+                batch.append((structured_output_request.grammar,
+                              cumulative_index, apply_bitmask))
+                if len(batch) == self.fill_bitmask_parallel_batch_size:
+                    promises.append(self._async_submit_fill_bitmask(batch))
+                    batch = []
+
+                cumulative_index += 1
+            if batch:
+                promises.append(self._async_submit_fill_bitmask(batch))
+
+            # Wait for all bitmask filling tasks to complete.
+            for promise in promises:
+                promise.result()
+        else:
+            # Fallback to serial filling of bitmasks for small-batch-size cases
+            for req_id, _ in ordered_seq:
+                request = requests[req_id]
+                structured_output_request = request.structured_output_request
+
+                if TYPE_CHECKING:
+                    assert structured_output_request is not None
+                    assert structured_output_request.grammar is not None
+                apply_bitmask = self.should_fill_bitmask(request)
+
+                state_advancements = 0
+                req_tokens = scheduled_spec_decode_tokens.get(req_id, [])
+                for i, token in enumerate(req_tokens + [None]):
+                    self._fill_bitmasks([(structured_output_request.grammar,
+                                          cumulative_index, apply_bitmask)])
+
+                    if apply_bitmask and token is not None and \
+                            not structured_output_request.grammar.is_terminated():
                         assert structured_output_request.grammar.accept_tokens(
                             req_id, [token])
                         state_advancements += 1
-                cumulative_index += 1
-            if state_advancements > 0:
-                structured_output_request.grammar.rollback(state_advancements)
+                    cumulative_index += 1
+                if state_advancements > 0:
+                    structured_output_request.grammar.rollback(
+                        state_advancements)
 
+        bitmask_tensor = self._grammar_bitmask
         if cumulative_index < bitmask_tensor.shape[0]:
             bitmask_tensor = bitmask_tensor[:cumulative_index]
 
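The serial path keeps the existing advance-then-rollback handling of speculative tokens. As a standalone illustration of that pattern, here is a minimal sketch with a hypothetical StubGrammar standing in for the real StructuredOutputGrammar (accept_tokens / rollback):

# Standalone sketch: advance the grammar FSM once per speculative token so
# each position gets the right mask, then roll back so scheduler state is
# left exactly as it was. StubGrammar is a hypothetical stand-in.
class StubGrammar:
    def __init__(self) -> None:
        self.state = 0

    def accept_tokens(self, req_id: str, tokens: list[int]) -> bool:
        self.state += len(tokens)  # advance the FSM by the accepted tokens
        return True

    def rollback(self, num_tokens: int) -> None:
        self.state -= num_tokens   # restore the pre-speculation state

grammar = StubGrammar()
spec_tokens = [11, 22, 33]

advancements = 0
for token in spec_tokens + [None]:  # None marks the bonus-token position
    # ... a bitmask for this position would be filled here ...
    if token is not None:
        assert grammar.accept_tokens("req-0", [token])
        advancements += 1
if advancements > 0:
    grammar.rollback(advancements)
assert grammar.state == 0  # unchanged after mask generation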
@@ -204,6 +248,15 @@ def grammar_bitmask(
         # and deserialization when sending this to the GPU workers.
         return bitmask_tensor.numpy()
 
+    def should_fill_bitmask(self, request: Request) -> bool:
+        if self.reasoner is not None:
+            assert request.structured_output_request is not None
+            if request.structured_output_request.reasoning_ended is None:
+                request.structured_output_request.reasoning_ended = \
+                    self.reasoner.is_reasoning_end(request.prompt_token_ids)
+            return request.structured_output_request.reasoning_ended
+        return True
+
     def should_advance(self, request: Request) -> bool:
         if not request.use_structured_output:
             return False
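The reasoning check that was previously inlined in the loop now lives in should_fill_bitmask. A minimal standalone sketch of its lazy, cached gating, assuming a hypothetical reasoner whose end-of-reasoning marker is token id 999:

# Standalone sketch of the lazy, cached reasoning check; StubReasoner,
# StubRequestState, and END_OF_REASONING are hypothetical stand-ins.
from typing import Optional

END_OF_REASONING = 999

class StubReasoner:
    def is_reasoning_end(self, token_ids: list[int]) -> bool:
        return END_OF_REASONING in token_ids

class StubRequestState:
    reasoning_ended: Optional[bool] = None

def should_fill_bitmask(reasoner: Optional[StubReasoner],
                        state: StubRequestState,
                        prompt_token_ids: list[int]) -> bool:
    if reasoner is not None:
        # Scan the prompt only once; the answer is cached on the request state.
        if state.reasoning_ended is None:
            state.reasoning_ended = reasoner.is_reasoning_end(prompt_token_ids)
        return state.reasoning_ended
    return True

print(should_fill_bitmask(StubReasoner(), StubRequestState(), [1, 2, 3]))              # False
print(should_fill_bitmask(StubReasoner(), StubRequestState(), [1, END_OF_REASONING]))  # True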