Skip to content

Commit 815cd95

Browse files
[async] WIF impersonation for AWS (#2517)
1 parent 67822a2 commit 815cd95

File tree

5 files changed

+66
-21
lines changed

5 files changed

+66
-21
lines changed

src/snowflake/connector/aio/_connection.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -416,14 +416,18 @@ async def __open_connection(self):
416416
)
417417
if (
418418
self._workload_identity_impersonation_path
419-
and self._workload_identity_provider != AttestationProvider.GCP
419+
and self._workload_identity_provider
420+
not in (
421+
AttestationProvider.GCP,
422+
AttestationProvider.AWS,
423+
)
420424
):
421425
Error.errorhandler_wrapper(
422426
self,
423427
None,
424428
ProgrammingError,
425429
{
426-
"msg": "workload_identity_impersonation_path is currently only supported for GCP.",
430+
"msg": "workload_identity_impersonation_path is currently only supported for GCP and AWS.",
427431
"errno": ER_INVALID_WIF_SETTINGS,
428432
},
429433
)

src/snowflake/connector/aio/_wif_util.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,37 @@ async def get_aws_region() -> str:
4545
return region
4646

4747

48-
async def create_aws_attestation() -> WorkloadIdentityAttestation:
48+
async def get_aws_session(impersonation_path: list[str] | None = None):
49+
"""Creates an aioboto3 session with the appropriate credentials.
50+
51+
If impersonation_path is provided, this uses the role at the end of the path. Otherwise, this uses the role attached to the current workload.
52+
"""
53+
session = aioboto3.Session()
54+
55+
impersonation_path = impersonation_path or []
56+
for arn in impersonation_path:
57+
async with session.client("sts") as sts_client:
58+
response = await sts_client.assume_role(
59+
RoleArn=arn, RoleSessionName="identity-federation-session"
60+
)
61+
creds = response["Credentials"]
62+
session = aioboto3.Session(
63+
aws_access_key_id=creds["AccessKeyId"],
64+
aws_secret_access_key=creds["SecretAccessKey"],
65+
aws_session_token=creds["SessionToken"],
66+
)
67+
return session
68+
69+
70+
async def create_aws_attestation(
71+
impersonation_path: list[str] | None = None,
72+
) -> WorkloadIdentityAttestation:
4973
"""Tries to create a workload identity attestation for AWS.
5074
5175
If the application isn't running on AWS or no credentials were found, raises an error.
5276
"""
53-
session = aioboto3.Session()
77+
session = await get_aws_session(impersonation_path)
78+
5479
aws_creds = await session.get_credentials()
5580
if not aws_creds:
5681
raise ProgrammingError(
@@ -276,7 +301,7 @@ async def create_attestation(
276301
)
277302

278303
if provider == AttestationProvider.AWS:
279-
return await create_aws_attestation()
304+
return await create_aws_attestation(impersonation_path)
280305
elif provider == AttestationProvider.AZURE:
281306
return await create_azure_attestation(entra_resource, session_manager)
282307
elif provider == AttestationProvider.GCP:

test/unit/aio/csp_helpers_async.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,8 @@ async def async_get_arn():
202202
)
203203

204204
# Mock the async STS client for direct aioboto3 usage
205+
fake_aws_self = self
206+
205207
class MockStsClient:
206208
async def __aenter__(self):
207209
return self
@@ -212,6 +214,9 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
212214
async def get_caller_identity(self):
213215
return await async_get_caller_identity()
214216

217+
async def assume_role(self, **kwargs):
218+
return fake_aws_self.assume_role(**kwargs)
219+
215220
def mock_session_client(service_name):
216221
if service_name == "sts":
217222
return MockStsClient()

test/unit/aio/test_auth_workload_identity_async.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,22 @@ async def test_explicit_aws_generates_unique_assertion_content(
253253
)
254254

255255

256+
async def test_aws_impersonation_calls_correct_apis_for_each_role_in_impersonation_path(
257+
fake_aws_environment: FakeAwsEnvironmentAsync,
258+
):
259+
impersonation_path = [
260+
"arn:aws:iam::123456789:role/role2",
261+
"arn:aws:iam::123456789:role/role3",
262+
]
263+
fake_aws_environment.assumption_path = impersonation_path
264+
auth_class = AuthByWorkloadIdentity(
265+
provider=AttestationProvider.AWS, impersonation_path=impersonation_path
266+
)
267+
await auth_class.prepare(conn=None)
268+
269+
assert fake_aws_environment.assume_role_call_count == 2
270+
271+
256272
# -- GCP Tests --
257273

258274

test/unit/aio/test_connection_async_unit.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -660,16 +660,14 @@ async def test_workload_identity_provider_is_required_for_wif_authenticator(
660660
"provider_param",
661661
[
662662
# Strongly-typed values.
663-
AttestationProvider.AWS,
664663
AttestationProvider.AZURE,
665664
AttestationProvider.OIDC,
666665
# String values.
667-
"AWS",
668666
"AZURE",
669667
"OIDC",
670668
],
671669
)
672-
async def test_workload_identity_impersonation_path_unsupported_for_non_gcp_providers(
670+
async def test_workload_identity_impersonation_path_errors_for_unsupported_providers(
673671
monkeypatch, provider_param
674672
):
675673
async def mock_authenticate(*_):
@@ -691,20 +689,22 @@ async def mock_authenticate(*_):
691689
],
692690
)
693691
assert (
694-
"workload_identity_impersonation_path is currently only supported for GCP."
692+
"workload_identity_impersonation_path is currently only supported for GCP and AWS."
695693
in str(excinfo.value)
696694
)
697695

698696

699697
@pytest.mark.parametrize(
700-
"provider_param",
698+
"provider_param,impersonation_path",
701699
[
702-
AttestationProvider.GCP,
703-
"GCP",
700+
(AttestationProvider.GCP, ["[email protected]"]),
701+
(AttestationProvider.AWS, ["arn:aws:iam::1234567890:role/role2"]),
702+
("GCP", ["[email protected]"]),
703+
("AWS", ["arn:aws:iam::1234567890:role/role2"]),
704704
],
705705
)
706-
async def test_workload_identity_impersonation_path_supported_for_gcp_provider(
707-
monkeypatch, provider_param
706+
async def test_workload_identity_impersonation_path_populates_auth_class_for_supported_provider(
707+
monkeypatch, provider_param, impersonation_path
708708
):
709709
async def mock_authenticate(*_):
710710
pass
@@ -719,14 +719,9 @@ async def mock_authenticate(*_):
719719
account="account",
720720
authenticator="WORKLOAD_IDENTITY",
721721
workload_identity_provider=provider_param,
722-
workload_identity_impersonation_path=[
723-
724-
],
722+
workload_identity_impersonation_path=impersonation_path,
725723
)
726-
assert conn.auth_class.provider == AttestationProvider.GCP
727-
assert conn.auth_class.impersonation_path == [
728-
729-
]
724+
assert conn.auth_class.impersonation_path == impersonation_path
730725

731726

732727
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)