Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/snowflake/connector/platform_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,21 @@ def is_github_action():
)


def is_aws_wif_outbound_token_enabled():
"""
Check if AWS WIF outbound token is enabled via environment variable.

Returns:
_DetectionState: DETECTED if ENABLE_AWS_WIF_OUTBOUND_TOKEN env var is true,
NOT_DETECTED otherwise.
"""
return (
_DetectionState.DETECTED
if os.environ.get("ENABLE_AWS_WIF_OUTBOUND_TOKEN", "false").lower() == "true"
else _DetectionState.NOT_DETECTED
)


@cache
def detect_platforms(
platform_detection_timeout_seconds: float | None,
Expand Down Expand Up @@ -490,6 +505,7 @@ def detect_platforms(
"is_gce_cloud_run_service": is_gcp_cloud_run_service(),
"is_gce_cloud_run_job": is_gcp_cloud_run_job(),
"is_github_action": is_github_action(),
"is_aws_wif_outbound_token_enabled": is_aws_wif_outbound_token_enabled(),
}

# Run network-calling functions in parallel
Expand Down
60 changes: 39 additions & 21 deletions src/snowflake/connector/wif_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,29 +195,47 @@ def create_aws_attestation(
)
region = get_aws_region()
partition = session.get_partition_for_region(region)
sts_hostname = get_aws_sts_hostname(region, partition)
request = AWSRequest(
method="POST",
url=f"https://{sts_hostname}/?Action=GetCallerIdentity&Version=2011-06-15",
headers={
"Host": sts_hostname,
"X-Snowflake-Audience": SNOWFLAKE_AUDIENCE,
},
)
# TODO: Remove this environment variable check once AWS WIF outbound token is fully released
# and make it the default behavior (SNOW-2919437)
if os.environ.get("ENABLE_AWS_WIF_OUTBOUND_TOKEN", "false").lower() == "true":
sts_client = session.client("sts", region_name=region)
response = sts_client.get_web_identity_token(
Audience=[SNOWFLAKE_AUDIENCE], SigningAlgorithm="ES384"
)
jwt_token = response["WebIdentityToken"]
return WorkloadIdentityAttestation(
AttestationProvider.AWS,
jwt_token,
{"region": region, "partition": partition},
)
else:
sts_hostname = get_aws_sts_hostname(region, partition)
request = AWSRequest(
method="POST",
url=f"https://{sts_hostname}/?Action=GetCallerIdentity&Version=2011-06-15",
headers={
"Host": sts_hostname,
"X-Snowflake-Audience": SNOWFLAKE_AUDIENCE,
},
)

SigV4Auth(aws_creds, "sts", region).add_auth(request)
SigV4Auth(aws_creds, "sts", region).add_auth(request)

assertion_dict = {
"url": request.url,
"method": request.method,
"headers": dict(request.headers.items()),
}
credential = b64encode(json.dumps(assertion_dict).encode("utf-8")).decode("utf-8")
# Unlike other providers, for AWS, we only include general identifiers (region and partition)
# rather than specific user identifiers, since we don't actually execute a GetCallerIdentity call.
return WorkloadIdentityAttestation(
AttestationProvider.AWS, credential, {"region": region, "partition": partition}
)
assertion_dict = {
"url": request.url,
"method": request.method,
"headers": dict(request.headers.items()),
}
credential = b64encode(json.dumps(assertion_dict).encode("utf-8")).decode(
"utf-8"
)
# Unlike other providers, for AWS, we only include general identifiers (region and partition)
# rather than specific user identifiers, since we don't actually execute a GetCallerIdentity call.
return WorkloadIdentityAttestation(
AttestationProvider.AWS,
credential,
{"region": region, "partition": partition},
)


def get_gcp_access_token(session_manager: SessionManager) -> str:
Expand Down
4 changes: 4 additions & 0 deletions test/csp_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ def __init__(self):
b'{"region": "us-east-1", "instanceId": "i-1234567890abcdef0"}'
)
self.metadata_token = "test-token"
self.web_identity_token = "fake.jwt.token-for-testing-only"

def assume_role(self, **kwargs):
if (
Expand Down Expand Up @@ -423,6 +424,9 @@ def boto3_client(self, *args, **kwargs):
mock_client = mock.Mock()
mock_client.get_caller_identity.return_value = self.caller_identity
mock_client.assume_role = self.assume_role
mock_client.get_web_identity_token.return_value = {
"WebIdentityToken": self.web_identity_token
}
return mock_client

def __enter__(self):
Expand Down
32 changes: 32 additions & 0 deletions test/unit/test_auth_workload_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,38 @@ def test_aws_impersonation_calls_correct_apis_for_each_role_in_impersonation_pat
assert fake_aws_environment.assume_role_call_count == 2


@pytest.mark.parametrize(
Copy link
Copy Markdown
Contributor

@sfc-gh-rsavenok sfc-gh-rsavenok Mar 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also need to have e2e test in test/wif/test_wif.py (ask llm to explain how that e2e test is executed)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for comments! I have 2 questions to get the e2e test working:

  1. The e2e test requires updating the WORKLOAD_IDENTITY configuration on TEST_WIF_E2E_AWS in sfctest0. May I know who should I reach out to for that?
  2. The GS-side changes are merged but not yet rolled out to prod. Does the CI run against a prod or non-prod environment? And would the e2e test need to wait for the full GS rollout?
    Thanks!

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems TEST_WIF_E2E_AWS is a test user used in our team cc @sfc-gh-xizhao if you have any idea, thanks!

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the changes available on preprod? We could write a test there first and later switch to sfctest0

Ping @sfc-gh-akolodziejczyk on Slack to get access to the accounts we use for wif e2e tests.

e2e tests are super important, as when we worked on WIF in other drivers based on Python implementation, we had some drivers not connecting because of wrong implementation

Copy link
Copy Markdown
Author

@sfc-gh-yuzzhang sfc-gh-yuzzhang Mar 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion. The e2e test test_should_authenticate_using_aws_outbound_token is added in test/wif/test_wif.py.

I've also manually validated the full flow on qa6 using the current branch(AWS VM) across all 4 scenarios and details are in Description section.
The CI test is currently failing because the GS param hasn't rolled out to prod yet, and TEST_WIF_E2E_AWS in sfctest0 needs some configured after GS rollout. I'll reach out to get the test account set up. The CI test will pass once GS is fully rolled out to prod and the account is configured.

"env_value,expected_format",
[
("true", "jwt"),
("false", "old"),
(None, "old"),
],
)
def test_aws_token_format_based_on_env_variable(
fake_aws_environment: FakeAwsEnvironment,
monkeypatch,
env_value,
expected_format,
):
"""Test that AWS uses correct token format based on ENABLE_AWS_WIF_OUTBOUND_TOKEN environment variable."""
if env_value is not None:
monkeypatch.setenv("ENABLE_AWS_WIF_OUTBOUND_TOKEN", env_value)

auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AWS)
auth_class.prepare(conn=None)

data = extract_api_data(auth_class)

assert data["AUTHENTICATOR"] == "WORKLOAD_IDENTITY"
assert data["PROVIDER"] == "AWS"

if expected_format == "jwt":
assert data["TOKEN"] == fake_aws_environment.web_identity_token
else:
verify_aws_token(data["TOKEN"], fake_aws_environment.region)


# -- GCP Tests --


Expand Down
Loading