-
Notifications
You must be signed in to change notification settings - Fork 67
Google Cloud Storage Integration #683
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| import os | ||
| from typing import List, Optional | ||
|
|
||
| from google.cloud import storage | ||
|
|
||
| from model_engine_server.core.config import infra_config | ||
| from model_engine_server.domain.gateways.file_storage_gateway import ( | ||
| FileMetadata, | ||
| FileStorageGateway, | ||
| ) | ||
| from model_engine_server.infra.gateways.gcs_filesystem_gateway import GCSFilesystemGateway | ||
|
|
||
|
|
||
def get_gcs_key(owner: str, file_id: str) -> str:
    """
    Constructs a GCS object key from the owner and file_id.

    GCS object keys always use "/" as the separator, so the key is built
    explicitly rather than with os.path.join, which would emit "\\" on
    Windows and — if file_id ever started with "/" — silently discard
    the owner component entirely.

    :param owner: Namespace (top-level "directory") the file belongs to.
    :param file_id: Identifier of the file within the owner's namespace.
    :return: The object key "{owner}/{file_id}".
    """
    return f"{owner}/{file_id}"
|
|
||
|
|
||
def get_gcs_url(owner: str, file_id: str) -> str:
    """
    Returns the gs:// URL for the bucket, using the GCS key.
    """
    bucket = infra_config().gcs_bucket
    key = get_gcs_key(owner, file_id)
    return f"gs://{bucket}/{key}"
|
|
||
|
|
||
class GCSFileStorageGateway(FileStorageGateway):
    """
    Concrete implementation of a file storage gateway backed by GCS.

    Read-style methods deliberately catch broad exceptions and return
    None/False so that a missing or inaccessible object is reported as
    "not found" instead of propagating an error to callers.
    """

    def __init__(self):
        self.filesystem_gateway = GCSFilesystemGateway()

    async def get_url_from_id(self, owner: str, file_id: str) -> Optional[str]:
        """
        Returns a signed GCS URL for the given file, or None on any failure.
        """
        try:
            return self.filesystem_gateway.generate_signed_url(get_gcs_url(owner, file_id))
        except Exception:
            return None

    async def get_file(self, owner: str, file_id: str) -> Optional[FileMetadata]:
        """
        Retrieves file metadata if it exists. Returns None if the file is missing.
        """
        try:
            client = self.filesystem_gateway.get_storage_client({})
            bucket = client.bucket(infra_config().gcs_bucket)
            blob = bucket.blob(get_gcs_key(owner, file_id))
            blob.reload()  # Fetch metadata; raises if the object is absent.
            return FileMetadata(
                id=file_id,
                filename=file_id,
                size=blob.size,
                owner=owner,
                updated_at=blob.updated,
            )
        except Exception:
            return None

    async def get_file_content(self, owner: str, file_id: str) -> Optional[str]:
        """
        Reads and returns the string content of the file, or None on failure.
        """
        try:
            with self.filesystem_gateway.open(get_gcs_url(owner, file_id)) as f:
                return f.read()
        except Exception:
            return None

    async def upload_file(self, owner: str, filename: str, content: bytes) -> str:
        """
        Uploads the file to the GCS bucket. Returns the filename used in bucket.

        NOTE(review): content is decoded as UTF-8, so this only supports text
        payloads — confirm binary uploads are out of scope.
        """
        with self.filesystem_gateway.open(get_gcs_url(owner, filename), mode="w") as f:
            f.write(content.decode("utf-8"))
        return filename

    async def delete_file(self, owner: str, file_id: str) -> bool:
        """
        Deletes the file from the GCS bucket. Returns True if successful, False otherwise.
        """
        try:
            client = self.filesystem_gateway.get_storage_client({})
            bucket = client.bucket(infra_config().gcs_bucket)
            bucket.blob(get_gcs_key(owner, file_id)).delete()
            return True
        except Exception:
            return False

    async def list_files(self, owner: str) -> List[FileMetadata]:
        """
        Lists all files in the GCS bucket for the given owner.
        """
        client = self.filesystem_gateway.get_storage_client({})
        # Include the trailing "/" in the prefix so that an owner named "abc"
        # does not pick up objects belonging to e.g. "abcdef" (a bare
        # `prefix=owner` matches any key that merely starts with the string).
        prefix = f"{owner}/"
        blobs = client.list_blobs(infra_config().gcs_bucket, prefix=prefix)
        files = [await self.get_file(owner, b.name[len(prefix):]) for b in blobs]
        return [f for f in files if f is not None]
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,46 @@ | ||
import datetime
import os
import re
from typing import IO, Dict, Optional

