Skip to content

Commit 75d33cc

Browse files
[async] WIF impersonation for AWS (#2517)
1 parent 7220c4c commit 75d33cc

File tree

5 files changed

+66
-21
lines changed

5 files changed

+66
-21
lines changed

src/snowflake/connector/aio/_connection.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -420,14 +420,18 @@ async def __open_connection(self):
420420
)
421421
if (
422422
self._workload_identity_impersonation_path
423-
and self._workload_identity_provider != AttestationProvider.GCP
423+
and self._workload_identity_provider
424+
not in (
425+
AttestationProvider.GCP,
426+
AttestationProvider.AWS,
427+
)
424428
):
425429
Error.errorhandler_wrapper(
426430
self,
427431
None,
428432
ProgrammingError,
429433
{
430-
"msg": "workload_identity_impersonation_path is currently only supported for GCP.",
434+
"msg": "workload_identity_impersonation_path is currently only supported for GCP and AWS.",
431435
"errno": ER_INVALID_WIF_SETTINGS,
432436
},
433437
)

src/snowflake/connector/aio/_wif_util.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,37 @@ async def get_aws_region() -> str:
4545
return region
4646

4747

48-
async def create_aws_attestation() -> WorkloadIdentityAttestation:
48+
async def get_aws_session(impersonation_path: list[str] | None = None):
49+
"""Creates an aioboto3 session with the appropriate credentials.
50+
51+
If impersonation_path is provided, this uses the role at the end of the path. Otherwise, this uses the role attached to the current workload.
52+
"""
53+
session = aioboto3.Session()
54+
55+
impersonation_path = impersonation_path or []
56+
for arn in impersonation_path:
57+
async with session.client("sts") as sts_client:
58+
response = await sts_client.assume_role(
59+
RoleArn=arn, RoleSessionName="identity-federation-session"
60+
)
61+
creds = response["Credentials"]
62+
session = aioboto3.Session(
63+
aws_access_key_id=creds["AccessKeyId"],
64+
aws_secret_access_key=creds["SecretAccessKey"],
65+
aws_session_token=creds["SessionToken"],
66+
)
67+
return session
68+
69+
70+
async def create_aws_attestation(
71+
impersonation_path: list[str] | None = None,
72+
) -> WorkloadIdentityAttestation:
4973
"""Tries to create a workload identity attestation for AWS.
5074
5175
If the application isn't running on AWS or no credentials were found, raises an error.
5276
"""
53-
session = aioboto3.Session()
77+
session = await get_aws_session(impersonation_path)
78+
5479
aws_creds = await session.get_credentials()
5580
if not aws_creds:
5681
raise ProgrammingError(
@@ -274,7 +299,7 @@ async def create_attestation(
274299
)
275300

276301
if provider == AttestationProvider.AWS:
277-
return await create_aws_attestation()
302+
return await create_aws_attestation(impersonation_path)
278303
elif provider == AttestationProvider.AZURE:
279304
return await create_azure_attestation(entra_resource, session_manager)
280305
elif provider == AttestationProvider.GCP:

test/unit/aio/csp_helpers_async.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,8 @@ async def async_get_arn():
202202
)
203203

204204
# Mock the async STS client for direct aioboto3 usage
205+
fake_aws_self = self
206+
205207
class MockStsClient:
206208
async def __aenter__(self):
207209
return self
@@ -212,6 +214,9 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
212214
async def get_caller_identity(self):
213215
return await async_get_caller_identity()
214216

217+
async def assume_role(self, **kwargs):
218+
return fake_aws_self.assume_role(**kwargs)
219+
215220
def mock_session_client(service_name):
216221
if service_name == "sts":
217222
return MockStsClient()

test/unit/aio/test_auth_workload_identity_async.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,22 @@ async def test_explicit_aws_generates_unique_assertion_content(
253253
)
254254

255255

256+
async def test_aws_impersonation_calls_correct_apis_for_each_role_in_impersonation_path(
257+
fake_aws_environment: FakeAwsEnvironmentAsync,
258+
):
259+
impersonation_path = [
260+
"arn:aws:iam::123456789:role/role2",
261+
"arn:aws:iam::123456789:role/role3",
262+
]
263+
fake_aws_environment.assumption_path = impersonation_path
264+
auth_class = AuthByWorkloadIdentity(
265+
provider=AttestationProvider.AWS, impersonation_path=impersonation_path
266+
)
267+
await auth_class.prepare(conn=None)
268+
269+
assert fake_aws_environment.assume_role_call_count == 2
270+
271+
256272
# -- GCP Tests --
257273

258274

test/unit/aio/test_connection_async_unit.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -659,16 +659,14 @@ async def test_workload_identity_provider_is_required_for_wif_authenticator(
659659
"provider_param",
660660
[
661661
# Strongly-typed values.
662-
AttestationProvider.AWS,
663662
AttestationProvider.AZURE,
664663
AttestationProvider.OIDC,
665664
# String values.
666-
"AWS",
667665
"AZURE",
668666
"OIDC",
669667
],
670668
)
671-
async def test_workload_identity_impersonation_path_unsupported_for_non_gcp_providers(
669+
async def test_workload_identity_impersonation_path_errors_for_unsupported_providers(
672670
monkeypatch, provider_param
673671
):
674672
async def mock_authenticate(*_):
@@ -690,20 +688,22 @@ async def mock_authenticate(*_):
690688
],
691689
)
692690
assert (
693-
"workload_identity_impersonation_path is currently only supported for GCP."
691+
"workload_identity_impersonation_path is currently only supported for GCP and AWS."
694692
in str(excinfo.value)
695693
)
696694

697695

698696
@pytest.mark.parametrize(
699-
"provider_param",
697+
"provider_param,impersonation_path",
700698
[
701-
AttestationProvider.GCP,
702-
"GCP",
699+
(AttestationProvider.GCP, ["[email protected]"]),
700+
(AttestationProvider.AWS, ["arn:aws:iam::1234567890:role/role2"]),
701+
("GCP", ["[email protected]"]),
702+
("AWS", ["arn:aws:iam::1234567890:role/role2"]),
703703
],
704704
)
705-
async def test_workload_identity_impersonation_path_supported_for_gcp_provider(
706-
monkeypatch, provider_param
705+
async def test_workload_identity_impersonation_path_populates_auth_class_for_supported_provider(
706+
monkeypatch, provider_param, impersonation_path
707707
):
708708
async def mock_authenticate(*_):
709709
pass
@@ -718,14 +718,9 @@ async def mock_authenticate(*_):
718718
account="account",
719719
authenticator="WORKLOAD_IDENTITY",
720720
workload_identity_provider=provider_param,
721-
workload_identity_impersonation_path=[
722-
723-
],
721+
workload_identity_impersonation_path=impersonation_path,
724722
)
725-
assert conn.auth_class.provider == AttestationProvider.GCP
726-
assert conn.auth_class.impersonation_path == [
727-
728-
]
723+
assert conn.auth_class.impersonation_path == impersonation_path
729724

730725

731726
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)