Skip to content

Commit e3bb078

Browse files
authored
More rigorous endpoint update handling (#558)
* Fix metadata update
* Update tests
1 parent 21cd469 commit e3bb078

File tree

3 files changed

+166
-11
lines changed

3 files changed

+166
-11
lines changed

model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,8 @@
111111

112112
logger = make_logger(logger_name())
113113

114+
LLM_METADATA_KEY = "_llm"
115+
RESERVED_METADATA_KEYS = [LLM_METADATA_KEY, CONVERTED_FROM_ARTIFACT_LIKE_KEY]
114116

115117
INFERENCE_FRAMEWORK_REPOSITORY: Dict[LLMInferenceFramework, str] = {
116118
LLMInferenceFramework.DEEPSPEED: "instant-llm",
@@ -279,11 +281,14 @@ async def _get_recommended_hardware_config_map() -> Dict[str, Any]:
279281
def _model_endpoint_entity_to_get_llm_model_endpoint_response(
280282
model_endpoint: ModelEndpoint,
281283
) -> GetLLMModelEndpointV1Response:
282-
if model_endpoint.record.metadata is None or "_llm" not in model_endpoint.record.metadata:
284+
if (
285+
model_endpoint.record.metadata is None
286+
or LLM_METADATA_KEY not in model_endpoint.record.metadata
287+
):
283288
raise ObjectHasInvalidValueException(
284289
f"Can't translate model entity to response, endpoint {model_endpoint.record.id} does not have LLM metadata."
285290
)
286-
llm_metadata = model_endpoint.record.metadata.get("_llm", {})
291+
llm_metadata = model_endpoint.record.metadata.get(LLM_METADATA_KEY, {})
287292
response = GetLLMModelEndpointV1Response(
288293
id=model_endpoint.record.id,
289294
name=model_endpoint.record.name,
@@ -962,7 +967,7 @@ async def execute(
962967
aws_role = self.authz_module.get_aws_role_for_user(user)
963968
results_s3_bucket = self.authz_module.get_s3_bucket_for_user(user)
964969

965-
request.metadata["_llm"] = asdict(
970+
request.metadata[LLM_METADATA_KEY] = asdict(
966971
LLMMetadata(
967972
model_name=request.model_name,
968973
source=request.source,
@@ -1088,6 +1093,16 @@ async def execute(self, user: User, model_endpoint_name: str) -> GetLLMModelEndp
10881093
return _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint)
10891094

10901095

1096+
def merge_metadata(
    request: Optional[Dict[str, Any]], record: Optional[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
    """Merge user-supplied metadata over the stored endpoint metadata.

    Keys present in *request* win over those in *record*; if either side is
    ``None`` the other is returned unchanged (so a ``None`` request leaves the
    stored metadata untouched rather than wiping it).
    """
    if request is None:
        return record
    if record is None:
        return request
    # Start from the stored record, then let the request overwrite collisions.
    merged = dict(record)
    merged.update(request)
    return merged
1104+
1105+
10911106
class UpdateLLMModelEndpointV1UseCase:
10921107
def __init__(
10931108
self,
@@ -1131,6 +1146,7 @@ async def execute(
11311146
raise EndpointInfraStateNotFound(error_msg)
11321147

11331148
infra_state = model_endpoint.infra_state
1149+
metadata: Optional[Dict[str, Any]]
11341150

11351151
if (
11361152
request.model_name
@@ -1140,7 +1156,7 @@ async def execute(
11401156
or request.quantize
11411157
or request.checkpoint_path
11421158
):
1143-
llm_metadata = (model_endpoint.record.metadata or {}).get("_llm", {})
1159+
llm_metadata = (model_endpoint.record.metadata or {}).get(LLM_METADATA_KEY, {})
11441160
inference_framework = llm_metadata["inference_framework"]
11451161

11461162
if request.inference_framework_image_tag == "latest":
@@ -1177,7 +1193,7 @@ async def execute(
11771193
)
11781194

11791195
metadata = endpoint_record.metadata or {}
1180-
metadata["_llm"] = asdict(
1196+
metadata[LLM_METADATA_KEY] = asdict(
11811197
LLMMetadata(
11821198
model_name=model_name,
11831199
source=source,
@@ -1188,7 +1204,7 @@ async def execute(
11881204
checkpoint_path=checkpoint_path,
11891205
)
11901206
)
1191-
request.metadata = metadata
1207+
endpoint_record.metadata = metadata
11921208

11931209
# For resources that are not specified in the update endpoint request, pass in resource from
11941210
# infra_state to make sure that after the update, all resources are valid and in sync.
@@ -1209,15 +1225,20 @@ async def execute(
12091225
endpoint_type=endpoint_record.endpoint_type,
12101226
)
12111227

1212-
if request.metadata is not None and CONVERTED_FROM_ARTIFACT_LIKE_KEY in request.metadata:
1213-
raise ObjectHasInvalidValueException(
1214-
f"{CONVERTED_FROM_ARTIFACT_LIKE_KEY} is a reserved metadata key and cannot be used by user."
1215-
)
1228+
if request.metadata is not None:
1229+
# If reserved metadata key is provided, throw ObjectHasInvalidValueException
1230+
for key in RESERVED_METADATA_KEYS:
1231+
if key in request.metadata:
1232+
raise ObjectHasInvalidValueException(
1233+
f"{key} is a reserved metadata key and cannot be used by user."
1234+
)
1235+
1236+
metadata = merge_metadata(request.metadata, endpoint_record.metadata)
12161237

12171238
updated_endpoint_record = await self.model_endpoint_service.update_model_endpoint(
12181239
model_endpoint_id=model_endpoint_id,
12191240
model_bundle_id=bundle.id,
1220-
metadata=request.metadata,
1241+
metadata=metadata,
12211242
post_inference_hooks=request.post_inference_hooks,
12221243
cpus=request.cpus,
12231244
gpus=request.gpus,

model-engine/tests/unit/domain/conftest.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
Quantization,
3232
StreamingEnhancedRunnableImageFlavor,
3333
)
34+
from model_engine_server.domain.use_cases.model_endpoint_use_cases import (
35+
CONVERTED_FROM_ARTIFACT_LIKE_KEY,
36+
)
3437

3538

3639
@pytest.fixture
@@ -265,6 +268,19 @@ def update_llm_model_endpoint_request() -> UpdateLLMModelEndpointV1Request:
265268
)
266269

267270

271+
@pytest.fixture
def update_llm_model_endpoint_request_only_workers() -> UpdateLLMModelEndpointV1Request:
    """Update request that changes only the autoscaling worker bounds."""
    return UpdateLLMModelEndpointV1Request(min_workers=5, max_workers=10)
277+
278+
279+
@pytest.fixture
def update_llm_model_endpoint_request_bad_metadata() -> UpdateLLMModelEndpointV1Request:
    """Update request whose metadata uses a reserved key and must be rejected."""
    bad_metadata = {CONVERTED_FROM_ARTIFACT_LIKE_KEY: {}}
    return UpdateLLMModelEndpointV1Request(metadata=bad_metadata)
282+
283+
268284
@pytest.fixture
269285
def create_llm_model_endpoint_request_llama_2() -> CreateLLMModelEndpointV1Request:
270286
return CreateLLMModelEndpointV1Request(

model-engine/tests/unit/domain/test_llm_use_cases.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
UpdateLLMModelEndpointV1UseCase,
5252
_fill_hardware_info,
5353
_infer_hardware,
54+
merge_metadata,
5455
validate_and_update_completion_params,
5556
validate_checkpoint_files,
5657
)
@@ -614,6 +615,7 @@ async def test_update_model_endpoint_use_case_success(
614615
fake_llm_model_endpoint_service,
615616
create_llm_model_endpoint_request_streaming: CreateLLMModelEndpointV1Request,
616617
update_llm_model_endpoint_request: UpdateLLMModelEndpointV1Request,
618+
update_llm_model_endpoint_request_only_workers: UpdateLLMModelEndpointV1Request,
617619
):
618620
fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository
619621
bundle_use_case = CreateModelBundleV2UseCase(
@@ -687,6 +689,102 @@ async def test_update_model_endpoint_use_case_success(
687689
== update_llm_model_endpoint_request.max_workers
688690
)
689691

692+
update_response2 = await update_use_case.execute(
693+
user=user,
694+
model_endpoint_name=create_llm_model_endpoint_request_streaming.name,
695+
request=update_llm_model_endpoint_request_only_workers,
696+
)
697+
assert update_response2.endpoint_creation_task_id
698+
699+
endpoint = (
700+
await fake_model_endpoint_service.list_model_endpoints(
701+
owner=None,
702+
name=create_llm_model_endpoint_request_streaming.name,
703+
order_by=None,
704+
)
705+
)[0]
706+
assert endpoint.record.metadata == {
707+
"_llm": {
708+
"model_name": create_llm_model_endpoint_request_streaming.model_name,
709+
"source": create_llm_model_endpoint_request_streaming.source,
710+
"inference_framework": create_llm_model_endpoint_request_streaming.inference_framework,
711+
"inference_framework_image_tag": "fake_docker_repository_latest_image_tag",
712+
"num_shards": create_llm_model_endpoint_request_streaming.num_shards,
713+
"quantize": None,
714+
"checkpoint_path": update_llm_model_endpoint_request.checkpoint_path,
715+
}
716+
}
717+
assert endpoint.infra_state.resource_state.memory == update_llm_model_endpoint_request.memory
718+
assert (
719+
endpoint.infra_state.deployment_state.min_workers
720+
== update_llm_model_endpoint_request_only_workers.min_workers
721+
)
722+
assert (
723+
endpoint.infra_state.deployment_state.max_workers
724+
== update_llm_model_endpoint_request_only_workers.max_workers
725+
)
726+
727+
728+
@pytest.mark.asyncio
@mock.patch(
    "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases._get_latest_tag",
    mocked__get_latest_tag(),
)
async def test_update_model_endpoint_use_case_failure(
    test_api_key: str,
    fake_model_bundle_repository,
    fake_model_endpoint_service,
    fake_docker_repository_image_always_exists,
    fake_model_primitive_gateway,
    fake_llm_artifact_gateway,
    fake_llm_model_endpoint_service,
    create_llm_model_endpoint_request_streaming: CreateLLMModelEndpointV1Request,
    update_llm_model_endpoint_request_bad_metadata: UpdateLLMModelEndpointV1Request,
):
    """An update whose metadata contains a reserved key must raise
    ObjectHasInvalidValueException instead of being applied."""
    fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository

    # Assemble the create/update use cases against the in-memory fakes.
    base_bundle_use_case = CreateModelBundleV2UseCase(
        model_bundle_repository=fake_model_bundle_repository,
        docker_repository=fake_docker_repository_image_always_exists,
        model_primitive_gateway=fake_model_primitive_gateway,
    )
    llm_bundle_use_case = CreateLLMModelBundleV1UseCase(
        create_model_bundle_use_case=base_bundle_use_case,
        model_bundle_repository=fake_model_bundle_repository,
        llm_artifact_gateway=fake_llm_artifact_gateway,
        docker_repository=fake_docker_repository_image_always_exists,
    )
    create_use_case = CreateLLMModelEndpointV1UseCase(
        create_llm_model_bundle_use_case=llm_bundle_use_case,
        model_endpoint_service=fake_model_endpoint_service,
        docker_repository=fake_docker_repository_image_always_exists,
        llm_artifact_gateway=fake_llm_artifact_gateway,
    )
    update_use_case = UpdateLLMModelEndpointV1UseCase(
        create_llm_model_bundle_use_case=llm_bundle_use_case,
        model_endpoint_service=fake_model_endpoint_service,
        llm_model_endpoint_service=fake_llm_model_endpoint_service,
        docker_repository=fake_docker_repository_image_always_exists,
    )

    user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True)

    # Create a streaming endpoint, then register it with the LLM endpoint
    # service so the update use case can find it by name.
    await create_use_case.execute(user=user, request=create_llm_model_endpoint_request_streaming)
    listed_endpoints = await fake_model_endpoint_service.list_model_endpoints(
        owner=None,
        name=create_llm_model_endpoint_request_streaming.name,
        order_by=None,
    )
    endpoint = listed_endpoints[0]
    fake_llm_model_endpoint_service.add_model_endpoint(endpoint)

    # The request's metadata carries a reserved key, so the update must fail.
    with pytest.raises(ObjectHasInvalidValueException):
        await update_use_case.execute(
            user=user,
            model_endpoint_name=create_llm_model_endpoint_request_streaming.name,
            request=update_llm_model_endpoint_request_bad_metadata,
        )
787+
690788

691789
def mocked_auto_tokenizer_from_pretrained(*args, **kwargs): # noqa
692790
class mocked_encode:
@@ -2241,3 +2339,23 @@ async def test_create_batch_completions(
22412339
"-c",
22422340
"ddtrace-run python vllm_batch.py",
22432341
]
2342+
2343+
2344+
def test_merge_metadata():
    """merge_metadata keeps both sides' keys and prefers request values on collision."""
    from_request = {
        "key1": "value1",
        "key2": "value2",
    }
    from_record = {
        "key1": "value0",
        "key3": "value3",
    }

    # A missing side leaves the other untouched.
    assert merge_metadata(from_request, None) == from_request
    assert merge_metadata(None, from_record) == from_record

    # On collision ("key1") the request wins; disjoint keys are unioned.
    expected = {
        "key1": "value1",
        "key2": "value2",
        "key3": "value3",
    }
    assert merge_metadata(from_request, from_record) == expected

0 commit comments

Comments
 (0)