Commit ba68b8d

Infer hardware from model name (#515)
* Infer hardware from model name
* fix
* fix lint
* fix
* Use formula instead of hardcode
* tests
* remove print and cache
* fixes
1 parent fbe7417 commit ba68b8d
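
The gist of the change: when a create-endpoint request leaves its hardware fields unset, the server derives them from the model's name and config instead of requiring callers to hardcode them. Below is a minimal sketch of that idea, assuming the parameter count can be parsed from names like "llama-2-70b" and an 80 GB card; the helper name, regex, and constants are illustrative and are not taken from this commit.

```python
import math
import re


# Hypothetical sketch, not the commit's actual formula: estimate how many GPUs
# a model needs from the parameter count embedded in a name like "llama-2-70b".
def infer_num_gpus(model_name: str, gpu_memory_gb: int = 80) -> int:
    match = re.search(r"(\d+(?:\.\d+)?)b", model_name.lower())
    if match is None:
        raise ValueError(f"cannot infer parameter count from '{model_name}'")
    num_params_billion = float(match.group(1))
    # ~2 bytes per parameter for fp16 weights, plus ~20% headroom for KV cache.
    required_gb = num_params_billion * 2 * 1.2
    return max(1, math.ceil(required_gb / gpu_memory_gb))


print(infer_num_gpus("llama-2-70b"))  # -> 3 under the assumptions above
```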

10 files changed: 487 additions, 139 deletions (three of the changed files are shown below)

model-engine/model_engine_server/api/llms_v1.py

9 additions, 6 deletions

```diff
@@ -1,5 +1,6 @@
 """LLM Model Endpoint routes for the hosted model inference service.
 """
+
 import traceback
 from datetime import datetime
 from typing import Optional
@@ -169,6 +170,7 @@ async def create_model_endpoint(
             create_llm_model_bundle_use_case=create_llm_model_bundle_use_case,
             model_endpoint_service=external_interfaces.model_endpoint_service,
             docker_repository=external_interfaces.docker_repository,
+            llm_artifact_gateway=external_interfaces.llm_artifact_gateway,
         )
         return await use_case.execute(user=auth, request=request)
     except ObjectAlreadyExistsException as exc:
@@ -331,9 +333,9 @@ async def create_completion_sync_task(
             external_interfaces.monitoring_metrics_gateway.emit_token_count_metrics,
             TokenUsage(
                 num_prompt_tokens=response.output.num_prompt_tokens if response.output else None,
-                num_completion_tokens=response.output.num_completion_tokens
-                if response.output
-                else None,
+                num_completion_tokens=(
+                    response.output.num_completion_tokens if response.output else None
+                ),
                 total_duration=use_case_timer.duration,
             ),
             metric_metadata,
@@ -401,9 +403,9 @@ async def event_generator():
             external_interfaces.monitoring_metrics_gateway.emit_token_count_metrics,
             TokenUsage(
                 num_prompt_tokens=message.output.num_prompt_tokens if message.output else None,
-                num_completion_tokens=message.output.num_completion_tokens
-                if message.output
-                else None,
+                num_completion_tokens=(
+                    message.output.num_completion_tokens if message.output else None
+                ),
                 total_duration=use_case_timer.duration,
                 time_to_first_token=time_to_first_token,
             ),
@@ -593,6 +595,7 @@ async def create_batch_completions(
             docker_image_batch_job_gateway=external_interfaces.docker_image_batch_job_gateway,
             docker_repository=external_interfaces.docker_repository,
             docker_image_batch_job_bundle_repo=external_interfaces.docker_image_batch_job_bundle_repository,
+            llm_artifact_gateway=external_interfaces.llm_artifact_gateway,
         )
         return await use_case.execute(user=auth, request=request)
     except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc:
```
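
The edits above only inject external_interfaces.llm_artifact_gateway into the create-endpoint and batch-completions use cases (the remaining hunks are formatting of the existing ternaries). A rough sketch, with assumed class, attribute, and path names, of how a use case could consult that gateway when a request leaves hardware unset:

```python
# Illustrative sketch only; names other than llm_artifact_gateway and
# get_model_config are assumptions, not taken from this commit.
class CreateLLMModelEndpointUseCase:
    def __init__(self, llm_artifact_gateway, model_endpoint_service, docker_repository):
        self.llm_artifact_gateway = llm_artifact_gateway
        self.model_endpoint_service = model_endpoint_service
        self.docker_repository = docker_repository

    async def execute(self, user, request):
        if request.gpus is None or request.gpu_type is None:
            # Pull the model config (hidden size, layer count, etc.) and use it
            # to decide how many GPUs and which GPU type the endpoint needs.
            config = self.llm_artifact_gateway.get_model_config(
                path=f"{user.team_id}/{request.model_name}"  # hypothetical path layout
            )
            ...  # apply the sizing formula here
        ...
```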

model-engine/model_engine_server/common/dtos/llms.py

4 additions, 4 deletions

```diff
@@ -51,10 +51,10 @@ class CreateLLMModelEndpointV1Request(BaseModel):
     metadata: Dict[str, Any]  # TODO: JSON type
     post_inference_hooks: Optional[List[str]]
     endpoint_type: ModelEndpointType = ModelEndpointType.SYNC
-    cpus: CpuSpecificationType
-    gpus: int
-    memory: StorageSpecificationType
-    gpu_type: GpuType
+    cpus: Optional[CpuSpecificationType]
+    gpus: Optional[int]
+    memory: Optional[StorageSpecificationType]
+    gpu_type: Optional[GpuType]
     storage: Optional[StorageSpecificationType]
     optimize_costs: Optional[bool]
     min_workers: int
```
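
With cpus, gpus, memory, and gpu_type now Optional, a caller can simply omit them and let the service infer the hardware. An abbreviated, illustrative request payload; the field values are examples, and only the optionality of the hardware keys comes from the diff above:

```python
# Hardware keys left out entirely -> the service infers them from the model name.
payload = {
    "name": "llama-2-70b-endpoint",
    "model_name": "llama-2-70b",
    "endpoint_type": "streaming",
    "metadata": {},
    "min_workers": 1,
    "max_workers": 1,
    # no "cpus", "gpus", "memory", or "gpu_type"
}
```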

model-engine/model_engine_server/domain/gateways/llm_artifact_gateway.py

11 additions, 1 deletion

```diff
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import List
+from typing import Any, Dict, List
 
 
 class LLMArtifactGateway(ABC):
@@ -39,3 +39,13 @@ def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[
         model_name (str): name of the model
         """
         pass
+
+    @abstractmethod
+    def get_model_config(self, path: str, **kwargs) -> Dict[str, Any]:
+        """
+        Gets the model config from the model files live at given folder.
+
+        Args:
+            path (str): path to model files
+        """
+        pass
```
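
Concrete gateways have to satisfy the new get_model_config contract. A minimal local-filesystem sketch, assuming the folder follows the usual Hugging Face layout with a config.json; the real implementations most likely read from blob storage instead:

```python
import json
import os
from typing import Any, Dict


def get_model_config(path: str, **kwargs) -> Dict[str, Any]:
    # Read the model's config.json (parameter count, hidden size, etc.) from
    # the folder of model files; a blob-storage-backed gateway would download
    # or stream the same file rather than opening it locally.
    with open(os.path.join(path, "config.json")) as f:
        return json.load(f)
```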
