Skip to content

Commit 2226c90

Browse files
Support WIF Impersonation on AWS workloads (#2517)
Co-authored-by: Peter Mansour <[email protected]>
1 parent 53e0165 commit 2226c90

File tree

7 files changed

+88
-24
lines changed

7 files changed

+88
-24
lines changed

DESCRIPTION.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
88

99
# Release Notes
1010
- v3.18.0(TBD)
11-
- Added the `workload_identity_impersonation_path` parameter to support service account impersonation for Workload Identity Federation on GCP workloads only
11+
- Added the `workload_identity_impersonation_path` parameter to support service account impersonation for Workload Identity Federation on GCP and AWS workloads only
1212

1313
- v3.17.3(September 02,2025)
1414
- Enhanced configuration file permission warning messages.
53 Bytes
Binary file not shown.

src/snowflake/connector/connection.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1358,14 +1358,18 @@ def __open_connection(self):
13581358
)
13591359
if (
13601360
self._workload_identity_impersonation_path
1361-
and self._workload_identity_provider != AttestationProvider.GCP
1361+
and self._workload_identity_provider
1362+
not in (
1363+
AttestationProvider.GCP,
1364+
AttestationProvider.AWS,
1365+
)
13621366
):
13631367
Error.errorhandler_wrapper(
13641368
self,
13651369
None,
13661370
ProgrammingError,
13671371
{
1368-
"msg": "workload_identity_impersonation_path is currently only supported for GCP.",
1372+
"msg": "workload_identity_impersonation_path is currently only supported for GCP and AWS.",
13691373
"errno": ER_INVALID_WIF_SETTINGS,
13701374
},
13711375
)

src/snowflake/connector/wif_util.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,15 +145,37 @@ def get_aws_sts_hostname(region: str, partition: str) -> str:
145145
)
146146

147147

148+
def get_aws_session(impersonation_path: list[str] | None = None):
149+
"""Creates a boto3 session with the appropriate credentials.
150+
151+
If impersonation_path is provided, this uses the role at the end of the path. Otherwise, this uses the role attached to the current workload.
152+
"""
153+
session = boto3.session.Session()
154+
155+
impersonation_path = impersonation_path or []
156+
for arn in impersonation_path:
157+
response = session.client("sts").assume_role(
158+
RoleArn=arn, RoleSessionName="identity-federation-session"
159+
)
160+
creds = response["Credentials"]
161+
session = boto3.session.Session(
162+
aws_access_key_id=creds["AccessKeyId"],
163+
aws_secret_access_key=creds["SecretAccessKey"],
164+
aws_session_token=creds["SessionToken"],
165+
)
166+
return session
167+
168+
148169
def create_aws_attestation(
149-
session_manager: SessionManager | None = None,
170+
impersonation_path: list[str] | None = None,
150171
) -> WorkloadIdentityAttestation:
151172
"""Tries to create a workload identity attestation for AWS.
152173
153174
If the application isn't running on AWS or no credentials were found, raises an error.
154175
"""
155176
# TODO: SNOW-2223669 Investigate if our adapters - containing settings of http traffic - should be passed here as boto urllib3session. Those requests go to local servers, so they do not need Proxy setup or Headers customization in theory. But we may want to have all the traffic going through one class (e.g. Adapter or mixin).
156-
session = boto3.session.Session()
177+
session = get_aws_session(impersonation_path)
178+
157179
aws_creds = session.get_credentials()
158180
if not aws_creds:
159181
raise ProgrammingError(
@@ -387,7 +409,7 @@ def create_attestation(
387409
)
388410

389411
if provider == AttestationProvider.AWS:
390-
return create_aws_attestation(session_manager)
412+
return create_aws_attestation(impersonation_path)
391413
elif provider == AttestationProvider.AZURE:
392414
return create_azure_attestation(entra_resource, session_manager)
393415
elif provider == AttestationProvider.GCP:

test/csp_helpers.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,8 @@ def gen_dummy_id_token(
4040
)
4141

4242

43-
def gen_dummy_access_token(sub="test-subject") -> str:
43+
def gen_dummy_access_token(sub="test-subject", key="secret") -> str:
4444
"""Generates a dummy access token using the given subject."""
45-
key = "secret"
4645
logger.debug(f"Generating dummy access token for subject {sub}")
4746
return (sub + key).encode("utf-8").hex()
4847

@@ -368,6 +367,11 @@ class FakeAwsEnvironment:
368367
def __init__(self):
369368
# Defaults used for generating a token. Can be overriden in individual tests.
370369
self.arn = "arn:aws:sts::123456789:assumed-role/My-Role/i-34afe100cad287fab"
370+
# Path of roles that can be assumed. Empty if no impersonation is allowed.
371+
# Can be overriden in individual tests.
372+
self.assumption_path = []
373+
self.assume_role_call_count = 0
374+
371375
self.caller_identity = {"Arn": self.arn}
372376
self.region = "us-east-1"
373377
self.credentials = Credentials(access_key="ak", secret_key="sk")
@@ -376,6 +380,25 @@ def __init__(self):
376380
)
377381
self.metadata_token = "test-token"
378382

