99import math
1010import os
1111import re
12- from dataclasses import asdict
12+ from dataclasses import asdict , dataclass
1313from typing import Any , AsyncIterable , Dict , List , Optional , Union
1414
1515from model_engine_server .common .config import hmi_config
2121 CompletionStreamV1Response ,
2222 CompletionSyncV1Request ,
2323 CompletionSyncV1Response ,
24+ CreateBatchCompletionsEngineRequest ,
2425 CreateBatchCompletionsRequest ,
2526 CreateBatchCompletionsResponse ,
2627 CreateLLMModelEndpointV1Request ,
@@ -2200,6 +2201,27 @@ async def execute(self, user: User, request: ModelDownloadRequest) -> ModelDownl
22002201 return ModelDownloadResponse (urls = urls )
22012202
22022203
2204+ @dataclass
2205+ class VLLMEngineArgs :
2206+ gpu_memory_utilization : Optional [float ] = None
2207+
2208+
2209+ def infer_addition_engine_args_from_model_name (model_name : str ) -> VLLMEngineArgs :
2210+ numbers = re .findall (r"\d+" , model_name )
2211+ if len (numbers ) == 0 :
2212+ raise ObjectHasInvalidValueException (
2213+ f"Model { model_name } is not supported for batch completions."
2214+ )
2215+
2216+ b_params = int (numbers [- 1 ])
2217+ if b_params >= 70 :
2218+ gpu_memory_utilization = 0.95
2219+ else :
2220+ gpu_memory_utilization = 0.9
2221+
2222+ return VLLMEngineArgs (gpu_memory_utilization = gpu_memory_utilization )
2223+
2224+
22032225def infer_hardware_from_model_name (model_name : str ) -> CreateDockerImageBatchJobResourceRequests :
22042226 if "mixtral-8x7b" in model_name :
22052227 cpus = "20"
@@ -2324,14 +2346,25 @@ async def execute(
23242346 assert hardware .gpus is not None
23252347 if request .model_config .num_shards :
23262348 hardware .gpus = max (hardware .gpus , request .model_config .num_shards )
2327- request .model_config .num_shards = hardware .gpus
23282349
2329- if request .tool_config and request .tool_config .name != "code_evaluator" :
2350+ engine_request = CreateBatchCompletionsEngineRequest .from_api (request )
2351+ engine_request .model_config .num_shards = hardware .gpus
2352+
2353+ if engine_request .tool_config and engine_request .tool_config .name != "code_evaluator" :
23302354 raise ObjectHasInvalidValueException (
23312355 "Only code_evaluator tool is supported for batch completions."
23322356 )
23332357
2334- batch_bundle = await self .create_batch_job_bundle (user , request , hardware )
2358+ additional_engine_args = infer_addition_engine_args_from_model_name (
2359+ engine_request .model_config .model
2360+ )
2361+
2362+ if additional_engine_args .gpu_memory_utilization is not None :
2363+ engine_request .max_gpu_memory_utilization = (
2364+ additional_engine_args .gpu_memory_utilization
2365+ )
2366+
2367+ batch_bundle = await self .create_batch_job_bundle (user , engine_request , hardware )
23352368
23362369 validate_resource_requests (
23372370 bundle = batch_bundle ,
@@ -2342,21 +2375,21 @@ async def execute(
23422375 gpu_type = hardware .gpu_type ,
23432376 )
23442377
2345- if request .max_runtime_sec is None or request .max_runtime_sec < 1 :
2378+ if engine_request .max_runtime_sec is None or engine_request .max_runtime_sec < 1 :
23462379 raise ObjectHasInvalidValueException ("max_runtime_sec must be a positive integer." )
23472380
23482381 job_id = await self .docker_image_batch_job_gateway .create_docker_image_batch_job (
23492382 created_by = user .user_id ,
23502383 owner = user .team_id ,
2351- job_config = request .dict (),
2384+ job_config = engine_request .dict (),
23522385 env = batch_bundle .env ,
23532386 command = batch_bundle .command ,
23542387 repo = batch_bundle .image_repository ,
23552388 tag = batch_bundle .image_tag ,
23562389 resource_requests = hardware ,
2357- labels = request .model_config .labels ,
2390+ labels = engine_request .model_config .labels ,
23582391 mount_location = batch_bundle .mount_location ,
2359- override_job_max_runtime_s = request .max_runtime_sec ,
2360- num_workers = request .data_parallelism ,
2392+ override_job_max_runtime_s = engine_request .max_runtime_sec ,
2393+ num_workers = engine_request .data_parallelism ,
23612394 )
23622395 return CreateBatchCompletionsResponse (job_id = job_id )
0 commit comments