
Commit 6fe18d3

get latest inference framework tag from configmap (#505)
* get latest inference framework tag from configmap
* comments
* fix for test
* make namespace a config
* fix s3 prefix bug
* fix checkpoint path fn + tests
* values change
* quotes
1 parent d0d6b8b commit 6fe18d3

File tree

9 files changed: +122 −51 lines changed

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+apiVersion: v1
+kind: ConfigMap
+metadata:
+  name: model-engine-inference-framework-latest-config
+  labels:
+    product: common
+    team: infra
+  annotations:
+    "helm.sh/hook": pre-install
+    "helm.sh/hook-weight": "-2"
+data:
+  deepspeed: "latest"
+  text_generation_inference: "latest"
+  vllm: "latest"
+  lightllm: "latest"
+  tensorrt_llm: "latest"
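Each data key matches an LLMInferenceFramework value, and its value is the image tag treated as "latest" for that framework. A minimal sketch of the lookup these keys enable (the dict literal stands in for fetched ConfigMap data; the real lookup is _get_latest_tag in the use-cases diff below):

    # Illustrative stand-in for the ConfigMap's data section.
    latest_tags = {
        "deepspeed": "latest",
        "text_generation_inference": "latest",
        "vllm": "latest",
        "lightllm": "latest",
        "tensorrt_llm": "latest",
    }

    def latest_tag_for(framework: str) -> str:
        # A missing key becomes an explicit error, mirroring
        # LatestImageTagNotFoundException in the domain layer.
        if framework not in latest_tags:
            raise KeyError(f"No latest tag pinned for {framework}")
        return latest_tags[framework]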

charts/model-engine/templates/service_config_map.yaml

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,7 @@ metadata:
 data:
   launch_service_config: |-
     dd_trace_enabled: {{ .Values.dd_trace_enabled | default false | quote }}
+    gateway_namespace: {{ .Release.Namespace | quote }}
     {{- with .Values.config.values.launch }}
     {{- range $key, $value := . }}
     {{ $key }}: {{ $value | quote }}
@@ -39,6 +40,7 @@ metadata:
 data:
   launch_service_config: |-
     dd_trace_enabled: {{ .Values.dd_trace_enabled | default false | quote }}
+    gateway_namespace: {{ .Release.Namespace | quote }}
     {{- with .Values.config.values.launch }}
     {{- range $key, $value := . }}
     {{ $key }}: {{ $value | quote }}
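Helm substitutes the release namespace at render time, so each deployed gateway's service config now records where it runs. A rough sketch of how the rendered block parses (PyYAML here is an assumption; the rendered string shows a release installed into namespace "prod"):

    import yaml

    # Hypothetical render of the two lines above for namespace "prod".
    rendered = 'dd_trace_enabled: "false"\ngateway_namespace: "prod"\n'

    config = yaml.safe_load(rendered)
    assert config["gateway_namespace"] == "prod"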

model-engine/model_engine_server/common/config.py

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@ def get_model_cache_directory_name(model_name: str):

 @dataclass
 class HostedModelInferenceServiceConfig:
+    gateway_namespace: str
     endpoint_namespace: str
     billing_queue_arn: str
     sqs_profile: str
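Since the new field has no default, every service config YAML must now supply gateway_namespace or construction of the dataclass fails, which is why service_config_circleci.yaml below gains the key. A minimal illustration (field subset and values are hypothetical):

    from dataclasses import dataclass

    @dataclass
    class HostedModelInferenceServiceConfig:
        gateway_namespace: str
        endpoint_namespace: str

    # Omitting gateway_namespace would raise TypeError at construction.
    hmi_config = HostedModelInferenceServiceConfig(
        gateway_namespace="default",
        endpoint_namespace="model-engine",
    )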
model-engine/model_engine_server/core/configmap.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+"""Read configmap from k8s."""
+
+from typing import Dict
+
+from kubernetes_asyncio import client
+from kubernetes_asyncio import config as kube_config
+from kubernetes_asyncio.client.rest import ApiException
+from kubernetes_asyncio.config.config_exception import ConfigException
+from model_engine_server.common.config import hmi_config
+from model_engine_server.core.loggers import logger_name, make_logger
+
+DEFAULT_NAMESPACE = "default"
+
+logger = make_logger(logger_name())
+
+
+async def read_config_map(
+    config_map_name: str, namespace: str = hmi_config.gateway_namespace
+) -> Dict[str, str]:
+    try:
+        kube_config.load_incluster_config()
+    except ConfigException:
+        logger.info("No incluster kubernetes config, falling back to local")
+        await kube_config.load_kube_config()
+
+    core_api = client.CoreV1Api()
+
+    try:
+        config_map = await core_api.read_namespaced_config_map(
+            name=config_map_name, namespace=namespace
+        )
+        return config_map.data
+    except ApiException as e:
+        logger.exception(f"Error reading configmap {config_map_name}")
+        raise e
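A minimal usage sketch, assuming the ConfigMap from the chart above is installed and the process can reach the cluster (in-cluster or via local kubeconfig):

    import asyncio

    from model_engine_server.core.configmap import read_config_map

    async def main() -> None:
        tags = await read_config_map("model-engine-inference-framework-latest-config")
        print(tags.get("vllm"))  # e.g. "latest"

    asyncio.run(main())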

model-engine/model_engine_server/domain/exceptions.py

Lines changed: 6 additions & 0 deletions
@@ -182,3 +182,9 @@ class PostInferenceHooksException(DomainException):
     """
     Thrown if the post inference hooks are invalid.
     """
+
+
+class LatestImageTagNotFoundException(DomainException):
+    """
+    Thrown if the latest image tag cannot be found.
+    """

model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py

Lines changed: 25 additions & 11 deletions
@@ -39,6 +39,7 @@
 from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Request, TaskStatus
 from model_engine_server.common.resource_limits import validate_resource_requests
 from model_engine_server.core.auth.authentication_repository import User
+from model_engine_server.core.configmap import read_config_map
 from model_engine_server.core.loggers import (
     LoggerTagKey,
     LoggerTagManager,
@@ -67,6 +68,7 @@
     EndpointLabelsException,
     EndpointUnsupportedInferenceTypeException,
     InvalidRequestException,
+    LatestImageTagNotFoundException,
     ObjectHasInvalidValueException,
     ObjectNotAuthorizedException,
     ObjectNotFoundException,
@@ -82,7 +84,10 @@
 )
 from model_engine_server.domain.services import LLMModelEndpointService, ModelEndpointService
 from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway
-from model_engine_server.infra.repositories.live_tokenizer_repository import SUPPORTED_MODELS_INFO
+from model_engine_server.infra.repositories.live_tokenizer_repository import (
+    SUPPORTED_MODELS_INFO,
+    get_models_s3_uri,
+)

 from ...common.datadog_utils import add_trace_request_id
 from ..authorization.live_authorization_module import LiveAuthorizationModule
@@ -246,6 +251,8 @@
 NUM_DOWNSTREAM_REQUEST_RETRIES = 80  # has to be high enough so that the retries take the 5 minutes
 DOWNSTREAM_REQUEST_TIMEOUT_SECONDS = 5 * 60  # 5 minutes

+LATEST_INFERENCE_FRAMEWORK_CONFIG_MAP_NAME = "model-engine-inference-framework-latest-config"
+

 def count_tokens(input: str, model_name: str, tokenizer_repository: TokenizerRepository) -> int:
     """
@@ -255,6 +262,15 @@ def count_tokens(input: str, model_name: str, tokenizer_repository: TokenizerRep
     return len(tokenizer.encode(input))


+async def _get_latest_tag(inference_framework: LLMInferenceFramework) -> str:
+    config_map = await read_config_map(LATEST_INFERENCE_FRAMEWORK_CONFIG_MAP_NAME)
+    if inference_framework not in config_map:
+        raise LatestImageTagNotFoundException(
+            f"Could not find latest tag for inference framework {inference_framework}."
+        )
+    return config_map[inference_framework]
+
+
 def _include_safetensors_bin_or_pt(model_files: List[str]) -> Optional[str]:
     """
     This function is used to determine whether to include "*.safetensors", "*.bin", or "*.pt" files
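Callers that cannot tolerate a missing entry can catch the new domain exception explicitly; a minimal sketch (assumes an async context and the names from the hunk above):

    try:
        tag = await _get_latest_tag(request.inference_framework)
    except LatestImageTagNotFoundException:
        # No configmap entry for this framework; re-raise rather than
        # guessing an image tag.
        raise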
@@ -337,11 +353,11 @@ def validate_checkpoint_path_uri(checkpoint_path: str) -> None:


 def get_checkpoint_path(model_name: str, checkpoint_path_override: Optional[str]) -> str:
-    checkpoint_path = (
-        SUPPORTED_MODELS_INFO[model_name].s3_repo
-        if not checkpoint_path_override
-        else checkpoint_path_override
-    )
+    checkpoint_path = None
+    if SUPPORTED_MODELS_INFO[model_name].s3_repo:
+        checkpoint_path = get_models_s3_uri(SUPPORTED_MODELS_INFO[model_name].s3_repo, "")
+    if checkpoint_path_override:
+        checkpoint_path = checkpoint_path_override

     if not checkpoint_path:
         raise InvalidRequestException(f"No checkpoint path found for model {model_name}")
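The rework routes the registered s3_repo through get_models_s3_uri to build a full URI (presumably the "s3 prefix bug" named in the commit message) and lets an explicit override win. A condensed restatement of the resolution order (function name is hypothetical; behavior mirrors the hunk above):

    from typing import Optional

    def resolve_checkpoint_path(s3_repo: Optional[str], override: Optional[str]) -> str:
        # The registered repo is formatted into a full S3 URI first...
        path = get_models_s3_uri(s3_repo, "") if s3_repo else None
        # ...but an explicit override always takes precedence.
        if override:
            path = override
        if not path:
            raise InvalidRequestException("No checkpoint path found")
        return path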
@@ -931,8 +947,8 @@ async def execute(
         )

         if request.inference_framework_image_tag == "latest":
-            request.inference_framework_image_tag = self.docker_repository.get_latest_image_tag(
-                INFERENCE_FRAMEWORK_REPOSITORY[request.inference_framework]
+            request.inference_framework_image_tag = await _get_latest_tag(
+                request.inference_framework
             )

         bundle = await self.create_llm_model_bundle_use_case.execute(
@@ -1149,9 +1165,7 @@ async def execute(
         inference_framework = llm_metadata["inference_framework"]

         if request.inference_framework_image_tag == "latest":
-            inference_framework_image_tag = self.docker_repository.get_latest_image_tag(
-                INFERENCE_FRAMEWORK_REPOSITORY[inference_framework]
-            )
+            inference_framework_image_tag = await _get_latest_tag(inference_framework)
         else:
             inference_framework_image_tag = (
                 request.inference_framework_image_tag

model-engine/service_configs/service_config_circleci.yaml

Lines changed: 3 additions & 0 deletions
@@ -1,3 +1,6 @@
+# Config to know where model-engine is running
+gateway_namespace: default
+
 # Config for Model Engine running in CircleCI
 model_primitive_host: "none"

model-engine/tests/unit/domain/conftest.py

Lines changed: 2 additions & 1 deletion
@@ -222,7 +222,7 @@ def create_llm_model_endpoint_request_async() -> CreateLLMModelEndpointV1Request
         labels={"team": "infra", "product": "my_product"},
         aws_role="test_aws_role",
         results_s3_bucket="test_s3_bucket",
-        checkpoint_path="s3://test_checkpoint_path",
+        checkpoint_path="s3://test-s3.tar",
     )


@@ -286,6 +286,7 @@ def create_llm_model_endpoint_request_llama_2() -> CreateLLMModelEndpointV1Reque
         labels={"team": "infra", "product": "my_product"},
         aws_role="test_aws_role",
         results_s3_bucket="test_s3_bucket",
+        checkpoint_path="s3://test-s3.tar",
     )

model-engine/tests/unit/domain/test_llm_use_cases.py

Lines changed: 32 additions & 39 deletions
@@ -1,5 +1,5 @@
 import json
-from typing import Any, Dict, List, Tuple
+from typing import Any, List, Tuple
 from unittest import mock

 import pytest
@@ -54,21 +54,19 @@
     validate_and_update_completion_params,
 )
 from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase
-from model_engine_server.infra.repositories import live_tokenizer_repository
-from model_engine_server.infra.repositories.live_tokenizer_repository import ModelInfo


-def good_models_info() -> Dict[str, ModelInfo]:
-    return {
-        k: ModelInfo(v.hf_repo, "s3://test-s3.tar")
-        for k, v in live_tokenizer_repository.SUPPORTED_MODELS_INFO.items()
-    }
+def mocked__get_latest_tag():
+    async def async_mock(*args, **kwargs):  # noqa
+        return "fake_docker_repository_latest_image_tag"
+
+    return mock.AsyncMock(side_effect=async_mock)


 @pytest.mark.asyncio
 @mock.patch(
-    "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.SUPPORTED_MODELS_INFO",
-    good_models_info(),
+    "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases._get_latest_tag",
+    mocked__get_latest_tag(),
 )
 async def test_create_model_endpoint_use_case_success(
     test_api_key: str,
@@ -183,40 +181,33 @@ async def test_create_model_endpoint_use_case_success(
     assert "--max-total-tokens" in bundle.flavor.command[-1] and "4096" in bundle.flavor.command[-1]


-def bad_models_info() -> Dict[str, ModelInfo]:
-    info = {
-        k: ModelInfo(v.hf_repo, v.s3_repo)
-        for k, v in live_tokenizer_repository.SUPPORTED_MODELS_INFO.items()
-    }
-    info.update(
-        {
-            "mpt-7b": ModelInfo("mosaicml/mpt-7b", None),
-            "mpt-7b-instruct": ModelInfo("mosaicml/mpt-7b-instruct", "gibberish"),
-        }
-    )
-    return info
-
-
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
-    "inference_framework, model_name, expected_error",
+    "inference_framework, model_name, checkpoint_path, expected_error",
     [
-        (LLMInferenceFramework.TEXT_GENERATION_INFERENCE, "mpt-7b", InvalidRequestException),
+        (LLMInferenceFramework.TEXT_GENERATION_INFERENCE, "mpt-7b", None, InvalidRequestException),
         (
            LLMInferenceFramework.TEXT_GENERATION_INFERENCE,
            "mpt-7b-instruct",
+           "gibberish",
+           ObjectHasInvalidValueException,
+        ),
+        (LLMInferenceFramework.LIGHTLLM, "mpt-7b", None, InvalidRequestException),
+        (
+           LLMInferenceFramework.LIGHTLLM,
+           "mpt-7b-instruct",
+           "gibberish",
+           ObjectHasInvalidValueException,
+        ),
+        (LLMInferenceFramework.VLLM, "mpt-7b", None, InvalidRequestException),
+        (
+           LLMInferenceFramework.VLLM,
+           "mpt-7b-instruct",
+           "gibberish",
            ObjectHasInvalidValueException,
         ),
-        (LLMInferenceFramework.LIGHTLLM, "mpt-7b", InvalidRequestException),
-        (LLMInferenceFramework.LIGHTLLM, "mpt-7b-instruct", ObjectHasInvalidValueException),
-        (LLMInferenceFramework.VLLM, "mpt-7b", InvalidRequestException),
-        (LLMInferenceFramework.VLLM, "mpt-7b-instruct", ObjectHasInvalidValueException),
     ],
 )
-@mock.patch(
-    "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.SUPPORTED_MODELS_INFO",
-    bad_models_info(),
-)
 async def test_create_model_bundle_fails_if_no_checkpoint(
     test_api_key: str,
     fake_model_bundle_repository,
@@ -227,6 +218,7 @@ async def test_create_model_bundle_fails_if_no_checkpoint(
     create_llm_model_endpoint_text_generation_inference_request_streaming: CreateLLMModelEndpointV1Request,
     inference_framework,
     model_name,
+    checkpoint_path,
     expected_error,
 ):
     fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository
@@ -255,7 +247,7 @@ async def test_create_model_bundle_fails_if_no_checkpoint(
         endpoint_type=request.endpoint_type,
         num_shards=request.num_shards,
         quantize=request.quantize,
-        checkpoint_path=None,
+        checkpoint_path=checkpoint_path,
     )


@@ -269,10 +261,6 @@ async def test_create_model_bundle_fails_if_no_checkpoint(
         (True, LLMInferenceFramework.VLLM, "0.1.3.6"),
     ],
 )
-@mock.patch(
-    "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.SUPPORTED_MODELS_INFO",
-    good_models_info(),
-)
 async def test_create_model_bundle_inference_framework_image_tag_validation(
     test_api_key: str,
     fake_model_bundle_repository,
@@ -307,6 +295,7 @@ async def test_create_model_bundle_inference_framework_image_tag_validation(
     request = create_llm_model_endpoint_text_generation_inference_request_streaming.copy()
     request.inference_framework = inference_framework
     request.inference_framework_image_tag = inference_framework_image_tag
+    request.checkpoint_path = "s3://test-s3.tar"
     user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True)
     if valid:
         await use_case.execute(user=user, request=request)
@@ -592,6 +581,10 @@ async def test_get_llm_model_endpoint_use_case_raises_not_authorized(


 @pytest.mark.asyncio
+@mock.patch(
+    "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases._get_latest_tag",
+    mocked__get_latest_tag(),
+)
 async def test_update_model_endpoint_use_case_success(
     test_api_key: str,
     fake_model_bundle_repository,
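mocked__get_latest_tag wraps an async function in mock.AsyncMock so the patched _get_latest_tag can still be awaited by the use case. A self-contained sketch of the same pattern:

    import asyncio
    from unittest import mock

    async def async_mock(*args, **kwargs):
        return "fake_docker_repository_latest_image_tag"

    patched = mock.AsyncMock(side_effect=async_mock)
    assert asyncio.run(patched()) == "fake_docker_repository_latest_image_tag"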
