Skip to content

Commit 7220c4c

Browse files
sfc-gh-eqin sfc-gh-pmansour
authored and committed
Support WIF Impersonation on AWS workloads (#2517)
Co-authored-by: Peter Mansour <[email protected]>
1 parent beb34fd commit 7220c4c

File tree

5 files changed

+87
-23
lines changed

5 files changed

+87
-23
lines changed

src/snowflake/connector/connection.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1358,14 +1358,18 @@ def __open_connection(self):
13581358
)
13591359
if (
13601360
self._workload_identity_impersonation_path
1361-
and self._workload_identity_provider != AttestationProvider.GCP
1361+
and self._workload_identity_provider
1362+
not in (
1363+
AttestationProvider.GCP,
1364+
AttestationProvider.AWS,
1365+
)
13621366
):
13631367
Error.errorhandler_wrapper(
13641368
self,
13651369
None,
13661370
ProgrammingError,
13671371
{
1368-
"msg": "workload_identity_impersonation_path is currently only supported for GCP.",
1372+
"msg": "workload_identity_impersonation_path is currently only supported for GCP and AWS.",
13691373
"errno": ER_INVALID_WIF_SETTINGS,
13701374
},
13711375
)

src/snowflake/connector/wif_util.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,15 +145,37 @@ def get_aws_sts_hostname(region: str, partition: str) -> str:
145145
)
146146

147147

148+
def get_aws_session(impersonation_path: list[str] | None = None):
149+
"""Creates a boto3 session with the appropriate credentials.
150+
151+
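If impersonation_path is provided, each role in the path is assumed in order (every hop uses the credentials obtained from the previous one), and the returned session uses the role at the end of the path. Otherwise, this uses the role attached to the current workload.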
152+
"""
153+
session = boto3.session.Session()
154+
155+
impersonation_path = impersonation_path or []
156+
for arn in impersonation_path:
157+
response = session.client("sts").assume_role(
158+
RoleArn=arn, RoleSessionName="identity-federation-session"
159+
)
160+
creds = response["Credentials"]
161+
session = boto3.session.Session(
162+
aws_access_key_id=creds["AccessKeyId"],
163+
aws_secret_access_key=creds["SecretAccessKey"],
164+
aws_session_token=creds["SessionToken"],
165+
)
166+
return session
167+
168+
148169
def create_aws_attestation(
149-
session_manager: SessionManager | None = None,
170+
impersonation_path: list[str] | None = None,
150171
) -> WorkloadIdentityAttestation:
151172
"""Tries to create a workload identity attestation for AWS.
152173
153174
If the application isn't running on AWS or no credentials were found, raises an error.
154175
"""
155176
# TODO: SNOW-2223669 Investigate if our adapters - containing settings of http traffic - should be passed here as boto urllib3session. Those requests go to local servers, so they do not need Proxy setup or Headers customization in theory. But we may want to have all the traffic going through one class (e.g. Adapter or mixin).
156-
session = boto3.session.Session()
177+
session = get_aws_session(impersonation_path)
178+
157179
aws_creds = session.get_credentials()
158180
if not aws_creds:
159181
raise ProgrammingError(
@@ -385,7 +407,7 @@ def create_attestation(
385407
)
386408

387409
if provider == AttestationProvider.AWS:
388-
return create_aws_attestation(session_manager)
410+
return create_aws_attestation(impersonation_path)
389411
elif provider == AttestationProvider.AZURE:
390412
return create_azure_attestation(entra_resource, session_manager)
391413
elif provider == AttestationProvider.GCP:

test/csp_helpers.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,8 @@ def gen_dummy_id_token(
4040
)
4141

4242

43-
def gen_dummy_access_token(sub="test-subject") -> str:
43+
def gen_dummy_access_token(sub="test-subject", key="secret") -> str:
4444
"""Generates a dummy access token using the given subject."""
45-
key = "secret"
4645
logger.debug(f"Generating dummy access token for subject {sub}")
4746
return (sub + key).encode("utf-8").hex()
4847

@@ -368,6 +367,11 @@ class FakeAwsEnvironment:
368367
def __init__(self):
369368
# Defaults used for generating a token. Can be overridden in individual tests.
370369
self.arn = "arn:aws:sts::123456789:assumed-role/My-Role/i-34afe100cad287fab"
370+
# Path of roles that can be assumed. Empty if no impersonation is allowed.
371+
# Can be overridden in individual tests.
372+
self.assumption_path = []
373+
self.assume_role_call_count = 0
374+
371375
self.caller_identity = {"Arn": self.arn}
372376
self.region = "us-east-1"
373377
self.credentials = Credentials(access_key="ak", secret_key="sk")
@@ -376,6 +380,25 @@ def __init__(self):
376380
)
377381
self.metadata_token = "test-token"
378382

383+
def assume_role(self, **kwargs):
384+
if (
385+
self.assumption_path
386+
and kwargs["RoleArn"] == self.assumption_path[self.assume_role_call_count]
387+
):
388+
arn = self.assumption_path[self.assume_role_call_count]
389+
self.assume_role_call_count += 1
390+
return {
391+
"Credentials": {
392+
"AccessKeyId": "access_key",
393+
"SecretAccessKey": "secret_key",
394+
"SessionToken": "session_token",
395+
"Expiration": int(time()) + 60 * 60,
396+
},
397+
"AssumedRoleUser": {"AssumedRoleId": hash(arn), "Arn": arn},
398+
"ResponseMetadata": {},
399+
}
400+
return {}
401+
379402
def get_region(self):
380403
return self.region
381404

@@ -401,6 +424,7 @@ def fetcher_fetch_metadata_token(self):
401424
def boto3_client(self, *args, **kwargs):
402425
mock_client = mock.Mock()
403426
mock_client.get_caller_identity.return_value = self.caller_identity
427+
mock_client.assume_role = self.assume_role
404428
return mock_client
405429

406430
def __enter__(self):
@@ -443,6 +467,9 @@ def __enter__(self):
443467
side_effect=self.boto3_client,
444468
)
445469
)
470+
self.patchers.append(
471+
mock.patch("boto3.session.Session.client", side_effect=self.boto3_client)
472+
)
446473
for patcher in self.patchers:
447474
patcher.__enter__()
448475
return self

test/unit/test_auth_workload_identity.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,22 @@ def test_get_aws_sts_hostname_invalid_inputs(region, partition):
275275
assert "Invalid AWS partition" in str(excinfo.value)
276276

277277

278+
def test_aws_impersonation_calls_correct_apis_for_each_role_in_impersonation_path(
279+
fake_aws_environment: FakeAwsEnvironment,
280+
):
281+
impersonation_path = [
282+
"arn:aws:iam::123456789:role/role2",
283+
"arn:aws:iam::123456789:role/role3",
284+
]
285+
fake_aws_environment.assumption_path = impersonation_path
286+
auth_class = AuthByWorkloadIdentity(
287+
provider=AttestationProvider.AWS, impersonation_path=impersonation_path
288+
)
289+
auth_class.prepare(conn=None)
290+
291+
assert fake_aws_environment.assume_role_call_count == 2
292+
293+
278294
# -- GCP Tests --
279295

280296

test/unit/test_connection.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -684,16 +684,14 @@ def test_workload_identity_provider_is_required_for_wif_authenticator(
684684
"provider_param",
685685
[
686686
# Strongly-typed values.
687-
AttestationProvider.AWS,
688687
AttestationProvider.AZURE,
689688
AttestationProvider.OIDC,
690689
# String values.
691-
"AWS",
692690
"AZURE",
693691
"OIDC",
694692
],
695693
)
696-
def test_workload_identity_impersonation_path_unsupported_for_non_gcp_providers(
694+
def test_workload_identity_impersonation_path_errors_for_unsupported_providers(
697695
monkeypatch, provider_param
698696
):
699697
with monkeypatch.context() as m:
@@ -711,20 +709,22 @@ def test_workload_identity_impersonation_path_unsupported_for_non_gcp_providers(
711709
],
712710
)
713711
assert (
714-
"workload_identity_impersonation_path is currently only supported for GCP."
712+
"workload_identity_impersonation_path is currently only supported for GCP and AWS."
715713
in str(excinfo.value)
716714
)
717715

718716

719717
@pytest.mark.parametrize(
720-
"provider_param",
718+
"provider_param,impersonation_path",
721719
[
722-
AttestationProvider.GCP,
723-
"GCP",
720+
(AttestationProvider.GCP, ["[email protected]"]),
721+
(AttestationProvider.AWS, ["arn:aws:iam::1234567890:role/role2"]),
722+
("GCP", ["[email protected]"]),
723+
("AWS", ["arn:aws:iam::1234567890:role/role2"]),
724724
],
725725
)
726-
def test_workload_identity_impersonation_path_supported_for_gcp_provider(
727-
monkeypatch, provider_param
726+
def test_workload_identity_impersonation_path_populates_auth_class_for_supported_provider(
727+
monkeypatch, provider_param, impersonation_path
728728
):
729729
with monkeypatch.context() as m:
730730
m.setattr(
@@ -735,14 +735,9 @@ def test_workload_identity_impersonation_path_supported_for_gcp_provider(
735735
account="account",
736736
authenticator="WORKLOAD_IDENTITY",
737737
workload_identity_provider=provider_param,
738-
workload_identity_impersonation_path=[
739-
740-
],
738+
workload_identity_impersonation_path=impersonation_path,
741739
)
742-
assert conn.auth_class.provider == AttestationProvider.GCP
743-
assert conn.auth_class.impersonation_path == [
744-
745-
]
740+
assert conn.auth_class.impersonation_path == impersonation_path
746741

747742

748743
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)