383+
def assume_role(self, **kwargs):
384+
if (
385+
self.assumption_path
386+
and kwargs["RoleArn"] == self.assumption_path[self.assume_role_call_count]
387+
):
388+
arn = self.assumption_path[self.assume_role_call_count]
389+
self.assume_role_call_count += 1
390+
return {
391+
"Credentials": {
392+
"AccessKeyId": "access_key",
393+
"SecretAccessKey": "secret_key",
394+
"SessionToken": "session_token",
395+
"Expiration": int(time()) + 60 * 60,
396+
},
397+
"AssumedRoleUser": {"AssumedRoleId": hash(arn), "Arn": arn},
398+
"ResponseMetadata": {},
399+
}
400+
return {}
401+
379402
def get_region(self):
380403
return self.region
381404

@@ -399,6 +422,7 @@ def fetcher_fetch_metadata_token(self):
399422
def boto3_client(self, *args, **kwargs):
400423
mock_client = mock.Mock()
401424
mock_client.get_caller_identity.return_value = self.caller_identity
425+
mock_client.assume_role = self.assume_role
402426
return mock_client
403427

404428
def __enter__(self):
@@ -439,6 +463,9 @@ def __enter__(self):
439463
side_effect=self.boto3_client,
440464
)
441465
)
466+
self.patchers.append(
467+
mock.patch("boto3.session.Session.client", side_effect=self.boto3_client)
468+
)
442469
for patcher in self.patchers:
443470
patcher.__enter__()
444471
return self

test/unit/test_auth_workload_identity.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,22 @@ def test_get_aws_sts_hostname_invalid_inputs(region, partition):
274274
assert "Invalid AWS partition" in str(excinfo.value)
275275

276276

277+
def test_aws_impersonation_calls_correct_apis_for_each_role_in_impersonation_path(
278+
fake_aws_environment: FakeAwsEnvironment,
279+
):
280+
impersonation_path = [
281+
"arn:aws:iam::123456789:role/role2",
282+
"arn:aws:iam::123456789:role/role3",
283+
]
284+
fake_aws_environment.assumption_path = impersonation_path
285+
auth_class = AuthByWorkloadIdentity(
286+
provider=AttestationProvider.AWS, impersonation_path=impersonation_path
287+
)
288+
auth_class.prepare(conn=None)
289+
290+
assert fake_aws_environment.assume_role_call_count == 2
291+
292+
277293
# -- GCP Tests --
278294

279295

test/unit/test_connection.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -682,16 +682,14 @@ def test_workload_identity_provider_is_required_for_wif_authenticator(
682682
"provider_param",
683683
[
684684
# Strongly-typed values.
685-
AttestationProvider.AWS,
686685
AttestationProvider.AZURE,
687686
AttestationProvider.OIDC,
688687
# String values.
689-
"AWS",
690688
"AZURE",
691689
"OIDC",
692690
],
693691
)
694-
def test_workload_identity_impersonation_path_unsupported_for_non_gcp_providers(
692+
def test_workload_identity_impersonation_path_errors_for_unsupported_providers(
695693
monkeypatch, provider_param
696694
):
697695
with monkeypatch.context() as m:
@@ -709,20 +707,22 @@ def test_workload_identity_impersonation_path_unsupported_for_non_gcp_providers(
709707
],
710708
)
711709
assert (
712-
"workload_identity_impersonation_path is currently only supported for GCP."
710+
"workload_identity_impersonation_path is currently only supported for GCP and AWS."
713711
in str(excinfo.value)
714712
)
715713

716714

717715
@pytest.mark.parametrize(
718-
"provider_param",
716+
"provider_param,impersonation_path",
719717
[
720-
AttestationProvider.GCP,
721-
"GCP",
718+
(AttestationProvider.GCP, ["[email protected]"]),
719+
(AttestationProvider.AWS, ["arn:aws:iam::1234567890:role/role2"]),
720+
("GCP", ["[email protected]"]),
721+
("AWS", ["arn:aws:iam::1234567890:role/role2"]),
722722
],
723723
)
724-
def test_workload_identity_impersonation_path_supported_for_gcp_provider(
725-
monkeypatch, provider_param
724+
def test_workload_identity_impersonation_path_populates_auth_class_for_supported_provider(
725+
monkeypatch, provider_param, impersonation_path
726726
):
727727
with monkeypatch.context() as m:
728728
m.setattr(
@@ -733,14 +733,9 @@ def test_workload_identity_impersonation_path_supported_for_gcp_provider(
733733
account="account",
734734
authenticator="WORKLOAD_IDENTITY",
735735
workload_identity_provider=provider_param,
736-
workload_identity_impersonation_path=[
737-
738-
],
736+
workload_identity_impersonation_path=impersonation_path,
739737
)
740-
assert conn.auth_class.provider == AttestationProvider.GCP
741-
assert conn.auth_class.impersonation_path == [
742-
743-
]
738+
assert conn.auth_class.impersonation_path == impersonation_path
744739

745740

746741
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)