# ---------------------------------------------------------------------------
# clp_py_utils/clp_config.py: credential models
# ---------------------------------------------------------------------------


def _utc_now() -> datetime:
    """:return: The current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


class AwsCredential(BaseModel):
    """
    A long-term AWS credential as stored in the `aws_credentials` table.

    Every instance carries a static key pair. `role_arn` is optional and, when present, names the
    IAM role associated with the credential.
    """

    id: int
    name: Annotated[
        str,
        Field(
            min_length=1,
            max_length=255,
            pattern=r"^[a-zA-Z0-9_-]+$",
            description="Credential name (alphanumeric, hyphens, underscores only; 1-255 characters)",
        ),
    ]

    # Key material is kept wrapped; it is only unwrapped in `to_s3_credentials`.
    access_key_id: SecretStr
    secret_access_key: SecretStr
    role_arn: str | None = None

    # Both timestamps default to "now" in UTC when the caller does not supply them.
    created_at: datetime = Field(default_factory=_utc_now)
    updated_at: datetime = Field(default_factory=_utc_now)

    def to_s3_credentials(self) -> S3Credentials:
        """
        Builds an `S3Credentials` object from the static key pair.

        No session token is set, so this is only suitable for static credentials; temporary
        credentials with a session token are modelled by `TemporaryCredential`.

        :return: `S3Credentials` holding the unwrapped access key ID and secret access key.
        """
        key_id = self.access_key_id.get_secret_value()
        secret_key = self.secret_access_key.get_secret_value()
        return S3Credentials(
            access_key_id=key_id,
            secret_access_key=secret_key,
            session_token=None,
        )


class TemporaryCredential(BaseModel):
    """
    A cached temporary credential (key pair plus session token) from the
    `aws_temporary_credentials` table.

    `source` records where the session token came from, e.g.:
    - A role ARN: "arn:aws:iam::123456789012:role/MyRole"
    - An S3 resource ARN: "arn:aws:s3:::bucket/path/*"
    """

    id: int
    # ID of the `aws_credentials` row this temporary credential was derived from.
    long_term_key_id: int
    access_key_id: SecretStr
    secret_access_key: SecretStr
    session_token: SecretStr
    # Role ARN or S3 resource ARN.
    source: str
    expires_at: datetime
    created_at: datetime

    def to_s3_credentials(self) -> S3Credentials:
        """
        Builds an `S3Credentials` object including the session token.

        :return: `S3Credentials` holding the unwrapped key pair and session token.
        """
        unwrapped = {
            "access_key_id": self.access_key_id.get_secret_value(),
            "secret_access_key": self.secret_access_key.get_secret_value(),
            "session_token": self.session_token.get_secret_value(),
        }
        return S3Credentials(**unwrapped)

    def is_expired(self, buffer_minutes: int = 5) -> bool:
        """
        Reports whether the credential has expired or will expire within the safety buffer.

        :param buffer_minutes: How many minutes before `expires_at` the credential should already
            be treated as expired.
        :return: True if the current UTC time is at or past `expires_at - buffer_minutes`.
        """
        deadline = self.expires_at
        if deadline.tzinfo is None:
            # A naive value is assumed to be UTC (as read from the DB) so that it can be compared
            # with an aware "now".
            deadline = deadline.replace(tzinfo=timezone.utc)

        cutoff = deadline - timedelta(minutes=buffer_minutes)
        return datetime.now(timezone.utc) >= cutoff


# ---------------------------------------------------------------------------
# clp_py_utils/clp_metadata_db_utils.py: credential tables
# ---------------------------------------------------------------------------

# Both suffixes also take part in the module's `TABLE_SUFFIX_MAX_LEN` computation.
AWS_CREDENTIALS_TABLE_SUFFIX = "aws_credentials"
AWS_TEMPORARY_CREDENTIALS_TABLE_SUFFIX = "aws_temporary_credentials"


def _create_aws_credentials_table(db_cursor, aws_credentials_table_name: str) -> None:
    """
    Creates the long-term credentials table if it doesn't already exist.

    :param db_cursor: Cursor used to execute the DDL statement.
    :param aws_credentials_table_name: Full name of the table to create.
    """
    ddl = f"""
        CREATE TABLE IF NOT EXISTS `{aws_credentials_table_name}` (
            `id` INT NOT NULL AUTO_INCREMENT,
            `name` VARCHAR(255) NOT NULL UNIQUE,
            `access_key_id` VARCHAR(255) NOT NULL,
            `secret_access_key` VARCHAR(255) NOT NULL,
            `role_arn` VARCHAR(2048),
            `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
            `updated_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
            PRIMARY KEY (`id`)
        ) ROW_FORMAT=DYNAMIC
        """
    db_cursor.execute(ddl)


def _create_aws_temporary_credentials_table(
    db_cursor, aws_temporary_credentials_table_name: str, aws_credentials_table_name: str
) -> None:
    """
    Creates the temporary-credentials cache table if it doesn't already exist.

    Rows reference the long-term credentials table through `long_term_key_id` and are removed
    together with their parent row (`ON DELETE CASCADE`).

    :param db_cursor: Cursor used to execute the DDL statement.
    :param aws_temporary_credentials_table_name: Full name of the table to create.
    :param aws_credentials_table_name: Full name of the long-term credentials table that the
        foreign key points at.
    """
    ddl = f"""
        CREATE TABLE IF NOT EXISTS `{aws_temporary_credentials_table_name}` (
            `id` INT NOT NULL AUTO_INCREMENT,
            `long_term_key_id` INT NOT NULL,
            `access_key_id` VARCHAR(255) NOT NULL,
            `secret_access_key` VARCHAR(255) NOT NULL,
            `session_token` VARCHAR(2048) NOT NULL,
            `source` VARCHAR(2048) NOT NULL,
            `expires_at` DATETIME NOT NULL,
            `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
            PRIMARY KEY (`id`),
            KEY `long_term_key_expires` (`long_term_key_id`, `expires_at`),
            KEY `source_expires` (`source`(512), `expires_at`),
            FOREIGN KEY (`long_term_key_id`) REFERENCES `{aws_credentials_table_name}` (`id`)
                ON DELETE CASCADE
        ) ROW_FORMAT=DYNAMIC
        """
    db_cursor.execute(ddl)


def create_aws_credentials_table(db_cursor, table_prefix: str) -> None:
    """
    Creates the table that stores user-managed long-term AWS credentials.

    :param db_cursor: The database cursor to execute the table creation.
    :param table_prefix: A string to prepend to the table name.
    """
    _create_aws_credentials_table(db_cursor, get_aws_credentials_table_name(table_prefix))


def create_aws_temporary_credentials_table(db_cursor, table_prefix: str) -> None:
    """
    Creates the table that caches temporary AWS credentials (session tokens).

    The table has a foreign key into the long-term credentials table, whose name is derived from
    the same `table_prefix`.

    :param db_cursor: The database cursor to execute the table creation.
    :param table_prefix: A string to prepend to the table name.
    """
    parent_table_name = get_aws_credentials_table_name(table_prefix)
    cache_table_name = get_aws_temporary_credentials_table_name(table_prefix)
    _create_aws_temporary_credentials_table(db_cursor, cache_table_name, parent_table_name)


def get_aws_credentials_table_name(table_prefix: str) -> str:
    """
    :param table_prefix: A string to prepend to the table name.
    :return: The name of the long-term AWS credentials table (not dataset-specific).
    """
    return _get_table_name(prefix=table_prefix, suffix=AWS_CREDENTIALS_TABLE_SUFFIX, dataset=None)


def get_aws_temporary_credentials_table_name(table_prefix: str) -> str:
    """
    :param table_prefix: A string to prepend to the table name.
    :return: The name of the temporary AWS credentials table (not dataset-specific).
    """
    return _get_table_name(
        prefix=table_prefix, suffix=AWS_TEMPORARY_CREDENTIALS_TABLE_SUFFIX, dataset=None
    )


# ---------------------------------------------------------------------------
# clp_py_utils/initialize-clp-metadata-db.py
# ---------------------------------------------------------------------------
# `main` calls `create_aws_credentials_table(metadata_db_cursor, table_prefix)` followed by
# `create_aws_temporary_credentials_table(metadata_db_cursor, table_prefix)` before the
# storage-engine-specific tables are created. The long-term table is created first because the
# temporary-credentials table declares a foreign key into it.
class S3CredentialManager:
    """
    Manages AWS S3 credentials in the CLP metadata database.

    Two kinds of rows are handled:
    - Long-term credentials (name, access key pair, optional role ARN).
    - Cached temporary credentials (session tokens) that point back at a long-term credential.

    Every public method opens its own connection and closes it before returning.
    """

    # Matched with `fullmatch`, so the whole name must consist of these characters. A `^...$`
    # pattern used with `re.match` would also accept a name followed by a trailing newline.
    _CREDENTIAL_NAME_PATTERN = re.compile(r"[a-zA-Z0-9_-]+")
    _CREDENTIAL_NAME_MAX_LEN = 255
    _SOURCE_MAX_LEN = 2048

    def __init__(self, database_config: Database):
        """
        Initializes the credential manager.

        :param database_config: Database configuration for connecting to CLP metadata database.
        """
        self.sql_adapter = SQL_Adapter(database_config)
        database_config.ensure_credentials_loaded()
        conn_params = database_config.get_clp_connection_params_and_type()
        self.table_prefix = conn_params.get("table_prefix", CLP_METADATA_TABLE_PREFIX)

    def create_credential(
        self,
        name: str,
        access_key_id: str,
        secret_access_key: str,
        role_arn: str | None = None,
    ) -> int:
        """
        Creates a new AWS long-term credential entry.

        :param name: Unique credential name (alphanumeric, hyphens and underscores only).
        :param access_key_id:
        :param secret_access_key:
        :param role_arn: IAM role to assume when using this credential, if any.
        :return: The ID of the created credential.
        :raises ValueError: If validation fails or if a credential with `name` already exists.
        """
        # Validate before touching the database so bad input never opens a connection.
        self._validate_credential_name(name)
        self._validate_access_key_id(access_key_id)
        self._validate_secret_access_key(secret_access_key)

        table_name = get_aws_credentials_table_name(self.table_prefix)

        with (
            closing(self.sql_adapter.create_connection(True)) as db_conn,
            closing(db_conn.cursor(dictionary=True)) as cursor,
        ):
            # NOTE(review): this SELECT-then-INSERT is not atomic. If two callers race on the same
            # name, the second INSERT is expected to be rejected by the database (the `name` column
            # is presumably UNIQUE) and will surface as a driver error rather than `ValueError`.
            cursor.execute(f"SELECT id FROM `{table_name}` WHERE name = %s", (name,))
            if cursor.fetchone():
                raise ValueError(f"Credential with name '{name}' already exists")

            cursor.execute(
                f"""
                INSERT INTO `{table_name}`
                (name, access_key_id, secret_access_key, role_arn)
                VALUES (%s, %s, %s, %s)
                """,
                (name, access_key_id, secret_access_key, role_arn),
            )
            db_conn.commit()

            credential_id = cursor.lastrowid
            logger.info("Created credential '%s' with ID %s", name, credential_id)
            return credential_id

    def list_credentials(self) -> list[tuple[int, str, datetime]]:
        """
        Lists all credential entries (metadata only, no secrets).

        :return: List of tuples, ordered by credential name, containing:
            - The credential id.
            - The credential name.
            - The time the credential was created at.
        """
        table_name = get_aws_credentials_table_name(self.table_prefix)

        with (
            closing(self.sql_adapter.create_connection(True)) as db_conn,
            closing(db_conn.cursor(dictionary=True)) as cursor,
        ):
            cursor.execute(
                f"""
                SELECT id, name, created_at
                FROM `{table_name}`
                ORDER BY name
                """
            )
            rows = cursor.fetchall()
            return [(row["id"], row["name"], row["created_at"]) for row in rows]

    def get_credential_by_id(self, credential_id: int) -> AwsCredential | None:
        """
        Retrieves a long-term credential by ID (includes secrets).

        :param credential_id:
        :return: `AwsCredential` object or None if not found.
        """
        return self._fetch_credential("id", credential_id)

    def get_credential_by_name(self, name: str) -> AwsCredential | None:
        """
        Retrieves a long-term credential by name (includes secrets).

        :param name:
        :return: `AwsCredential` object or None if not found.
        """
        return self._fetch_credential("name", name)

    def update_credential(
        self,
        credential_id: int,
        name: str | None = None,
        access_key_id: str | None = None,
        secret_access_key: str | None = None,
        role_arn: str | None = None,
    ) -> bool:
        """
        Updates an existing long-term credential. Only provided fields are updated.

        :param credential_id:
        :param name:
        :param access_key_id:
        :param secret_access_key:
        :param role_arn: Replacement IAM role to associate; use empty string to clear.
        :return: True if updated, False if credential not found.
        :raises ValueError: If no field is given, if `name` conflicts, or if validation fails.
        """
        if all(v is None for v in (name, access_key_id, secret_access_key, role_arn)):
            raise ValueError("At least one field must be specified for update")

        table_name = get_aws_credentials_table_name(self.table_prefix)

        with (
            closing(self.sql_adapter.create_connection(True)) as db_conn,
            closing(db_conn.cursor(dictionary=True)) as cursor,
        ):
            # A missing credential is reported as False before any field is validated.
            cursor.execute(f"SELECT id FROM `{table_name}` WHERE id = %s", (credential_id,))
            if not cursor.fetchone():
                return False

            if name is not None:
                self._validate_credential_name(name)
                # The new name must not belong to a different credential.
                cursor.execute(
                    f"SELECT id FROM `{table_name}` WHERE name = %s AND id != %s",
                    (name, credential_id),
                )
                if cursor.fetchone():
                    raise ValueError(f"Credential with name '{name}' already exists")

            if access_key_id is not None:
                self._validate_access_key_id(access_key_id)
            if secret_access_key is not None:
                self._validate_secret_access_key(secret_access_key)

            assignments, params = self._build_update_assignments(
                name, access_key_id, secret_access_key, role_arn
            )
            params.append(credential_id)

            cursor.execute(
                f"""
                UPDATE `{table_name}`
                SET {', '.join(assignments)}
                WHERE id = %s
                """,
                params,
            )
            db_conn.commit()

            logger.info("Updated credential ID %s", credential_id)
            return True

    def delete_credential(self, credential_id: int) -> bool:
        """
        Deletes a credential by ID.

        Cached temporary credentials that reference this credential are expected to be removed by
        the database through the foreign key's `ON DELETE CASCADE`.

        :param credential_id:
        :return: True if deleted, False if not found.
        """
        table_name = get_aws_credentials_table_name(self.table_prefix)

        with (
            closing(self.sql_adapter.create_connection(True)) as db_conn,
            closing(db_conn.cursor(dictionary=True)) as cursor,
        ):
            cursor.execute(f"DELETE FROM `{table_name}` WHERE id = %s", (credential_id,))
            deleted = cursor.rowcount > 0
            db_conn.commit()

            if deleted:
                logger.info("Deleted credential ID %s", credential_id)
            else:
                logger.warning("Credential ID %s not found for deletion", credential_id)

            return deleted

    # Temporary credential cache methods

    def cache_session_token(
        self,
        long_term_key_id: int,
        access_key_id: str,
        secret_access_key: str,
        session_token: str,
        source: str,
        expires_at: datetime,
    ) -> int:
        """
        Caches a session token in the temporary credentials table.

        :param long_term_key_id: ID of the long-term credential the token was derived from.
        :param access_key_id:
        :param secret_access_key:
        :param session_token:
        :param source: Origin marker describing how the token was generated, such as a role ARN.
        :param expires_at: Expiry time. Timezone-aware values are converted to UTC before being
            stored; naive values are stored unchanged and are assumed to already be in UTC.
        :return: The ID of the cached temporary credential row.
        :raises ValueError: If validation fails.
        """
        if not access_key_id or not secret_access_key or not session_token:
            raise ValueError("Temporary credentials cannot have empty fields")

        if not source:
            raise ValueError("Source cannot be empty")

        if len(source) > self._SOURCE_MAX_LEN:
            raise ValueError("Source cannot exceed 2048 characters")

        table_name = get_aws_temporary_credentials_table_name(self.table_prefix)
        # The expiry filters compare against UTC, so what gets stored must be UTC wall-clock time.
        stored_expires_at = self._to_naive_utc(expires_at)

        with (
            closing(self.sql_adapter.create_connection(True)) as db_conn,
            closing(db_conn.cursor(dictionary=True)) as cursor,
        ):
            cursor.execute(
                f"""
                INSERT INTO `{table_name}`
                (long_term_key_id, access_key_id, secret_access_key, session_token, source, expires_at)
                VALUES (%s, %s, %s, %s, %s, %s)
                """,
                (
                    long_term_key_id,
                    access_key_id,
                    secret_access_key,
                    session_token,
                    source,
                    stored_expires_at,
                ),
            )
            db_conn.commit()

            credential_id = cursor.lastrowid
            logger.info(
                "Cached session token for source '%s' with ID '%s', expires at %s",
                source,
                credential_id,
                expires_at,
            )
            return credential_id

    def get_cached_session_token_by_source(
        self, source: str, long_term_key_id: int | None = None
    ) -> TemporaryCredential | None:
        """
        Retrieves a valid (non-expired) cached session token by source.

        When several tokens match, the one that expires last is returned.

        :param source: Source identifier (role ARN or S3 resource ARN).
        :param long_term_key_id: Optional filter by associated long-term credential ID.
        :return: `TemporaryCredential` object or None if not found or expired.
        """
        table_name = get_aws_temporary_credentials_table_name(self.table_prefix)

        # `UTC_TIMESTAMP()` rather than `NOW()`: `expires_at` is treated as UTC, whereas `NOW()`
        # follows the session time zone and would mis-evaluate expiry on a non-UTC server.
        query = f"""
            SELECT id, long_term_key_id, access_key_id, secret_access_key,
                   session_token, source, expires_at, created_at
            FROM `{table_name}`
            WHERE source = %s AND expires_at > UTC_TIMESTAMP()
            """
        params: list[int | str] = [source]

        if long_term_key_id is not None:
            query += " AND long_term_key_id = %s"
            params.append(long_term_key_id)

        query += " ORDER BY expires_at DESC LIMIT 1"

        with (
            closing(self.sql_adapter.create_connection(True)) as db_conn,
            closing(db_conn.cursor(dictionary=True)) as cursor,
        ):
            cursor.execute(query, params)
            row = cursor.fetchone()

        if not row:
            return None

        return TemporaryCredential(
            id=row["id"],
            long_term_key_id=row["long_term_key_id"],
            access_key_id=SecretStr(row["access_key_id"]),
            secret_access_key=SecretStr(row["secret_access_key"]),
            session_token=SecretStr(row["session_token"]),
            source=row["source"],
            expires_at=row["expires_at"],
            created_at=row["created_at"],
        )

    def cleanup_expired_session_tokens(self) -> int:
        """
        Deletes expired session tokens from the cache.

        :return: Count of deleted credentials.
        """
        table_name = get_aws_temporary_credentials_table_name(self.table_prefix)

        with (
            closing(self.sql_adapter.create_connection(True)) as db_conn,
            closing(db_conn.cursor(dictionary=True)) as cursor,
        ):
            # Same UTC clock as the lookup query, so both agree on what counts as expired.
            cursor.execute(
                f"""
                DELETE FROM `{table_name}`
                WHERE expires_at < UTC_TIMESTAMP()
                """
            )
            deleted_count = cursor.rowcount
            db_conn.commit()

            if deleted_count > 0:
                logger.info("Cleaned up %s expired session token(s)", deleted_count)
            else:
                logger.debug("No expired session tokens to clean up")

            return deleted_count

    def _fetch_credential(self, column: str, value: int | str) -> AwsCredential | None:
        """
        Loads a single long-term credential whose `column` equals `value`.

        :param column: Column to filter on. It is interpolated into the SQL text, so it must be a
            trusted, hard-coded column name ("id" or "name"), never caller-supplied input.
        :param value: Value to match; passed to the driver as a bound parameter.
        :return: `AwsCredential` object or None if no row matches.
        """
        table_name = get_aws_credentials_table_name(self.table_prefix)

        with (
            closing(self.sql_adapter.create_connection(True)) as db_conn,
            closing(db_conn.cursor(dictionary=True)) as cursor,
        ):
            cursor.execute(
                f"""
                SELECT id, name, access_key_id, secret_access_key,
                       role_arn, created_at, updated_at
                FROM `{table_name}`
                WHERE `{column}` = %s
                """,
                (value,),
            )
            row = cursor.fetchone()

        if not row:
            return None

        return AwsCredential(
            id=row["id"],
            name=row["name"],
            access_key_id=SecretStr(row["access_key_id"]),
            secret_access_key=SecretStr(row["secret_access_key"]),
            role_arn=row["role_arn"],
            created_at=row["created_at"],
            updated_at=row["updated_at"],
        )

    @staticmethod
    def _build_update_assignments(
        name: str | None,
        access_key_id: str | None,
        secret_access_key: str | None,
        role_arn: str | None,
    ) -> tuple[list[str], list]:
        """
        Builds the `SET` assignments and bound parameters for a partial credential update.

        :param name:
        :param access_key_id:
        :param secret_access_key:
        :param role_arn: None leaves the column untouched; an empty string clears it (NULL).
        :return: A tuple of:
            - The `column = %s` fragments, in column order.
            - The parameter values matching those fragments.
        """
        assignments: list[str] = []
        params: list = []

        plain_columns = (
            ("name", name),
            ("access_key_id", access_key_id),
            ("secret_access_key", secret_access_key),
        )
        for column, value in plain_columns:
            if value is not None:
                assignments.append(f"{column} = %s")
                params.append(value)

        if role_arn is not None:
            assignments.append("role_arn = %s")
            # An empty string means "clear the role", which is stored as NULL.
            params.append(role_arn if role_arn != "" else None)

        return assignments, params

    @staticmethod
    def _to_naive_utc(moment: datetime) -> datetime:
        """
        Converts a datetime to naive UTC wall-clock time for storage.

        :param moment:
        :return: `moment` unchanged if it is naive; otherwise its UTC equivalent with the tzinfo
            removed.
        """
        if moment.utcoffset() is None:
            return moment

        # Local import: the module's top-level imports only bring in `datetime`.
        from datetime import timezone

        return moment.astimezone(timezone.utc).replace(tzinfo=None)

    def _validate_credential_name(self, name: str) -> None:
        """
        Validates credential name.

        :param name:
        :raises ValueError: If `name` is empty, longer than 255 characters, contains spaces, or
            contains anything other than alphanumeric characters, hyphens and underscores.
        """
        if not name or not name.strip():
            raise ValueError("Credential name cannot be empty")

        if len(name) > self._CREDENTIAL_NAME_MAX_LEN:
            raise ValueError("Credential name cannot exceed 255 characters")

        if " " in name:
            raise ValueError("Credential name cannot contain spaces")

        # `fullmatch` anchors at the true end of the string, so "name\n" is rejected too.
        if not self._CREDENTIAL_NAME_PATTERN.fullmatch(name):
            raise ValueError(
                "Credential name can only contain alphanumeric characters, hyphens, and underscores"
            )

    def _validate_access_key_id(self, access_key_id: str) -> None:
        """
        Validates that the AWS access key ID is not empty or whitespace-only.

        :param access_key_id:
        :raises ValueError: If `access_key_id` is invalid.
        """
        if not access_key_id or not access_key_id.strip():
            raise ValueError("Access key ID cannot be empty")

    def _validate_secret_access_key(self, secret_access_key: str) -> None:
        """
        Validates that the AWS secret access key is not empty or whitespace-only.

        :param secret_access_key:
        :raises ValueError: If `secret_access_key` is invalid.
        """
        if not secret_access_key or not secret_access_key.strip():
            raise ValueError("Secret access key cannot be empty")