4 files changed, 42 insertions(+), 0 deletions(-)
Changed paths include model_engine_server/domain/use_cases.

@@ -677,6 +677,9 @@ async def create_vllm_bundle(
         if hmi_config.sensitive_log_mode:  # pragma: no cover
             subcommands[-1] = subcommands[-1] + " --disable-log-requests"
 
+        if "llama-3-70b" in model_name:
+            subcommands[-1] = subcommands[-1] + " --gpu-memory-utilization 0.95 --enforce-eager"
+
         command = [
             "/bin/bash",
             "-c",
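Reviewer note on the two vLLM server flags (both are standard vLLM options): --gpu-memory-utilization sets the fraction of GPU memory vLLM may claim for weights and KV cache (its default is 0.9), and --enforce-eager disables CUDA graph capture, trading some decode latency for the memory the graphs would otherwise reserve, a sensible trade for a 70B model. A minimal sketch of the append pattern in isolation; the base subcommand below is illustrative, not the one this file actually builds:

# Sketch of the flag-appending pattern above; `base` stands in for the vLLM
# server invocation that create_vllm_bundle assembles from config.
def append_llama_3_70b_flags(subcommands: list[str], model_name: str) -> list[str]:
    # Substring match, so variants such as "llama-3-70b-instruct" also qualify.
    if "llama-3-70b" in model_name:
        subcommands[-1] += " --gpu-memory-utilization 0.95 --enforce-eager"
    return subcommands

base = ["python -m vllm.entrypoints.openai.api_server --model llama-3-70b"]
print(append_llama_3_70b_flags(base, "llama-3-70b")[-1])
# ... --model llama-3-70b --gpu-memory-utilization 0.95 --enforce-eager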
@@ -763,6 +763,7 @@ def __init__(self):
             "llama-7b/special_tokens_map.json": ["llama-7b/special_tokens_map.json"],
             "llama-2-7b": ["model-fake.safetensors"],
             "mpt-7b": ["model-fake.safetensors"],
+            "llama-3-70b": ["model-fake.safetensors"],
         }
         self.urls = {"filename": "https://test-bucket.s3.amazonaws.com/llm/llm-1.0.0.tar.gz"}
         self.model_config = {
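Reviewer note: this fake model-files mapping is what the tests consult when the use case validates a checkpoint, so llama-3-70b needs to report at least one safetensors file for the new fixture's checkpoint_path to be accepted (an inference from the neighboring llama-2-7b and mpt-7b entries, not stated in the diff).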
@@ -292,6 +292,33 @@ def create_llm_model_endpoint_request_llama_2() -> CreateLLMModelEndpointV1Request
     )
 
 
+@pytest.fixture
+def create_llm_model_endpoint_request_llama_3_70b() -> CreateLLMModelEndpointV1Request:
+    return CreateLLMModelEndpointV1Request(
+        name="test_llm_endpoint_name_llama_3_70b",
+        model_name="llama-3-70b",
+        source="hugging_face",
+        inference_framework="vllm",
+        inference_framework_image_tag="1.0.0",
+        num_shards=2,
+        endpoint_type=ModelEndpointType.STREAMING,
+        metadata={},
+        post_inference_hooks=["billing"],
+        cpus=1,
+        gpus=2,
+        memory="8G",
+        gpu_type=GpuType.NVIDIA_HOPPER_H100,
+        storage="10G",
+        min_workers=1,
+        max_workers=3,
+        per_worker=2,
+        labels={"team": "infra", "product": "my_product"},
+        aws_role="test_aws_role",
+        results_s3_bucket="test_s3_bucket",
+        checkpoint_path="s3://llama-3-70b",
+    )
+
+
 @pytest.fixture
 def create_llm_model_endpoint_text_generation_inference_request_streaming() -> (
     CreateLLMModelEndpointV1Request
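Reviewer note: pytest injects this fixture by argument name, exactly like the neighboring llama-2 fixture; num_shards=2 mirrors gpus=2, which the fixture presumably uses as the tensor-parallel degree. A hypothetical standalone consumer, assuming the imports already present in this conftest (the real usage is in the test change below):

# Hypothetical consumer; pytest resolves the argument by fixture name.
@pytest.mark.asyncio
async def test_llama_3_70b_request_defaults(
    create_llm_model_endpoint_request_llama_3_70b: CreateLLMModelEndpointV1Request,
):
    request = create_llm_model_endpoint_request_llama_3_70b
    assert request.inference_framework == "vllm"
    assert request.num_shards == 2  # shard count chosen to match gpus=2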
@@ -80,6 +80,7 @@ async def test_create_model_endpoint_use_case_success(
     create_llm_model_endpoint_request_sync: CreateLLMModelEndpointV1Request,
     create_llm_model_endpoint_request_streaming: CreateLLMModelEndpointV1Request,
     create_llm_model_endpoint_request_llama_2: CreateLLMModelEndpointV1Request,
+    create_llm_model_endpoint_request_llama_3_70b: CreateLLMModelEndpointV1Request,
 ):
     fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository
     bundle_use_case = CreateModelBundleV2UseCase(
@@ -182,6 +183,16 @@ async def test_create_model_endpoint_use_case_success(
     )
     assert "--max-total-tokens" in bundle.flavor.command[-1] and "4096" in bundle.flavor.command[-1]
 
+    response_5 = await use_case.execute(
+        user=user, request=create_llm_model_endpoint_request_llama_3_70b
+    )
+    assert response_5.endpoint_creation_task_id
+    assert isinstance(response_5, CreateLLMModelEndpointV1Response)
+    bundle = await fake_model_bundle_repository.get_latest_model_bundle_by_name(
+        owner=user.team_id, name=create_llm_model_endpoint_request_llama_3_70b.name
+    )
+    assert " --gpu-memory-utilization 0.95" in bundle.flavor.command[-1]
+
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
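Reviewer note: the new assertion pins only the memory-utilization flag (the leading space guards against matching a longer flag name). A slightly stricter variant, sketched against the same fakes, would also cover the second flag the bundle change appends:

# Hypothetical tightening of the new check: assert both flags that
# create_vllm_bundle appends for llama-3-70b.
vllm_args = bundle.flavor.command[-1]
assert " --gpu-memory-utilization 0.95" in vllm_args
assert " --enforce-eager" in vllm_args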