6 changes: 3 additions & 3 deletions charts/model-engine/values_sample.yaml
@@ -157,17 +157,17 @@ serviceTemplate:
config:
values:
infra:
# cloud_provider [required]; either "aws" or "azure"
# cloud_provider [required]; either "aws", "azure", or "gcp"
cloud_provider: aws
# k8s_cluster_name [required] is the name of the k8s cluster
k8s_cluster_name: main_cluster
# dns_host_domain [required] is the domain name of the k8s cluster
dns_host_domain: llm-engine.domain.com
# default_region [required] is the default AWS region for various resources (e.g ECR)
default_region: us-east-1
# aws_account_id [required] is the AWS account ID for various resources (e.g ECR)
# ml_account_id [required] is the AWS account ID for various resources (e.g. ECR) if cloud_provider is "aws", and the GCP project ID if cloud_provider is "gcp"
ml_account_id: "000000000000"
# docker_repo_prefix [required] is the prefix for AWS ECR repositories
# docker_repo_prefix [required] is the prefix for AWS ECR repositories, GCP Artifact Registry repositories, or Azure Container Registry repositories
docker_repo_prefix: "000000000000.dkr.ecr.us-east-1.amazonaws.com"
# redis_host [required if redis_aws_secret_name not present] is the hostname of the redis cluster you wish to connect
redis_host: llm-engine-prod-cache.use1.cache.amazonaws.com
7 changes: 6 additions & 1 deletion model-engine/README.md
@@ -42,8 +42,13 @@ Run `mypy . --install-types` to set up mypy.
Most of the business logic in Model Engine should contain unit tests, located in
[`tests/unit`](./tests/unit). To run the tests, run `pytest`.

## Building Docker Images

To build Docker images, change directories into the root of the llm-engine repository and then run
`docker build -f model-engine/Dockerfile .`

## Generating OpenAI types
We've decided to make our V2 APIs OpenAI compatible. We generate the
corresponding Pydantic models:
1. Fetch the OpenAPI spec from https://github.com/openai/openai-openapi/blob/master/openapi.yaml
2. Run scripts/generate-openai-types.sh
2. Run scripts/generate-openai-types.sh
14 changes: 13 additions & 1 deletion model-engine/model_engine_server/api/dependencies.py
@@ -110,6 +110,7 @@
DbTriggerRepository,
ECRDockerRepository,
FakeDockerRepository,
GCPArtifactRegistryDockerRepository,
LiveTokenizerRepository,
LLMFineTuneRepository,
RedisModelEndpointCacheRepository,
@@ -226,6 +227,10 @@ def _get_external_interfaces(
elif infra_config().cloud_provider == "azure":
inference_task_queue_gateway = servicebus_task_queue_gateway
infra_task_queue_gateway = servicebus_task_queue_gateway
elif infra_config().cloud_provider == "gcp":
Author comment: @kovben95scale had a good question about this -- is there a reason we don't use redis on azure etc? I think we're just using this as a celery task queue here, which seems like it would fit with redis.

# we use redis for gcp (instead of using servicebus or the like)
inference_task_queue_gateway = redis_24h_task_queue_gateway
infra_task_queue_gateway = redis_task_queue_gateway
else:
inference_task_queue_gateway = sqs_task_queue_gateway
infra_task_queue_gateway = sqs_task_queue_gateway
@@ -345,6 +350,12 @@ def _get_external_interfaces(
docker_repository = FakeDockerRepository()
elif infra_config().docker_repo_prefix.endswith("azurecr.io"):
docker_repository = ACRDockerRepository()
elif "pkg.dev" in infra_config().docker_repo_prefix:
assert (
infra_config().docker_repo_prefix
== f"{infra_config().default_region}-docker.pkg.dev/{infra_config().ml_account_id}" # this stores the gcp project id (when cloud_provider is gcp)
)
docker_repository = GCPArtifactRegistryDockerRepository()
else:
docker_repository = ECRDockerRepository()
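For reference, the Artifact Registry prefix that this assertion expects can be derived from the region and project ID already present in the infra config (see also the TODO in core/config.py below). A minimal sketch, assuming a hypothetical helper name and sample values:

```python
def expected_gcp_docker_repo_prefix(default_region: str, ml_account_id: str) -> str:
    # On GCP, ml_account_id holds the project ID, so the expected prefix is
    # "<region>-docker.pkg.dev/<project-id>", matching the assertion above.
    return f"{default_region}-docker.pkg.dev/{ml_account_id}"


# expected_gcp_docker_repo_prefix("us-central1", "my-gcp-project")
# -> "us-central1-docker.pkg.dev/my-gcp-project"
```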

@@ -387,7 +398,8 @@ def get_default_external_interfaces() -> ExternalInterfaces:

def get_default_external_interfaces_read_only() -> ExternalInterfaces:
session = async_scoped_session(
get_session_read_only_async(), scopefunc=asyncio.current_task # type: ignore
get_session_read_only_async(),
scopefunc=asyncio.current_task, # type: ignore
)
return _get_external_interfaces(read_only=True, session=session)

7 changes: 6 additions & 1 deletion model-engine/model_engine_server/common/config.py
@@ -70,12 +70,14 @@ class HostedModelInferenceServiceConfig:
user_inference_tensorflow_repository: str
docker_image_layer_cache_repository: str
sensitive_log_mode: bool
# Exactly one of the following three must be specified
# Exactly one of the following four must be specified
cache_redis_aws_url: Optional[str] = None # also using this to store sync autoscaling metrics
cache_redis_azure_host: Optional[str] = None
cache_redis_aws_secret_name: Optional[str] = (
None # Not an env var because the redis cache info is already here
)
cache_redis_gcp_host: Optional[str] = None

sglang_repository: Optional[str] = None

@classmethod
@@ -103,6 +105,9 @@ def cache_redis_url(self) -> str:
), "cache_redis_aws_secret_name is only for AWS"
creds = get_key_file(self.cache_redis_aws_secret_name) # Use default role
return creds["cache-url"]
elif self.cache_redis_gcp_host:
assert infra_config().cloud_provider == "gcp"
return f"rediss://{self.cache_redis_gcp_host}"

assert self.cache_redis_azure_host and infra_config().cloud_provider == "azure"
username = os.getenv("AZURE_OBJECT_ID")
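To make the new GCP branch concrete: the configured Memorystore host is simply wrapped in a TLS redis URL. A minimal sketch, with a hypothetical host value:

```python
def gcp_cache_redis_url(cache_redis_gcp_host: str) -> str:
    # Mirrors the GCP branch of cache_redis_url above: GCP Memorystore is
    # reached over TLS, so the configured host/port pair becomes a rediss:// URL.
    return f"rediss://{cache_redis_gcp_host}"


# gcp_cache_redis_url("10.0.0.5:6378") -> "rediss://10.0.0.5:6378"
```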
@@ -43,6 +43,7 @@ class BrokerName(str, Enum):
"""

REDIS = "redis-message-broker-master"
REDIS_GCP = "redis-gcp-memorystore-message-broker-master"
SQS = "sqs-message-broker-master"
SERVICEBUS = "servicebus-message-broker-master"

2 changes: 2 additions & 0 deletions model-engine/model_engine_server/core/celery/__init__.py
@@ -5,12 +5,14 @@
TaskVisibility,
celery_app,
get_all_db_indexes,
get_default_backend_protocol,
get_redis_host_port,
inspect_app,
)

__all__: Sequence[str] = (
"celery_app",
"get_default_backend_protocol",
"get_all_db_indexes",
"get_redis_host_port",
"inspect_app",
14 changes: 11 additions & 3 deletions model-engine/model_engine_server/core/celery/app.py
@@ -116,9 +116,7 @@ def seconds_to_visibility(timeout: int) -> "TaskVisibility":
@staticmethod
def from_name(name: str) -> "TaskVisibility":
# pylint: disable=no-member,protected-access
lookup = {
x.name: x.value for x in TaskVisibility._value2member_map_.values()
} # type: ignore
lookup = {x.name: x.value for x in TaskVisibility._value2member_map_.values()} # type: ignore
return TaskVisibility(lookup[name.upper()])


@@ -595,3 +593,13 @@ async def get_num_unclaimed_tasks_async(
if redis_instance is None:
await _redis_instance.close() # type: ignore
return num_unclaimed


def get_default_backend_protocol():
logger.info("CLOUD PROVIDER: %s", infra_config().cloud_provider)
if infra_config().cloud_provider == "azure":
return "abs"
elif infra_config().cloud_provider == "gcp":
return "redis" # TODO: THIS IS TEMPORARY! replace with cloud storage
else:
return "s3"
@@ -45,6 +45,7 @@ def excluded_namespaces():
ELASTICACHE_REDIS_BROKER = "redis-elasticache-message-broker-master"
SQS_BROKER = "sqs-message-broker-master"
SERVICEBUS_BROKER = "servicebus-message-broker-master"
GCP_REDIS_BROKER = "redis-gcp-memorystore-message-broker-master"

UPDATE_DEPLOYMENT_MAX_RETRIES = 10

@@ -588,6 +589,7 @@ async def main():
ELASTICACHE_REDIS_BROKER: RedisBroker(use_elasticache=True),
SQS_BROKER: SQSBroker(),
SERVICEBUS_BROKER: ASBBroker(),
GCP_REDIS_BROKER: RedisBroker(use_elasticache=False),
}

broker = BROKER_NAME_TO_CLASS[autoscaler_broker]
@@ -598,10 +600,18 @@
)

if broker_type == "redis":
# TODO: change this backend_protocol to use the correct protocol
# NOTE: the infra config is not available in the autoscaler (for some reason), so we have
# to use the autoscaler_broker to determine the infra.
backend_protocol = "redis" if "gcp" in autoscaler_broker else "s3"
inspect = {
db_index: inspect_app(
app=celery_app(
None, broker_type=broker_type, task_visibility=db_index, aws_role=aws_profile
None,
broker_type=broker_type,
task_visibility=db_index,
aws_role=aws_profile,
backend_protocol=backend_protocol,
)
)
for db_index in get_all_db_indexes()
3 changes: 2 additions & 1 deletion model-engine/model_engine_server/core/config.py
@@ -38,7 +38,7 @@ class _InfraConfig:
k8s_cluster_name: str
dns_host_domain: str
default_region: str
ml_account_id: str
ml_account_id: str # NOTE: this stores the aws account id if cloud_provider is aws, and the gcp project id if cloud_provider is gcp
docker_repo_prefix: str
s3_bucket: str
redis_host: Optional[str] = None
@@ -49,6 +49,7 @@ class _InfraConfig:
firehose_role_arn: Optional[str] = None
firehose_stream_name: Optional[str] = None
prometheus_server_address: Optional[str] = None
# TODO: on gcp, the repo prefix is derived from the project_id and default_region. So it should be computed somehow.


@dataclass
3 changes: 3 additions & 0 deletions model-engine/model_engine_server/core/configmap.py
@@ -33,3 +33,6 @@ async def read_config_map(
except ApiException as e:
logger.exception(f"Error reading configmap {config_map_name}")
raise e


# TODO: figure out what this does
35 changes: 32 additions & 3 deletions model-engine/model_engine_server/db/base.py
@@ -1,3 +1,4 @@
import json
import os
import sys
import time
@@ -7,6 +8,7 @@
import sqlalchemy
from azure.identity import DefaultAzureCredential
from azure.keyvault.secrets import SecretClient
from google.cloud.secretmanager_v1 import SecretManagerServiceClient
from model_engine_server.core.aws.secrets import get_key_file
from model_engine_server.core.config import InfraConfig, infra_config
from model_engine_server.core.loggers import logger_name, make_logger
@@ -20,8 +22,12 @@


def get_key_file_name(environment: str) -> str:
if infra_config().cloud_provider == "azure":
# azure and gcp don't support "/" in the key file secret name
# so we use dashes
if infra_config().cloud_provider == "azure" or infra_config().cloud_provider == "gcp":
return f"{environment}-ml-infra-pg".replace("training", "prod").replace("-new", "")

# aws does support "/" in the key file secret name
return f"{environment}/ml_infra_pg".replace("training", "prod").replace("-new", "")


@@ -55,16 +61,17 @@ def get_engine_url(
key_file = os.environ.get("DB_SECRET_NAME")
if env is None:
env = infra_config().env
Author comment: What are the values of env / where is it used?

# TODO: what are the values of env?
if key_file is None:
key_file = get_key_file_name(env) # type: ignore
logger.debug(f"Using key file {key_file}")

if infra_config().cloud_provider == "azure":
client = SecretClient(
az_secret_client = SecretClient(
vault_url=f"https://{os.environ.get('KEYVAULT_NAME')}.vault.azure.net",
credential=DefaultAzureCredential(),
)
db = client.get_secret(key_file).value
db = az_secret_client.get_secret(key_file).value
user = os.environ.get("AZURE_IDENTITY_NAME")
token = DefaultAzureCredential().get_token(
"https://ossrdbms-aad.database.windows.net/.default"
@@ -76,6 +83,28 @@
# for recommendations on how to work with rotating auth credentials
engine_url = f"postgresql://{user}:{password}@{db}?sslmode=require"
expiry_in_sec = token.expires_on
elif infra_config().cloud_provider == "gcp":
gcp_secret_manager_client = (
SecretManagerServiceClient()
) # uses application default credentials (see: https://cloud.google.com/secret-manager/docs/reference/libraries#client-libraries-usage-python)
secret_version = gcp_secret_manager_client.access_secret_version(
request={
"name": f"projects/{infra_config().ml_account_id}/secrets/{key_file}/versions/latest"
}
)
creds = json.loads(secret_version.payload.data.decode("utf-8"))

user = creds.get("username")
password = creds.get("password")
host = creds.get("host")
port = str(creds.get("port"))
dbname = creds.get("dbname")

assert all([user, password, host, port, dbname]) # TODO: remove this

logger.info(f"Connecting to db {host}:{port}, name {dbname}")

engine_url = f"postgresql://{user}:{password}@{host}:{port}/{dbname}"
else:
db_secret_aws_profile = os.environ.get("DB_SECRET_AWS_PROFILE")
creds = get_key_file(key_file, db_secret_aws_profile)
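For reference, the GCP branch above expects the Secret Manager payload to be a JSON document containing the keys it reads. A sketch of the shape, with made-up values:

```python
import json

# Hypothetical payload stored at
# projects/<project-id>/secrets/<key_file>/versions/latest
example_payload = json.dumps(
    {
        "username": "llm_engine",
        "password": "example-password",
        "host": "10.20.0.3",
        "port": 5432,
        "dbname": "llm_engine",
    }
)
# get_engine_url would then build:
# postgresql://llm_engine:example-password@10.20.0.3:5432/llm_engine
```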
@@ -7,14 +7,15 @@
GetAsyncTaskV1Response,
TaskStatus,
)
from model_engine_server.core.celery import TaskVisibility, celery_app
from model_engine_server.core.celery import TaskVisibility, celery_app, get_default_backend_protocol
from model_engine_server.core.config import infra_config
from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.domain.exceptions import InvalidRequestException
from model_engine_server.domain.gateways.task_queue_gateway import TaskQueueGateway

logger = make_logger(logger_name())
backend_protocol = "abs" if infra_config().cloud_provider == "azure" else "s3"

backend_protocol = get_default_backend_protocol()

celery_redis = celery_app(
None,
@@ -554,8 +554,8 @@ def get_endpoint_resource_arguments_from_request(
image_hash = compute_image_hash(request.image)

# In Circle CI, we use Redis on localhost instead of SQS
if CIRCLECI:
broker_name = BrokerName.REDIS.value
if CIRCLECI or infra_config().cloud_provider == "gcp":
broker_name = BrokerName.REDIS_GCP.value
broker_type = BrokerType.REDIS.value
elif infra_config().cloud_provider == "azure":
broker_name = BrokerName.SERVICEBUS.value
@@ -576,6 +576,7 @@
abs_account_name = os.getenv("ABS_ACCOUNT_NAME")
if abs_account_name is not None:
main_env.append({"name": "ABS_ACCOUNT_NAME", "value": abs_account_name})
# TODO: what should we add here

# LeaderWorkerSet exclusive
worker_env = None
@@ -12,6 +12,7 @@
from .ecr_docker_repository import ECRDockerRepository
from .fake_docker_repository import FakeDockerRepository
from .feature_flag_repository import FeatureFlagRepository
from .gcp_artifact_registry_docker_repository import GCPArtifactRegistryDockerRepository
from .live_tokenizer_repository import LiveTokenizerRepository
from .llm_fine_tune_repository import LLMFineTuneRepository
from .model_endpoint_cache_repository import ModelEndpointCacheRepository
@@ -42,4 +43,5 @@
"RedisModelEndpointCacheRepository",
"S3FileLLMFineTuneRepository",
"S3FileLLMFineTuneEventsRepository",
"GCPArtifactRegistryDockerRepository",
]