charts/model-engine/values_sample.yaml (1 addition, 1 deletion)
@@ -157,7 +157,7 @@ serviceTemplate:
 config:
   values:
     infra:
-      # cloud_provider [required]; either "aws" or "azure"
+      # cloud_provider [required]; either "aws" or "azure" or "gcp"
       cloud_provider: aws
       # k8s_cluster_name [required] is the name of the k8s cluster
       k8s_cluster_name: main_cluster
model-engine/model_engine_server/api/dependencies.py (23 additions, 15 deletions)
@@ -128,6 +128,9 @@
     LiveLLMModelEndpointService,
 )
 from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session
+from model_engine_server.infra.gateways.gcs_filesystem_gateway import GCSFilesystemGateway
+from model_engine_server.infra.gateways.gcs_llm_artifact_gateway import GCSLLMArtifactGateway
+from model_engine_server.infra.gateways.gcs_file_storage_gateway import GCSFileStorageGateway
 
 logger = make_logger(logger_name())
 
@@ -258,16 +261,20 @@ def _get_external_interfaces(
         monitoring_metrics_gateway=monitoring_metrics_gateway,
         use_asyncio=(not CIRCLECI),
     )
-    filesystem_gateway = (
-        ABSFilesystemGateway()
-        if infra_config().cloud_provider == "azure"
-        else S3FilesystemGateway()
-    )
-    llm_artifact_gateway = (
-        ABSLLMArtifactGateway()
-        if infra_config().cloud_provider == "azure"
-        else S3LLMArtifactGateway()
-    )
+    if infra_config().cloud_provider == "azure":
+        filesystem_gateway = ABSFilesystemGateway()
+    elif infra_config().cloud_provider == "gcp":
+        filesystem_gateway = GCSFilesystemGateway()
+    else:
+        filesystem_gateway = S3FilesystemGateway()
+
+    if infra_config().cloud_provider == "azure":
+        llm_artifact_gateway = ABSLLMArtifactGateway()
+    elif infra_config().cloud_provider == "gcp":
+        llm_artifact_gateway = GCSLLMArtifactGateway()
+    else:
+        llm_artifact_gateway = S3LLMArtifactGateway()
+
     model_endpoints_schema_gateway = LiveModelEndpointsSchemaGateway(
         filesystem_gateway=filesystem_gateway
     )
@@ -334,11 +341,12 @@ def _get_external_interfaces(
         docker_image_batch_job_gateway=docker_image_batch_job_gateway
     )
 
-    file_storage_gateway = (
-        ABSFileStorageGateway()
-        if infra_config().cloud_provider == "azure"
-        else S3FileStorageGateway()
-    )
+    if infra_config().cloud_provider == "azure":
+        file_storage_gateway = ABSFileStorageGateway()
+    elif infra_config().cloud_provider == "gcp":
+        file_storage_gateway = GCSFileStorageGateway()
+    else:
+        file_storage_gateway = S3FileStorageGateway()
 
     docker_repository: DockerRepository
     if CIRCLECI:
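Aside: the provider dispatch above now appears three times in _get_external_interfaces. A minimal consolidation sketch (hypothetical helper, not part of this diff; same behavior as the branches above):

from typing import Callable, TypeVar

T = TypeVar("T")

def _gateway_for_provider(azure: Callable[[], T], gcp: Callable[[], T], default: Callable[[], T]) -> T:
    # Single dispatch point for infra_config().cloud_provider.
    provider = infra_config().cloud_provider
    if provider == "azure":
        return azure()
    if provider == "gcp":
        return gcp()
    return default()

# Example with the classes this diff wires up:
filesystem_gateway = _gateway_for_provider(
    ABSFilesystemGateway, GCSFilesystemGateway, S3FilesystemGateway
)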
model-engine/model_engine_server/core/config.py (1 addition)
@@ -41,6 +41,7 @@ class _InfraConfig:
     ml_account_id: str
     docker_repo_prefix: str
     s3_bucket: str
+    gcs_bucket: Optional[str] = None
     redis_host: Optional[str] = None
     redis_aws_secret_name: Optional[str] = None
     profile_ml_worker: str = "default"
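For reference, a minimal sketch of how the new field is consumed downstream (the GCS gateways in this diff read it the same way):

from model_engine_server.core.config import infra_config

# gcs_bucket defaults to None, so existing AWS/Azure deployments are unaffected.
bucket_name = infra_config().gcs_bucket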
model-engine/model_engine_server/infra/gateways/gcs_file_storage_gateway.py (new file, 104 additions)
@@ -0,0 +1,104 @@
import os
Review comment (Contributor): note: I think this is only really used for some fine tuning apis that aren't really used at this point, think it's fine to keep ofc since you'll probably need to initialize dependencies anyways, but this code probably won't really get exercised at all
from typing import List, Optional

from google.cloud import storage

from model_engine_server.core.config import infra_config
from model_engine_server.domain.gateways.file_storage_gateway import (
    FileMetadata,
    FileStorageGateway,
)
from model_engine_server.infra.gateways.gcs_filesystem_gateway import GCSFilesystemGateway


def get_gcs_key(owner: str, file_id: str) -> str:
Review comment (Member): nit: I'd prefix these w/ an underscore so that no one is tempted to try and import these from outside this file, thus breaking Clean Architecture norms.

Reply (Contributor): btw I think the s3_file_storage_gateway also doesn't have the prefixes

"""
Constructs a GCS object key from the owner and file_id.
"""
return os.path.join(owner, file_id)


def get_gcs_url(owner: str, file_id: str) -> str:
    """
    Returns the gs:// URL for the bucket, using the GCS key.
    """
    return f"gs://{infra_config().gcs_bucket}/{get_gcs_key(owner, file_id)}"


class GCSFileStorageGateway(FileStorageGateway):
    """
    Concrete implementation of a file storage gateway backed by GCS.
    """

    def __init__(self):
        self.filesystem_gateway = GCSFilesystemGateway()

    async def get_url_from_id(self, owner: str, file_id: str) -> Optional[str]:
        """
        Returns a signed GCS URL for the given file.
        """
        try:
            return self.filesystem_gateway.generate_signed_url(get_gcs_url(owner, file_id))
        except Exception:
            return None

    async def get_file(self, owner: str, file_id: str) -> Optional[FileMetadata]:
        """
        Retrieves file metadata if it exists. Returns None if the file is missing.
        """
        try:
            client = self.filesystem_gateway.get_storage_client({})
            bucket = client.bucket(infra_config().gcs_bucket)
Review comment (Member): I know this pattern was already there, but I think it'd probably make more sense to pass in the bucket into the constructor of this class. This way, there's one less dependency on the old infra_config object. @seanshi-scale @tiffzhao5 thoughts?

Could also make the argument to just pass in the bucket as an argument with every get_file call, but that's outside of the scope of this change I'd say.

Reply (Contributor): I'm fine with having the bucket be passed in the constructor (in addition to anything else from any configs); dependencies.py does read in from infra_config at times to figure out constructor arguments, so there's precedent already

(A sketch of this suggestion follows the end of this file.)

            blob = bucket.blob(get_gcs_key(owner, file_id))
            blob.reload()  # Fetch metadata
            return FileMetadata(
                id=file_id,
                filename=file_id,
                size=blob.size,
                owner=owner,
                updated_at=blob.updated,
            )
        except Exception:
            return None

    async def get_file_content(self, owner: str, file_id: str) -> Optional[str]:
        """
        Reads and returns the string content of the file.
        """
        try:
            with self.filesystem_gateway.open(get_gcs_url(owner, file_id)) as f:
                return f.read()
        except Exception:
            return None

    async def upload_file(self, owner: str, filename: str, content: bytes) -> str:
        """
        Uploads the file to the GCS bucket. Returns the filename used in the bucket.
        """
        with self.filesystem_gateway.open(get_gcs_url(owner, filename), mode="w") as f:
            f.write(content.decode("utf-8"))
        return filename

    async def delete_file(self, owner: str, file_id: str) -> bool:
        """
        Deletes the file from the GCS bucket. Returns True if successful, False otherwise.
        """
        try:
            client = self.filesystem_gateway.get_storage_client({})
            bucket = client.bucket(infra_config().gcs_bucket)
            blob = bucket.blob(get_gcs_key(owner, file_id))
            blob.delete()
            return True
        except Exception:
            return False

    async def list_files(self, owner: str) -> List[FileMetadata]:
        """
        Lists all files in the GCS bucket for the given owner.
        """
        client = self.filesystem_gateway.get_storage_client({})
        blobs = client.list_blobs(infra_config().gcs_bucket, prefix=owner)
        files = [
            await self.get_file(owner, blob.name[len(owner) + 1 :])
            for blob in blobs
            if blob.name != owner
        ]
        return [f for f in files if f is not None]
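A minimal sketch of the two review suggestions above (underscore-prefixed module helpers, bucket passed into the constructor); hypothetical, not part of this diff, reusing this file's imports:

def _get_gcs_key(owner: str, file_id: str) -> str:
    # Leading underscore signals module-private, per the review nit.
    return os.path.join(owner, file_id)


class GCSFileStorageGateway(FileStorageGateway):
    def __init__(self, bucket: Optional[str] = None):
        # Bucket is injected once; dependencies.py can still read it from
        # infra_config() when building constructor arguments.
        self.bucket = bucket or infra_config().gcs_bucket
        self.filesystem_gateway = GCSFilesystemGateway()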
model-engine/model_engine_server/infra/gateways/gcs_filesystem_gateway.py (new file, 46 additions)
@@ -0,0 +1,46 @@
import datetime
import os
import re
from typing import IO, Dict, Optional

import smart_open
from google.cloud import storage
from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway


class GCSFilesystemGateway(FilesystemGateway):
    """
    Concrete implementation for interacting with Google Cloud Storage.
    """

    def get_storage_client(self, kwargs: Optional[Dict]) -> storage.Client:
        """
        Retrieve or create a Google Cloud Storage client. Could optionally
        utilize environment variables or passed-in credentials.
        """
        kwargs = kwargs or {}  # tolerate None, which the annotation allows
        project = kwargs.get("gcp_project", os.getenv("GCP_PROJECT"))
Review comment (Contributor): where does this env var get set? it seems analogous to AWS_PROFILE but those changes would need to be baked into any relevant k8s yamls most likely

        return storage.Client(project=project)
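On the GCP_PROJECT question above: storage.Client(project=None) resolves the project from Application Default Credentials (or GOOGLE_CLOUD_PROJECT), so the env var is an optional override; a quick check, assuming ADC is configured:

import os
from google.cloud import storage

# Falls back to the ADC project when GCP_PROJECT is unset.
client = storage.Client(project=os.getenv("GCP_PROJECT"))
print(client.project)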

    def open(self, uri: str, mode: str = "rt", **kwargs) -> IO:
        """
        Uses smart_open to handle reading/writing to GCS.
        """
        # The `transport_params` is how smart_open passes in the storage client
        client = self.get_storage_client(kwargs)
        transport_params = {"client": client}
        return smart_open.open(uri, mode, transport_params=transport_params)

    def generate_signed_url(self, uri: str, expiration: int = 3600, **kwargs) -> str:
        """
        Generate a signed URL for the given GCS URI, valid for `expiration` seconds.
        """
        # Expecting URIs in the form: 'gs://bucket_name/some_key'
        match = re.search(r"^gs://([^/]+)/(.+)$", uri)
        if not match:
            raise ValueError(f"Invalid GCS URI: {uri}")

        bucket_name, blob_name = match.groups()
        client = self.get_storage_client(kwargs)
        bucket = client.bucket(bucket_name)
        blob = bucket.blob(blob_name)

        # Wrap the seconds in a timedelta: google-cloud-storage treats a bare
        # int as an absolute epoch timestamp, not a duration.
        return blob.generate_signed_url(expiration=datetime.timedelta(seconds=expiration))
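Usage sketch (hypothetical bucket and key; note that URL signing needs credentials that can sign, e.g. a service account key, rather than bare Compute Engine default credentials):

gateway = GCSFilesystemGateway()
url = gateway.generate_signed_url("gs://my-bucket/some/key.txt", expiration=900)  # 15 minutes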
model-engine/model_engine_server/infra/gateways/gcs_llm_artifact_gateway.py (new file, 115 additions)
@@ -0,0 +1,115 @@
import json
import os
from typing import Any, Dict, List

from google.cloud import storage
from model_engine_server.common.config import get_model_cache_directory_name, hmi_config
from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.core.utils.url import parse_attachment_url
from model_engine_server.domain.gateways import LLMArtifactGateway

logger = make_logger(logger_name())


class GCSLLMArtifactGateway(LLMArtifactGateway):
    """
    Concrete implementation for interacting with a filesystem backed by GCS.
    """

    def _get_gcs_client(self, kwargs) -> storage.Client:
        """
        Returns a GCS client. If desired, you could pass in project info
        or credentials via `kwargs`.
        """
        project = kwargs.get("gcp_project", os.getenv("GCP_PROJECT"))
        return storage.Client(project=project)

    def list_files(self, path: str, **kwargs) -> List[str]:
        """
        Lists all files under the path argument in GCS. The path is expected
        to be in the form 'gs://bucket/prefix'.
        """
        gcs = self._get_gcs_client(kwargs)
        parsed_remote = parse_attachment_url(path, clean_key=False)
        bucket_name = parsed_remote.bucket
        prefix = parsed_remote.key

        bucket = gcs.bucket(bucket_name)
        files = [blob.name for blob in bucket.list_blobs(prefix=prefix)]
        return files

    def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) -> List[str]:
Review comment (Contributor): IIUC with the current state of the code, this only gets called if you use TGI, LightLLM, TensorRT-LLM, Deepspeed as inference frameworks, so I doubt this code ends up getting exercised in practice (it's only used to download a tokenizer to count tokens on the Gateway)

"""
Downloads all files under the given path to the local target_path directory.
"""
gcs = self._get_gcs_client(kwargs)
parsed_remote = parse_attachment_url(path, clean_key=False)
bucket_name = parsed_remote.bucket
prefix = parsed_remote.key

bucket = gcs.bucket(bucket_name)
blobs = bucket.list_blobs(prefix=prefix)
downloaded_files = []

for blob in blobs:
Review comment (Member): Looks like this is just sequentially downloading the files? Is this the recommended way? ChatGPT showed me two responses, (1) with a ThreadPoolExecutor, and (2) this one. (A ThreadPoolExecutor sketch follows this method.)

            # Remove prefix and leading slash to derive local name
            file_path_suffix = blob.name.replace(prefix, "").lstrip("/")
Review comment (Contributor): do you want to replace(prefix, "", count=1)? (or something like that) just in case the prefix appears elsewhere in the string

for that matter is that also a bug in the s3 implementation?

            local_path = os.path.join(target_path, file_path_suffix).rstrip("/")

            if not overwrite and os.path.exists(local_path):
                downloaded_files.append(local_path)
                continue

            local_dir = os.path.dirname(local_path)
            if not os.path.exists(local_dir):
                os.makedirs(local_dir)

            logger.info(f"Downloading {blob.name} to {local_path}")
            blob.download_to_filename(local_path)
            downloaded_files.append(local_path)

        return downloaded_files
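A minimal sketch of the ThreadPoolExecutor variant raised above, which also uses replace(prefix, "", 1) so only the leading occurrence of the prefix is stripped (hypothetical helpers, not part of this diff):

import os
from concurrent.futures import ThreadPoolExecutor
from typing import List

def _download_one(blob, prefix: str, target_path: str) -> str:
    # Strip only the first occurrence of the prefix, per the review note.
    suffix = blob.name.replace(prefix, "", 1).lstrip("/")
    local_path = os.path.join(target_path, suffix)
    local_dir = os.path.dirname(local_path)
    if local_dir:
        os.makedirs(local_dir, exist_ok=True)
    blob.download_to_filename(local_path)
    return local_path

def _download_all(blobs, prefix: str, target_path: str) -> List[str]:
    # Blobs download in parallel; each worker writes a distinct local file.
    with ThreadPoolExecutor(max_workers=8) as pool:
        return list(pool.map(lambda b: _download_one(b, prefix, target_path), blobs))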

    def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]:
        """
        Retrieves URLs for all model weight artifacts stored under the
        prefix: hmi_config.hf_user_fine_tuned_weights_prefix / {owner} / {model_cache_name}
        """
        gcs = self._get_gcs_client(kwargs)
        prefix_base = hmi_config.hf_user_fine_tuned_weights_prefix
        if prefix_base.startswith("gs://"):
            # Strip "gs://" for prefix logic below
            prefix_base = prefix_base[5:]
        bucket_name, prefix_base = prefix_base.split("/", 1)

        model_cache_name = get_model_cache_directory_name(model_name)
        prefix = f"{prefix_base}/{owner}/{model_cache_name}"

        bucket = gcs.bucket(bucket_name)
        blobs = bucket.list_blobs(prefix=prefix)

        model_files = [f"gs://{bucket_name}/{blob.name}" for blob in blobs]
        return model_files

    def get_model_config(self, path: str, **kwargs) -> Dict[str, Any]:
        """
        Downloads the 'config.json' file located at path/config.json in GCS
        and returns it as a dictionary.
        """
        gcs = self._get_gcs_client(kwargs)
        parsed_remote = parse_attachment_url(path, clean_key=False)
        bucket_name = parsed_remote.bucket
        # The key from parse_attachment_url might be e.g. "weight_prefix/model_dir",
        # so we append "config.json" and build a local path to download it to.
        key_with_config = os.path.join(parsed_remote.key, "config.json")

        bucket = gcs.bucket(bucket_name)
        blob = bucket.blob(key_with_config)

        # Download to a tmp path and load
        filepath = os.path.join("/tmp", key_with_config.replace("/", "_"))
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        blob.download_to_filename(filepath)

        with open(filepath, "r") as f:
            return json.load(f)