Skip to content

Commit 30865dc

Browse files
Fix async get_aws_region
1 parent c6394c7 commit 30865dc

File tree

2 files changed

+29
-15
lines changed

2 files changed

+29
-15
lines changed

src/snowflake/connector/aio/_wif_util.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,11 @@
1010
import os
1111
from base64 import b64encode
1212

13+
import aioboto3
1314
import aiohttp
14-
15-
try:
16-
import aioboto3
17-
from botocore.auth import SigV4Auth
18-
from botocore.awsrequest import AWSRequest
19-
from botocore.utils import InstanceMetadataRegionFetcher
20-
except ImportError:
21-
aioboto3 = None
22-
SigV4Auth = None
23-
AWSRequest = None
24-
InstanceMetadataRegionFetcher = None
15+
from aiobotocore.utils import InstanceMetadataRegionFetcher
16+
from botocore.auth import SigV4Auth
17+
from botocore.awsrequest import AWSRequest
2518

2619
from ..errorcode import ER_WIF_CREDENTIALS_NOT_FOUND
2720
from ..errors import ProgrammingError
@@ -62,10 +55,10 @@ async def try_metadata_service_call(
6255

6356
async def get_aws_region() -> str | None:
6457
"""Get the current AWS workload's region, if any."""
65-
# Use sync implementation which has proper mocking support
66-
from ..wif_util import get_aws_region as sync_get_aws_region
67-
68-
return sync_get_aws_region()
58+
if "AWS_REGION" in os.environ: # Lambda
59+
return os.environ["AWS_REGION"]
60+
else: # EC2
61+
return await InstanceMetadataRegionFetcher().retrieve_region()
6962

7063

7164
async def get_aws_arn() -> str | None:

test/csp_helpers.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,9 @@ async def async_get_credentials():
364364
async def async_get_caller_identity():
365365
return {"Arn": self.arn}
366366

367+
async def async_get_region():
368+
return self.get_region()
369+
367370
# Mock aioboto3.Session.get_credentials (IS async)
368371
self.patchers.append(
369372
mock.patch(
@@ -372,6 +375,24 @@ async def async_get_caller_identity():
372375
)
373376
)
374377

378+
# Mock the async AWS region and ARN functions
379+
self.patchers.append(
380+
mock.patch(
381+
"snowflake.connector.aio._wif_util.get_aws_region",
382+
side_effect=async_get_region,
383+
)
384+
)
385+
386+
async def async_get_arn():
387+
return self.get_arn()
388+
389+
self.patchers.append(
390+
mock.patch(
391+
"snowflake.connector.aio._wif_util.get_aws_arn",
392+
side_effect=async_get_arn,
393+
)
394+
)
395+
375396
# Mock the async STS client for direct aioboto3 usage
376397
class MockStsClient:
377398
async def __aenter__(self):

0 commit comments

Comments
 (0)