 
 CONFIG_FILE = os.getenv("CONFIG_FILE")
 AWS_REGION = os.getenv("AWS_REGION", "us-west-2")
-MODEL_WEIGHTS_FOLDER = os.getenv("MODEL_WEIGHTS_FOLDER", "./model_weights")
+MODEL_WEIGHTS_FOLDER = os.getenv("MODEL_WEIGHTS_FOLDER", "model_weights")
 
 SKIP_AWS_PROFILE_SET = os.getenv("SKIP_AWS_PROFILE_SET", "false").lower() == "true"
 if not SKIP_AWS_PROFILE_SET:
@@ -191,6 +191,26 @@ async def generate_v1_completions(
     return outputs


+# This is needed to handle cases where processing all of the requests takes longer than
+# the configured 'VLLM_ENGINE_ITERATION_TIMEOUT_S' 30s timeout.
+def determine_max_concurrent_requests(
+    requests: Union[List[CompletionRequest], List[ChatCompletionRequest]]
+) -> int:
+    # Guided decoding:
+    # vLLM initializes a guided decoding logit processor per request, and anecdotally
+    # the engine handles around 7 req/s (for outlines), so cap at 30 * 7 ~= 200.
+    if any(
+        request.to_sampling_params(
+            default_max_tokens=0, logits_processor_pattern=None
+        ).guided_decoding
+        for request in requests
+    ):
+        return 200
+
+    # Fairly arbitrary upper bound
+    return 10000
+
+
 async def generate_v2_completions(
     engine: EngineClient,
     requests: Union[List[CompletionRequest], List[ChatCompletionRequest]],
@@ -203,15 +223,28 @@ async def generate_v2_completions(
             Union[ErrorResponse, AsyncGenerator[str, None], CompletionResponse],
         ]
     ] = []
+
+    max_concurrent_requests = determine_max_concurrent_requests(requests)
+    print(f"max_concurrent_requests: {max_concurrent_requests}")
+    semaphore = asyncio.Semaphore(max_concurrent_requests)
+
+    async def process_request(
+        request: Union[CompletionRequest, ChatCompletionRequest]
+    ) -> Coroutine[
+        Any,
+        Any,
+        Union[ErrorResponse, AsyncGenerator[str, None], CompletionResponse],
+    ]:
+        async with semaphore:
+            if isinstance(request, CompletionRequest):
+                return await openai_serving_completion.create_completion(request, dummy_request)
+            elif isinstance(request, ChatCompletionRequest):
+                return await openai_serving_chat.create_chat_completion(request)
+            else:
+                assert_never(request)
+
     for request in requests:
-        if isinstance(request, CompletionRequest):
-            results_generators.append(
-                openai_serving_completion.create_completion(request, dummy_request)
-            )
-        elif isinstance(request, ChatCompletionRequest):
-            results_generators.append(openai_serving_chat.create_chat_completion(request))
-        else:
-            assert_never(request)
+        results_generators.append(process_request(request))
 
     results_generator = await_coroutines(*results_generators)
     outputs: List[Optional[CompletionResponse]] = [None] * len(requests)
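
The change above caps in-flight work with an asyncio.Semaphore: a coroutine is still created for every request up front, but only max_concurrent_requests of them can be inside the engine at once, so the rest queue on the semaphore instead of tripping vLLM's per-iteration timeout. A minimal standalone sketch of that pattern follows; the handle coroutine and asyncio.gather here are stand-ins for the real openai_serving_* calls and the project's await_coroutines helper.

import asyncio
from typing import List


async def handle(request: str) -> str:
    # Stand-in for openai_serving_*.create_*: just simulates engine latency.
    await asyncio.sleep(0.1)
    return f"done: {request}"


async def run_all(requests: List[str], max_concurrent: int) -> List[str]:
    semaphore = asyncio.Semaphore(max_concurrent)

    async def bounded(request: str) -> str:
        # At most max_concurrent handlers run at the same time; the rest wait here.
        async with semaphore:
            return await handle(request)

    # Coroutines are created for every request, but each blocks on the
    # semaphore before doing any real work.
    return await asyncio.gather(*(bounded(r) for r in requests))


print(asyncio.run(run_all([f"req-{i}" for i in range(5)], max_concurrent=2)))
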
@@ -236,7 +269,8 @@ async def generate_completions(
 
 
 async def init_engine(
-    model: str,
+    model_id: str,
+    served_model_name: str,
     request: CreateBatchCompletionsEngineRequest,
 ) -> EngineClient:
     global openai_serving_chat
@@ -253,7 +287,7 @@ async def init_engine(
 
     engine_args_dict = parsed_configs.model_dump(exclude_none=True)
     default_engine_args_dict = dict(
-        model=model,
+        model=model_id,
         tensor_parallel_size=request.model_cfg.num_shards,
         pipeline_parallel_size=int(
             os.environ.get("NUM_INSTANCES", 1)
@@ -269,7 +303,7 @@ async def init_engine(
     engine_client = AsyncLLMEngine.from_engine_args(engine_args)
     model_config = await engine_client.get_model_config()
     resolved_chat_template = load_chat_template(parsed_configs.chat_template)
-    base_model_paths = [BaseModelPath(name=model, model_path=model)]
+    base_model_paths = [BaseModelPath(name=served_model_name, model_path=model_id)]
 
     openai_serving_chat = OpenAIServingChat(
         engine_client,
@@ -312,7 +346,7 @@ def load_batch_content(
 
     # Recast the content to vLLM's schema
     if isinstance(content, List) and len(content) > 0:
-        model = get_model_name(request.model_cfg)
+        model = request.model_cfg.model
         return TypeAdapter(
             Union[List[CompletionRequest], List[ChatCompletionRequest]]
         ).validate_python(
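
As an aside, the recast in load_batch_content relies on pydantic's TypeAdapter validating the raw JSON list against a union of request schemas in one shot. A simplified sketch with toy models (the real classes are vLLM's CompletionRequest / ChatCompletionRequest, not these):

from typing import List, Union

from pydantic import BaseModel, TypeAdapter


class ToyCompletionRequest(BaseModel):
    model: str
    prompt: str


class ToyChatRequest(BaseModel):
    model: str
    messages: List[dict]


adapter = TypeAdapter(Union[List[ToyCompletionRequest], List[ToyChatRequest]])

# The whole payload is validated against one list schema or the other.
requests = adapter.validate_python(
    [{"model": "m", "prompt": "hello"}, {"model": "m", "prompt": "bye"}]
)
print(type(requests[0]).__name__)  # ToyCompletionRequest
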
@@ -325,7 +359,7 @@ def load_batch_content(
     return content
 
 
-def get_model_name(model_config: BatchCompletionsModelConfig) -> str:
+def get_model_id(model_config: BatchCompletionsModelConfig) -> str:
     return MODEL_WEIGHTS_FOLDER if model_config.checkpoint_path else model_config.model
 
 
@@ -334,7 +368,9 @@ async def handle_batch_job(
 ) -> None:
     metrics_gateway = DatadogInferenceMonitoringMetricsGateway()
 
-    model = get_model_name(request.model_cfg)
+    served_model_name = request.model_cfg.model
+    model_id = get_model_id(request.model_cfg)
+
     if request.model_cfg.checkpoint_path:
         await download_model(
             checkpoint_path=request.model_cfg.checkpoint_path,
@@ -378,7 +414,8 @@ async def handle_batch_job(
 
     content = load_batch_content(request)
     engine = await init_engine(
-        model,
+        model_id,
+        served_model_name,
         request=request,
     )
 
@@ -387,7 +424,7 @@ async def handle_batch_job(
         f.write(json.dumps([output.model_dump() if output else None for output in outputs]))
 
     metrics_gateway.emit_batch_completions_metric(
-        model,
+        served_model_name,
         use_tool=False,
         num_prompt_tokens=0,
         num_completion_tokens=0,
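
Net effect of the renames: served_model_name is the user-facing model name that keeps showing up in responses and metrics, while model_id is whatever path vLLM should actually load, which becomes the local MODEL_WEIGHTS_FOLDER whenever a fine-tuned checkpoint was downloaded. A rough sketch of that resolution, using a made-up config dataclass rather than the project's BatchCompletionsModelConfig:

import os
from dataclasses import dataclass
from typing import Optional

MODEL_WEIGHTS_FOLDER = os.getenv("MODEL_WEIGHTS_FOLDER", "model_weights")


@dataclass
class ToyModelConfig:
    model: str                      # user-facing name, e.g. "llama-3-8b-instruct"
    checkpoint_path: Optional[str]  # fine-tuned weights to download, if any


def get_model_id(cfg: ToyModelConfig) -> str:
    # Point vLLM at the local download folder when a checkpoint was provided,
    # otherwise let it resolve the named model directly.
    return MODEL_WEIGHTS_FOLDER if cfg.checkpoint_path else cfg.model


cfg = ToyModelConfig(model="llama-3-8b-instruct", checkpoint_path="s3://bucket/ckpt")
served_model_name = cfg.model   # reported in responses and metrics
model_id = get_model_id(cfg)    # what init_engine loads ("model_weights" here)
print(served_model_name, model_id)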