diff --git a/DESCRIPTION.md b/DESCRIPTION.md
index 61c80607c..3b4fda420 100644
--- a/DESCRIPTION.md
+++ b/DESCRIPTION.md
@@ -7,7 +7,8 @@ https://docs.snowflake.com/
 Source code is also available at: https://github.com/snowflakedb/snowflake-connector-python
 
 # Release Notes
-- v3.16.1(TBD)
+- v3.17(TBD)
+  - Removed boto and botocore dependencies.
   - Added in-band OCSP exception telemetry.
   - Added `APPLICATION_PATH` within `CLIENT_ENVIRONMENT` to distinguish between multiple scripts using the PythonConnector in the same environment.
   - Disabled token caching for OAuth Client Credentials authentication
diff --git a/setup.cfg b/setup.cfg
index 2fd14a0ea..ab2fc4e5a 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -44,8 +44,6 @@ python_requires = >=3.9
 packages = find_namespace:
 install_requires =
     asn1crypto>0.24.0,<2.0.0
-    boto3>=1.24
-    botocore>=1.24
    cffi>=1.9,<2.0.0
    cryptography>=3.1.0
    pyOpenSSL>=22.0.0,<26.0.0
@@ -92,6 +90,8 @@ development =
     pytest-timeout
     pytest-xdist
     pytzdata
+    botocore
+    boto3
 pandas =
     pandas>=2.1.2,<3.0.0
     pyarrow<19.0.0
diff --git a/src/snowflake/connector/_aws_credentials.py b/src/snowflake/connector/_aws_credentials.py
new file mode 100644
index 000000000..88b2032a3
--- /dev/null
+++ b/src/snowflake/connector/_aws_credentials.py
@@ -0,0 +1,136 @@
+"""
+Lightweight AWS credential resolution without boto3.
+
+Resolves credentials in the order: environment → ECS/EKS task metadata → EC2 IMDSv2.
+Returns a minimal `SfAWSCredentials` object that can be passed to SigV4 signing
+helpers unchanged.
+"""
+
+from __future__ import annotations
+
+import logging
+import os
+from dataclasses import dataclass
+from functools import partial
+from typing import Callable
+
+from .vendored import requests
+
+logger = logging.getLogger(__name__)
+
+_ECS_CRED_BASE_URL = "http://169.254.170.2"
+_IMDS_BASE_URL = "http://169.254.169.254"
+_IMDS_TOKEN_PATH = "/latest/api/token"
+_IMDS_ROLE_PATH = "/latest/meta-data/iam/security-credentials/"
+_IMDS_AZ_PATH = "/latest/meta-data/placement/availability-zone"
+
+
+@dataclass
+class SfAWSCredentials:
+    """Minimal stand-in for ``botocore.credentials.Credentials``."""
+
+    access_key: str
+    secret_key: str
+    token: str | None = None
+
+
+def get_env_credentials() -> SfAWSCredentials | None:
+    key, secret = os.getenv("AWS_ACCESS_KEY_ID"), os.getenv("AWS_SECRET_ACCESS_KEY")
+    if key and secret:
+        return SfAWSCredentials(key, secret, os.getenv("AWS_SESSION_TOKEN"))
+    return None
+
+
+def get_container_credentials(*, timeout: float) -> SfAWSCredentials | None:
+    """Credentials from the ECS/EKS task-metadata endpoint."""
+    rel_uri = os.getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI")
+    full_uri = os.getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI")
+    if not rel_uri and not full_uri:
+        return None
+
+    url = full_uri or f"{_ECS_CRED_BASE_URL}{rel_uri}"
+    try:
+        response = requests.get(url, timeout=timeout)
+        if response.ok:
+            data = response.json()
+            return SfAWSCredentials(
+                data["AccessKeyId"], data["SecretAccessKey"], data.get("Token")
+            )
+    except (requests.Timeout, requests.ConnectionError, ValueError) as exc:
+        logger.debug("ECS credential fetch failed: %s", exc, exc_info=True)
+    return None
+
+
+def _get_imds_v2_token(timeout: float) -> str | None:
+    try:
+        response = requests.request(
+            "PUT",
+            f"{_IMDS_BASE_URL}{_IMDS_TOKEN_PATH}",
+            headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"},
+            timeout=timeout,
+        )
+        return response.text if response.ok else None
+    except (requests.Timeout, requests.ConnectionError):
+        return None
+
+
+def get_imds_credentials(*, timeout: float) -> SfAWSCredentials | None:
+    """Instance-profile credentials from the EC2 metadata service."""
+    token = _get_imds_v2_token(timeout)
+    headers = {"X-aws-ec2-metadata-token": token} if token else {}
+
+    try:
+        role_resp = requests.get(
+            f"{_IMDS_BASE_URL}{_IMDS_ROLE_PATH}", headers=headers, timeout=timeout
+        )
+        if not role_resp.ok:
+            return None
+        role_name = role_resp.text.strip()
+
+        cred_resp = requests.get(
+            f"{_IMDS_BASE_URL}{_IMDS_ROLE_PATH}{role_name}",
+            headers=headers,
+            timeout=timeout,
+        )
+        if cred_resp.ok:
+            data = cred_resp.json()
+            return SfAWSCredentials(
+                data["AccessKeyId"], data["SecretAccessKey"], data.get("Token")
+            )
+    except (requests.Timeout, requests.ConnectionError, ValueError) as exc:
+        logger.debug("IMDS credential fetch failed: %s", exc, exc_info=True)
+    return None
+
+
+def load_default_credentials(timeout: float = 2.0) -> SfAWSCredentials | None:
+    """Resolve credentials using the default AWS chain (env → task → IMDS)."""
+    providers: tuple[Callable[[], SfAWSCredentials | None], ...] = (
+        get_env_credentials,
+        partial(get_container_credentials, timeout=timeout),
+        partial(get_imds_credentials, timeout=timeout),
+    )
+    for try_fetch_credentials in providers:
+        credentials = try_fetch_credentials()
+        if credentials:
+            return credentials
+    return None
+
+
+def get_region(timeout: float = 1.0) -> str | None:
+    """Return the current AWS region if it can be discovered."""
+    if region := os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION"):
+        return region
+
+    token = _get_imds_v2_token(timeout)
+    headers = {"X-aws-ec2-metadata-token": token} if token else {}
+    try:
+        response = requests.request(
+            "GET", f"{_IMDS_BASE_URL}{_IMDS_AZ_PATH}", headers=headers, timeout=timeout
+        )
+        if response.ok:
+            az = response.text.strip()
+            return az[:-1] if az and az[-1].isalpha() else None
+    except (requests.Timeout, requests.ConnectionError) as exc:
+        logger.debug("IMDS region lookup failed: %s", exc, exc_info=True)
+
+    return None
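+
+
+# Illustrative sketch of how the helpers above compose (values invented for
+# the example, not real credentials). On an EC2 instance with an instance
+# profile, the chain skips the env and ECS providers and answers from IMDSv2:
+#
+#     creds = load_default_credentials(timeout=2.0)
+#     region = get_region()  # e.g. "us-east-1", derived from the "us-east-1a" AZ
+#     if creds and region:
+#         # creds.access_key / creds.secret_key / creds.token feed directly
+#         # into _aws_sign_v4.sign_get_caller_identity() below.
+#         ...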
diff --git a/src/snowflake/connector/_aws_sign_v4.py b/src/snowflake/connector/_aws_sign_v4.py
new file mode 100644
index 000000000..4210aeba9
--- /dev/null
+++ b/src/snowflake/connector/_aws_sign_v4.py
@@ -0,0 +1,105 @@
+from __future__ import annotations
+
+import datetime
+import hashlib
+import hmac
+import urllib.parse as urlparse
+
+_ALGORITHM: str = "AWS4-HMAC-SHA256"
+_EMPTY_PAYLOAD_SHA256: str = hashlib.sha256(b"").hexdigest()
+_SAFE_CHARS: str = "-_.~"
+
+
+def _sign(key: bytes, msg: str) -> bytes:
+    """Return an HMAC-SHA256 of *msg* keyed with *key*."""
+    return hmac.new(key, msg.encode(), hashlib.sha256).digest()
+
+
+def _canonical_query_string(query: str) -> str:
+    """Return the query string in canonical (sorted & URL-escaped) form."""
+    pairs = urlparse.parse_qsl(query, keep_blank_values=True)
+    pairs.sort()
+    return "&".join(
+        f"{urlparse.quote(k, _SAFE_CHARS)}={urlparse.quote(v, _SAFE_CHARS)}"
+        for k, v in pairs
+    )
+
+
+def sign_get_caller_identity(
+    url: str,
+    region: str,
+    access_key: str,
+    secret_key: str,
+    session_token: str | None = None,
+) -> dict[str, str]:
+    """
+    Return the SigV4 headers needed for a presigned POST to AWS STS
+    `GetCallerIdentity`.
+
+    Parameters:
+
+    url
+        The full STS endpoint with query parameters
+        (e.g. ``https://sts.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15``)
+    region
+        The AWS region used for signing (``us-east-1``, ``us-gov-west-1`` …).
+    access_key
+        AWS access-key ID.
+    secret_key
+        AWS secret-access key.
+    session_token
+        (Optional) session token for temporary credentials.
+    """
+    timestamp = datetime.datetime.utcnow()
+    amz_date = timestamp.strftime("%Y%m%dT%H%M%SZ")
+    short_date = timestamp.strftime("%Y%m%d")
+    service = "sts"
+
+    parsed = urlparse.urlparse(url)
+
+    headers: dict[str, str] = {
+        "host": parsed.netloc.lower(),
+        "x-amz-date": amz_date,
+        "x-snowflake-audience": "snowflakecomputing.com",
+    }
+    if session_token:
+        headers["x-amz-security-token"] = session_token
+
+    # Canonical request
+    signed_headers = ";".join(sorted(headers))  # e.g. host;x-amz-date;...
+    canonical_request = "\n".join(
+        (
+            "POST",
+            urlparse.quote(parsed.path or "/", safe="/"),
+            _canonical_query_string(parsed.query),
+            "".join(f"{k}:{headers[k]}\n" for k in sorted(headers)),
+            signed_headers,
+            _EMPTY_PAYLOAD_SHA256,
+        )
+    )
+    canonical_request_hash = hashlib.sha256(canonical_request.encode()).hexdigest()
+
+    # String to sign
+    credential_scope = f"{short_date}/{region}/{service}/aws4_request"
+    string_to_sign = "\n".join(
+        (_ALGORITHM, amz_date, credential_scope, canonical_request_hash)
+    )
+
+    # Signature (SigV4 key derivation: date -> region -> service -> "aws4_request")
+    key_date = _sign(("AWS4" + secret_key).encode(), short_date)
+    key_region = _sign(key_date, region)
+    key_service = _sign(key_region, service)
+    key_signing = _sign(key_service, "aws4_request")
+    signature = hmac.new(
+        key_signing, string_to_sign.encode(), hashlib.sha256
+    ).hexdigest()
+
+    # Final Authorization header
+    headers["authorization"] = (
+        f"{_ALGORITHM} "
+        f"Credential={access_key}/{credential_scope}, "
+        f"SignedHeaders={signed_headers}, "
+        f"Signature={signature}"
+    )
+
+    return headers
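+
+
+# Minimal usage sketch (the key pair below is the well-known dummy pair from
+# the AWS SigV4 test suite, also used by the parity tests in this change):
+#
+#     headers = sign_get_caller_identity(
+#         url="https://sts.us-east-1.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15",
+#         region="us-east-1",
+#         access_key="AKIDEXAMPLE",
+#         secret_key="wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
+#     )
+#     # headers now carries host, x-amz-date, x-snowflake-audience and the
+#     # computed authorization header; POSTing them with an empty body to the
+#     # URL yields the caller identity.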
diff --git a/src/snowflake/connector/wif_util.py b/src/snowflake/connector/wif_util.py
index 3449cdd5e..9a624d27d 100644
--- a/src/snowflake/connector/wif_util.py
+++ b/src/snowflake/connector/wif_util.py
@@ -1,18 +1,26 @@
 from __future__ import annotations
 
+"""Workload-identity attestation helpers.
+
+This module builds the attestation token that the Snowflake Python connector
+sends when authenticating with Workload Identity Federation (WIF).
+It supports AWS, Azure, GCP and generic OIDC environments without pulling
+in heavy SDKs such as botocore: we only need a small presigned STS request
+for AWS and a couple of metadata-server calls for Azure/GCP.
+"""
+
 import json
 import logging
 import os
 from base64 import b64encode
 from dataclasses import dataclass
 from enum import Enum, unique
+from typing import Any
 
-import boto3
 import jwt
-from botocore.auth import SigV4Auth
-from botocore.awsrequest import AWSRequest
-from botocore.utils import InstanceMetadataRegionFetcher
 
+from ._aws_credentials import get_region, load_default_credentials
+from ._aws_sign_v4 import sign_get_caller_identity
 from .errorcode import ER_WIF_CREDENTIALS_NOT_FOUND
 from .errors import ProgrammingError
 from .vendored import requests
@@ -51,36 +59,36 @@ class AttestationProvider(Enum):
 
     @staticmethod
     def from_string(provider: str) -> AttestationProvider:
-        """Converts a string to a strongly-typed enum value of AttestationProvider."""
+        """Converts a string to a strongly-typed enum value of :class:`AttestationProvider`."""
         return AttestationProvider[provider.upper()]
 
 
 @dataclass
 class WorkloadIdentityAttestation:
     provider: AttestationProvider
-    credential: str
-    user_identifier_components: dict
+    credential: str  # base64-encoded, provider-specific JSON blob
+    user_identifier_components: dict[str, Any]
 
 
 def try_metadata_service_call(
-    method: str, url: str, headers: dict, timeout_sec: int = 3
+    method: str, url: str, headers: dict[str, str], *, timeout: int = 3
 ) -> Response | None:
-    """Tries to make a HTTP request to the metadata service with the given URL, method, headers and timeout.
+    """Tries to make an HTTP request to the metadata service with the given URL, method, headers and timeout in seconds.
 
     If we receive an error response or any exceptions are raised, returns None. Otherwise returns the response.
     """
     try:
         res: Response = requests.request(
-            method=method, url=url, headers=headers, timeout=timeout_sec
+            method=method, url=url, headers=headers, timeout=timeout
        )
-        if not res.ok:
-            return None
+        return res if res.ok else None
    except requests.RequestException:
        return None
-    return res
 
 
-def extract_iss_and_sub_without_signature_verification(jwt_str: str) -> tuple[str, str]:
+def extract_iss_and_sub_without_signature_verification(
+    jwt_str: str,
+) -> tuple[str | None, str | None]:
     """Extracts the 'iss' and 'sub' claims from the given JWT, without verifying the signature.
 
     Note: the real token verification (including signature verification) happens on the Snowflake side. The driver doesn't have
@@ -92,141 +100,128 @@ def extract_iss_and_sub_without_signature_verification(jwt_str: str) -> tuple[st
     If there are any errors in parsing the token or extracting iss and sub, this will return (None, None).
""" - try: - claims = jwt.decode(jwt_str, options={"verify_signature": False}) - except jwt.exceptions.InvalidTokenError: - logger.warning("Token is not a valid JWT.", exc_info=True) + claims = _decode_jwt_without_validation(jwt_str) + if claims is None: return None, None - if not ("iss" in claims and "sub" in claims): + if "iss" not in claims or "sub" not in claims: logger.warning("Token is missing 'iss' or 'sub' claims.") return None, None return claims["iss"], claims["sub"] -def get_aws_region() -> str | None: - """Get the current AWS workload's region, if any.""" - if "AWS_REGION" in os.environ: # Lambda - return os.environ["AWS_REGION"] - else: # EC2 - return InstanceMetadataRegionFetcher().retrieve_region() +def _decode_jwt_without_validation(token: str) -> Any: + """Helper that decodes *token* with ``verify_signature=False``.:contentReference[oaicite:1]{index=1}""" + try: + return jwt.decode(token, options={"verify_signature": False}) + except jwt.exceptions.InvalidTokenError: + logger.warning("Token is not a valid JWT.", exc_info=True) + return None -def get_aws_arn() -> str | None: - """Get the current AWS workload's ARN, if any.""" - caller_identity = boto3.client("sts").get_caller_identity() - if not caller_identity or "Arn" not in caller_identity: - return None - return caller_identity["Arn"] +class AWSPartition(str, Enum): + BASE = "aws" + CHINA = "aws-cn" + GOV = "aws-us-gov" -def get_aws_partition(arn: str) -> str | None: - """Get the current AWS partition from ARN, if any. +def _partition_from_region(region: str) -> AWSPartition: + if region.startswith("cn-"): + return AWSPartition.CHINA + if region.startswith("us-gov-"): + return AWSPartition.GOV + return AWSPartition.BASE - Args: - arn (str): The Amazon Resource Name (ARN) string. - Returns: - str | None: The AWS partition (e.g., 'aws', 'aws-cn', 'aws-us-gov') - if found, otherwise None. +def _sts_host_from_region(region: str) -> str | None: + """ + Construct the STS endpoint hostname for region according to the + regionalised-STS rules published by AWS.:contentReference[oaicite:2]{index=2} - Reference: https://docs.aws.amazon.com/IAM/latest/UserGuide/reference-arns.html. + References: + - https://docs.aws.amazon.com/sdkref/latest/guide/feature-sts-regionalized-endpoints.html + - https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_region-endpoints.html + - https://docs.aws.amazon.com/general/latest/gr/sts.html """ - if not arn or not isinstance(arn, str): + if not region or not isinstance(region, str): return None - parts = arn.split(":") - if len(parts) > 1 and parts[0] == "arn" and parts[1]: - return parts[1] - logger.warning("Invalid AWS ARN: %s", arn) - return None + part = _partition_from_region(region) + suffix = ".amazonaws.com.cn" if part is AWSPartition.CHINA else ".amazonaws.com" + return f"sts.{region}{suffix}" -def get_aws_sts_hostname(region: str, partition: str) -> str | None: - """Constructs the AWS STS hostname for a given region and partition. - Args: - region (str): The AWS region (e.g., 'us-east-1', 'cn-north-1'). - partition (str): The AWS partition (e.g., 'aws', 'aws-cn', 'aws-us-gov'). +def _try_get_arn_from_env_vars() -> str | None: + """Try to get ARN already exposed by the runtime (no extra network I/O). - Returns: - str | None: The AWS STS hostname (e.g., 'sts.us-east-1.amazonaws.com') - if a valid hostname can be constructed, otherwise None. 
 
 
-def get_aws_sts_hostname(region: str, partition: str) -> str | None:
-    """Constructs the AWS STS hostname for a given region and partition.
 
-    Args:
-        region (str): The AWS region (e.g., 'us-east-1', 'cn-north-1').
-        partition (str): The AWS partition (e.g., 'aws', 'aws-cn', 'aws-us-gov').
+def _try_get_arn_from_env_vars() -> str | None:
+    """Try to get an ARN already exposed by the runtime (no extra network I/O).
 
-    Returns:
-        str | None: The AWS STS hostname (e.g., 'sts.us-east-1.amazonaws.com')
-        if a valid hostname can be constructed, otherwise None.
-
-    References:
-        - https://docs.aws.amazon.com/sdkref/latest/guide/feature-sts-regionalized-endpoints.html
-        - https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_region-endpoints.html
-        - https://docs.aws.amazon.com/general/latest/gr/sts.html
+    - `AWS_ROLE_ARN`: web-identity and many FaaS runtimes
+    - `AWS_EC2_METADATA_ARN`: some IMDSv2 environments
+    - `AWS_SESSION_ARN`: recent AWS SDKs export this when assuming a role
     """
-    if (
-        not region
-        or not partition
-        or not isinstance(region, str)
-        or not isinstance(partition, str)
+    for possible_arn_env_var in (
+        "AWS_ROLE_ARN",
+        "AWS_EC2_METADATA_ARN",
+        "AWS_SESSION_ARN",
     ):
-        return None
+        value = os.getenv(possible_arn_env_var)
+        if value and value.startswith("arn:"):
+            return value
+    return None
 
-    if partition == "aws":
-        # For the 'aws' partition, STS endpoints are generally regional
-        # except for the global endpoint (sts.amazonaws.com) which is
-        # generally resolved to us-east-1 under the hood by the SDKs
-        # when a region is not explicitly specified.
-        # However, for explicit regional endpoints, the format is sts..amazonaws.com
-        return f"sts.{region}.amazonaws.com"
-    elif partition == "aws-cn":
-        # China regions have a different domain suffix
-        return f"sts.{region}.amazonaws.com.cn"
-    elif partition == "aws-us-gov":
-        return (
-            f"sts.{region}.amazonaws.com"  # GovCloud uses .com, but dedicated regions
-        )
-    else:
-        logger.warning("Invalid AWS partition: %s", partition)
-        return None
+
+def try_compose_aws_user_identifier(region: str | None = None) -> dict[str, str]:
+    """Return an identifier for the running AWS workload.
+
+    Always includes the AWS region; adds an *arn* key only if one is already
+    discoverable via common environment variables. Returns {} only if
+    the region cannot be determined.
+    """
+    region = region or get_region()
+    if not region:
+        return {}
+
+    identifier: dict[str, str] = {"region": region}
+
+    if arn := _try_get_arn_from_env_vars():
+        identifier["arn"] = arn
+
+    return identifier
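+
+
+# Example return shapes (values invented for illustration):
+#
+#     try_compose_aws_user_identifier("us-east-1")
+#         -> {"region": "us-east-1"}
+#     # with AWS_ROLE_ARN set:
+#         -> {"region": "us-east-1", "arn": "arn:aws:iam::123456789012:role/MyRole"}
+#     # region undiscoverable:
+#         -> {}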
- """ - aws_creds = boto3.session.Session().get_credentials() - if not aws_creds: - logger.debug("No AWS credentials were found.") + creds = load_default_credentials() + if not creds: + logger.debug("No AWS credentials available.") return None - region = get_aws_region() + + region = get_region() if not region: - logger.debug("No AWS region was found.") - return None - arn = get_aws_arn() - if not arn: - logger.debug("No AWS caller identity was found.") - return None - partition = get_aws_partition(arn) - if not partition: - logger.debug("No AWS partition was found.") + logger.debug("AWS region could not be determined.") return None - sts_hostname = get_aws_sts_hostname(region, partition) - request = AWSRequest( - method="POST", - url=f"https://{sts_hostname}/?Action=GetCallerIdentity&Version=2011-06-15", - headers={ - "Host": sts_hostname, - "X-Snowflake-Audience": SNOWFLAKE_AUDIENCE, - }, + sts_url = ( + f"https://{_sts_host_from_region(region)}" + "/?Action=GetCallerIdentity&Version=2011-06-15" + ) + signed_headers = sign_get_caller_identity( + url=sts_url, + region=region, + access_key=creds.access_key, + secret_key=creds.secret_key, + session_token=creds.token, ) - SigV4Auth(aws_creds, "sts", region).add_auth(request) + attestation = b64encode( + json.dumps( + {"url": sts_url, "method": "POST", "headers": signed_headers} + ).encode() + ).decode() - assertion_dict = { - "url": request.url, - "method": request.method, - "headers": dict(request.headers.items()), - } - credential = b64encode(json.dumps(assertion_dict).encode("utf-8")).decode("utf-8") + user_identifier = try_compose_aws_user_identifier(region) return WorkloadIdentityAttestation( - AttestationProvider.AWS, credential, {"arn": arn} + AttestationProvider.AWS, attestation, user_identifier ) diff --git a/test/csp_helpers.py b/test/csp_helpers.py index ac3533616..332ee021c 100644 --- a/test/csp_helpers.py +++ b/test/csp_helpers.py @@ -1,29 +1,63 @@ -#!/usr/bin/env python -import datetime +from __future__ import annotations + +import contextlib import json import logging import os from abc import ABC, abstractmethod +from contextlib import ExitStack from time import time from unittest import mock from urllib.parse import parse_qs, urlparse import jwt -from botocore.awsrequest import AWSRequest from botocore.credentials import Credentials -from snowflake.connector.vendored.requests.exceptions import ConnectTimeout, HTTPError +from snowflake.connector._aws_credentials import ( + _ECS_CRED_BASE_URL, + _IMDS_BASE_URL, + _IMDS_ROLE_PATH, + _IMDS_TOKEN_PATH, +) +from snowflake.connector.vendored.requests.exceptions import ConnectTimeout from snowflake.connector.vendored.requests.models import Response logger = logging.getLogger(__name__) +AZURE_VM_METADATA_HOST = "169.254.169.254" +AZURE_VM_TOKEN_PATH = "/metadata/identity/oauth2/token" + +AZURE_FUNCTION_IDENTITY_ENDPOINT = "http://169.254.255.2:8081/msi/token" +AZURE_FUNCTION_IDENTITY_HEADER = "FD80F6DA783A4881BE9FAFA365F58E7A" + +GCE_METADATA_HOST = "169.254.169.254" +GCE_IDENTITY_PATH = "/computeMetadata/v1/instance/service-accounts/default/identity" + +AWS_REGION_ENV_KEYS = ("AWS_REGION", "AWS_DEFAULT_REGION") +AWS_CONTAINER_CRED_ENV = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" +AWS_LAMBDA_FUNCTION_ENV = "AWS_LAMBDA_FUNCTION_NAME" + +HDR_IDENTITY = "X-IDENTITY-HEADER" +HDR_METADATA = "Metadata" +HDR_METADATA_FLAVOR = "Metadata-Flavor" + +AWS_CREDENTIAL_ENV_KEYS = ( + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_SESSION_TOKEN", + "AWS_ROLE_ARN", + "AWS_EC2_METADATA_ARN", + 
"AWS_SESSION_ARN", +) + def gen_dummy_id_token( - sub="test-subject", iss="test-issuer", aud="snowflakecomputing.com" + sub: str = "test-subject", + iss: str = "test-issuer", + aud: str = "snowflakecomputing.com", ) -> str: - """Generates a dummy ID token using the given subject and issuer.""" + """Generates a dummy HS256-signed JWT.""" now = int(time()) - key = "secret" payload = { "sub": sub, "iss": iss, @@ -31,117 +65,146 @@ def gen_dummy_id_token( "iat": now, "exp": now + 60 * 60, } - logger.debug(f"Generating dummy token with the following claims:\n{str(payload)}") - return jwt.encode( - payload=payload, - key=key, - algorithm="HS256", - ) + logger.debug("Generating dummy token with claims %s", payload) + return jwt.encode(payload, key="secret", algorithm="HS256") -def build_response(content: bytes, status_code: int = 200) -> Response: - """Builds a requests.Response object with the given status code and content.""" - response = Response() - response.status_code = status_code - response._content = content - return response +def build_response( + content: bytes, + status_code: int = 200, + headers: dict[str, str] | None = None, +) -> Response: + """Return a minimal Response object with canned body/headers.""" + resp = Response() + resp.status_code = status_code + resp._content = content + if headers: + resp.headers.update(headers) + return resp class FakeMetadataService(ABC): - """Base class for fake metadata service implementations.""" + """Base class for cloud-metadata fakes.""" - def __init__(self): + def __init__(self) -> None: self.reset_defaults() + self._context_stack: ExitStack | None = None + + @staticmethod + def _clean_env_vars_for_scope() -> dict[str, str]: + """Return a mapping that blanks all AWS-specific env-vars. + + Used by Azure / GCP fakes so tests stay hermetic even when + executed inside a real AWS runner. + """ + return {k: "" for k in AWS_CREDENTIAL_ENV_KEYS + AWS_REGION_ENV_KEYS} @abstractmethod - def reset_defaults(self): + def reset_defaults(self) -> None: """Resets any default values for test parameters. This is called in the constructor and when entering as a context manager. """ pass - @property @abstractmethod - def expected_hostname(self): - """Hostname at which this metadata service is listening. + def is_expected_hostname(self, host: str | None) -> bool: + """Returns true if the passed hostname is the one at which this metadata service is listening. Used to raise a ConnectTimeout for requests not targeted to this hostname. """ pass @abstractmethod - def handle_request(self, method, parsed_url, headers, timeout): + def handle_request( + self, + method, + parsed_url, + headers, + timeout, + ) -> Response: """Main business logic for handling this request. 
Should return a Response object.""" pass - def __call__(self, method, url, headers, timeout): - """Entry point for the requests mock.""" - logger.debug(f"Received request: {method} {url} {str(headers)}") - parsed_url = urlparse(url) + def __call__(self, method, url, headers=None, timeout=None, **_kw): + """Entry-point for the requests monkey-patch.""" + headers = headers or {} + parsed = urlparse(url) + logger.debug("FakeMetadataService received %s %s %s", method, url, headers) - if not parsed_url.hostname == self.expected_hostname: + if not self.is_expected_hostname(parsed.hostname): logger.debug( - f"Received request to unexpected hostname {parsed_url.hostname}" + "Received request to unexpected hostname %s – timeout", parsed.hostname ) raise ConnectTimeout() - return self.handle_request(method, parsed_url, headers, timeout) + return self.handle_request(method.upper(), parsed, headers, timeout) def __enter__(self): """Patches the relevant HTTP calls when entering as a context manager.""" self.reset_defaults() - self.patchers = [] - # requests.request is used by the direct metadata service API calls from our code. This is the main - # thing being faked here. - self.patchers.append( + self._context_stack = ExitStack() + self._context_stack.enter_context( + mock.patch( + "snowflake.connector.vendored.requests.request", + side_effect=self, + ) + ) + self._context_stack.enter_context( mock.patch( - "snowflake.connector.vendored.requests.request", side_effect=self + "snowflake.connector.vendored.requests.sessions.Session.request", + side_effect=self, ) ) - # HTTPConnection.request is used by the AWS boto libraries. We're not mocking those calls here, so we - # simply raise a ConnectTimeout to avoid making real network calls. - self.patchers.append( + self._context_stack.enter_context( mock.patch( "urllib3.connection.HTTPConnection.request", side_effect=ConnectTimeout(), ) ) - for patcher in self.patchers: - patcher.__enter__() return self - def __exit__(self, *args, **kwargs): - for patcher in self.patchers: - patcher.__exit__(*args, **kwargs) + def __exit__(self, *exc): + self._context_stack.close() class NoMetadataService(FakeMetadataService): - """Emulates an environment without any metadata service.""" + """Always times out – simulates an environment without any metadata service.""" - def reset_defaults(self): + def reset_defaults(self) -> None: pass - @property - def expected_hostname(self): - return None # Always raise a ConnectTimeout. + def is_expected_hostname(self, host: str | None) -> bool: + return False - def handle_request(self, method, parsed_url, headers, timeout): + def handle_request(self, *_): # This should never be called because we always raise a ConnectTimeout. - pass + raise AssertionError( + "This should never be called because we always raise a ConnectTimeout." + ) class FakeAzureVmMetadataService(FakeMetadataService): """Emulates an environment with the Azure VM metadata service.""" - def reset_defaults(self): + def reset_defaults(self) -> None: # Defaults used for generating an Entra ID token. Can be overriden in individual tests. 
self.sub = "611ab25b-2e81-4e18-92a7-b21f2bebb269" self.iss = "https://sts.windows.net/2c0183ed-cf17-480d-b3f7-df91bc0a97cd" - @property - def expected_hostname(self): - return "169.254.169.254" + def __enter__(self): + self._stack = contextlib.ExitStack() + self._stack.enter_context( + mock.patch.dict(os.environ, self._clean_env_vars_for_scope(), clear=False) + ) + return super().__enter__() + + def __exit__(self, *exc): + self._stack.close() + return super().__exit__(*exc) + + def is_expected_hostname(self, host: str | None) -> bool: + return host == AZURE_VM_METADATA_HOST def handle_request(self, method, parsed_url, headers, timeout): query_string = parse_qs(parsed_url.query) @@ -149,11 +212,11 @@ def handle_request(self, method, parsed_url, headers, timeout): # Reject malformed requests. if not ( method == "GET" - and parsed_url.path == "/metadata/identity/oauth2/token" - and headers.get("Metadata") == "True" - and query_string["resource"] + and parsed_url.path == AZURE_VM_TOKEN_PATH + and headers.get(HDR_METADATA, "").lower() == "true" # <-- patched + and query_string.get("resource") ): - raise HTTPError() + raise ConnectTimeout() logger.debug("Received request for Azure VM metadata service") @@ -165,18 +228,38 @@ def handle_request(self, method, parsed_url, headers, timeout): class FakeAzureFunctionMetadataService(FakeMetadataService): """Emulates an environment with the Azure Function metadata service.""" - def reset_defaults(self): - # Defaults used for generating an Entra ID token. Can be overriden in individual tests. + def reset_defaults(self) -> None: self.sub = "611ab25b-2e81-4e18-92a7-b21f2bebb269" self.iss = "https://sts.windows.net/2c0183ed-cf17-480d-b3f7-df91bc0a97cd" - - self.identity_endpoint = "http://169.254.255.2:8081/msi/token" - self.identity_header = "FD80F6DA783A4881BE9FAFA365F58E7A" + self.identity_endpoint = AZURE_FUNCTION_IDENTITY_ENDPOINT + self.identity_header = AZURE_FUNCTION_IDENTITY_HEADER self.parsed_identity_endpoint = urlparse(self.identity_endpoint) - @property - def expected_hostname(self): - return self.parsed_identity_endpoint.hostname + def __enter__(self): + self._stack = contextlib.ExitStack() + # Inject the variables without touching os.environ directly + self._stack.enter_context( + mock.patch.dict( + os.environ, + { + "IDENTITY_ENDPOINT": self.identity_endpoint, + "IDENTITY_HEADER": self.identity_header, + }, + clear=False, + ) + ) + self._stack.enter_context( + mock.patch.dict(os.environ, self._clean_env_vars_for_scope(), clear=False) + ) + + return super().__enter__() + + def __exit__(self, *exc): + self._stack.close() + return super().__exit__(*exc) + + def is_expected_hostname(self, host: str | None) -> bool: + return host == self.parsed_identity_endpoint.hostname def handle_request(self, method, parsed_url, headers, timeout): query_string = parse_qs(parsed_url.query) @@ -185,43 +268,43 @@ def handle_request(self, method, parsed_url, headers, timeout): if not ( method == "GET" and parsed_url.path == self.parsed_identity_endpoint.path - and headers.get("X-IDENTITY-HEADER") == self.identity_header + and headers.get(HDR_IDENTITY) == self.identity_header and query_string["resource"] ): logger.warning( - f"Received malformed request: {method} {parsed_url.path} {str(headers)} {str(query_string)}" + f"Received malformed request: {method} {parsed_url.path} " + f"{str(headers)} {str(query_string)}" ) - raise HTTPError() + raise ConnectTimeout() logger.debug("Received request for Azure Functions metadata service") resource = 
query_string["resource"][0] - self.token = gen_dummy_id_token(sub=self.sub, iss=self.iss, aud=resource) - return build_response(json.dumps({"access_token": self.token}).encode("utf-8")) - - def __enter__(self): - # In addition to the normal patching, we need to set the environment variables that Azure Functions would set. - os.environ["IDENTITY_ENDPOINT"] = self.identity_endpoint - os.environ["IDENTITY_HEADER"] = self.identity_header - return super().__enter__() - - def __exit__(self, *args, **kwargs): - os.environ.pop("IDENTITY_ENDPOINT") - os.environ.pop("IDENTITY_HEADER") - return super().__exit__(*args, **kwargs) + self.token = gen_dummy_id_token(self.sub, self.iss, resource) + return build_response(json.dumps({"access_token": self.token}).encode()) class FakeGceMetadataService(FakeMetadataService): - """Emulates an environment with the GCE metadata service.""" + """Simulates GCE metadata endpoint.""" - def reset_defaults(self): + def reset_defaults(self) -> None: # Defaults used for generating a token. Can be overriden in individual tests. self.sub = "123" self.iss = "https://accounts.google.com" - @property - def expected_hostname(self): - return "169.254.169.254" + def __enter__(self): + self._stack = contextlib.ExitStack() + self._stack.enter_context( + mock.patch.dict(os.environ, self._clean_env_vars_for_scope(), clear=False) + ) + return super().__enter__() + + def __exit__(self, *exc): + self._stack.close() + return super().__exit__(*exc) + + def is_expected_hostname(self, host: str | None) -> bool: + return host == GCE_METADATA_HOST def handle_request(self, method, parsed_url, headers, timeout): query_string = parse_qs(parsed_url.query) @@ -229,80 +312,256 @@ def handle_request(self, method, parsed_url, headers, timeout): # Reject malformed requests. if not ( method == "GET" - and parsed_url.path - == "/computeMetadata/v1/instance/service-accounts/default/identity" - and headers.get("Metadata-Flavor") == "Google" - and query_string["audience"] + and parsed_url.path == GCE_IDENTITY_PATH + and headers.get(HDR_METADATA_FLAVOR) == "Google" + and query_string.get("audience") ): - raise HTTPError() + raise ConnectTimeout() logger.debug("Received request for GCE metadata service") audience = query_string["audience"][0] self.token = gen_dummy_id_token(sub=self.sub, iss=self.iss, aud=audience) - return build_response(self.token.encode("utf-8")) + return build_response(self.token.encode()) -class FakeAwsEnvironment: - """Emulates the AWS environment-specific functions used in wif_util.py. +class _AwsMetadataService(FakeMetadataService): + """Low-level fake for IMDSv2 and ECS endpoints.""" - Unlike the other metadata services, the HTTP calls made by AWS are deep within boto libaries, so - emulating them here would be complex and fragile. Instead, we emulate the higher-level functions - called by the connector code. 
+    HDR_IMDS_TOKEN_TTL = "x-aws-ec2-metadata-token-ttl-seconds"
+    IMDS_INSTANCE_IDENTITY_DOC = "/latest/dynamic/instance-identity/document"
+    IMDS_REGION_PATH = "/latest/meta-data/placement/region"
+    IMDS_AZ_PATH = "/latest/meta-data/placement/availability-zone"
+
+    def reset_defaults(self) -> None:
+        self.role_name = "MyRole"
+        self.access_key = "AKIA_TEST"
+        self.secret_key = "SK_TEST"
+        self.session_token = "STS_TOKEN"
+        self.imds_token = "IMDS_TOKEN"
+        self.region = "us-east-1"
+
+    def is_expected_hostname(self, host: str | None) -> bool:
+        return host in {
+            urlparse(_IMDS_BASE_URL).hostname,
+            urlparse(_ECS_CRED_BASE_URL).hostname,
+        }
+
+    def handle_request(self, method, parsed_url, headers, timeout):
+        url = f"{parsed_url.scheme}://{parsed_url.hostname}{parsed_url.path}"
+
+        if method == "PUT" and url == f"{_IMDS_BASE_URL}{_IMDS_TOKEN_PATH}":
+            return build_response(
+                self.imds_token.encode(),
+                headers={self.__class__.HDR_IMDS_TOKEN_TTL: "21600"},
+            )
+
+        if method == "GET" and url == f"{_IMDS_BASE_URL}{_IMDS_ROLE_PATH}":
+            return build_response(self.role_name.encode())
+
+        if (
+            method == "GET"
+            and url == f"{_IMDS_BASE_URL}{_IMDS_ROLE_PATH}{self.role_name}"
+        ):
+            if self.access_key is None or self.secret_key is None:
+                return build_response(b"", status_code=404)
+            creds_json = json.dumps(
+                {
+                    "AccessKeyId": self.access_key,
+                    "SecretAccessKey": self.secret_key,
+                    "Token": self.session_token,
+                }
+            ).encode()
+            return build_response(creds_json)
+
+        ecs_uri = os.getenv(AWS_CONTAINER_CRED_ENV)
+        if ecs_uri and method == "GET" and url == f"{_ECS_CRED_BASE_URL}{ecs_uri}":
+            creds_json = json.dumps(
+                {
+                    "AccessKeyId": self.access_key,
+                    "SecretAccessKey": self.secret_key,
+                    "Token": self.session_token,
+                }
+            ).encode()
+            return build_response(creds_json)
+
+        if (
+            method == "GET"
+            and url == f"{_IMDS_BASE_URL}{self.__class__.IMDS_REGION_PATH}"
+        ):
+            return build_response(self.region.encode())
+
+        if method == "GET" and url == f"{_IMDS_BASE_URL}{self.__class__.IMDS_AZ_PATH}":
+            return build_response(f"{self.region}a".encode())
+
+        if (
+            method == "GET"
+            and url == f"{_IMDS_BASE_URL}{self.__class__.IMDS_INSTANCE_IDENTITY_DOC}"
+        ):
+            return build_response(json.dumps({"region": self.region}).encode())
+
+        raise ConnectTimeout()
+
+
+class FakeAwsEnvironment:
+    """
+    Base context-manager for AWS runtime fakes.
+    Subclasses override `_prepare_runtime()` to tweak env-vars / creds.
     """
 
     def __init__(self):
         # Defaults used for generating a token. Can be overriden in individual tests.
+        self._region = "us-east-1"
         self.arn = "arn:aws:sts::123456789:assumed-role/My-Role/i-34afe100cad287fab"
-        self.region = "us-east-1"
-        self.credentials = Credentials(access_key="ak", secret_key="sk")
+        self.credentials: Credentials | None = Credentials(
+            access_key="ak", secret_key="sk", token="tk"
+        )
+        self._metadata = _AwsMetadataService()
+        self._stack: ExitStack | None = None
 
-    def get_region(self):
-        return self.region
+    @property
+    def region(self) -> str:
+        return self._region
 
-    def get_arn(self):
-        return self.arn
+    @region.setter
+    def region(self, new_region: str) -> None:
+        """Change the runtime region and, if the env-vars already exist,
+        patch them via ExitStack so they're cleaned up on __exit__.
+        """
+        self._region = new_region
+        self._metadata.region = new_region
 
-    def get_credentials(self):
-        return self.credentials
+        if getattr(self, "_stack", None):
+            for key in AWS_REGION_ENV_KEYS:
+                if key in os.environ:
+                    self._stack.enter_context(
+                        mock.patch.dict(os.environ, {key: new_region}, clear=False)
+                    )
 
-    def sign_request(self, request: AWSRequest):
-        request.headers.add_header("X-Amz-Date", datetime.time().isoformat())
-        request.headers.add_header("X-Amz-Security-Token", "")
-        request.headers.add_header(
-            "Authorization",
-            f"AWS4-HMAC-SHA256 Credential=, SignedHeaders={';'.join(request.headers.keys())}, Signature=",
-        )
+    def _prepare_runtime(self):
+        """Sub-classes patch env / credentials here."""
+        return None
 
     def __enter__(self):
-        # Patch the relevant functions to do what we want.
-        self.patchers = []
-        self.patchers.append(
+        self._stack = ExitStack()
+
+        self._stack.enter_context(
             mock.patch(
-                "boto3.session.Session.get_credentials",
-                side_effect=self.get_credentials,
+                "snowflake.connector.vendored.requests.request",
+                side_effect=self._metadata,
             )
         )
-        self.patchers.append(
+        self._stack.enter_context(
             mock.patch(
-                "botocore.auth.SigV4Auth.add_auth", side_effect=self.sign_request
+                "snowflake.connector.vendored.requests.sessions.Session.request",
+                side_effect=self._metadata,
             )
         )
-        self.patchers.append(
+        self._stack.enter_context(
             mock.patch(
-                "snowflake.connector.wif_util.get_aws_region",
-                side_effect=self.get_region,
+                "urllib3.connection.HTTPConnection.request",
+                side_effect=ConnectTimeout(),
             )
         )
-        self.patchers.append(
-            mock.patch(
-                "snowflake.connector.wif_util.get_aws_arn", side_effect=self.get_arn
-            )
+
+        # Keep the metadata stub in sync with the final credential set.
+        self._metadata.access_key = (
+            self.credentials.access_key if self.credentials else None
         )
-        for patcher in self.patchers:
-            patcher.__enter__()
+        self._metadata.secret_key = (
+            self.credentials.secret_key if self.credentials else None
+        )
+        self._metadata.session_token = (
+            self.credentials.token if self.credentials else None
+        )
+        self._metadata.region = self.region
+
+        env_for_chain = {key: self.region for key in AWS_REGION_ENV_KEYS}
+        if self.credentials:
+            env_for_chain["AWS_ACCESS_KEY_ID"] = self.credentials.access_key
+            env_for_chain["AWS_SECRET_ACCESS_KEY"] = self.credentials.secret_key
+            if self.credentials.token:
+                env_for_chain["AWS_SESSION_TOKEN"] = self.credentials.token
+
+        self._stack.enter_context(
+            mock.patch.dict(os.environ, env_for_chain, clear=False)
+        )
+
+        # Runtime-specific tweaks (may change creds / env).
+        self._prepare_runtime()
         return self
 
-    def __exit__(self, *args, **kwargs):
-        for patcher in self.patchers:
-            patcher.__exit__(*args, **kwargs)
+    def __exit__(self, *exc):
+        self._stack.close()
+
+
+class FakeAwsEc2(FakeAwsEnvironment):
+    """Default runtime: IMDSv2 only."""
+
+
+class FakeAwsEcs(FakeAwsEnvironment):
+    """ECS/EKS task role: exposes creds via the task metadata endpoint."""
+
+    def _prepare_runtime(self):
+        self._stack.enter_context(
+            mock.patch.dict(
+                os.environ,
+                {AWS_CONTAINER_CRED_ENV: "/v2/credentials/test-id"},
+                clear=False,
+            )
+        )
+
+
+class FakeAwsLambda(FakeAwsEnvironment):
+    """Lambda runtime: temporary credentials plus runtime env-vars."""
+
+    def __init__(self):
+        super().__init__()
+        # Lambda always returns *session* credentials
+        self.credentials = Credentials(
+            access_key="ak",
+            secret_key="sk",
+            token="dummy-session-token",
+        )
+
+    def _prepare_runtime(self) -> None:
+        # Patch env vars via mock.patch.dict so nothing touches os.environ directly
+        self._stack.enter_context(
+            mock.patch.dict(
+                os.environ,
+                {AWS_LAMBDA_FUNCTION_ENV: "dummy-fn"},
+                clear=False,
+            )
+        )
+
+
+class _AwsMetadataTimeout(_AwsMetadataService):
+    """IMDS/ECS stub that never answers, simulating a totally unreachable endpoint."""
+
+    def handle_request(self, *args, **kwargs):
+        raise ConnectTimeout()
+
+
+class FakeAwsNoCreds(FakeAwsEnvironment):
+    """Negative path: no credentials anywhere *and* IMDS/ECS completely unreachable."""
+
+    def __init__(self):
+        super().__init__()
+        # Use the timeout-only IMDS stub
+        self._metadata = _AwsMetadataTimeout()
+
+    def _prepare_runtime(self):
+        # Strip every env-var that could satisfy the AWS credential chain
+        self.credentials = None
+        self._stack.enter_context(
+            mock.patch.dict(
+                os.environ,
+                {
+                    "AWS_ACCESS_KEY_ID": "",
+                    "AWS_SECRET_ACCESS_KEY": "",
+                    "AWS_SESSION_TOKEN": "",
+                    AWS_CONTAINER_CRED_ENV: "",
+                },
+                clear=False,
+            )
+        )
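+
+
+# Typical use in tests (illustrative sketch; the real fixture wiring lives in
+# test/unit/conftest.py):
+#
+#     from snowflake.connector._aws_credentials import load_default_credentials
+#
+#     with FakeAwsLambda():          # env vars + metadata stub patched in
+#         assert load_default_credentials() is not None
+#
+#     with FakeAwsNoCreds():         # nothing resolvable anywhere
+#         assert load_default_credentials() is None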
+ """ + for key in ("AWS_REGION", "AWS_DEFAULT_REGION"): + monkeypatch.delenv(key, raising=False) + yield fake_aws_environment + + +@pytest.fixture(params=[FakeAwsNoCreds], ids=["aws_no_creds"]) +def malformed_aws_environment(request): + """Runtime where *no* credentials are discoverable (negative-path).""" + with request.param() as env: yield env diff --git a/test/unit/test_auth_workload_identity.py b/test/unit/test_auth_workload_identity.py index f2e42aae3..791a254f3 100644 --- a/test/unit/test_auth_workload_identity.py +++ b/test/unit/test_auth_workload_identity.py @@ -17,8 +17,8 @@ from snowflake.connector.wif_util import ( AZURE_ISSUER_PREFIXES, AttestationProvider, - get_aws_partition, - get_aws_sts_hostname, + _partition_from_region, + _sts_host_from_region, ) from ..csp_helpers import FakeAwsEnvironment, FakeGceMetadataService, gen_dummy_id_token @@ -34,28 +34,36 @@ def extract_api_data(auth_class: AuthByWorkloadIdentity): def verify_aws_token(token: str, region: str): - """Performs some basic checks on a 'token' produced for AWS, to ensure it includes the expected fields.""" - decoded_token = json.loads(b64decode(token)) - - parsed_url = urlparse(decoded_token["url"]) - assert parsed_url.scheme == "https" - assert parsed_url.hostname == f"sts.{region}.amazonaws.com" - query_string = parse_qs(parsed_url.query) - assert query_string.get("Action")[0] == "GetCallerIdentity" - assert query_string.get("Version")[0] == "2011-06-15" - - assert decoded_token["method"] == "POST" - - headers = decoded_token["headers"] - assert set(headers.keys()) == { - "Host", - "X-Snowflake-Audience", - "X-Amz-Date", - "X-Amz-Security-Token", - "Authorization", + """Accepts both SigV4 variants (with / without session token).""" + decoded_payload = json.loads(b64decode(token)) + + # URL validation + sts_request_url = urlparse(decoded_payload["url"]) + assert sts_request_url.scheme == "https" + assert sts_request_url.hostname == f"sts.{region}.amazonaws.com" + + query_params = parse_qs(sts_request_url.query) + assert query_params["Action"][0] == "GetCallerIdentity" + assert query_params["Version"][0] == "2011-06-15" + + # Method validation + assert decoded_payload["method"] == "POST" + + # Header validation + headers = {k.lower(): v for k, v in decoded_payload["headers"].items()} + + mandatory_headers = { + "host", + "x-snowflake-audience", + "x-amz-date", + "authorization", } - assert headers["Host"] == f"sts.{region}.amazonaws.com" - assert headers["X-Snowflake-Audience"] == "snowflakecomputing.com" + optional_headers = {"x-amz-security-token"} + + assert mandatory_headers.issubset(headers) + assert set(headers).issubset(mandatory_headers | optional_headers) + assert headers["host"] == f"sts.{region}.amazonaws.com" + assert headers["x-snowflake-audience"] == "snowflakecomputing.com" # -- OIDC Tests -- @@ -107,8 +115,10 @@ def test_explicit_oidc_no_token_raises_error(): # -- AWS Tests -- -def test_explicit_aws_no_auth_raises_error(fake_aws_environment: FakeAwsEnvironment): - fake_aws_environment.credentials = None +def test_explicit_aws_no_auth_raises_error( + malformed_aws_environment: FakeAwsEnvironment, +): + malformed_aws_environment.credentials = None auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AWS) with pytest.raises(ProgrammingError) as excinfo: @@ -137,7 +147,7 @@ def test_explicit_aws_uses_regional_hostname(fake_aws_environment: FakeAwsEnviro data = extract_api_data(auth_class) decoded_token = json.loads(b64decode(data["TOKEN"])) hostname_from_url = 
diff --git a/test/unit/test_auth_workload_identity.py b/test/unit/test_auth_workload_identity.py
index f2e42aae3..791a254f3 100644
--- a/test/unit/test_auth_workload_identity.py
+++ b/test/unit/test_auth_workload_identity.py
@@ -17,8 +17,8 @@
 from snowflake.connector.wif_util import (
     AZURE_ISSUER_PREFIXES,
     AttestationProvider,
-    get_aws_partition,
-    get_aws_sts_hostname,
+    _partition_from_region,
+    _sts_host_from_region,
 )
 
 from ..csp_helpers import FakeAwsEnvironment, FakeGceMetadataService, gen_dummy_id_token
@@ -34,28 +34,36 @@ def extract_api_data(auth_class: AuthByWorkloadIdentity):
 
 
 def verify_aws_token(token: str, region: str):
-    """Performs some basic checks on a 'token' produced for AWS, to ensure it includes the expected fields."""
-    decoded_token = json.loads(b64decode(token))
-
-    parsed_url = urlparse(decoded_token["url"])
-    assert parsed_url.scheme == "https"
-    assert parsed_url.hostname == f"sts.{region}.amazonaws.com"
-    query_string = parse_qs(parsed_url.query)
-    assert query_string.get("Action")[0] == "GetCallerIdentity"
-    assert query_string.get("Version")[0] == "2011-06-15"
-
-    assert decoded_token["method"] == "POST"
-
-    headers = decoded_token["headers"]
-    assert set(headers.keys()) == {
-        "Host",
-        "X-Snowflake-Audience",
-        "X-Amz-Date",
-        "X-Amz-Security-Token",
-        "Authorization",
+    """Accepts both SigV4 variants (with or without a session token)."""
+    decoded_payload = json.loads(b64decode(token))
+
+    # URL validation
+    sts_request_url = urlparse(decoded_payload["url"])
+    assert sts_request_url.scheme == "https"
+    assert sts_request_url.hostname == f"sts.{region}.amazonaws.com"
+
+    query_params = parse_qs(sts_request_url.query)
+    assert query_params["Action"][0] == "GetCallerIdentity"
+    assert query_params["Version"][0] == "2011-06-15"
+
+    # Method validation
+    assert decoded_payload["method"] == "POST"
+
+    # Header validation
+    headers = {k.lower(): v for k, v in decoded_payload["headers"].items()}
+
+    mandatory_headers = {
+        "host",
+        "x-snowflake-audience",
+        "x-amz-date",
+        "authorization",
     }
-    assert headers["Host"] == f"sts.{region}.amazonaws.com"
-    assert headers["X-Snowflake-Audience"] == "snowflakecomputing.com"
+    optional_headers = {"x-amz-security-token"}
+
+    assert mandatory_headers.issubset(headers)
+    assert set(headers).issubset(mandatory_headers | optional_headers)
+    assert headers["host"] == f"sts.{region}.amazonaws.com"
+    assert headers["x-snowflake-audience"] == "snowflakecomputing.com"
 
 
 # -- OIDC Tests --
@@ -107,8 +115,10 @@
 # -- AWS Tests --
 
 
-def test_explicit_aws_no_auth_raises_error(fake_aws_environment: FakeAwsEnvironment):
-    fake_aws_environment.credentials = None
+def test_explicit_aws_no_auth_raises_error(
+    malformed_aws_environment: FakeAwsEnvironment,
+):
+    malformed_aws_environment.credentials = None
 
     auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AWS)
     with pytest.raises(ProgrammingError) as excinfo:
@@ -137,7 +147,7 @@
     data = extract_api_data(auth_class)
     decoded_token = json.loads(b64decode(data["TOKEN"]))
     hostname_from_url = urlparse(decoded_token["url"]).hostname
-    hostname_from_header = decoded_token["headers"]["Host"]
+    hostname_from_header = decoded_token["headers"]["host"]
 
     expected_hostname = "sts.antarctica-northeast-3.amazonaws.com"
     assert expected_hostname == hostname_from_url
@@ -147,83 +157,84 @@
 def test_explicit_aws_generates_unique_assertion_content(
     fake_aws_environment: FakeAwsEnvironment,
 ):
-    fake_aws_environment.arn = (
-        "arn:aws:sts::123456789:assumed-role/A-Different-Role/i-34afe100cad287fab"
-    )
+    # Change region to ensure assertion_content updates accordingly.
+    fake_aws_environment.region = "antarctica-northeast-3"
+
     auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AWS)
     auth_class.prepare()
 
-    assert (
-        '{"_provider":"AWS","arn":"arn:aws:sts::123456789:assumed-role/A-Different-Role/i-34afe100cad287fab"}'
-        == auth_class.assertion_content
-    )
+    expected = '{"_provider":"AWS","region":"' + fake_aws_environment.region + '"}'
+    assert auth_class.assertion_content == expected
+
+
+@pytest.mark.parametrize(
+    "arn_env_var",
+    [
+        "AWS_ROLE_ARN",
+        "AWS_EC2_METADATA_ARN",
+        "AWS_SESSION_ARN",
+    ],
+)
+def test_explicit_aws_includes_arn_when_env_present(
+    fake_aws_environment: FakeAwsEnvironment,
+    monkeypatch,
+    arn_env_var,
+):
+    dummy_arn = "arn:aws:sts::123456789012:assumed-role/MyRole/i-abcdef123456"
+    monkeypatch.setenv(arn_env_var, dummy_arn)
+
+    auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AWS)
+    auth_class.prepare()
+
+    # Parse the JSON to ignore ordering.
+    assertion_data = json.loads(auth_class.assertion_content)
+
+    assert assertion_data["_provider"] == "AWS"
+    assert assertion_data["region"] == fake_aws_environment.region
+    assert assertion_data["arn"] == dummy_arn
 
 
 @pytest.mark.parametrize(
-    "arn, expected_partition",
+    "region, expected_partition",
     [
-        ("arn:aws:iam::123456789012:role/MyTestRole", "aws"),
-        (
-            "arn:aws-cn:ec2:cn-north-1:987654321098:instance/i-1234567890abcdef0",
-            "aws-cn",
-        ),
-        ("arn:aws-us-gov:s3:::my-gov-bucket", "aws-us-gov"),
-        ("arn:aws:s3:::my-bucket/my/key", "aws"),
-        ("arn:aws:lambda:us-east-1:123456789012:function:my-function", "aws"),
-        ("arn:aws:sns:eu-west-1:111122223333:my-topic", "aws"),
-        # Edge cases / Invalid inputs
-        ("invalid-arn", None),
-        ("arn::service:region:account:resource", None),  # Missing partition
-        ("arn:aws:iam:", "aws"),  # Incomplete ARN, but partition is present
-        ("", None),  # Empty string
-        (None, None),  # None input
-        (123, None),  # Non-string input
+        # happy-path AWS commercial
+        ("us-east-1", "aws"),
+        ("eu-central-1", "aws"),
+        ("ap-south-1", "aws"),
+        # China partitions
+        ("cn-north-1", "aws-cn"),
+        ("cn-northwest-1", "aws-cn"),
+        # GovCloud partitions
+        ("us-gov-west-1", "aws-us-gov"),
+        ("us-gov-east-1", "aws-us-gov"),
+        # Weird values also fall back to commercial
+        ("invalid-region", "aws"),
+        ("", "aws"),
    ],
)
-def test_get_aws_partition_valid_and_invalid_arns(arn, expected_partition):
-    assert get_aws_partition(arn) == expected_partition
+def test_partition_from_region(region, expected_partition):
+    assert _partition_from_region(region).value == expected_partition
 
 
 @pytest.mark.parametrize(
-    "region, partition, expected_hostname",
+    "region, expected_hostname",
    [
-        # AWS partition
-        ("us-east-1", "aws", "sts.us-east-1.amazonaws.com"),
-        ("eu-west-2", "aws", "sts.eu-west-2.amazonaws.com"),
-        ("ap-southeast-1", "aws", "sts.ap-southeast-1.amazonaws.com"),
-        (
-            "us-east-1",
-            "aws",
-            "sts.us-east-1.amazonaws.com",
-        ),  # Redundant but good for coverage
-        # AWS China partition
-        ("cn-north-1", "aws-cn", "sts.cn-north-1.amazonaws.com.cn"),
-        ("cn-northwest-1", "aws-cn", "sts.cn-northwest-1.amazonaws.com.cn"),
-        ("", "aws-cn", None),  # No global endpoint for 'aws-cn' without region
-        # AWS GovCloud partition
-        ("us-gov-west-1", "aws-us-gov", "sts.us-gov-west-1.amazonaws.com"),
-        ("us-gov-east-1", "aws-us-gov", "sts.us-gov-east-1.amazonaws.com"),
-        ("", "aws-us-gov", None),  # No global endpoint for 'aws-us-gov' without region
-        # Invalid/Edge cases
-        ("us-east-1", "unknown-partition", None),  # Unknown partition
-        ("some-region", "invalid-partition", None),  # Invalid partition
-        (None, "aws", None),  # None region
-        ("us-east-1", None, None),  # None partition
-        (123, "aws", None),  # Non-string region
-        ("us-east-1", 456, None),  # Non-string partition
-        ("", "", None),  # Empty region and partition
-        ("us-east-1", "", None),  # Empty partition
-        (
-            "invalid-region",
-            "aws",
-            "sts.invalid-region.amazonaws.com",
-        ),  # Valid format, invalid region name
+        # commercial partition
+        ("us-east-1", "sts.us-east-1.amazonaws.com"),
+        ("eu-west-2", "sts.eu-west-2.amazonaws.com"),
+        # China
+        ("cn-north-1", "sts.cn-north-1.amazonaws.com.cn"),
+        # GovCloud
+        ("us-gov-east-1", "sts.us-gov-east-1.amazonaws.com"),
+        # unknown but syntactically valid regions are still formatted
+        ("invalid-region", "sts.invalid-region.amazonaws.com"),
+        ("", None),
+        (None, None),
+        (123, None),
    ],
)
-def test_get_aws_sts_hostname_valid_and_invalid_inputs(
-    region, partition, expected_hostname
-):
-    assert get_aws_sts_hostname(region, partition) == expected_hostname
+def test_sts_host_from_region_valid_inputs(region, expected_hostname):
+    assert _sts_host_from_region(region) == expected_hostname
 
 
 # -- GCP Tests --
@@ -456,3 +467,28 @@
     assert "No workload identity credential was found for 'auto-detect" in str(
         excinfo.value
     )
+
+
+def test_explicit_aws_region_falls_back_to_imds(imds_only_aws_environment):
+    """
+    When region env-vars are absent, the connector must discover the region via
+    the runtime metadata service (IMDS / task-metadata / lambda env).
+    """
+    # Advertise a non-default region through the fake metadata service
+    imds_only_aws_environment.region = "us-west-2"
+
+    auth = AuthByWorkloadIdentity(provider=AttestationProvider.AWS)
+    auth.prepare()
+
+    verify_aws_token(extract_api_data(auth)["TOKEN"], "us-west-2")
+
+
+def test_autodetect_prefers_gcp_when_no_aws_env(fake_gce_metadata_service):
+    """
+    With no AWS env-vars and a responsive GCP metadata server, auto-detect selects GCP.
+    """
+    auth_class = AuthByWorkloadIdentity(provider=None)
+    auth_class.prepare()
+
+    assert extract_api_data(auth_class)["PROVIDER"] == "GCP"
+    assert extract_api_data(auth_class)["TOKEN"] == fake_gce_metadata_service.token
diff --git a/test/unit/test_boto_compatibility.py b/test/unit/test_boto_compatibility.py
new file mode 100644
index 000000000..fe3a8f895
--- /dev/null
+++ b/test/unit/test_boto_compatibility.py
@@ -0,0 +1,184 @@
+from __future__ import annotations
+
+import datetime
+import urllib.parse as urlparse
+
+import pytest
+from botocore import session as boto_session
+from botocore.auth import SigV4Auth
+from botocore.awsrequest import AWSRequest
+from botocore.credentials import Credentials
+
+from snowflake.connector import _aws_credentials
+from snowflake.connector._aws_sign_v4 import sign_get_caller_identity
+from snowflake.connector.wif_util import _sts_host_from_region
+
+
+def _normalise_headers(headers: dict[str, str]) -> dict[str, str]:
+    """Lower-case keys, trim values, drop User-Agent (botocore adds it)."""
+    return {
+        k.lower(): v.strip() for k, v in headers.items() if k.lower() != "user-agent"
+    }
+
+
+@pytest.fixture
+def freeze_utcnow(monkeypatch: pytest.MonkeyPatch):
+    """Freeze `datetime.datetime.utcnow()` for deterministic SigV4 signatures."""
+    fixed = datetime.datetime(2025, 1, 1, 0, 0, 0)
+
+    class _FrozenDateTime(datetime.datetime):
+        @classmethod
+        def utcnow(cls):
+            return fixed
+
+    monkeypatch.setattr(datetime, "datetime", _FrozenDateTime)
+    yield
+
+
+@pytest.mark.parametrize("region", ["us-east-1", "eu-west-1", "us-gov-west-1"])
+def test_sigv4_parity_with_botocore(region: str, freeze_utcnow):
+    url = (
+        f"https://{_sts_host_from_region(region)}"
+        "/?Action=GetCallerIdentity&Version=2011-06-15"
+    )
+
+    access_key_id = "AKIDEXAMPLE"
+    secret_access_key = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY"
+
+    sf_driver_aws_headers = sign_get_caller_identity(
+        url=url,
+        region=region,
+        access_key=access_key_id,
+        secret_key=secret_access_key,
+    )
+
+    boto_req = AWSRequest(
+        method="POST",
+        url=url,
+        headers={
+            "Host": sf_driver_aws_headers["host"],
+            "X-Snowflake-Audience": "snowflakecomputing.com",
+            "X-Amz-Date": datetime.datetime.utcnow().strftime("%Y%m%dT%H%M%SZ"),
+        },
+    )
+    SigV4Auth(Credentials(access_key_id, secret_access_key), "sts", region).add_auth(
+        boto_req
+    )
+
+    assert "authorization" in sf_driver_aws_headers
+    assert _normalise_headers(sf_driver_aws_headers) == _normalise_headers(
+        boto_req.headers
+    )
+
+
+@pytest.mark.parametrize("region", ["us-east-1", "eu-west-1", "us-gov-west-1"])
+def test_sigv4_parity_with_session_token(region: str, freeze_utcnow):
+    url = (
+        f"https://{_sts_host_from_region(region)}"
+        "/?Action=GetCallerIdentity&Version=2011-06-15"
+    )
+
+    access_key_id = "AKIDEXAMPLE"
+    secret_access_key = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY"
+    session_token = "IQoJb3JpZ2luX2VjEPr//////////wEaCXVzLWFz"
+
+    sf_driver_aws_headers = sign_get_caller_identity(
+        url=url,
+        region=region,
+        access_key=access_key_id,
+        secret_key=secret_access_key,
+        session_token=session_token,
+    )
+
+    boto_req = AWSRequest(
+        method="POST",
+        url=url,
+        headers={
+            "Host": sf_driver_aws_headers["host"],
+            "X-Snowflake-Audience": "snowflakecomputing.com",
+            "X-Amz-Date": sf_driver_aws_headers["x-amz-date"],
+            "X-Amz-Security-Token": session_token,
+        },
+    )
+    SigV4Auth(
+        Credentials(access_key_id, secret_access_key, token=session_token),
+        "sts",
+        region,
+    ).add_auth(boto_req)
+
+    assert _normalise_headers(sf_driver_aws_headers) == _normalise_headers(
+        boto_req.headers
+    )
+
+
+@pytest.mark.parametrize(
+    "region", ["us-east-1", "eu-west-1", "us-gov-west-1", "cn-north-1"]
+)
+def test_sts_host_from_region_matches_botocore(
+    monkeypatch: pytest.MonkeyPatch, region: str
+):
+    sf_host = _sts_host_from_region(region)
+
+    # Force botocore into regional mode so it doesn't fall back to the legacy
+    # global host (sts.amazonaws.com) for regions such as us-east-1. Both
+    # endpoints accept the same requests; the regional host is simply what
+    # the driver constructs.
+    monkeypatch.setenv("AWS_STS_REGIONAL_ENDPOINTS", "regional")
+
+    boto_host = urlparse.urlparse(
+        boto_session.Session()
+        .create_client(
+            "sts", region_name=region, aws_access_key_id="x", aws_secret_access_key="y"
+        )
+        .meta.endpoint_url
+    ).netloc.lower()
+
+    assert sf_host == boto_host
+
+
+def test_region_env_var_default(monkeypatch: pytest.MonkeyPatch) -> None:
+    """
+    Both libraries should resolve the region from AWS_DEFAULT_REGION
+    without any extra hints.
+    """
+    expected_region = "ap-southeast-2"
+    monkeypatch.delenv("AWS_REGION", raising=False)
+    monkeypatch.setenv("AWS_DEFAULT_REGION", expected_region)
+
+    # Driver
+    sf_region = _aws_credentials.get_region()
+    assert sf_region == expected_region
+
+    # Botocore
+    boto_region = (
+        boto_session.Session()
+        .create_client("s3", aws_access_key_id="x", aws_secret_access_key="y")
+        .meta.region_name
+    )
+    assert boto_region == sf_region
+
+
+def test_region_env_var_legacy(monkeypatch: pytest.MonkeyPatch) -> None:
+    """
+    botocore does not currently honor AWS_REGION, though support is planned:
+    https://docs.aws.amazon.com/sdkref/latest/guide/feature-region.html
+    For now we set it as an env var for the driver and pass the region to
+    botocore via an explicit parameter.
+    """
+    desired_region = "ca-central-1"
+    monkeypatch.delenv("AWS_DEFAULT_REGION", raising=False)
+    monkeypatch.setenv("AWS_REGION", desired_region)
+
+    # Snowflake helper sees AWS_REGION
+    sf_region = _aws_credentials.get_region()
+    assert sf_region == desired_region
+
+    # botocore needs an explicit region_name when AWS_REGION is set
+    boto_region = (
+        boto_session.Session()
+        .create_client(
+            "s3",
+            region_name=desired_region,
+            aws_access_key_id="x",
+            aws_secret_access_key="y",
+        )
+        .meta.region_name
+    )
+    assert boto_region == desired_region