
Commit b7284df

Enforce model checkpoints existing for endpoint/bundle creation (#503)
* Enforce model checkpoints existing for endpoint/bundle creation
* Add test mock for good models info
* Clean up checkpoint validation
* Rename validate to get for semantics
1 parent 55d538b commit b7284df

File tree

3 files changed: +157 -62 lines changed

- model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py
- model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py
- model-engine/tests/unit/domain/test_llm_use_cases.py


model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py

Lines changed: 59 additions & 58 deletions
@@ -329,6 +329,27 @@ def validate_quantization(
         )
 
 
+def validate_checkpoint_path_uri(checkpoint_path: str) -> None:
+    if not checkpoint_path.startswith("s3://"):
+        raise ObjectHasInvalidValueException(
+            f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}."
+        )
+
+
+def get_checkpoint_path(model_name: str, checkpoint_path_override: Optional[str]) -> str:
+    checkpoint_path = (
+        SUPPORTED_MODELS_INFO[model_name].s3_repo
+        if not checkpoint_path_override
+        else checkpoint_path_override
+    )
+
+    if not checkpoint_path:
+        raise InvalidRequestException(f"No checkpoint path found for model {model_name}")
+
+    validate_checkpoint_path_uri(checkpoint_path)
+    return checkpoint_path
+
+
 class CreateLLMModelBundleV1UseCase:
     def __init__(
         self,
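A minimal usage sketch of the new helper (illustrative, not part of the diff; the model name is hypothetical, and SUPPORTED_MODELS_INFO is assumed to map model names to ModelInfo(hf_repo, s3_repo) entries):

    # With no override, resolution falls back to the model's registered s3_repo;
    # InvalidRequestException is raised when that entry has no S3 checkpoint at all.
    path = get_checkpoint_path("example-model", checkpoint_path_override=None)

    # Any resolved path that is not an s3:// URI fails validate_checkpoint_path_uri.
    get_checkpoint_path("example-model", "gs://bucket/weights")  # raises ObjectHasInvalidValueException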
@@ -449,22 +470,16 @@ async def create_text_generation_inference_bundle(
             max_total_tokens = 4096
 
         subcommands = []
-        if checkpoint_path is not None:
-            if checkpoint_path.startswith("s3://"):
-                final_weights_folder = "model_files"
 
-                subcommands += self.load_model_weights_sub_commands(
-                    LLMInferenceFramework.TEXT_GENERATION_INFERENCE,
-                    framework_image_tag,
-                    checkpoint_path,
-                    final_weights_folder,
-                )
-            else:
-                raise ObjectHasInvalidValueException(
-                    f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}."
-                )
-        else:
-            final_weights_folder = SUPPORTED_MODELS_INFO[model_name].hf_repo
+        checkpoint_path = get_checkpoint_path(model_name, checkpoint_path)
+        final_weights_folder = "model_files"
+
+        subcommands += self.load_model_weights_sub_commands(
+            LLMInferenceFramework.TEXT_GENERATION_INFERENCE,
+            framework_image_tag,
+            checkpoint_path,
+            final_weights_folder,
+        )
 
         subcommands.append(
             f"text-generation-launcher --hostname :: --model-id {final_weights_folder} --num-shard {num_shards} --port 5005 --max-input-length {max_input_length} --max-total-tokens {max_total_tokens}"
@@ -672,25 +687,19 @@ async def create_vllm_bundle(
                 break
 
         subcommands = []
-        if checkpoint_path is not None:
-            if checkpoint_path.startswith("s3://"):
-                # added as workaround since transformers doesn't support mistral yet, vllm expects "mistral" in model weights folder
-                if "mistral" in model_name:
-                    final_weights_folder = "mistral_files"
-                else:
-                    final_weights_folder = "model_files"
-                subcommands += self.load_model_weights_sub_commands(
-                    LLMInferenceFramework.VLLM,
-                    framework_image_tag,
-                    checkpoint_path,
-                    final_weights_folder,
-                )
-            else:
-                raise ObjectHasInvalidValueException(
-                    f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}."
-                )
+
+        checkpoint_path = get_checkpoint_path(model_name, checkpoint_path)
+        # added as workaround since transformers doesn't support mistral yet, vllm expects "mistral" in model weights folder
+        if "mistral" in model_name:
+            final_weights_folder = "mistral_files"
         else:
-            final_weights_folder = SUPPORTED_MODELS_INFO[model_name].hf_repo
+            final_weights_folder = "model_files"
+        subcommands += self.load_model_weights_sub_commands(
+            LLMInferenceFramework.VLLM,
+            framework_image_tag,
+            checkpoint_path,
+            final_weights_folder,
+        )
 
         if max_model_len:
             subcommands.append(
@@ -770,21 +779,15 @@ async def create_lightllm_bundle(
             max_req_total_len = 4096
 
         subcommands = []
-        if checkpoint_path is not None:
-            if checkpoint_path.startswith("s3://"):
-                final_weights_folder = "model_files"
-                subcommands += self.load_model_weights_sub_commands(
-                    LLMInferenceFramework.LIGHTLLM,
-                    framework_image_tag,
-                    checkpoint_path,
-                    final_weights_folder,
-                )
-            else:
-                raise ObjectHasInvalidValueException(
-                    f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}."
-                )
-        else:
-            final_weights_folder = SUPPORTED_MODELS_INFO[model_name].hf_repo
+
+        checkpoint_path = get_checkpoint_path(model_name, checkpoint_path)
+        final_weights_folder = "model_files"
+        subcommands += self.load_model_weights_sub_commands(
+            LLMInferenceFramework.LIGHTLLM,
+            framework_image_tag,
+            checkpoint_path,
+            final_weights_folder,
+        )
 
         subcommands.append(
             f"python -m lightllm.server.api_server --model_dir {final_weights_folder} --port 5005 --tp {num_shards} --max_total_token_num {max_total_token_num} --max_req_input_len {max_req_input_len} --max_req_total_len {max_req_total_len} --tokenizer_mode auto"
@@ -835,20 +838,18 @@ async def create_tensorrt_llm_bundle(
         command = []
 
         subcommands = []
-        if checkpoint_path is not None:
-            if checkpoint_path.startswith("s3://"):
-                subcommands += self.load_model_files_sub_commands_trt_llm(
-                    checkpoint_path,
-                )
-            else:
-                raise ObjectHasInvalidValueException(
-                    f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}."
-                )
-        else:
+
+        if not checkpoint_path:
             raise ObjectHasInvalidValueException(
                 "Checkpoint must be provided for TensorRT-LLM models."
             )
 
+        validate_checkpoint_path_uri(checkpoint_path)
+
+        subcommands += self.load_model_files_sub_commands_trt_llm(
+            checkpoint_path,
+        )
+
         subcommands.append(
             f"python3 launch_triton_server.py --world_size={num_shards} --model_repo=./model_repo/"
         )

model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py

Lines changed: 4 additions & 3 deletions
@@ -1,7 +1,6 @@
 import os
-from collections import namedtuple
 from functools import lru_cache
-from typing import Dict, Optional
+from typing import Dict, NamedTuple, Optional
 
 from huggingface_hub import list_repo_refs
 from huggingface_hub.utils._errors import RepositoryNotFoundError
@@ -25,7 +24,9 @@
 TOKENIZER_TARGET_DIR = "/opt/.cache/model_engine_server/tokenizers"
 
 
-ModelInfo = namedtuple("ModelInfo", ["hf_repo", "s3_repo"])
+class ModelInfo(NamedTuple):
+    hf_repo: str
+    s3_repo: Optional[str]
 
 
 def get_default_supported_models_info() -> Dict[str, ModelInfo]:
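For illustration (not part of the diff): the typed NamedTuple keeps the old namedtuple's runtime behaviour while letting type checkers see that s3_repo may be absent.

    # Construction and field access are unchanged; only the annotations are new.
    info = ModelInfo(hf_repo="mosaicml/mpt-7b", s3_repo=None)
    assert info.hf_repo == "mosaicml/mpt-7b"
    assert info.s3_repo is None  # Optional[str]: a model with no S3 checkpoint is representable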

model-engine/tests/unit/domain/test_llm_use_cases.py

Lines changed: 94 additions & 1 deletion
@@ -1,5 +1,5 @@
 import json
-from typing import Any, List, Tuple
+from typing import Any, Dict, List, Tuple
 from unittest import mock
 
 import pytest
@@ -54,9 +54,22 @@
     validate_and_update_completion_params,
 )
 from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase
+from model_engine_server.infra.repositories import live_tokenizer_repository
+from model_engine_server.infra.repositories.live_tokenizer_repository import ModelInfo
+
+
+def good_models_info() -> Dict[str, ModelInfo]:
+    return {
+        k: ModelInfo(v.hf_repo, "s3://test-s3.tar")
+        for k, v in live_tokenizer_repository.SUPPORTED_MODELS_INFO.items()
+    }
 
 
 @pytest.mark.asyncio
+@mock.patch(
+    "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.SUPPORTED_MODELS_INFO",
+    good_models_info(),
+)
 async def test_create_model_endpoint_use_case_success(
     test_api_key: str,
     fake_model_bundle_repository,
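A note on the mocking pattern used here and in the tests below (sketch with a hypothetical test name): passing a value as the second argument to mock.patch swaps it in for the module-level SUPPORTED_MODELS_INFO dict for the duration of the test, so checkpoint resolution sees the fake entries.

    # good_models_info() is evaluated once at decoration time; the returned dict
    # temporarily replaces the attribute named by the dotted path while the test runs.
    @mock.patch(
        "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.SUPPORTED_MODELS_INFO",
        good_models_info(),
    )
    async def test_uses_patched_models_info():  # hypothetical
        ...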
@@ -170,6 +183,82 @@ async def test_create_model_endpoint_use_case_success(
     assert "--max-total-tokens" in bundle.flavor.command[-1] and "4096" in bundle.flavor.command[-1]
 
 
+def bad_models_info() -> Dict[str, ModelInfo]:
+    info = {
+        k: ModelInfo(v.hf_repo, v.s3_repo)
+        for k, v in live_tokenizer_repository.SUPPORTED_MODELS_INFO.items()
+    }
+    info.update(
+        {
+            "mpt-7b": ModelInfo("mosaicml/mpt-7b", None),
+            "mpt-7b-instruct": ModelInfo("mosaicml/mpt-7b-instruct", "gibberish"),
+        }
+    )
+    return info
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "inference_framework, model_name, expected_error",
+    [
+        (LLMInferenceFramework.TEXT_GENERATION_INFERENCE, "mpt-7b", InvalidRequestException),
+        (
+            LLMInferenceFramework.TEXT_GENERATION_INFERENCE,
+            "mpt-7b-instruct",
+            ObjectHasInvalidValueException,
+        ),
+        (LLMInferenceFramework.LIGHTLLM, "mpt-7b", InvalidRequestException),
+        (LLMInferenceFramework.LIGHTLLM, "mpt-7b-instruct", ObjectHasInvalidValueException),
+        (LLMInferenceFramework.VLLM, "mpt-7b", InvalidRequestException),
+        (LLMInferenceFramework.VLLM, "mpt-7b-instruct", ObjectHasInvalidValueException),
+    ],
+)
+@mock.patch(
+    "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.SUPPORTED_MODELS_INFO",
+    bad_models_info(),
+)
+async def test_create_model_bundle_fails_if_no_checkpoint(
+    test_api_key: str,
+    fake_model_bundle_repository,
+    fake_model_endpoint_service,
+    fake_docker_repository_image_always_exists,
+    fake_model_primitive_gateway,
+    fake_llm_artifact_gateway,
+    create_llm_model_endpoint_text_generation_inference_request_streaming: CreateLLMModelEndpointV1Request,
+    inference_framework,
+    model_name,
+    expected_error,
+):
+    fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository
+    bundle_use_case = CreateModelBundleV2UseCase(
+        model_bundle_repository=fake_model_bundle_repository,
+        docker_repository=fake_docker_repository_image_always_exists,
+        model_primitive_gateway=fake_model_primitive_gateway,
+    )
+    use_case = CreateLLMModelBundleV1UseCase(
+        create_model_bundle_use_case=bundle_use_case,
+        model_bundle_repository=fake_model_bundle_repository,
+        llm_artifact_gateway=fake_llm_artifact_gateway,
+        docker_repository=fake_docker_repository_image_always_exists,
+    )
+    user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True)
+    request = create_llm_model_endpoint_text_generation_inference_request_streaming.copy()
+
+    with pytest.raises(expected_error):
+        await use_case.execute(
+            user=user,
+            endpoint_name=request.name,
+            model_name=model_name,
+            source=request.source,
+            framework=inference_framework,
+            framework_image_tag="0.0.0",
+            endpoint_type=request.endpoint_type,
+            num_shards=request.num_shards,
+            quantize=request.quantize,
+            checkpoint_path=None,
+        )
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "valid, inference_framework, inference_framework_image_tag",
@@ -180,6 +269,10 @@ async def test_create_model_endpoint_use_case_success(
         (True, LLMInferenceFramework.VLLM, "0.1.3.6"),
     ],
 )
+@mock.patch(
+    "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.SUPPORTED_MODELS_INFO",
+    good_models_info(),
+)
 async def test_create_model_bundle_inference_framework_image_tag_validation(
     test_api_key: str,
     fake_model_bundle_repository,
