diff --git a/projects/pgai/pgai/vectorizer/loading.py b/projects/pgai/pgai/vectorizer/loading.py index bf8369d4a..a51d4a660 100644 --- a/projects/pgai/pgai/vectorizer/loading.py +++ b/projects/pgai/pgai/vectorizer/loading.py @@ -3,9 +3,12 @@ from io import BytesIO from typing import Any, Literal +import structlog from filetype import filetype # type: ignore from pydantic import BaseModel +logger = structlog.get_logger() + @dataclass class LoadedDocument: @@ -56,6 +59,7 @@ def load(self, row: dict[str, str]) -> LoadedDocument: file_path = row[self.column_name] transport_params = None + s3_resource = None if file_path.startswith("s3://") and self.aws_role_arn is not None: external_id = os.getenv("AWS_ASSUME_ROLE_EXTERNAL_ID") sts_client: STSClient = boto3.client("sts") # type: ignore @@ -80,11 +84,17 @@ def load(self, row: dict[str, str]) -> LoadedDocument: # Create an S3 client using the session with assumed role s3_client: S3Client = session.client("s3") # type: ignore transport_params = {"client": s3_client} - content = BytesIO( - smart_open.open( # type: ignore - file_path, "rb", transport_params=transport_params - ).read() - ) + s3_resource = session.resource("s3") # type: ignore + if file_path.startswith("s3://") and not self.aws_role_arn: + import boto3 + + s3_resource = boto3.resource("s3") # type: ignore + file = smart_open.open(file_path, "rb", transport_params=transport_params) # type: ignore + if s3_resource is not None: + size = file.to_boto3(s3_resource).content_length # type: ignore + logger.info(f"Preparing to download file {file_path}, size: {size} bytes") + + content = BytesIO(file.read()) # type: ignore return LoadedDocument( content=content, file_path=file_path,