diff --git a/DESCRIPTION.md b/DESCRIPTION.md
index 3adb2b4ae9..a5b33e59e9 100644
--- a/DESCRIPTION.md
+++ b/DESCRIPTION.md
@@ -7,6 +7,10 @@ https://docs.snowflake.com/
 Source code is also available at: https://github.com/snowflakedb/snowflake-connector-python
 
 # Release Notes
+- Unreleased
+  - Added support for AWS outbound JWT token attestation for Workload Identity Federation (WIF). This can be enabled by setting the
+    `SNOWFLAKE_ENABLE_AWS_WIF_OUTBOUND_TOKEN` environment variable to `true`.
+
 - v4.4.0(March 24,2026)
   - Bump the lower boundary of cryptography to 46.0.5 due to CVE-2026-26007.
   - Added support for Python 3.14.
diff --git a/src/snowflake/connector/platform_detection.py b/src/snowflake/connector/platform_detection.py
index 3a5a02c78f..0d3127f5fc 100644
--- a/src/snowflake/connector/platform_detection.py
+++ b/src/snowflake/connector/platform_detection.py
@@ -429,6 +429,22 @@ def is_github_action():
     )
 
 
+def is_aws_wif_outbound_token_enabled():
+    """
+    Check if AWS WIF outbound token is enabled via environment variable.
+
+    Returns:
+        _DetectionState: DETECTED if SNOWFLAKE_ENABLE_AWS_WIF_OUTBOUND_TOKEN env var is true,
+        NOT_DETECTED otherwise.
+    """
+    return (
+        _DetectionState.DETECTED
+        if os.environ.get("SNOWFLAKE_ENABLE_AWS_WIF_OUTBOUND_TOKEN", "false").lower()
+        == "true"
+        else _DetectionState.NOT_DETECTED
+    )
+
+
 @cache
 def detect_platforms(
     platform_detection_timeout_seconds: float | None,
@@ -490,6 +506,7 @@ def detect_platforms(
             "is_gce_cloud_run_service": is_gcp_cloud_run_service(),
             "is_gce_cloud_run_job": is_gcp_cloud_run_job(),
             "is_github_action": is_github_action(),
+            "is_aws_wif_outbound_token_enabled": is_aws_wif_outbound_token_enabled(),
         }
 
         # Run network-calling functions in parallel
diff --git a/src/snowflake/connector/wif_util.py b/src/snowflake/connector/wif_util.py
index fbe5d51f93..f95f5a7734 100644
--- a/src/snowflake/connector/wif_util.py
+++ b/src/snowflake/connector/wif_util.py
@@ -195,6 +195,25 @@ def create_aws_attestation(
         )
     region = get_aws_region()
     partition = session.get_partition_for_region(region)
+    # TODO: Remove this environment variable check once AWS WIF outbound token is fully released
+    # and make it the default behavior (SNOW-2919437)
+    if (
+        os.environ.get("SNOWFLAKE_ENABLE_AWS_WIF_OUTBOUND_TOKEN", "false").lower()
+        == "true"
+    ):
+        sts_client = session.client("sts", region_name=region)
+        response = sts_client.get_web_identity_token(
+            Audience=[SNOWFLAKE_AUDIENCE], SigningAlgorithm="ES384"
+        )
+        jwt_token = response["WebIdentityToken"]
+        # The token is a bearer credential: never write any part of it to the log.
+        logger.debug("Obtained AWS outbound web identity token from STS")
+        return WorkloadIdentityAttestation(
+            AttestationProvider.AWS,
+            jwt_token,
+            {"region": region, "partition": partition},
+        )
+
     sts_hostname = get_aws_sts_hostname(region, partition)
     request = AWSRequest(
         method="POST",
diff --git a/test/csp_helpers.py b/test/csp_helpers.py
index 10b111ee40..6b3446a577 100644
--- a/test/csp_helpers.py
+++ b/test/csp_helpers.py
@@ -379,6 +379,7 @@ def __init__(self):
             b'{"region": "us-east-1", "instanceId": "i-1234567890abcdef0"}'
         )
         self.metadata_token = "test-token"
+        self.web_identity_token = "fake.jwt.token-for-testing-only"
 
     def assume_role(self, **kwargs):
         if (
@@ -423,6 +424,9 @@ def boto3_client(self, *args, **kwargs):
         mock_client = mock.Mock()
         mock_client.get_caller_identity.return_value = self.caller_identity
         mock_client.assume_role = self.assume_role
+        mock_client.get_web_identity_token.return_value = {
+            "WebIdentityToken": self.web_identity_token
+        }
         return mock_client
 
     def __enter__(self):
diff --git a/test/unit/test_auth_workload_identity.py b/test/unit/test_auth_workload_identity.py
index 7c9ff4a03e..b6bff08511 100644
--- a/test/unit/test_auth_workload_identity.py
+++ b/test/unit/test_auth_workload_identity.py
@@ -333,6 +333,41 @@ def test_aws_impersonation_calls_correct_apis_for_each_role_in_impersonation_pat
     assert fake_aws_environment.assume_role_call_count == 2
 
 
+@pytest.mark.parametrize(
+    "env_value,expected_format",
+    [
+        ("true", "jwt"),
+        ("false", "old"),
+        (None, "old"),
+    ],
+)
+def test_aws_token_format_based_on_env_variable(
+    fake_aws_environment: FakeAwsEnvironment,
+    monkeypatch,
+    env_value,
+    expected_format,
+):
+    """Test that AWS uses correct token format based on SNOWFLAKE_ENABLE_AWS_WIF_OUTBOUND_TOKEN environment variable."""
+    # Clear the flag explicitly when unset so the ambient environment cannot leak in.
+    if env_value is None:
+        monkeypatch.delenv("SNOWFLAKE_ENABLE_AWS_WIF_OUTBOUND_TOKEN", raising=False)
+    else:
+        monkeypatch.setenv("SNOWFLAKE_ENABLE_AWS_WIF_OUTBOUND_TOKEN", env_value)
+
+    auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AWS)
+    auth_class.prepare(conn=None)
+
+    data = extract_api_data(auth_class)
+
+    assert data["AUTHENTICATOR"] == "WORKLOAD_IDENTITY"
+    assert data["PROVIDER"] == "AWS"
+
+    if expected_format == "jwt":
+        assert data["TOKEN"] == fake_aws_environment.web_identity_token
+    else:
+        verify_aws_token(data["TOKEN"], fake_aws_environment.region)
+
+
 # -- GCP Tests --
 
 
diff --git a/test/wif/test_wif.py b/test/wif/test_wif.py
index 4b57aa0d76..f496fda22e 100644
--- a/test/wif/test_wif.py
+++ b/test/wif/test_wif.py
@@ -79,6 +79,24 @@ def test_should_authenticate_with_impersonation():
     ), f"Failed to connect using WIF with provider {PROVIDER}"
 
 
+@pytest.mark.wif
+def test_should_authenticate_using_aws_outbound_token(monkeypatch):
+    if PROVIDER != "AWS":
+        pytest.skip("Skipping test - not running on AWS")
+
+    # monkeypatch restores any pre-existing value of the variable on teardown.
+    monkeypatch.setenv("SNOWFLAKE_ENABLE_AWS_WIF_OUTBOUND_TOKEN", "true")
+    connection_params = {
+        "host": HOST,
+        "account": ACCOUNT,
+        "authenticator": "WORKLOAD_IDENTITY",
+        "workload_identity_provider": "AWS",
+    }
+    assert connect_and_execute_simple_query(
+        connection_params, EXPECTED_USERNAME
+    ), "Failed to connect using WIF with AWS outbound token"
+
+
 def is_provider_gcp() -> bool:
     return PROVIDER == "GCP"
 