|
51 | 51 | UpdateLLMModelEndpointV1UseCase, |
52 | 52 | _fill_hardware_info, |
53 | 53 | _infer_hardware, |
| 54 | + merge_metadata, |
54 | 55 | validate_and_update_completion_params, |
55 | 56 | validate_checkpoint_files, |
56 | 57 | ) |
@@ -614,6 +615,7 @@ async def test_update_model_endpoint_use_case_success( |
614 | 615 | fake_llm_model_endpoint_service, |
615 | 616 | create_llm_model_endpoint_request_streaming: CreateLLMModelEndpointV1Request, |
616 | 617 | update_llm_model_endpoint_request: UpdateLLMModelEndpointV1Request, |
| 618 | + update_llm_model_endpoint_request_only_workers: UpdateLLMModelEndpointV1Request, |
617 | 619 | ): |
618 | 620 | fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository |
619 | 621 | bundle_use_case = CreateModelBundleV2UseCase( |
@@ -687,6 +689,102 @@ async def test_update_model_endpoint_use_case_success( |
687 | 689 | == update_llm_model_endpoint_request.max_workers |
688 | 690 | ) |
689 | 691 |
|
| 692 | + update_response2 = await update_use_case.execute( |
| 693 | + user=user, |
| 694 | + model_endpoint_name=create_llm_model_endpoint_request_streaming.name, |
| 695 | + request=update_llm_model_endpoint_request_only_workers, |
| 696 | + ) |
| 697 | + assert update_response2.endpoint_creation_task_id |
| 698 | + |
| 699 | + endpoint = ( |
| 700 | + await fake_model_endpoint_service.list_model_endpoints( |
| 701 | + owner=None, |
| 702 | + name=create_llm_model_endpoint_request_streaming.name, |
| 703 | + order_by=None, |
| 704 | + ) |
| 705 | + )[0] |
| 706 | + assert endpoint.record.metadata == { |
| 707 | + "_llm": { |
| 708 | + "model_name": create_llm_model_endpoint_request_streaming.model_name, |
| 709 | + "source": create_llm_model_endpoint_request_streaming.source, |
| 710 | + "inference_framework": create_llm_model_endpoint_request_streaming.inference_framework, |
| 711 | + "inference_framework_image_tag": "fake_docker_repository_latest_image_tag", |
| 712 | + "num_shards": create_llm_model_endpoint_request_streaming.num_shards, |
| 713 | + "quantize": None, |
| 714 | + "checkpoint_path": update_llm_model_endpoint_request.checkpoint_path, |
| 715 | + } |
| 716 | + } |
| 717 | + assert endpoint.infra_state.resource_state.memory == update_llm_model_endpoint_request.memory |
| 718 | + assert ( |
| 719 | + endpoint.infra_state.deployment_state.min_workers |
| 720 | + == update_llm_model_endpoint_request_only_workers.min_workers |
| 721 | + ) |
| 722 | + assert ( |
| 723 | + endpoint.infra_state.deployment_state.max_workers |
| 724 | + == update_llm_model_endpoint_request_only_workers.max_workers |
| 725 | + ) |
| 726 | + |
| 727 | + |
@pytest.mark.asyncio
@mock.patch(
    "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases._get_latest_tag",
    mocked__get_latest_tag(),
)
async def test_update_model_endpoint_use_case_failure(
    test_api_key: str,
    fake_model_bundle_repository,
    fake_model_endpoint_service,
    fake_docker_repository_image_always_exists,
    fake_model_primitive_gateway,
    fake_llm_artifact_gateway,
    fake_llm_model_endpoint_service,
    create_llm_model_endpoint_request_streaming: CreateLLMModelEndpointV1Request,
    update_llm_model_endpoint_request_bad_metadata: UpdateLLMModelEndpointV1Request,
):
    """Updating an existing LLM endpoint with the bad-metadata request must fail.

    The test creates an endpoint from the streaming create request, registers
    it with the fake LLM endpoint service, and then expects the update use
    case to raise ``ObjectHasInvalidValueException`` when given the
    ``update_llm_model_endpoint_request_bad_metadata`` fixture.
    """
    # The endpoint service resolves bundles through the same fake repository
    # that the bundle use cases write to.
    fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository

    # Wire up the use-case chain: plain bundle -> LLM bundle -> update / create.
    model_bundle_creator = CreateModelBundleV2UseCase(
        model_bundle_repository=fake_model_bundle_repository,
        docker_repository=fake_docker_repository_image_always_exists,
        model_primitive_gateway=fake_model_primitive_gateway,
    )
    llm_bundle_creator = CreateLLMModelBundleV1UseCase(
        create_model_bundle_use_case=model_bundle_creator,
        model_bundle_repository=fake_model_bundle_repository,
        llm_artifact_gateway=fake_llm_artifact_gateway,
        docker_repository=fake_docker_repository_image_always_exists,
    )
    endpoint_updater = UpdateLLMModelEndpointV1UseCase(
        create_llm_model_bundle_use_case=llm_bundle_creator,
        model_endpoint_service=fake_model_endpoint_service,
        llm_model_endpoint_service=fake_llm_model_endpoint_service,
        docker_repository=fake_docker_repository_image_always_exists,
    )
    endpoint_creator = CreateLLMModelEndpointV1UseCase(
        create_llm_model_bundle_use_case=llm_bundle_creator,
        model_endpoint_service=fake_model_endpoint_service,
        docker_repository=fake_docker_repository_image_always_exists,
        llm_artifact_gateway=fake_llm_artifact_gateway,
    )

    privileged_user = User(
        user_id=test_api_key,
        team_id=test_api_key,
        is_privileged_user=True,
    )

    # Create the endpoint that the update will target.
    await endpoint_creator.execute(
        user=privileged_user,
        request=create_llm_model_endpoint_request_streaming,
    )

    # Look the new endpoint up by name and register it with the fake LLM
    # endpoint service, which the update use case is constructed with.
    matching_endpoints = await fake_model_endpoint_service.list_model_endpoints(
        owner=None,
        name=create_llm_model_endpoint_request_streaming.name,
        order_by=None,
    )
    created_endpoint = matching_endpoints[0]
    fake_llm_model_endpoint_service.add_model_endpoint(created_endpoint)

    with pytest.raises(ObjectHasInvalidValueException):
        await endpoint_updater.execute(
            user=privileged_user,
            model_endpoint_name=create_llm_model_endpoint_request_streaming.name,
            request=update_llm_model_endpoint_request_bad_metadata,
        )
| 787 | + |
690 | 788 |
|
691 | 789 | def mocked_auto_tokenizer_from_pretrained(*args, **kwargs): # noqa |
692 | 790 | class mocked_encode: |
@@ -2241,3 +2339,23 @@ async def test_create_batch_completions( |
2241 | 2339 | "-c", |
2242 | 2340 | "ddtrace-run python vllm_batch.py", |
2243 | 2341 | ] |
| 2342 | + |
| 2343 | + |
def test_merge_metadata():
    """Check ``merge_metadata`` with one side missing and with overlapping keys.

    When only one argument is a dict, the result must equal that dict. When
    both are given, the result holds the union of keys, and the first
    (request) argument's value is used for a key present in both.
    """
    from_request = {"key1": "value1", "key2": "value2"}
    from_endpoint = {"key1": "value0", "key3": "value3"}

    # "key1" is present on both sides; the request value is the expected one.
    expected_merged = {
        "key1": "value1",
        "key2": "value2",
        "key3": "value3",
    }

    # One side absent: the other side comes back as-is.
    assert merge_metadata(from_request, None) == from_request
    assert merge_metadata(None, from_endpoint) == from_endpoint

    # Both sides present.
    assert merge_metadata(from_request, from_endpoint) == expected_merged
0 commit comments