diff --git a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py index 9c7dd2f76..42957e491 100644 --- a/model-engine/model_engine_server/api/dependencies.py +++ b/model-engine/model_engine_server/api/dependencies.py @@ -94,6 +94,9 @@ from model_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import ( LiveEndpointResourceGateway, ) +from model_engine_server.infra.gateways.resources.onprem_queue_endpoint_resource_delegate import ( + OnPremQueueEndpointResourceDelegate, +) from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import ( QueueEndpointResourceDelegate, ) @@ -114,6 +117,7 @@ FakeDockerRepository, LiveTokenizerRepository, LLMFineTuneRepository, + OnPremDockerRepository, RedisModelEndpointCacheRepository, S3FileLLMFineTuneEventsRepository, S3FileLLMFineTuneRepository, @@ -225,6 +229,8 @@ def _get_external_interfaces( queue_delegate = FakeQueueEndpointResourceDelegate() elif infra_config().cloud_provider == "azure": queue_delegate = ASBQueueEndpointResourceDelegate() + elif infra_config().cloud_provider == "onprem": + queue_delegate = OnPremQueueEndpointResourceDelegate() else: queue_delegate = SQSQueueEndpointResourceDelegate( sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile) @@ -238,6 +244,9 @@ def _get_external_interfaces( elif infra_config().cloud_provider == "azure": inference_task_queue_gateway = servicebus_task_queue_gateway infra_task_queue_gateway = servicebus_task_queue_gateway + elif infra_config().cloud_provider == "onprem": + inference_task_queue_gateway = redis_task_queue_gateway + infra_task_queue_gateway = redis_task_queue_gateway elif infra_config().celery_broker_type_redis: inference_task_queue_gateway = redis_task_queue_gateway infra_task_queue_gateway = redis_task_queue_gateway @@ -274,16 +283,12 @@ def _get_external_interfaces( monitoring_metrics_gateway=monitoring_metrics_gateway, use_asyncio=(not CIRCLECI), ) - filesystem_gateway = ( - ABSFilesystemGateway() - if infra_config().cloud_provider == "azure" - else S3FilesystemGateway() - ) - llm_artifact_gateway = ( - ABSLLMArtifactGateway() - if infra_config().cloud_provider == "azure" - else S3LLMArtifactGateway() - ) + if infra_config().cloud_provider == "azure": + filesystem_gateway = ABSFilesystemGateway() + llm_artifact_gateway = ABSLLMArtifactGateway() + else: + filesystem_gateway = S3FilesystemGateway() + llm_artifact_gateway = S3LLMArtifactGateway() model_endpoints_schema_gateway = LiveModelEndpointsSchemaGateway( filesystem_gateway=filesystem_gateway ) @@ -328,18 +333,11 @@ def _get_external_interfaces( hmi_config.cloud_file_llm_fine_tune_repository, ) if infra_config().cloud_provider == "azure": - llm_fine_tune_repository = ABSFileLLMFineTuneRepository( - file_path=file_path, - ) + llm_fine_tune_repository = ABSFileLLMFineTuneRepository(file_path=file_path) + llm_fine_tune_events_repository = ABSFileLLMFineTuneEventsRepository() else: - llm_fine_tune_repository = S3FileLLMFineTuneRepository( - file_path=file_path, - ) - llm_fine_tune_events_repository = ( - ABSFileLLMFineTuneEventsRepository() - if infra_config().cloud_provider == "azure" - else S3FileLLMFineTuneEventsRepository() - ) + llm_fine_tune_repository = S3FileLLMFineTuneRepository(file_path=file_path) + llm_fine_tune_events_repository = S3FileLLMFineTuneEventsRepository() llm_fine_tuning_service = DockerImageBatchJobLLMFineTuningService( docker_image_batch_job_gateway=docker_image_batch_job_gateway, 
docker_image_batch_job_bundle_repo=docker_image_batch_job_bundle_repository, @@ -350,17 +348,18 @@ def _get_external_interfaces( docker_image_batch_job_gateway=docker_image_batch_job_gateway ) - file_storage_gateway = ( - ABSFileStorageGateway() - if infra_config().cloud_provider == "azure" - else S3FileStorageGateway() - ) + if infra_config().cloud_provider == "azure": + file_storage_gateway = ABSFileStorageGateway() + else: + file_storage_gateway = S3FileStorageGateway() docker_repository: DockerRepository if CIRCLECI: docker_repository = FakeDockerRepository() - elif infra_config().docker_repo_prefix.endswith("azurecr.io"): + elif infra_config().cloud_provider == "azure": docker_repository = ACRDockerRepository() + elif infra_config().cloud_provider == "onprem": + docker_repository = OnPremDockerRepository() else: docker_repository = ECRDockerRepository() diff --git a/model-engine/model_engine_server/common/config.py b/model-engine/model_engine_server/common/config.py index 532ead21a..286ad46b9 100644 --- a/model-engine/model_engine_server/common/config.py +++ b/model-engine/model_engine_server/common/config.py @@ -90,21 +90,29 @@ def from_yaml(cls, yaml_path): @property def cache_redis_url(self) -> str: + cloud_provider = infra_config().cloud_provider + + if cloud_provider == "onprem": + if self.cache_redis_aws_url: + logger.info("On-prem deployment using cache_redis_aws_url") + return self.cache_redis_aws_url + redis_host = os.getenv("REDIS_HOST", "redis") + redis_port = getattr(infra_config(), "redis_port", 6379) + return f"redis://{redis_host}:{redis_port}/0" + if self.cache_redis_aws_url: - assert infra_config().cloud_provider == "aws", "cache_redis_aws_url is only for AWS" + assert cloud_provider == "aws", "cache_redis_aws_url is only for AWS" if self.cache_redis_aws_secret_name: logger.warning( "Both cache_redis_aws_url and cache_redis_aws_secret_name are set. 
Using cache_redis_aws_url" ) return self.cache_redis_aws_url elif self.cache_redis_aws_secret_name: - assert ( - infra_config().cloud_provider == "aws" - ), "cache_redis_aws_secret_name is only for AWS" - creds = get_key_file(self.cache_redis_aws_secret_name) # Use default role + assert cloud_provider == "aws", "cache_redis_aws_secret_name is only for AWS" + creds = get_key_file(self.cache_redis_aws_secret_name) return creds["cache-url"] - assert self.cache_redis_azure_host and infra_config().cloud_provider == "azure" + assert self.cache_redis_azure_host and cloud_provider == "azure" username = os.getenv("AZURE_OBJECT_ID") token = DefaultAzureCredential().get_token("https://redis.azure.com/.default") password = token.token diff --git a/model-engine/model_engine_server/common/io.py b/model-engine/model_engine_server/common/io.py index c9d9458ff..9984c969d 100644 --- a/model-engine/model_engine_server/common/io.py +++ b/model-engine/model_engine_server/common/io.py @@ -10,12 +10,11 @@ def open_wrapper(uri: str, mode: str = "rt", **kwargs): client: Any - cloud_provider: str - # This follows the 5.1.0 smart_open API try: cloud_provider = infra_config().cloud_provider except Exception: cloud_provider = "aws" + if cloud_provider == "azure": from azure.identity import DefaultAzureCredential from azure.storage.blob import BlobServiceClient @@ -24,6 +23,20 @@ def open_wrapper(uri: str, mode: str = "rt", **kwargs): f"https://{os.getenv('ABS_ACCOUNT_NAME')}.blob.core.windows.net", DefaultAzureCredential(), ) + elif cloud_provider == "onprem": + session = boto3.Session() + client_kwargs = {} + + s3_endpoint = getattr(infra_config(), "s3_endpoint_url", None) or os.getenv( + "S3_ENDPOINT_URL" + ) + if s3_endpoint: + client_kwargs["endpoint_url"] = s3_endpoint + + addressing_style = getattr(infra_config(), "s3_addressing_style", "path") + client_kwargs["config"] = boto3.session.Config(s3={"addressing_style": addressing_style}) + + client = session.client("s3", **client_kwargs) else: profile_name = kwargs.get("aws_profile", os.getenv("AWS_PROFILE")) session = boto3.Session(profile_name=profile_name) diff --git a/model-engine/model_engine_server/core/celery/app.py b/model-engine/model_engine_server/core/celery/app.py index af7790d1e..de352f01a 100644 --- a/model-engine/model_engine_server/core/celery/app.py +++ b/model-engine/model_engine_server/core/celery/app.py @@ -531,17 +531,28 @@ def _get_backend_url_and_conf( backend_url = get_redis_endpoint(1) elif backend_protocol == "s3": backend_url = "s3://" - if aws_role is None: - aws_session = session(infra_config().profile_ml_worker) + if infra_config().cloud_provider == "aws": + if aws_role is None: + aws_session = session(infra_config().profile_ml_worker) + else: + aws_session = session(aws_role) + out_conf_changes.update( + { + "s3_boto3_session": aws_session, + "s3_bucket": s3_bucket, + "s3_base_path": s3_base_path, + } + ) else: - aws_session = session(aws_role) - out_conf_changes.update( - { - "s3_boto3_session": aws_session, - "s3_bucket": s3_bucket, - "s3_base_path": s3_base_path, - } - ) + logger.info( + "Non-AWS deployment, using environment variables for S3 backend credentials" + ) + out_conf_changes.update( + { + "s3_bucket": s3_bucket, + "s3_base_path": s3_base_path, + } + ) elif backend_protocol == "abs": backend_url = f"azureblockblob://{os.getenv('ABS_ACCOUNT_NAME')}" else: diff --git a/model-engine/model_engine_server/core/configs/onprem.yaml b/model-engine/model_engine_server/core/configs/onprem.yaml new file mode 100644 index 
000000000..9206286ac --- /dev/null +++ b/model-engine/model_engine_server/core/configs/onprem.yaml @@ -0,0 +1,72 @@ +# On-premise deployment configuration +# This configuration file provides defaults for on-prem deployments +# Many values can be overridden via environment variables + +cloud_provider: "onprem" +env: "production" # Can be: production, staging, development, local +k8s_cluster_name: "onprem-cluster" +dns_host_domain: "ml.company.local" +default_region: "us-east-1" # Placeholder for compatibility with cloud-agnostic code + +# ==================== +# Object Storage (MinIO/S3-compatible) +# ==================== +s3_bucket: "model-engine" +# S3 endpoint URL - can be overridden by S3_ENDPOINT_URL env var +# Examples: "https://minio.company.local", "http://minio-service:9000" +s3_endpoint_url: "" # Set via S3_ENDPOINT_URL env var if not specified here +# MinIO requires path-style addressing (bucket in URL path, not subdomain) +s3_addressing_style: "path" + +# ==================== +# Redis Configuration +# ==================== +# Redis is used for: +# - Celery task queue broker +# - Model endpoint caching +# - Inference autoscaling metrics +redis_host: "" # Set via REDIS_HOST env var (e.g., "redis.company.local" or "redis-service") +redis_port: 6379 +# Whether to use Redis as Celery broker (true for on-prem) +celery_broker_type_redis: true + +# ==================== +# Celery Configuration +# ==================== +# Backend protocol: "redis" for on-prem (not "s3" or "abs") +celery_backend_protocol: "redis" + +# ==================== +# Database Configuration +# ==================== +# Database connection settings (credentials from environment variables) +# DB_HOST, DB_PORT, DB_NAME, DB_USER, DB_PASSWORD +db_host: "postgres" # Default hostname, can be overridden by DB_HOST env var +db_port: 5432 +db_name: "llm_engine" +db_engine_pool_size: 20 +db_engine_max_overflow: 10 +db_engine_echo: false +db_engine_echo_pool: false +db_engine_disconnect_strategy: "pessimistic" + +# ==================== +# Docker Registry Configuration +# ==================== +# Docker registry prefix for container images +# Examples: "registry.company.local", "harbor.company.local/ml-platform" +# Leave empty if using full image paths directly +docker_repo_prefix: "registry.company.local" + +# ==================== +# Monitoring & Observability +# ==================== +# Prometheus server address for metrics (optional) +# prometheus_server_address: "http://prometheus:9090" + +# ==================== +# Not applicable for on-prem (kept for compatibility) +# ==================== +ml_account_id: "onprem" +profile_ml_worker: "default" +profile_ml_inference_worker: "default" diff --git a/model-engine/model_engine_server/db/base.py b/model-engine/model_engine_server/db/base.py index 5033d8ada..1e2f3149d 100644 --- a/model-engine/model_engine_server/db/base.py +++ b/model-engine/model_engine_server/db/base.py @@ -59,7 +59,17 @@ def get_engine_url( key_file = get_key_file_name(env) # type: ignore logger.debug(f"Using key file {key_file}") - if infra_config().cloud_provider == "azure": + if infra_config().cloud_provider == "onprem": + user = os.environ.get("DB_USER", "postgres") + password = os.environ.get("DB_PASSWORD", "postgres") + host = os.environ.get("DB_HOST_RO") or os.environ.get("DB_HOST", "localhost") + port = os.environ.get("DB_PORT", "5432") + dbname = os.environ.get("DB_NAME", "llm_engine") + logger.info(f"Connecting to db {host}:{port}, name {dbname}") + + engine_url = 
f"postgresql://{user}:{password}@{host}:{port}/{dbname}" + + elif infra_config().cloud_provider == "azure": client = SecretClient( vault_url=f"https://{os.environ.get('KEYVAULT_NAME')}.vault.azure.net", credential=DefaultAzureCredential(), diff --git a/model-engine/model_engine_server/domain/entities/model_bundle_entity.py b/model-engine/model_engine_server/domain/entities/model_bundle_entity.py index 2a5a4863c..512818e35 100644 --- a/model-engine/model_engine_server/domain/entities/model_bundle_entity.py +++ b/model-engine/model_engine_server/domain/entities/model_bundle_entity.py @@ -71,10 +71,18 @@ def validate_fields_present_for_framework_type(cls, field_values): "type was selected." ) else: # field_values["framework_type"] == ModelBundleFramework.CUSTOM: - assert field_values["ecr_repo"] and field_values["image_tag"], ( - "Expected `ecr_repo` and `image_tag` to be non-null because the custom framework " + assert field_values["image_tag"], ( + "Expected `image_tag` to be non-null because the custom framework " "type was selected." ) + if not field_values.get("ecr_repo"): + from model_engine_server.core.config import infra_config + + if infra_config().cloud_provider != "onprem": + raise ValueError( + "Expected `ecr_repo` to be non-null for custom framework. " + "For on-prem deployments, ecr_repo can be omitted to use direct image references." + ) return field_values model_config = ConfigDict(from_attributes=True) diff --git a/model-engine/model_engine_server/entrypoints/k8s_cache.py b/model-engine/model_engine_server/entrypoints/k8s_cache.py index 98dcd9b35..b6740ec25 100644 --- a/model-engine/model_engine_server/entrypoints/k8s_cache.py +++ b/model-engine/model_engine_server/entrypoints/k8s_cache.py @@ -42,6 +42,9 @@ ECRDockerRepository, FakeDockerRepository, ) +from model_engine_server.infra.repositories.onprem_docker_repository import ( + OnPremDockerRepository, +) from model_engine_server.infra.repositories.db_model_endpoint_record_repository import ( DbModelEndpointRecordRepository, ) @@ -124,8 +127,10 @@ async def main(args: Any): docker_repo: DockerRepository if CIRCLECI: docker_repo = FakeDockerRepository() - elif infra_config().docker_repo_prefix.endswith("azurecr.io"): + elif infra_config().cloud_provider == "azure": docker_repo = ACRDockerRepository() + elif infra_config().cloud_provider == "onprem": + docker_repo = OnPremDockerRepository() else: docker_repo = ECRDockerRepository() while True: diff --git a/model-engine/model_engine_server/inference/forwarding/forwarding.py b/model-engine/model_engine_server/inference/forwarding/forwarding.py index 068b6e856..5183955b9 100644 --- a/model-engine/model_engine_server/inference/forwarding/forwarding.py +++ b/model-engine/model_engine_server/inference/forwarding/forwarding.py @@ -639,6 +639,7 @@ async def _make_request( "host", "content-length", "connection", + "transfer-encoding", } headers = {k: v for k, v in headers.items() if k.lower() not in excluded_headers} url = request.url diff --git a/model-engine/model_engine_server/infra/gateways/resources/onprem_queue_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/onprem_queue_endpoint_resource_delegate.py new file mode 100644 index 000000000..c86eed1cd --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/resources/onprem_queue_endpoint_resource_delegate.py @@ -0,0 +1,52 @@ +from typing import Any, Dict, Sequence + +from model_engine_server.core.loggers import logger_name, make_logger +from 
model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import ( + QueueEndpointResourceDelegate, + QueueInfo, +) + +logger = make_logger(logger_name()) + +__all__: Sequence[str] = ("OnPremQueueEndpointResourceDelegate",) + + +class OnPremQueueEndpointResourceDelegate(QueueEndpointResourceDelegate): + async def create_queue_if_not_exists( + self, + endpoint_id: str, + endpoint_name: str, + endpoint_created_by: str, + endpoint_labels: Dict[str, Any], + ) -> QueueInfo: + queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) + + logger.debug( + f"On-prem queue for endpoint {endpoint_id}: {queue_name} " + f"(Redis queues don't require explicit creation)" + ) + + return QueueInfo(queue_name=queue_name, queue_url=queue_name) + + async def delete_queue(self, endpoint_id: str) -> None: + queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) + logger.debug(f"Delete request for queue {queue_name} (no-op for Redis-based queues)") + + async def get_queue_attributes(self, endpoint_id: str) -> Dict[str, Any]: + queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) + + logger.warning( + f"Getting queue attributes for {queue_name} - returning hardcoded values. " + f"On-prem Redis queues do not support real-time message counts. " + f"Do not rely on ApproximateNumberOfMessages for autoscaling decisions." + ) + + return { + "Attributes": { + "ApproximateNumberOfMessages": "0", + "QueueName": queue_name, + }, + "ResponseMetadata": { + "HTTPStatusCode": 200, + }, + } diff --git a/model-engine/model_engine_server/infra/gateways/s3_file_storage_gateway.py b/model-engine/model_engine_server/infra/gateways/s3_file_storage_gateway.py index a50207408..8d6747890 100644 --- a/model-engine/model_engine_server/infra/gateways/s3_file_storage_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/s3_file_storage_gateway.py @@ -2,35 +2,41 @@ from typing import List, Optional from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.gateways.file_storage_gateway import ( FileMetadata, FileStorageGateway, ) -from model_engine_server.infra.gateways import S3FilesystemGateway +from model_engine_server.infra.gateways.s3_filesystem_gateway import S3FilesystemGateway +from model_engine_server.infra.gateways.s3_utils import get_s3_client +logger = make_logger(logger_name()) -def get_s3_key(owner: str, file_id: str): + +def get_s3_key(owner: str, file_id: str) -> str: return os.path.join(owner, file_id) -def get_s3_url(owner: str, file_id: str): +def get_s3_url(owner: str, file_id: str) -> str: return f"s3://{infra_config().s3_bucket}/{get_s3_key(owner, file_id)}" class S3FileStorageGateway(FileStorageGateway): - """ - Concrete implementation of a file storage gateway backed by S3. 
- """ - def __init__(self): self.filesystem_gateway = S3FilesystemGateway() async def get_url_from_id(self, owner: str, file_id: str) -> Optional[str]: - return self.filesystem_gateway.generate_signed_url(get_s3_url(owner, file_id)) + try: + url = self.filesystem_gateway.generate_signed_url(get_s3_url(owner, file_id)) + logger.debug(f"Generated presigned URL for {owner}/{file_id}") + return url + except Exception as e: + logger.error(f"Failed to generate presigned URL for {owner}/{file_id}: {e}") + return None async def get_file(self, owner: str, file_id: str) -> Optional[FileMetadata]: try: - obj = self.filesystem_gateway.get_s3_client({}).head_object( + obj = get_s3_client({}).head_object( Bucket=infra_config().s3_bucket, Key=get_s3_key(owner, file_id), ) @@ -41,7 +47,8 @@ async def get_file(self, owner: str, file_id: str) -> Optional[FileMetadata]: owner=owner, updated_at=obj.get("LastModified"), ) - except: # noqa: E722 + except Exception as e: + logger.debug(f"File not found or error retrieving {owner}/{file_id}: {e}") return None async def get_file_content(self, owner: str, file_id: str) -> Optional[str]: @@ -49,8 +56,11 @@ async def get_file_content(self, owner: str, file_id: str) -> Optional[str]: with self.filesystem_gateway.open( get_s3_url(owner, file_id), aws_profile=infra_config().profile_ml_worker ) as f: - return f.read() - except: # noqa: E722 + content = f.read() + logger.debug(f"Retrieved content for {owner}/{file_id}") + return content + except Exception as e: + logger.error(f"Failed to read file {owner}/{file_id}: {e}") return None async def upload_file(self, owner: str, filename: str, content: bytes) -> str: @@ -58,22 +68,38 @@ async def upload_file(self, owner: str, filename: str, content: bytes) -> str: get_s3_url(owner, filename), mode="w", aws_profile=infra_config().profile_ml_worker ) as f: f.write(content.decode("utf-8")) + logger.info(f"Uploaded file {owner}/{filename}") return filename async def delete_file(self, owner: str, file_id: str) -> bool: try: - self.filesystem_gateway.get_s3_client({}).delete_object( + get_s3_client({}).delete_object( Bucket=infra_config().s3_bucket, Key=get_s3_key(owner, file_id), ) + logger.info(f"Deleted file {owner}/{file_id}") return True - except: # noqa: E722 + except Exception as e: + logger.error(f"Failed to delete file {owner}/{file_id}: {e}") return False async def list_files(self, owner: str) -> List[FileMetadata]: - objects = self.filesystem_gateway.get_s3_client({}).list_objects_v2( - Bucket=infra_config().s3_bucket, - Prefix=owner, - ) - files = [await self.get_file(owner, obj["Name"]) for obj in objects] - return [f for f in files if f is not None] + try: + objects = get_s3_client({}).list_objects_v2( + Bucket=infra_config().s3_bucket, + Prefix=owner, + ) + files = [] + for obj in objects.get("Contents", []): + key = obj["Key"] + if key.startswith(owner): + file_id = key[len(owner) :].lstrip("/") + if file_id: + file_metadata = await self.get_file(owner, file_id) + if file_metadata: + files.append(file_metadata) + logger.debug(f"Listed {len(files)} files for owner {owner}") + return files + except Exception as e: + logger.error(f"Failed to list files for owner {owner}: {e}") + return [] diff --git a/model-engine/model_engine_server/infra/gateways/s3_filesystem_gateway.py b/model-engine/model_engine_server/infra/gateways/s3_filesystem_gateway.py index b0bf9e84e..4cdf02c35 100644 --- a/model-engine/model_engine_server/infra/gateways/s3_filesystem_gateway.py +++ 
b/model-engine/model_engine_server/infra/gateways/s3_filesystem_gateway.py @@ -1,33 +1,22 @@ -import os import re from typing import IO -import boto3 import smart_open from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway +from model_engine_server.infra.gateways.s3_utils import get_s3_client class S3FilesystemGateway(FilesystemGateway): - """ - Concrete implementation for interacting with a filesystem backed by S3. - """ - - def get_s3_client(self, kwargs): - profile_name = kwargs.get("aws_profile", os.getenv("AWS_PROFILE")) - session = boto3.Session(profile_name=profile_name) - client = session.client("s3") - return client - def open(self, uri: str, mode: str = "rt", **kwargs) -> IO: - # This follows the 5.1.0 smart_open API - client = self.get_s3_client(kwargs) + client = get_s3_client(kwargs) transport_params = {"client": client} return smart_open.open(uri, mode, transport_params=transport_params) def generate_signed_url(self, uri: str, expiration: int = 3600, **kwargs) -> str: - client = self.get_s3_client(kwargs) - match = re.search("^s3://([^/]+)/(.*?)$", uri) - assert match + client = get_s3_client(kwargs) + match = re.search(r"^s3://([^/]+)/(.*?)$", uri) + if not match: + raise ValueError(f"Invalid S3 URI format: {uri}") bucket, key = match.group(1), match.group(2) return client.generate_presigned_url( diff --git a/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py b/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py index b48d1eef2..504234c59 100644 --- a/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py @@ -2,49 +2,42 @@ import os from typing import Any, Dict, List -import boto3 from model_engine_server.common.config import get_model_cache_directory_name, hmi_config from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.core.utils.url import parse_attachment_url from model_engine_server.domain.gateways import LLMArtifactGateway +from model_engine_server.infra.gateways.s3_utils import get_s3_resource logger = make_logger(logger_name()) class S3LLMArtifactGateway(LLMArtifactGateway): - """ - Concrete implemention for interacting with a filesystem backed by S3. 
- """ - - def _get_s3_resource(self, kwargs): - profile_name = kwargs.get("aws_profile", os.getenv("AWS_PROFILE")) - session = boto3.Session(profile_name=profile_name) - resource = session.resource("s3") - return resource - def list_files(self, path: str, **kwargs) -> List[str]: - s3 = self._get_s3_resource(kwargs) + s3 = get_s3_resource(kwargs) parsed_remote = parse_attachment_url(path, clean_key=False) bucket = parsed_remote.bucket key = parsed_remote.key s3_bucket = s3.Bucket(bucket) files = [obj.key for obj in s3_bucket.objects.filter(Prefix=key)] + logger.debug(f"Listed {len(files)} files from {path}") return files def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) -> List[str]: - s3 = self._get_s3_resource(kwargs) + s3 = get_s3_resource(kwargs) parsed_remote = parse_attachment_url(path, clean_key=False) bucket = parsed_remote.bucket key = parsed_remote.key s3_bucket = s3.Bucket(bucket) downloaded_files: List[str] = [] + for obj in s3_bucket.objects.filter(Prefix=key): file_path_suffix = obj.key.replace(key, "").lstrip("/") local_path = os.path.join(target_path, file_path_suffix).rstrip("/") if not overwrite and os.path.exists(local_path): + logger.debug(f"Skipping existing file: {local_path}") downloaded_files.append(local_path) continue @@ -55,10 +48,12 @@ def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) logger.info(f"Downloading {obj.key} to {local_path}") s3_bucket.download_file(obj.key, local_path) downloaded_files.append(local_path) + + logger.info(f"Downloaded {len(downloaded_files)} files to {target_path}") return downloaded_files def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]: - s3 = self._get_s3_resource(kwargs) + s3 = get_s3_resource(kwargs) parsed_remote = parse_attachment_url( hmi_config.hf_user_fine_tuned_weights_prefix, clean_key=False ) @@ -69,17 +64,27 @@ def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[ model_files: List[str] = [] model_cache_name = get_model_cache_directory_name(model_name) prefix = f"{fine_tuned_weights_prefix}/{owner}/{model_cache_name}" + for obj in s3_bucket.objects.filter(Prefix=prefix): model_files.append(f"s3://{bucket}/{obj.key}") + + logger.debug(f"Found {len(model_files)} model weight files for {owner}/{model_name}") return model_files def get_model_config(self, path: str, **kwargs) -> Dict[str, Any]: - s3 = self._get_s3_resource(kwargs) + s3 = get_s3_resource(kwargs) parsed_remote = parse_attachment_url(path, clean_key=False) bucket = parsed_remote.bucket key = os.path.join(parsed_remote.key, "config.json") + s3_bucket = s3.Bucket(bucket) - filepath = os.path.join("/tmp", key).replace("/", "_") + filepath = os.path.join("/tmp", key.replace("/", "_")) + + logger.debug(f"Downloading config from {bucket}/{key} to {filepath}") s3_bucket.download_file(key, filepath) + with open(filepath, "r") as f: - return json.load(f) + config = json.load(f) + + logger.debug(f"Loaded model config from {path}") + return config diff --git a/model-engine/model_engine_server/infra/gateways/s3_utils.py b/model-engine/model_engine_server/infra/gateways/s3_utils.py new file mode 100644 index 000000000..88c142fe2 --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/s3_utils.py @@ -0,0 +1,60 @@ +import os +from typing import Any, Dict, Optional + +import boto3 +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger + +logger = 
make_logger(logger_name()) + + +def get_s3_client(kwargs: Optional[Dict[str, Any]] = None): + kwargs = kwargs or {} + session = boto3.Session() + client_kwargs = {} + + if infra_config().cloud_provider == "onprem": + logger.debug("Using on-prem/MinIO S3-compatible configuration") + + s3_endpoint = getattr(infra_config(), "s3_endpoint_url", None) or os.getenv( + "S3_ENDPOINT_URL" + ) + if s3_endpoint: + client_kwargs["endpoint_url"] = s3_endpoint + logger.debug(f"Using S3 endpoint: {s3_endpoint}") + + addressing_style = getattr(infra_config(), "s3_addressing_style", "path") + client_kwargs["config"] = boto3.session.Config(s3={"addressing_style": addressing_style}) + else: + logger.debug("Using AWS S3 configuration") + aws_profile = kwargs.get("aws_profile") + if aws_profile: + session = boto3.Session(profile_name=aws_profile) + + return session.client("s3", **client_kwargs) + + +def get_s3_resource(kwargs: Optional[Dict[str, Any]] = None): + kwargs = kwargs or {} + session = boto3.Session() + resource_kwargs = {} + + if infra_config().cloud_provider == "onprem": + logger.debug("Using on-prem/MinIO S3-compatible configuration") + + s3_endpoint = getattr(infra_config(), "s3_endpoint_url", None) or os.getenv( + "S3_ENDPOINT_URL" + ) + if s3_endpoint: + resource_kwargs["endpoint_url"] = s3_endpoint + logger.debug(f"Using S3 endpoint: {s3_endpoint}") + + addressing_style = getattr(infra_config(), "s3_addressing_style", "path") + resource_kwargs["config"] = boto3.session.Config(s3={"addressing_style": addressing_style}) + else: + logger.debug("Using AWS S3 configuration") + aws_profile = kwargs.get("aws_profile") + if aws_profile: + session = boto3.Session(profile_name=aws_profile) + + return session.resource("s3", **resource_kwargs) diff --git a/model-engine/model_engine_server/infra/repositories/__init__.py b/model-engine/model_engine_server/infra/repositories/__init__.py index f14cf69f7..5a9a32070 100644 --- a/model-engine/model_engine_server/infra/repositories/__init__.py +++ b/model-engine/model_engine_server/infra/repositories/__init__.py @@ -16,6 +16,7 @@ from .llm_fine_tune_repository import LLMFineTuneRepository from .model_endpoint_cache_repository import ModelEndpointCacheRepository from .model_endpoint_record_repository import ModelEndpointRecordRepository +from .onprem_docker_repository import OnPremDockerRepository from .redis_feature_flag_repository import RedisFeatureFlagRepository from .redis_model_endpoint_cache_repository import RedisModelEndpointCacheRepository from .s3_file_llm_fine_tune_events_repository import S3FileLLMFineTuneEventsRepository @@ -38,6 +39,7 @@ "LLMFineTuneRepository", "ModelEndpointRecordRepository", "ModelEndpointCacheRepository", + "OnPremDockerRepository", "RedisFeatureFlagRepository", "RedisModelEndpointCacheRepository", "S3FileLLMFineTuneRepository", diff --git a/model-engine/model_engine_server/infra/repositories/onprem_docker_repository.py b/model-engine/model_engine_server/infra/repositories/onprem_docker_repository.py new file mode 100644 index 000000000..4e2787ee5 --- /dev/null +++ b/model-engine/model_engine_server/infra/repositories/onprem_docker_repository.py @@ -0,0 +1,48 @@ +from typing import Optional + +from model_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.repositories import DockerRepository + +logger = make_logger(logger_name()) + + 
+class OnPremDockerRepository(DockerRepository): + def image_exists( + self, image_tag: str, repository_name: str, aws_profile: Optional[str] = None + ) -> bool: + if not repository_name: + logger.warning( + f"Direct image reference: {image_tag}, assuming exists. " + f"Image validation skipped for on-prem deployments." + ) + return True + + logger.warning( + f"Registry image: {repository_name}:{image_tag}, assuming exists. " + f"Image validation skipped for on-prem deployments. " + f"Deployment will fail if image does not exist in registry." + ) + return True + + def get_image_url(self, image_tag: str, repository_name: str) -> str: + if not repository_name: + logger.debug(f"Using direct image reference: {image_tag}") + return image_tag + + image_url = f"{infra_config().docker_repo_prefix}/{repository_name}:{image_tag}" + logger.debug(f"Constructed image URL: {image_url}") + return image_url + + def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse: + raise NotImplementedError( + "OnPremDockerRepository does not support building images. " + "Images should be built via CI/CD and pushed to the on-prem registry." + ) + + def get_latest_image_tag(self, repository_name: str) -> str: + raise NotImplementedError( + "OnPremDockerRepository does not support querying latest image tags. " + "Please specify explicit image tags in your deployment configuration." + ) diff --git a/model-engine/model_engine_server/infra/repositories/s3_file_llm_fine_tune_events_repository.py b/model-engine/model_engine_server/infra/repositories/s3_file_llm_fine_tune_events_repository.py index 2dfcbc769..86241f968 100644 --- a/model-engine/model_engine_server/infra/repositories/s3_file_llm_fine_tune_events_repository.py +++ b/model-engine/model_engine_server/infra/repositories/s3_file_llm_fine_tune_events_repository.py @@ -1,18 +1,19 @@ import json -import os from json.decoder import JSONDecodeError from typing import IO, List -import boto3 import smart_open from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.entities.llm_fine_tune_entity import LLMFineTuneEvent from model_engine_server.domain.exceptions import ObjectNotFoundException from model_engine_server.domain.repositories.llm_fine_tune_events_repository import ( LLMFineTuneEventsRepository, ) +from model_engine_server.infra.gateways.s3_utils import get_s3_client + +logger = make_logger(logger_name()) -# Echoes llm/finetune_pipeline/docker_image_fine_tuning_entrypoint.py S3_HF_USER_FINE_TUNED_WEIGHTS_PREFIX = ( f"s3://{infra_config().s3_bucket}/hosted-model-inference/fine_tuned_weights" ) @@ -20,34 +21,24 @@ class S3FileLLMFineTuneEventsRepository(LLMFineTuneEventsRepository): def __init__(self): - pass - - # _get_s3_client + _open copypasted from s3_file_llm_fine_tune_repo, in turn from s3_filesystem_gateway - # sorry - def _get_s3_client(self, kwargs): - profile_name = kwargs.get("aws_profile", os.getenv("S3_WRITE_AWS_PROFILE")) - session = boto3.Session(profile_name=profile_name) - client = session.client("s3") - return client + logger.debug("Initialized S3FileLLMFineTuneEventsRepository") def _open(self, uri: str, mode: str = "rt", **kwargs) -> IO: - # This follows the 5.1.0 smart_open API - client = self._get_s3_client(kwargs) + client = get_s3_client(kwargs) transport_params = {"client": client} return smart_open.open(uri, mode, transport_params=transport_params) - # echoes llm/finetune_pipeline/docker_image_fine_tuning_entrypoint.py 
- def _get_model_cache_directory_name(self, model_name: str): + def _get_model_cache_directory_name(self, model_name: str) -> str: """How huggingface maps model names to directory names in their cache for model files. We adopt this when storing model cache files in s3. - Args: model_name (str): Name of the huggingface model """ + name = "models--" + model_name.replace("/", "--") return name - def _get_file_location(self, user_id: str, model_endpoint_name: str): + def _get_file_location(self, user_id: str, model_endpoint_name: str) -> str: model_cache_name = self._get_model_cache_directory_name(model_endpoint_name) s3_file_location = ( f"{S3_HF_USER_FINE_TUNED_WEIGHTS_PREFIX}/{user_id}/{model_cache_name}.jsonl" @@ -78,12 +69,18 @@ async def get_fine_tune_events( level="info", ) final_events.append(event) + logger.debug( + f"Retrieved {len(final_events)} events for {user_id}/{model_endpoint_name}" + ) return final_events - except Exception as exc: # TODO better exception + except Exception as exc: + logger.error(f"Failed to get fine-tune events from {s3_file_location}: {exc}") raise ObjectNotFoundException from exc async def initialize_events(self, user_id: str, model_endpoint_name: str) -> None: s3_file_location = self._get_file_location( user_id=user_id, model_endpoint_name=model_endpoint_name ) - self._open(s3_file_location, "w") + with self._open(s3_file_location, "w"): + pass + logger.info(f"Initialized events file at {s3_file_location}") diff --git a/model-engine/model_engine_server/infra/repositories/s3_file_llm_fine_tune_repository.py b/model-engine/model_engine_server/infra/repositories/s3_file_llm_fine_tune_repository.py index 6b3ea8aa8..a58f9c4d1 100644 --- a/model-engine/model_engine_server/infra/repositories/s3_file_llm_fine_tune_repository.py +++ b/model-engine/model_engine_server/infra/repositories/s3_file_llm_fine_tune_repository.py @@ -1,57 +1,61 @@ import json -import os from typing import IO, Dict, Optional -import boto3 import smart_open +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.entities.llm_fine_tune_entity import LLMFineTuneTemplate +from model_engine_server.infra.gateways.s3_utils import get_s3_client from model_engine_server.infra.repositories.llm_fine_tune_repository import LLMFineTuneRepository +logger = make_logger(logger_name()) + class S3FileLLMFineTuneRepository(LLMFineTuneRepository): def __init__(self, file_path: str): self.file_path = file_path - - def _get_s3_client(self, kwargs): - profile_name = kwargs.get("aws_profile", os.getenv("AWS_PROFILE")) - session = boto3.Session(profile_name=profile_name) - client = session.client("s3") - return client + logger.debug(f"Initialized S3FileLLMFineTuneRepository with path: {file_path}") def _open(self, uri: str, mode: str = "rt", **kwargs) -> IO: - # This follows the 5.1.0 smart_open API - client = self._get_s3_client(kwargs) + client = get_s3_client(kwargs) transport_params = {"client": client} return smart_open.open(uri, mode, transport_params=transport_params) @staticmethod - def _get_key(model_name, fine_tuning_method): + def _get_key(model_name: str, fine_tuning_method: str) -> str: return f"{model_name}-{fine_tuning_method}" # possible for collisions but we control these names async def get_job_template_for_model( self, model_name: str, fine_tuning_method: str ) -> Optional[LLMFineTuneTemplate]: - # can hot reload the file lol - with self._open(self.file_path, "r") as f: - data = json.load(f) - key = self._get_key(model_name, fine_tuning_method) - 
job_template_dict = data.get(key, None) - if job_template_dict is None: - return None - return LLMFineTuneTemplate.parse_obj(job_template_dict) + try: + with self._open(self.file_path, "r") as f: + data = json.load(f) + key = self._get_key(model_name, fine_tuning_method) + job_template_dict = data.get(key, None) + if job_template_dict is None: + logger.debug(f"No template found for {key}") + return None + logger.debug(f"Retrieved template for {key}") + return LLMFineTuneTemplate.parse_obj(job_template_dict) + except Exception as e: + logger.error(f"Failed to get job template for {model_name}/{fine_tuning_method}: {e}") + return None async def write_job_template_for_model( self, model_name: str, fine_tuning_method: str, job_template: LLMFineTuneTemplate ): - # Use locally in script with self._open(self.file_path, "r") as f: data: Dict = json.load(f) + key = self._get_key(model_name, fine_tuning_method) data[key] = dict(job_template) + with self._open(self.file_path, "w") as f: json.dump(data, f) + logger.info(f"Wrote job template for {key}") + async def initialize_data(self): - # Use locally in script with self._open(self.file_path, "w") as f: json.dump({}, f) + logger.info(f"Initialized fine-tune repository at {self.file_path}") diff --git a/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py b/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py index 1b32104a5..c285ab19f 100644 --- a/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py +++ b/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py @@ -250,12 +250,9 @@ async def build_endpoint( else: flavor = model_bundle.flavor assert isinstance(flavor, RunnableImageLike) - repository = ( - f"{infra_config().docker_repo_prefix}/{flavor.repository}" - if self.docker_repository.is_repo_name(flavor.repository) - else flavor.repository + image = self.docker_repository.get_image_url( + image_tag=flavor.tag, repository_name=flavor.repository ) - image = f"{repository}:{flavor.tag}" # Because this update is not the final update in the lock, the 'update_in_progress' # value isn't really necessary for correctness in not having races, but it's still diff --git a/model-engine/model_engine_server/service_builder/tasks_v1.py b/model-engine/model_engine_server/service_builder/tasks_v1.py index 8db4a109c..cf40510d8 100644 --- a/model-engine/model_engine_server/service_builder/tasks_v1.py +++ b/model-engine/model_engine_server/service_builder/tasks_v1.py @@ -52,6 +52,9 @@ RedisFeatureFlagRepository, RedisModelEndpointCacheRepository, ) +from model_engine_server.infra.repositories.onprem_docker_repository import ( + OnPremDockerRepository, +) from model_engine_server.infra.services import LiveEndpointBuilderService from model_engine_server.service_builder.celery import service_builder_service @@ -83,8 +86,10 @@ def get_live_endpoint_builder_service( docker_repository: DockerRepository if CIRCLECI: docker_repository = FakeDockerRepository() - elif infra_config().docker_repo_prefix.endswith("azurecr.io"): + elif infra_config().cloud_provider == "azure": docker_repository = ACRDockerRepository() + elif infra_config().cloud_provider == "onprem": + docker_repository = OnPremDockerRepository() else: docker_repository = ECRDockerRepository() inference_autoscaling_metrics_gateway = ( diff --git a/model-engine/requirements.txt b/model-engine/requirements.txt index 6e784ecc9..516a26677 100644 --- a/model-engine/requirements.txt +++ 
b/model-engine/requirements.txt
@@ -326,7 +326,7 @@ protobuf==3.20.3
     #   -r model-engine/requirements.in
     #   ddsketch
     #   ddtrace
-psycopg2-binary==2.9.3
+psycopg2-binary==2.9.10
     # via -r model-engine/requirements.in
 py-xid==0.3.0
     # via -r model-engine/requirements.in
diff --git a/pr.md b/pr.md
new file mode 100644
index 000000000..2e1a102c2
--- /dev/null
+++ b/pr.md
@@ -0,0 +1,38 @@
+# Add On-Premise Deployment Support
+
+This PR adds support for on-premise deployments using Redis, MinIO/S3-compatible storage, and private container registries as alternatives to cloud-managed services.
+
+## Key Changes
+
+- **New on-prem configuration**: Added an `onprem.yaml` config file with settings for MinIO, Redis, and private registries
+- **Redis-based infrastructure**: Routed Celery task queues through Redis and added an on-prem queue endpoint resource delegate
+- **S3-compatible storage**: Added support for MinIO and custom S3 endpoints with configurable addressing styles
+- **Container registry flexibility**: Support for private registries via `OnPremDockerRepository`
+- **Database configuration**: Environment-variable-based PostgreSQL connection for on-prem deployments
+- **Improved logging**: Better error handling and debug logging across the S3-backed gateways and fine-tune repositories
+
+## Configuration Highlights
+
+The on-prem setup allows deployments to use:
+- MinIO or S3-compatible object storage instead of AWS S3/Azure Blob
+- Redis for Celery task queues and caching instead of SQS/ASB
+- Local PostgreSQL with environment-based credentials
+- Private container registries instead of ECR/ACR
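+
+## Example Environment Variables
+
+For illustration, a minimal on-prem deployment pointing at MinIO, Redis, and an in-cluster PostgreSQL might set the variables below. The names are the ones read by the code in this PR; the values and the YAML shape are placeholders, so inject them through whatever your tooling expects (Helm values, Kubernetes env vars, a dotenv file).
+
+```yaml
+# Hypothetical values for a single on-prem cluster; adjust to your environment.
+env:
+  S3_ENDPOINT_URL: "http://minio-service:9000"   # read by s3_utils.py and common/io.py when cloud_provider == "onprem"
+  AWS_ACCESS_KEY_ID: "minio-access-key"          # standard boto3 credential variables; MinIO keys go here
+  AWS_SECRET_ACCESS_KEY: "minio-secret-key"
+  REDIS_HOST: "redis-service"                    # used to build cache_redis_url when cache_redis_aws_url is unset
+  DB_HOST: "postgres"                            # DB_HOST_RO can optionally point reads at a replica
+  DB_PORT: "5432"
+  DB_NAME: "llm_engine"
+  DB_USER: "postgres"
+  DB_PASSWORD: "changeme"                        # supply via a secret in real deployments
+```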