Skip to content

Commit 37af55c

Browse files
authored
Azure fixes + additional asks (#468)
1 parent 5b0aaf1 commit 37af55c

File tree

13 files changed

+158
-108
lines changed

13 files changed

+158
-108
lines changed

charts/model-engine/templates/_helpers.tpl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,8 +270,6 @@ env:
270270
value: {{ .Values.azure.abs_account_name }}
271271
- name: SERVICEBUS_NAMESPACE
272272
value: {{ .Values.azure.servicebus_namespace }}
273-
- name: SERVICEBUS_SAS_KEY
274-
value: {{ .Values.azure.servicebus_sas_key }}
275273
{{- end }}
276274
{{- end }}
277275

model-engine/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# syntax = docker/dockerfile:experimental
22

3-
FROM python:3.8.8-slim as model-engine
3+
FROM python:3.8.18-slim as model-engine
44
WORKDIR /workspace
55

66
RUN apt-get update && apt-get install -y \

model-engine/model_engine_server/api/dependencies.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import os
3+
import time
34
from dataclasses import dataclass
45
from typing import Callable, Optional
56

@@ -442,6 +443,7 @@ async def verify_authentication(
442443
def get_or_create_aioredis_pool() -> aioredis.ConnectionPool:
443444
global _pool
444445

445-
if _pool is None:
446+
expiration_timestamp = hmi_config.cache_redis_url_expiration_timestamp
447+
if _pool is None or (expiration_timestamp is not None and time.time() > expiration_timestamp):
446448
_pool = aioredis.BlockingConnectionPool.from_url(hmi_config.cache_redis_url)
447449
return _pool

model-engine/model_engine_server/api/files_v1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ async def upload_file(
4444
)
4545
return await use_case.execute(
4646
user=auth,
47-
filename=file.filename,
47+
filename=file.filename or "",
4848
content=file.file.read(),
4949
)
5050

model-engine/model_engine_server/common/config.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030

3131
SERVICE_CONFIG_PATH = os.environ.get("DEPLOY_SERVICE_CONFIG_PATH", DEFAULT_SERVICE_CONFIG_PATH)
3232

33+
redis_cache_expiration_timestamp = None
34+
3335

3436
# duplicated from llm/ia3_finetune
3537
def get_model_cache_directory_name(model_name: str):
@@ -81,9 +83,17 @@ def cache_redis_url(self) -> str:
8183

8284
assert self.cache_redis_azure_host and infra_config().cloud_provider == "azure"
8385
username = os.getenv("AZURE_OBJECT_ID")
84-
password = DefaultAzureCredential().get_token("https://redis.azure.com/.default").token
86+
token = DefaultAzureCredential().get_token("https://redis.azure.com/.default")
87+
password = token.token
88+
global redis_cache_expiration_timestamp
89+
redis_cache_expiration_timestamp = token.expires_on
8590
return f"rediss://{username}:{password}@{self.cache_redis_azure_host}"
8691

92+
@property
93+
def cache_redis_url_expiration_timestamp(self) -> Optional[int]:
94+
global redis_cache_expiration_timestamp
95+
return redis_cache_expiration_timestamp
96+
8797
@property
8898
def cache_redis_host_port(self) -> str:
8999
# redis://redis.url:6379/<db_index>

model-engine/model_engine_server/db/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def get_engine_url(env: Optional[str] = None, read_only: bool = True, sync: bool
5858
user = os.environ.get("AZURE_IDENTITY_NAME")
5959
password = (
6060
DefaultAzureCredential()
61-
.get_token("https://ossrdbms-aad.database.windows.net")
61+
.get_token("https://ossrdbms-aad.database.windows.net/.default")
6262
.token
6363
)
6464
logger.info(f"Connecting to db {db} as user {user}")
@@ -81,7 +81,9 @@ def get_engine_url(env: Optional[str] = None, read_only: bool = True, sync: bool
8181

8282
# For async postgres, we need to use an async dialect.
8383
if not sync:
84-
engine_url = engine_url.replace("postgresql://", "postgresql+asyncpg://")
84+
engine_url = engine_url.replace("postgresql://", "postgresql+asyncpg://").replace(
85+
"sslmode", "ssl"
86+
)
8587
return engine_url
8688

8789

model-engine/model_engine_server/domain/exceptions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ class DockerImageNotFoundException(DomainException):
4141
tag: str
4242

4343

44+
class DockerRepositoryNotFoundException(DomainException):
45+
"""
46+
Thrown when a Docker repository that is trying to be accessed doesn't exist.
47+
"""
48+
49+
4450
class DockerBuildFailedException(DomainException):
4551
"""
4652
Thrown if the server failed to build a docker image.

model-engine/model_engine_server/infra/gateways/live_model_endpoints_schema_gateway.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
from typing import Any, Callable, Dict, Sequence, Set, Type, Union
44

55
from fastapi import routing
6+
from fastapi._compat import GenerateJsonSchema, get_model_definitions
7+
from fastapi.openapi.constants import REF_TEMPLATE
68
from fastapi.openapi.utils import get_openapi_path
7-
from fastapi.utils import get_model_definitions
89
from model_engine_server.common.dtos.tasks import (
910
EndpointPredictV1Request,
1011
GetAsyncTaskV1Response,
@@ -119,8 +120,13 @@ def get_openapi(
119120
if isinstance(route, routing.APIRoute):
120121
prefix = model_endpoint_name
121122
model_name_map = LiveModelEndpointsSchemaGateway.get_model_name_map(prefix)
123+
schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE)
122124
result = get_openapi_path(
123-
route=route, model_name_map=model_name_map, operation_ids=operation_ids
125+
route=route,
126+
model_name_map=model_name_map,
127+
operation_ids=operation_ids,
128+
schema_generator=schema_generator,
129+
field_mapping={},
124130
)
125131
if result:
126132
path, security_schemes, path_definitions = result

model-engine/model_engine_server/infra/repositories/acr_docker_repository.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from model_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse
77
from model_engine_server.core.config import infra_config
88
from model_engine_server.core.loggers import logger_name, make_logger
9+
from model_engine_server.domain.exceptions import DockerRepositoryNotFoundException
910
from model_engine_server.domain.repositories import DockerRepository
1011

1112
logger = make_logger(logger_name())
@@ -36,7 +37,11 @@ def get_latest_image_tag(self, repository_name: str) -> str:
3637
credential = DefaultAzureCredential()
3738
client = ContainerRegistryClient(endpoint, credential)
3839

39-
image = client.list_manifest_properties(
40-
repository_name, order_by="time_desc", results_per_page=1
41-
).next()
42-
return image.tags[0]
40+
try:
41+
image = client.list_manifest_properties(
42+
repository_name, order_by="time_desc", results_per_page=1
43+
).next()
44+
# Azure automatically deletes empty ACR repositories, so repos will always have at least one image
45+
return image.tags[0]
46+
except ResourceNotFoundError:
47+
raise DockerRepositoryNotFoundException

model-engine/model_engine_server/infra/services/image_cache_service.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from model_engine_server.core.config import infra_config
88
from model_engine_server.core.loggers import logger_name, make_logger
99
from model_engine_server.domain.entities import GpuType, ModelEndpointInfraState
10+
from model_engine_server.domain.exceptions import DockerRepositoryNotFoundException
1011
from model_engine_server.domain.repositories import DockerRepository
1112
from model_engine_server.infra.gateways.resources.image_cache_gateway import (
1213
CachedImages,
@@ -78,11 +79,14 @@ def _cache_finetune_llm_images(
7879
vllm_image_032 = DockerImage(
7980
f"{infra_config().docker_repo_prefix}/{hmi_config.vllm_repository}", "0.3.2"
8081
)
81-
latest_tag = (
82-
self.docker_repository.get_latest_image_tag(hmi_config.batch_inference_vllm_repository)
83-
if not CIRCLECI
84-
else "fake_docker_repository_latest_image_tag"
85-
)
82+
latest_tag = "fake_docker_repository_latest_image_tag"
83+
if not CIRCLECI:
84+
try: # pragma: no cover
85+
latest_tag = self.docker_repository.get_latest_image_tag(
86+
hmi_config.batch_inference_vllm_repository
87+
)
88+
except DockerRepositoryNotFoundException:
89+
pass
8690
vllm_batch_image_latest = DockerImage(
8791
f"{infra_config().docker_repo_prefix}/{hmi_config.batch_inference_vllm_repository}",
8892
latest_tag,

0 commit comments

Comments
 (0)