@@ -284,6 +284,56 @@ async def test_create_model_endpoint_text_generation_inference_use_case_success(
284284 )
285285
286286
def test_load_model_weights_sub_commands(
    fake_model_bundle_repository,
    fake_model_endpoint_service,
    fake_docker_repository_image_always_exists,
    fake_model_primitive_gateway,
    fake_llm_artifact_gateway,
):
    """Check the weight-loading sub-commands produced for each framework.

    ``load_model_weights_sub_commands`` is invoked once for vLLM and once for
    text-generation-inference, both times with the same checkpoint path and
    destination folder, and the returned value is compared against the exact
    list of shell commands expected for that framework.
    """
    fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository
    llm_bundle_use_case = CreateLLMModelBundleV1UseCase(
        create_model_bundle_use_case=CreateModelBundleV2UseCase(
            model_bundle_repository=fake_model_bundle_repository,
            docker_repository=fake_docker_repository_image_always_exists,
            model_primitive_gateway=fake_model_primitive_gateway,
        ),
        model_bundle_repository=fake_model_bundle_repository,
        llm_artifact_gateway=fake_llm_artifact_gateway,
        docker_repository=fake_docker_repository_image_always_exists,
    )

    # Both scenarios share the same source checkpoint and destination folder.
    checkpoint_path = "fake-checkpoint"
    final_weights_folder = "test_folder"

    # Each case: (framework, framework image tag, expected sub-commands).
    cases = [
        (
            LLMInferenceFramework.VLLM,
            "0.2.7",
            [
                "./s5cmd --numworkers 512 cp --concurrency 10 --include '*.model' --include '*.json' --include '*.safetensors' --exclude 'optimizer*' fake-checkpoint/* test_folder",
            ],
        ),
        (
            LLMInferenceFramework.TEXT_GENERATION_INFERENCE,
            "1.0.0",
            [
                # The expected list starts with a conda install of s5cmd that
                # only runs when invoking s5cmd fails.
                "s5cmd > /dev/null || conda install -c conda-forge -y s5cmd",
                "s5cmd --numworkers 512 cp --concurrency 10 --include '*.model' --include '*.json' --include '*.safetensors' --exclude 'optimizer*' fake-checkpoint/* test_folder",
            ],
        ),
    ]

    for framework, framework_image_tag, expected_result in cases:
        subcommands = llm_bundle_use_case.load_model_weights_sub_commands(
            framework, framework_image_tag, checkpoint_path, final_weights_folder
        )
        assert expected_result == subcommands
335+
336+
287337@pytest .mark .asyncio
288338async def test_create_model_endpoint_trt_llm_use_case_success (
289339 test_api_key : str ,
0 commit comments