Skip to content

Commit 10b76d3

Browse files
remove silent exception catching; fix async get_aws_region
1 parent 30865dc commit 10b76d3

File tree

2 files changed

+64
-65
lines changed

2 files changed

+64
-65
lines changed

src/snowflake/connector/aio/_wif_util.py

Lines changed: 40 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import aioboto3
1414
import aiohttp
15-
from aiobotocore.utils import InstanceMetadataRegionFetcher
15+
from aiobotocore.utils import AioInstanceMetadataRegionFetcher
1616
from botocore.auth import SigV4Auth
1717
from botocore.awsrequest import AWSRequest
1818

@@ -58,84 +58,61 @@ async def get_aws_region() -> str | None:
5858
if "AWS_REGION" in os.environ: # Lambda
5959
return os.environ["AWS_REGION"]
6060
else: # EC2
61-
return await InstanceMetadataRegionFetcher().retrieve_region()
61+
return await AioInstanceMetadataRegionFetcher().retrieve_region()
6262

6363

6464
async def get_aws_arn() -> str | None:
6565
"""Get the current AWS workload's ARN, if any."""
66-
if aioboto3 is None:
67-
logger.debug("aioboto3 not available, falling back to sync implementation")
68-
from ..wif_util import get_aws_arn as sync_get_aws_arn
69-
70-
return sync_get_aws_arn()
71-
72-
try:
73-
session = aioboto3.Session()
74-
async with session.client("sts") as client:
75-
caller_identity = await client.get_caller_identity()
76-
if not caller_identity or "Arn" not in caller_identity:
77-
return None
78-
return caller_identity["Arn"]
79-
except Exception:
80-
logger.debug("Failed to get AWS ARN", exc_info=True)
81-
return None
66+
session = aioboto3.Session()
67+
async with session.client("sts") as client:
68+
caller_identity = await client.get_caller_identity()
69+
if not caller_identity or "Arn" not in caller_identity:
70+
return None
71+
return caller_identity["Arn"]
8272

8373

8474
async def create_aws_attestation() -> WorkloadIdentityAttestation | None:
8575
"""Tries to create a workload identity attestation for AWS.
8676
8777
If the application isn't running on AWS or no credentials were found, returns None.
8878
"""
89-
if aioboto3 is None:
90-
logger.debug("aioboto3 not available, falling back to sync implementation")
91-
from ..wif_util import create_aws_attestation as sync_create_aws_attestation
92-
93-
return sync_create_aws_attestation()
94-
95-
try:
96-
# Get credentials using aioboto3
97-
session = aioboto3.Session()
98-
aws_creds = await session.get_credentials() # This IS async in aioboto3
99-
if not aws_creds:
100-
logger.debug("No AWS credentials were found.")
101-
return None
79+
session = aioboto3.Session()
80+
aws_creds = await session.get_credentials()
81+
if not aws_creds:
82+
logger.debug("No AWS credentials were found.")
83+
return None
10284

103-
region = await get_aws_region()
104-
if not region:
105-
logger.debug("No AWS region was found.")
106-
return None
85+
region = await get_aws_region()
86+
if not region:
87+
logger.debug("No AWS region was found.")
88+
return None
10789

108-
arn = await get_aws_arn()
109-
if not arn:
110-
logger.debug("No AWS caller identity was found.")
111-
return None
90+
arn = await get_aws_arn()
91+
if not arn:
92+
logger.debug("No AWS caller identity was found.")
93+
return None
11294

113-
sts_hostname = f"sts.{region}.amazonaws.com"
114-
request = AWSRequest(
115-
method="POST",
116-
url=f"https://{sts_hostname}/?Action=GetCallerIdentity&Version=2011-06-15",
117-
headers={
118-
"Host": sts_hostname,
119-
"X-Snowflake-Audience": SNOWFLAKE_AUDIENCE,
120-
},
121-
)
95+
sts_hostname = f"sts.{region}.amazonaws.com"
96+
request = AWSRequest(
97+
method="POST",
98+
url=f"https://{sts_hostname}/?Action=GetCallerIdentity&Version=2011-06-15",
99+
headers={
100+
"Host": sts_hostname,
101+
"X-Snowflake-Audience": SNOWFLAKE_AUDIENCE,
102+
},
103+
)
122104

