Skip to content

Commit 4d9d934

Browse files
Sync scale from zero, part 1 (#229)
* add patch files * fix ruff
1 parent c53f3c4 commit 4d9d934

23 files changed

+337
-55
lines changed

model-engine/model_engine_server/api/dependencies.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
LiveStreamingModelEndpointInferenceGateway,
5050
LiveSyncModelEndpointInferenceGateway,
5151
ModelEndpointInfraGateway,
52+
RedisInferenceAutoscalingMetricsGateway,
5253
S3FilesystemGateway,
5354
S3LLMArtifactGateway,
5455
)
@@ -179,6 +180,9 @@ def _get_external_interfaces(
179180
model_endpoints_schema_gateway = LiveModelEndpointsSchemaGateway(
180181
filesystem_gateway=filesystem_gateway
181182
)
183+
inference_autoscaling_metrics_gateway = RedisInferenceAutoscalingMetricsGateway(
184+
redis_client=redis_client
185+
) # we can just reuse the existing redis client, we shouldn't get key collisions because of the prefix
182186
model_endpoint_service = LiveModelEndpointService(
183187
model_endpoint_record_repository=model_endpoint_record_repo,
184188
model_endpoint_infra_gateway=model_endpoint_infra_gateway,
@@ -187,6 +191,7 @@ def _get_external_interfaces(
187191
streaming_model_endpoint_inference_gateway=streaming_model_endpoint_inference_gateway,
188192
sync_model_endpoint_inference_gateway=sync_model_endpoint_inference_gateway,
189193
model_endpoints_schema_gateway=model_endpoints_schema_gateway,
194+
inference_autoscaling_metrics_gateway=inference_autoscaling_metrics_gateway,
190195
)
191196
llm_model_endpoint_service = LiveLLMModelEndpointService(
192197
model_endpoint_record_repository=model_endpoint_record_repo,

model-engine/model_engine_server/api/tasks_v1.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
CreateAsyncTaskV1Response,
1212
EndpointPredictV1Request,
1313
GetAsyncTaskV1Response,
14+
SyncEndpointPredictV1Request,
1415
SyncEndpointPredictV1Response,
1516
TaskStatus,
1617
)
@@ -97,7 +98,7 @@ def get_async_inference_task(
9798
@inference_task_router_v1.post("/sync-tasks", response_model=SyncEndpointPredictV1Response)
9899
async def create_sync_inference_task(
99100
model_endpoint_id: str,
100-
request: EndpointPredictV1Request,
101+
request: SyncEndpointPredictV1Request,
101102
auth: User = Depends(verify_authentication),
102103
external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only),
103104
) -> SyncEndpointPredictV1Response:
@@ -137,7 +138,7 @@ async def create_sync_inference_task(
137138
@inference_task_router_v1.post("/streaming-tasks")
138139
async def create_streaming_inference_task(
139140
model_endpoint_id: str,
140-
request: EndpointPredictV1Request,
141+
request: SyncEndpointPredictV1Request,
141142
auth: User = Depends(verify_authentication),
142143
external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only),
143144
) -> EventSourceResponse:

model-engine/model_engine_server/common/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def get_model_cache_directory_name(model_name: str):
4545
class HostedModelInferenceServiceConfig:
4646
endpoint_namespace: str
4747
billing_queue_arn: str
48-
cache_redis_url: str
48+
cache_redis_url: str # also using this to store sync autoscaling metrics
4949
sqs_profile: str
5050
sqs_queue_policy_template: str
5151
sqs_queue_tag_template: str

model-engine/model_engine_server/common/dtos/tasks.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Any, Optional
77

88
from model_engine_server.domain.entities import CallbackAuth
9-
from pydantic import BaseModel
9+
from pydantic import BaseModel, Field
1010

1111

1212
class ResponseSchema(BaseModel):
@@ -49,3 +49,10 @@ class EndpointPredictV1Request(BaseModel):
4949
callback_url: Optional[str] = None
5050
callback_auth: Optional[CallbackAuth] = None
5151
return_pickled: bool = False
52+
53+
54+
class SyncEndpointPredictV1Request(EndpointPredictV1Request):
55+
timeout_seconds: Optional[float] = Field(default=None, gt=0)
56+
num_retries: Optional[int] = Field(default=None, ge=0)
57+
# See live_{sync,streaming}_model_endpoint_inference_gateway to see how timeout_seconds/num_retries interact.
58+
# Also these fields are only relevant for sync endpoints

model-engine/model_engine_server/domain/exceptions.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,13 @@ class TooManyRequestsException(DomainException):
5959
"""
6060

6161

62+
class NoHealthyUpstreamException(DomainException):
63+
"""
64+
Thrown if an endpoint returns a 503 exception for no healthy upstream. This can happen if there are zero pods
65+
available to serve the request.
66+
"""
67+
68+
6269
class CorruptRecordInfraStateException(DomainException):
6370
"""
6471
Thrown if the data from existing state (i.e. the db, k8s, etc.) is somehow uninterpretable

model-engine/model_engine_server/domain/gateways/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from .cron_job_gateway import CronJobGateway
33
from .docker_image_batch_job_gateway import DockerImageBatchJobGateway
44
from .file_storage_gateway import FileStorageGateway
5+
from .inference_autoscaling_metrics_gateway import InferenceAutoscalingMetricsGateway
56
from .llm_artifact_gateway import LLMArtifactGateway
67
from .model_endpoints_schema_gateway import ModelEndpointsSchemaGateway
78
from .model_primitive_gateway import ModelPrimitiveGateway
@@ -15,6 +16,7 @@
1516
"CronJobGateway",
1617
"DockerImageBatchJobGateway",
1718
"FileStorageGateway",
19+
"InferenceAutoscalingMetricsGateway",
1820
"LLMArtifactGateway",
1921
"ModelEndpointsSchemaGateway",
2022
"ModelPrimitiveGateway",
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
from abc import ABC, abstractmethod


class InferenceAutoscalingMetricsGateway(ABC):
    """Interface for emitting per-endpoint autoscaling metrics on inference traffic.

    Concrete implementations publish a signal (e.g. to Redis — see
    RedisInferenceAutoscalingMetricsGateway) that an external autoscaler
    resource, such as a Keda ScaledObject, can consume to scale inference
    endpoints.
    """

    @abstractmethod
    async def emit_inference_autoscaling_metric(self, endpoint_id: str):
        """Record that an inference request hit the endpoint identified by *endpoint_id*."""
        ...

    @abstractmethod
    async def emit_prewarm_metric(self, endpoint_id: str):
        """Emit a metric for *endpoint_id* to warm the endpoint up ahead of real traffic."""
        ...

model-engine/model_engine_server/domain/gateways/streaming_model_endpoint_inference_gateway.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import AsyncIterable
33

44
from model_engine_server.common.dtos.tasks import (
5-
EndpointPredictV1Request,
5+
SyncEndpointPredictV1Request,
66
SyncEndpointPredictV1Response,
77
)
88

@@ -17,7 +17,7 @@ class StreamingModelEndpointInferenceGateway(ABC):
1717

1818
@abstractmethod
1919
def streaming_predict(
20-
self, topic: str, predict_request: EndpointPredictV1Request
20+
self, topic: str, predict_request: SyncEndpointPredictV1Request
2121
) -> AsyncIterable[SyncEndpointPredictV1Response]:
2222
"""
2323
Runs a prediction request and returns a streaming response.

model-engine/model_engine_server/domain/gateways/sync_model_endpoint_inference_gateway.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from abc import ABC, abstractmethod
22

33
from model_engine_server.common.dtos.tasks import (
4-
EndpointPredictV1Request,
4+
SyncEndpointPredictV1Request,
55
SyncEndpointPredictV1Response,
66
)
77

@@ -16,7 +16,7 @@ class SyncModelEndpointInferenceGateway(ABC):
1616

1717
@abstractmethod
1818
async def predict(
19-
self, topic: str, predict_request: EndpointPredictV1Request
19+
self, topic: str, predict_request: SyncEndpointPredictV1Request
2020
) -> SyncEndpointPredictV1Response:
2121
"""
2222
Runs a prediction request and returns a response.

model-engine/model_engine_server/domain/services/model_endpoint_service.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
StreamingModelEndpointInferenceGateway,
1919
SyncModelEndpointInferenceGateway,
2020
)
21+
from model_engine_server.domain.gateways.inference_autoscaling_metrics_gateway import (
22+
InferenceAutoscalingMetricsGateway,
23+
)
2124

2225

2326
class ModelEndpointService(ABC):
@@ -49,6 +52,14 @@ def get_streaming_model_endpoint_inference_gateway(
4952
Returns the sync model endpoint inference gateway.
5053
"""
5154

55+
@abstractmethod
56+
def get_inference_auto_scaling_metrics_gateway(
57+
self,
58+
) -> InferenceAutoscalingMetricsGateway:
59+
"""
60+
Returns the inference autoscaling metrics gateway.
61+
"""
62+
5263
@abstractmethod
5364
async def create_model_endpoint(
5465
self,

0 commit comments

Comments
 (0)