Skip to content

Commit 63f918c

Browse files
use aioboto3
1 parent 7ecd740 commit 63f918c

File tree

3 files changed

+140
-4
lines changed

3 files changed

+140
-4
lines changed

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,4 @@ secure-local-storage =
100100
keyring>=23.1.0,<26.0.0
101101
aio =
102102
aiohttp
103+
aioboto3

src/snowflake/connector/aio/_wif_util.py

Lines changed: 99 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,28 @@
88
import json
99
import logging
1010
import os
11+
from base64 import b64encode
1112

1213
import aiohttp
1314

15+
try:
16+
import aioboto3
17+
from botocore.auth import SigV4Auth
18+
from botocore.awsrequest import AWSRequest
19+
from botocore.utils import InstanceMetadataRegionFetcher
20+
except ImportError:
21+
aioboto3 = None
22+
SigV4Auth = None
23+
AWSRequest = None
24+
InstanceMetadataRegionFetcher = None
25+
1426
from ..errorcode import ER_WIF_CREDENTIALS_NOT_FOUND
1527
from ..errors import ProgrammingError
1628
from ..wif_util import (
1729
DEFAULT_ENTRA_SNOWFLAKE_RESOURCE,
1830
SNOWFLAKE_AUDIENCE,
1931
AttestationProvider,
2032
WorkloadIdentityAttestation,
21-
create_aws_attestation,
2233
create_oidc_attestation,
2334
extract_iss_and_sub_without_signature_verification,
2435
)
@@ -49,6 +60,91 @@ async def try_metadata_service_call(
4960
return None
5061

5162

63+
async def get_aws_region() -> str | None:
64+
"""Get the current AWS workload's region, if any."""
65+
# Use sync implementation which has proper mocking support
66+
from ..wif_util import get_aws_region as sync_get_aws_region
67+
68+
return sync_get_aws_region()
69+
70+
71+
async def get_aws_arn() -> str | None:
72+
"""Get the current AWS workload's ARN, if any."""
73+
if aioboto3 is None:
74+
logger.debug("aioboto3 not available, falling back to sync implementation")
75+
from ..wif_util import get_aws_arn as sync_get_aws_arn
76+
77+
return sync_get_aws_arn()
78+
79+
try:
80+
session = aioboto3.Session()
81+
async with session.client("sts") as client:
82+
caller_identity = await client.get_caller_identity()
83+
if not caller_identity or "Arn" not in caller_identity:
84+
return None
85+
return caller_identity["Arn"]
86+
except Exception:
87+
logger.debug("Failed to get AWS ARN", exc_info=True)
88+
return None
89+
90+
91+
async def create_aws_attestation() -> WorkloadIdentityAttestation | None:
92+
"""Tries to create a workload identity attestation for AWS.
93+
94+
If the application isn't running on AWS or no credentials were found, returns None.
95+
"""
96+
if aioboto3 is None:
97+
logger.debug("aioboto3 not available, falling back to sync implementation")
98+
from ..wif_util import create_aws_attestation as sync_create_aws_attestation
99+
100+
return sync_create_aws_attestation()
101+
102+
try:
103+
# Get credentials using aioboto3
104+
session = aioboto3.Session()
105+
aws_creds = await session.get_credentials() # This IS async in aioboto3
106+
if not aws_creds:
107+
logger.debug("No AWS credentials were found.")
108+
return None
109+
110+
region = await get_aws_region()
111+
if not region:
112+
logger.debug("No AWS region was found.")
113+
return None
114+
115+
arn = await get_aws_arn()
116+
if not arn:
117+
logger.debug("No AWS caller identity was found.")
118+
return None
119+
120+
sts_hostname = f"sts.{region}.amazonaws.com"
121+
request = AWSRequest(
122+
method="POST",
123+
url=f"https://{sts_hostname}/?Action=GetCallerIdentity&Version=2011-06-15",
124+
headers={
125+
"Host": sts_hostname,
126+
"X-Snowflake-Audience": SNOWFLAKE_AUDIENCE,
127+
},
128+
)
129+
130+
SigV4Auth(aws_creds, "sts", region).add_auth(request)
131+
132+
assertion_dict = {
133+
"url": request.url,
134+
"method": request.method,
135+
"headers": dict(request.headers.items()),
136+
}
137+
credential = b64encode(json.dumps(assertion_dict).encode("utf-8")).decode(
138+
"utf-8"
139+
)
140+
return WorkloadIdentityAttestation(
141+
AttestationProvider.AWS, credential, {"arn": arn}
142+
)
143+
except Exception:
144+
logger.debug("Failed to create AWS attestation", exc_info=True)
145+
return None
146+
147+
52148
async def create_gcp_attestation() -> WorkloadIdentityAttestation | None:
53149
"""Tries to create a workload identity attestation for GCP.
54150
@@ -157,7 +253,7 @@ async def create_autodetect_attestation(
157253
if attestation:
158254
return attestation
159255

160-
attestation = create_aws_attestation()
256+
attestation = await create_aws_attestation()
161257
if attestation:
162258
return attestation
163259

@@ -188,7 +284,7 @@ async def create_attestation(
188284

189285
attestation: WorkloadIdentityAttestation = None
190286
if provider == AttestationProvider.AWS:
191-
attestation = create_aws_attestation()
287+
attestation = await create_aws_attestation()
192288
elif provider == AttestationProvider.AZURE:
193289
attestation = await create_azure_attestation(entra_resource)
194290
elif provider == AttestationProvider.GCP:

test/csp_helpers.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,8 @@ def sign_request(self, request: AWSRequest):
332332
def __enter__(self):
333333
# Patch the relevant functions to do what we want.
334334
self.patchers = []
335+
336+
# Patch sync boto3 calls
335337
self.patchers.append(
336338
mock.patch(
337339
"boto3.session.Session.get_credentials",
@@ -354,7 +356,44 @@ def __enter__(self):
354356
"snowflake.connector.wif_util.get_aws_arn", side_effect=self.get_arn
355357
)
356358
)
357-
# Note: No need to patch async versions anymore since async now imports from sync
359+
360+
# Patch async aioboto3 calls (for when aioboto3 is used directly)
361+
async def async_get_credentials():
362+
return self.credentials
363+
364+
async def async_get_caller_identity():
365+
return {"Arn": self.arn}
366+
367+
# Mock aioboto3.Session.get_credentials (IS async)
368+
self.patchers.append(
369+
mock.patch(
370+
"snowflake.connector.aio._wif_util.aioboto3.Session.get_credentials",
371+
side_effect=async_get_credentials,
372+
)
373+
)
374+
375+
# Mock the async STS client for direct aioboto3 usage
376+
class MockStsClient:
377+
async def __aenter__(self):
378+
return self
379+
380+
async def __aexit__(self, exc_type, exc_val, exc_tb):
381+
pass
382+
383+
async def get_caller_identity(self):
384+
return await async_get_caller_identity()
385+
386+
def mock_session_client(service_name):
387+
if service_name == "sts":
388+
return MockStsClient()
389+
return None
390+
391+
self.patchers.append(
392+
mock.patch(
393+
"snowflake.connector.aio._wif_util.aioboto3.Session.client",
394+
side_effect=mock_session_client,
395+
)
396+
)
358397
for patcher in self.patchers:
359398
patcher.__enter__()
360399
return self

0 commit comments

Comments
 (0)