
Commit 3b751fa

Handle large batches w/ guided decoding (#687)
* Handle large batches w/ guided decoding
* Fix mkdocs ci issue
1 parent 3be3d8b commit 3b751fa

File tree

model-engine/model_engine_server/inference/vllm/vllm_batch.py
requirements-docs.txt

2 files changed: +59 -22 lines changed

model-engine/model_engine_server/inference/vllm/vllm_batch.py

Lines changed: 54 additions & 17 deletions
@@ -58,7 +58,7 @@
 
 CONFIG_FILE = os.getenv("CONFIG_FILE")
 AWS_REGION = os.getenv("AWS_REGION", "us-west-2")
-MODEL_WEIGHTS_FOLDER = os.getenv("MODEL_WEIGHTS_FOLDER", "./model_weights")
+MODEL_WEIGHTS_FOLDER = os.getenv("MODEL_WEIGHTS_FOLDER", "model_weights")
 
 SKIP_AWS_PROFILE_SET = os.getenv("SKIP_AWS_PROFILE_SET", "false").lower() == "true"
 if not SKIP_AWS_PROFILE_SET:
@@ -191,6 +191,26 @@ async def generate_v1_completions(
     return outputs
 
 
+# This is needed to handle the cases where it takes too long to process all of the requests before
+# the configured 'VLLM_ENGINE_ITERATION_TIMEOUT_S' 30s timeout.
+def determine_max_concurrent_requests(
+    requests: Union[List[CompletionRequest], List[ChatCompletionRequest]]
+) -> int:
+    # Guided decoding
+    # For example, with guided decoding, vLLM initializes a guided decoding logit processor per request, and
+    # anecdotally, we're seeing the engine able to handle around 7req/s (for outlines), so set to 30 * 7 ~= 200
+    if any(
+        request.to_sampling_params(
+            default_max_tokens=0, logits_processor_pattern=None
+        ).guided_decoding
+        for request in requests
+    ):
+        return 200
+
+    # Kinda arbitrary number
+    return 10000
+
+
 async def generate_v2_completions(
     engine: EngineClient,
     requests: Union[List[CompletionRequest], List[ChatCompletionRequest]],
@@ -203,15 +223,28 @@ async def generate_v2_completions(
             Union[ErrorResponse, AsyncGenerator[str, None], CompletionResponse],
         ]
     ] = []
+
+    max_concurrent_requests = determine_max_concurrent_requests(requests)
+    print(f"max_concurrent_requests: {max_concurrent_requests}")
+    semaphore = asyncio.Semaphore(max_concurrent_requests)
+
+    async def process_request(
+        request: Union[CompletionRequest, ChatCompletionRequest]
+    ) -> Coroutine[
+        Any,
+        Any,
+        Union[ErrorResponse, AsyncGenerator[str, None], CompletionResponse],
+    ]:
+        async with semaphore:
+            if isinstance(request, CompletionRequest):
+                return await openai_serving_completion.create_completion(request, dummy_request)
+            elif isinstance(request, ChatCompletionRequest):
+                return await openai_serving_chat.create_chat_completion(request)
+            else:
+                assert_never(request)
+
     for request in requests:
-        if isinstance(request, CompletionRequest):
-            results_generators.append(
-                openai_serving_completion.create_completion(request, dummy_request)
-            )
-        elif isinstance(request, ChatCompletionRequest):
-            results_generators.append(openai_serving_chat.create_chat_completion(request))
-        else:
-            assert_never(request)
+        results_generators.append(process_request(request))
 
     results_generator = await_coroutines(*results_generators)
     outputs: List[Optional[CompletionResponse]] = [None] * len(requests)
@@ -236,7 +269,8 @@ async def generate_completions(
 
 
 async def init_engine(
-    model: str,
+    model_id: str,
+    served_model_name: str,
     request: CreateBatchCompletionsEngineRequest,
 ) -> EngineClient:
     global openai_serving_chat
@@ -253,7 +287,7 @@ async def init_engine(
 
     engine_args_dict = parsed_configs.model_dump(exclude_none=True)
     default_engine_args_dict = dict(
-        model=model,
+        model=model_id,
         tensor_parallel_size=request.model_cfg.num_shards,
         pipeline_parallel_size=int(
            os.environ.get("NUM_INSTANCES", 1)
@@ -269,7 +303,7 @@ async def init_engine(
     engine_client = AsyncLLMEngine.from_engine_args(engine_args)
     model_config = await engine_client.get_model_config()
     resolved_chat_template = load_chat_template(parsed_configs.chat_template)
-    base_model_paths = [BaseModelPath(name=model, model_path=model)]
+    base_model_paths = [BaseModelPath(name=served_model_name, model_path=model_id)]
 
     openai_serving_chat = OpenAIServingChat(
         engine_client,
@@ -312,7 +346,7 @@ def load_batch_content(
 
     # Recast the content to vLLMs schema
     if isinstance(content, List) and len(content) > 0:
-        model = get_model_name(request.model_cfg)
+        model = request.model_cfg.model
        return TypeAdapter(
             Union[List[CompletionRequest], List[ChatCompletionRequest]]
         ).validate_python(
@@ -325,7 +359,7 @@ def load_batch_content(
     return content
 
 
-def get_model_name(model_config: BatchCompletionsModelConfig) -> str:
+def get_model_id(model_config: BatchCompletionsModelConfig) -> str:
     return MODEL_WEIGHTS_FOLDER if model_config.checkpoint_path else model_config.model
 
 
@@ -334,7 +368,9 @@ async def handle_batch_job(
 ) -> None:
     metrics_gateway = DatadogInferenceMonitoringMetricsGateway()
 
-    model = get_model_name(request.model_cfg)
+    served_model_name = request.model_cfg.model
+    model_id = get_model_id(request.model_cfg)
+
     if request.model_cfg.checkpoint_path:
         await download_model(
             checkpoint_path=request.model_cfg.checkpoint_path,
@@ -378,7 +414,8 @@ async def handle_batch_job(
 
     content = load_batch_content(request)
     engine = await init_engine(
-        model,
+        model_id,
+        served_model_name,
         request=request,
     )
 
@@ -387,7 +424,7 @@ async def handle_batch_job(
         f.write(json.dumps([output.model_dump() if output else None for output in outputs]))
 
     metrics_gateway.emit_batch_completions_metric(
-        model,
+        served_model_name,
         use_tool=False,
         num_prompt_tokens=0,
         num_completion_tokens=0,
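Taken out of the diff, the pattern this change introduces is: bound how many requests are submitted to the engine at once with an asyncio.Semaphore, and pick a much lower bound whenever any request uses guided decoding, since the engine otherwise falls behind the 30s VLLM_ENGINE_ITERATION_TIMEOUT_S. The sketch below is a minimal, self-contained illustration of that pattern only; handle_one, run_batch, and the sleep placeholder are hypothetical names rather than code from this repo, and asyncio.gather stands in for the repo's await_coroutines.

import asyncio
from typing import List


async def handle_one(request_id: int, semaphore: asyncio.Semaphore) -> str:
    # The semaphore caps how many requests are in flight at once, so a large
    # batch is drip-fed to the engine instead of being submitted all together.
    async with semaphore:
        await asyncio.sleep(0.01)  # placeholder for the actual completion call
        return f"done-{request_id}"


async def run_batch(num_requests: int, uses_guided_decoding: bool) -> List[str]:
    # Mirrors determine_max_concurrent_requests: guided decoding builds a logits
    # processor per request and is far slower (~7 req/s with outlines), so the
    # cap drops to roughly 30 s * 7 req/s ~= 200; otherwise a loose 10000 is used.
    max_concurrent = 200 if uses_guided_decoding else 10000
    semaphore = asyncio.Semaphore(max_concurrent)
    tasks = [handle_one(i, semaphore) for i in range(num_requests)]
    return await asyncio.gather(*tasks)


if __name__ == "__main__":
    results = asyncio.run(run_batch(num_requests=500, uses_guided_decoding=True))
    print(f"{len(results)} requests completed")

With the cap in place, excess requests wait at the semaphore instead of piling up inside the engine loop, which is what was tripping the iteration timeout on large guided-decoding batches.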

requirements-docs.txt

Lines changed: 5 additions & 5 deletions
@@ -1,11 +1,11 @@
 mdx-include~=1.4.2
-mkdocs~=1.4.2
-mkdocs-material~=9.1.6
-mkdocs-material-extensions==1.1.1
-mkdocs-render-swagger-plugin~=0.0.4
+mkdocs~=1.6.1
+mkdocs-material~=9.6.5
+mkdocs-material-extensions==1.3.1
+mkdocs-render-swagger-plugin~=0.1.2
 mkdocs-simple-hooks~=0.1.5
 mkdocs-video~=1.5.0
-mkdocstrings[python]~=0.24.0
+mkdocstrings[python]~=0.28.2
 pydantic==2.8.2
 griffe<1.0
 neoteroi-mkdocs~=1.0.0