import smart_open
from google.cloud import storage

from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway
class GCSFilesystemGateway(FilesystemGateway):
    """
    Concrete implementation for interacting with Google Cloud Storage.
    """

    def get_storage_client(self, kwargs: Optional[Dict]) -> storage.Client:
        """
        Retrieve or create a Google Cloud Storage client.

        :param kwargs: Optional dict; "gcp_project" overrides the GCP_PROJECT
            environment variable. May be None.
        """
        # Guard against None: the signature advertises Optional[Dict], but the
        # previous body called kwargs.get(...) unconditionally and would raise
        # AttributeError when passed None.
        kwargs = kwargs or {}
        project = kwargs.get("gcp_project", os.getenv("GCP_PROJECT"))
        return storage.Client(project=project)

    def open(self, uri: str, mode: str = "rt", **kwargs) -> IO:
        """
        Uses smart_open to handle reading/writing to GCS.
        """
        # `transport_params` is how smart_open receives the storage client.
        client = self.get_storage_client(kwargs)
        transport_params = {"client": client}
        return smart_open.open(uri, mode, transport_params=transport_params)

    def generate_signed_url(self, uri: str, expiration: int = 3600, **kwargs) -> str:
        """
        Generate a signed URL for the given GCS URI, valid for `expiration` seconds.

        :raises ValueError: if `uri` is not of the form 'gs://bucket_name/some_key'.
        """
        match = re.search(r"^gs://([^/]+)/(.+)$", uri)
        if not match:
            raise ValueError(f"Invalid GCS URI: {uri}")

        bucket_name, blob_name = match.groups()
        client = self.get_storage_client(kwargs)
        blob = client.bucket(bucket_name).blob(blob_name)

        # Pass a timedelta: google-cloud-storage interprets a bare int as an
        # absolute timestamp (seconds since epoch), so expiration=3600 would
        # yield a URL that expired in 1970 rather than one valid for an hour.
        return blob.generate_signed_url(expiration=datetime.timedelta(seconds=expiration))
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,115 @@ | ||
| import json | ||
| import os | ||
| from typing import Any, Dict, List | ||
|
|
||
| from google.cloud import storage | ||
| from model_engine_server.common.config import get_model_cache_directory_name, hmi_config | ||
| from model_engine_server.core.loggers import logger_name, make_logger | ||
| from model_engine_server.core.utils.url import parse_attachment_url | ||
| from model_engine_server.domain.gateways import LLMArtifactGateway | ||
|
|
||
| logger = make_logger(logger_name()) | ||
|
|
||
|
|
||
class GCSLLMArtifactGateway(LLMArtifactGateway):
    """
    Concrete implementation for interacting with a filesystem backed by GCS.
    """

    def _get_gcs_client(self, kwargs) -> storage.Client:
        """
        Returns a GCS client. `kwargs` may carry a "gcp_project" override;
        otherwise the GCP_PROJECT environment variable is used.
        """
        kwargs = kwargs or {}  # tolerate None, mirroring the Optional contract elsewhere
        project = kwargs.get("gcp_project", os.getenv("GCP_PROJECT"))
        return storage.Client(project=project)

    def list_files(self, path: str, **kwargs) -> List[str]:
        """
        Lists all files under the path argument in GCS. The path is expected
        to be in the form 'gs://bucket/prefix'.
        """
        gcs = self._get_gcs_client(kwargs)
        parsed_remote = parse_attachment_url(path, clean_key=False)
        bucket = gcs.bucket(parsed_remote.bucket)
        return [blob.name for blob in bucket.list_blobs(prefix=parsed_remote.key)]

    def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) -> List[str]:
        """
        Downloads all files under the given path to the local target_path
        directory. Returns the local paths, including files skipped because
        they already existed and `overwrite` is False.

        NOTE(review): downloads run sequentially; parallelize with a thread
        pool if this ever becomes a hot path.
        """
        gcs = self._get_gcs_client(kwargs)
        parsed_remote = parse_attachment_url(path, clean_key=False)
        prefix = parsed_remote.key

        bucket = gcs.bucket(parsed_remote.bucket)
        downloaded_files: List[str] = []

        for blob in bucket.list_blobs(prefix=prefix):
            # Strip only the *leading* prefix. str.replace would also clobber
            # any later occurrence of the prefix inside the object name.
            name = blob.name
            if name.startswith(prefix):
                name = name[len(prefix):]
            file_path_suffix = name.lstrip("/")
            local_path = os.path.join(target_path, file_path_suffix).rstrip("/")

            if not overwrite and os.path.exists(local_path):
                downloaded_files.append(local_path)
                continue

            # exist_ok avoids the check-then-create race of the previous
            # `if not os.path.exists(...): os.makedirs(...)` pattern.
            os.makedirs(os.path.dirname(local_path), exist_ok=True)

            logger.info(f"Downloading {blob.name} to {local_path}")
            blob.download_to_filename(local_path)
            downloaded_files.append(local_path)

        return downloaded_files

    def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]:
        """
        Retrieves URLs for all model weight artifacts stored under the
        prefix: hmi_config.hf_user_fine_tuned_weights_prefix / {owner} / {model_cache_name}
        """
        gcs = self._get_gcs_client(kwargs)
        prefix_base = hmi_config.hf_user_fine_tuned_weights_prefix
        if prefix_base.startswith("gs://"):
            # Strip "gs://" for prefix logic below
            prefix_base = prefix_base[5:]
        # assumes prefix_base is "bucket/some/prefix" at this point — a bare
        # bucket name with no "/" would raise ValueError here; TODO confirm
        bucket_name, prefix_base = prefix_base.split("/", 1)

        model_cache_name = get_model_cache_directory_name(model_name)
        prefix = f"{prefix_base}/{owner}/{model_cache_name}"

        bucket = gcs.bucket(bucket_name)
        blobs = bucket.list_blobs(prefix=prefix)

        return [f"gs://{bucket_name}/{blob.name}" for blob in blobs]

    def get_model_config(self, path: str, **kwargs) -> Dict[str, Any]:
        """
        Downloads a 'config.json' file from GCS located at path/config.json
        and returns it as a dictionary.
        """
        gcs = self._get_gcs_client(kwargs)
        parsed_remote = parse_attachment_url(path, clean_key=False)
        # The key from parse_attachment_url might be e.g. "weight_prefix/model_dir",
        # so append "/config.json" to address the config object.
        key_with_config = os.path.join(parsed_remote.key, "config.json")

        bucket = gcs.bucket(parsed_remote.bucket)
        blob = bucket.blob(key_with_config)

        # Stream the object directly instead of round-tripping through a file
        # under /tmp: avoids leftover temp files and collisions between
        # concurrent processes downloading the same config.
        return json.loads(blob.download_as_text())
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
note: I think this is only really used for some fine tuning apis that aren't really used at this point, think it's fine to keep ofc since you'll probably need to initialize dependencies anyways, but this code probably won't really get exercised at all