123-
SigV4Auth(aws_creds, "sts", region).add_auth(request)
105+
SigV4Auth(aws_creds, "sts", region).add_auth(request)
124106

125-
assertion_dict = {
126-
"url": request.url,
127-
"method": request.method,
128-
"headers": dict(request.headers.items()),
129-
}
130-
credential = b64encode(json.dumps(assertion_dict).encode("utf-8")).decode(
131-
"utf-8"
132-
)
133-
return WorkloadIdentityAttestation(
134-
AttestationProvider.AWS, credential, {"arn": arn}
135-
)
136-
except Exception:
137-
logger.debug("Failed to create AWS attestation", exc_info=True)
138-
return None
107+
assertion_dict = {
108+
"url": request.url,
109+
"method": request.method,
110+
"headers": dict(request.headers.items()),
111+
}
112+
credential = b64encode(json.dumps(assertion_dict).encode("utf-8")).decode("utf-8")
113+
return WorkloadIdentityAttestation(
114+
AttestationProvider.AWS, credential, {"arn": arn}
115+
)
139116

140117

141118
async def create_gcp_attestation() -> WorkloadIdentityAttestation | None:

test/unit/aio/test_auth_workload_identity_async.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -337,9 +337,17 @@ async def test_autodetect_aws_present(
337337
verify_aws_token(data["TOKEN"], fake_aws_environment.region)
338338

339339

340+
@mock.patch("snowflake.connector.aio._wif_util.AioInstanceMetadataRegionFetcher")
340341
async def test_autodetect_gcp_present(
342+
mock_fetcher,
341343
fake_gce_metadata_service: FakeGceMetadataService,
342344
):
345+
# Mock AioInstanceMetadataRegionFetcher to return None properly as an async function
346+
async def mock_retrieve_region():
347+
return None
348+
349+
mock_fetcher.return_value.retrieve_region.side_effect = mock_retrieve_region
350+
343351
auth_class = AuthByWorkloadIdentity(provider=None)
344352
await auth_class.prepare()
345353

@@ -350,7 +358,14 @@ async def test_autodetect_gcp_present(
350358
}
351359

352360

353-
async def test_autodetect_azure_present(fake_azure_metadata_service):
361+
@mock.patch("snowflake.connector.aio._wif_util.AioInstanceMetadataRegionFetcher")
362+
async def test_autodetect_azure_present(mock_fetcher, fake_azure_metadata_service):
363+
# Mock AioInstanceMetadataRegionFetcher to return None properly as an async function
364+
async def mock_retrieve_region():
365+
return None
366+
367+
mock_fetcher.return_value.retrieve_region.side_effect = mock_retrieve_region
368+
354369
auth_class = AuthByWorkloadIdentity(provider=None)
355370
await auth_class.prepare()
356371

@@ -373,7 +388,14 @@ async def test_autodetect_oidc_present(no_metadata_service):
373388
}
374389

375390

376-
async def test_autodetect_no_provider_raises_error(no_metadata_service):
391+
@mock.patch("snowflake.connector.aio._wif_util.AioInstanceMetadataRegionFetcher")
392+
async def test_autodetect_no_provider_raises_error(mock_fetcher, no_metadata_service):
393+
# Mock AioInstanceMetadataRegionFetcher to return None properly as an async function
394+
async def mock_retrieve_region():
395+
return None
396+
397+
mock_fetcher.return_value.retrieve_region.side_effect = mock_retrieve_region
398+
377399
auth_class = AuthByWorkloadIdentity(provider=None, token=None)
378400
with pytest.raises(ProgrammingError) as excinfo:
379401
await auth_class.prepare()

0 commit comments

Comments
 (0)