File tree Expand file tree Collapse file tree 2 files changed +29
-15
lines changed
src/snowflake/connector/aio Expand file tree Collapse file tree 2 files changed +29
-15
lines changed Original file line number Diff line number Diff line change 10
10
import os
11
11
from base64 import b64encode
12
12
13
+ import aioboto3
13
14
import aiohttp
14
-
15
- try :
16
- import aioboto3
17
- from botocore .auth import SigV4Auth
18
- from botocore .awsrequest import AWSRequest
19
- from botocore .utils import InstanceMetadataRegionFetcher
20
- except ImportError :
21
- aioboto3 = None
22
- SigV4Auth = None
23
- AWSRequest = None
24
- InstanceMetadataRegionFetcher = None
15
+ from aiobotocore .utils import InstanceMetadataRegionFetcher
16
+ from botocore .auth import SigV4Auth
17
+ from botocore .awsrequest import AWSRequest
25
18
26
19
from ..errorcode import ER_WIF_CREDENTIALS_NOT_FOUND
27
20
from ..errors import ProgrammingError
@@ -62,10 +55,10 @@ async def try_metadata_service_call(
62
55
63
56
async def get_aws_region () -> str | None :
64
57
"""Get the current AWS workload's region, if any."""
65
- # Use sync implementation which has proper mocking support
66
- from .. wif_util import get_aws_region as sync_get_aws_region
67
-
68
- return sync_get_aws_region ()
58
+ if "AWS_REGION" in os . environ : # Lambda
59
+ return os . environ [ "AWS_REGION" ]
60
+ else : # EC2
61
+ return await InstanceMetadataRegionFetcher (). retrieve_region ()
69
62
70
63
71
64
async def get_aws_arn () -> str | None :
Original file line number Diff line number Diff line change @@ -364,6 +364,9 @@ async def async_get_credentials():
364
364
async def async_get_caller_identity ():
365
365
return {"Arn" : self .arn }
366
366
367
+ async def async_get_region ():
368
+ return self .get_region ()
369
+
367
370
# Mock aioboto3.Session.get_credentials (IS async)
368
371
self .patchers .append (
369
372
mock .patch (
@@ -372,6 +375,24 @@ async def async_get_caller_identity():
372
375
)
373
376
)
374
377
378
+ # Mock the async AWS region and ARN functions
379
+ self .patchers .append (
380
+ mock .patch (
381
+ "snowflake.connector.aio._wif_util.get_aws_region" ,
382
+ side_effect = async_get_region ,
383
+ )
384
+ )
385
+
386
+ async def async_get_arn ():
387
+ return self .get_arn ()
388
+
389
+ self .patchers .append (
390
+ mock .patch (
391
+ "snowflake.connector.aio._wif_util.get_aws_arn" ,
392
+ side_effect = async_get_arn ,
393
+ )
394
+ )
395
+
375
396
# Mock the async STS client for direct aioboto3 usage
376
397
class MockStsClient :
377
398
async def __aenter__ (self ):
You can’t perform that action at this time.
0 commit comments