|
12 | 12 |
|
13 | 13 | import aioboto3
|
14 | 14 | import aiohttp
|
15 |
| -from aiobotocore.utils import InstanceMetadataRegionFetcher |
| 15 | +from aiobotocore.utils import AioInstanceMetadataRegionFetcher |
16 | 16 | from botocore.auth import SigV4Auth
|
17 | 17 | from botocore.awsrequest import AWSRequest
|
18 | 18 |
|
@@ -58,84 +58,61 @@ async def get_aws_region() -> str | None:
|
58 | 58 | if "AWS_REGION" in os.environ: # Lambda
|
59 | 59 | return os.environ["AWS_REGION"]
|
60 | 60 | else: # EC2
|
61 |
| - return await InstanceMetadataRegionFetcher().retrieve_region() |
| 61 | + return await AioInstanceMetadataRegionFetcher().retrieve_region() |
62 | 62 |
|
63 | 63 |
|
64 | 64 | async def get_aws_arn() -> str | None:
|
65 | 65 | """Get the current AWS workload's ARN, if any."""
|
66 |
| - if aioboto3 is None: |
67 |
| - logger.debug("aioboto3 not available, falling back to sync implementation") |
68 |
| - from ..wif_util import get_aws_arn as sync_get_aws_arn |
69 |
| - |
70 |
| - return sync_get_aws_arn() |
71 |
| - |
72 |
| - try: |
73 |
| - session = aioboto3.Session() |
74 |
| - async with session.client("sts") as client: |
75 |
| - caller_identity = await client.get_caller_identity() |
76 |
| - if not caller_identity or "Arn" not in caller_identity: |
77 |
| - return None |
78 |
| - return caller_identity["Arn"] |
79 |
| - except Exception: |
80 |
| - logger.debug("Failed to get AWS ARN", exc_info=True) |
81 |
| - return None |
| 66 | + session = aioboto3.Session() |
| 67 | + async with session.client("sts") as client: |
| 68 | + caller_identity = await client.get_caller_identity() |
| 69 | + if not caller_identity or "Arn" not in caller_identity: |
| 70 | + return None |
| 71 | + return caller_identity["Arn"] |
82 | 72 |
|
83 | 73 |
|
84 | 74 | async def create_aws_attestation() -> WorkloadIdentityAttestation | None:
|
85 | 75 | """Tries to create a workload identity attestation for AWS.
|
86 | 76 |
|
87 | 77 | If the application isn't running on AWS or no credentials were found, returns None.
|
88 | 78 | """
|
89 |
| - if aioboto3 is None: |
90 |
| - logger.debug("aioboto3 not available, falling back to sync implementation") |
91 |
| - from ..wif_util import create_aws_attestation as sync_create_aws_attestation |
92 |
| - |
93 |
| - return sync_create_aws_attestation() |
94 |
| - |
95 |
| - try: |
96 |
| - # Get credentials using aioboto3 |
97 |
| - session = aioboto3.Session() |
98 |
| - aws_creds = await session.get_credentials() # This IS async in aioboto3 |
99 |
| - if not aws_creds: |
100 |
| - logger.debug("No AWS credentials were found.") |
101 |
| - return None |
| 79 | + session = aioboto3.Session() |
| 80 | + aws_creds = await session.get_credentials() |
| 81 | + if not aws_creds: |
| 82 | + logger.debug("No AWS credentials were found.") |
| 83 | + return None |
102 | 84 |
|
103 |
| - region = await get_aws_region() |
104 |
| - if not region: |
105 |
| - logger.debug("No AWS region was found.") |
106 |
| - return None |
| 85 | + region = await get_aws_region() |
| 86 | + if not region: |
| 87 | + logger.debug("No AWS region was found.") |
| 88 | + return None |
107 | 89 |
|
108 |
| - arn = await get_aws_arn() |
109 |
| - if not arn: |
110 |
| - logger.debug("No AWS caller identity was found.") |
111 |
| - return None |
| 90 | + arn = await get_aws_arn() |
| 91 | + if not arn: |
| 92 | + logger.debug("No AWS caller identity was found.") |
| 93 | + return None |
112 | 94 |
|
113 |
| - sts_hostname = f"sts.{region}.amazonaws.com" |
114 |
| - request = AWSRequest( |
115 |
| - method="POST", |
116 |
| - url=f"https://{sts_hostname}/?Action=GetCallerIdentity&Version=2011-06-15", |
117 |
| - headers={ |
118 |
| - "Host": sts_hostname, |
119 |
| - "X-Snowflake-Audience": SNOWFLAKE_AUDIENCE, |
120 |
| - }, |
121 |
| - ) |
| 95 | + sts_hostname = f"sts.{region}.amazonaws.com" |
| 96 | + request = AWSRequest( |
| 97 | + method="POST", |
| 98 | + url=f"https://{sts_hostname}/?Action=GetCallerIdentity&Version=2011-06-15", |
| 99 | + headers={ |
| 100 | + "Host": sts_hostname, |
| 101 | + "X-Snowflake-Audience": SNOWFLAKE_AUDIENCE, |
| 102 | + }, |
| 103 | + ) |
122 | 104 |
|
123 |
| - SigV4Auth(aws_creds, "sts", region).add_auth(request) |
| 105 | + SigV4Auth(aws_creds, "sts", region).add_auth(request) |
124 | 106 |
|
125 |
| - assertion_dict = { |
126 |
| - "url": request.url, |
127 |
| - "method": request.method, |
128 |
| - "headers": dict(request.headers.items()), |
129 |
| - } |
130 |
| - credential = b64encode(json.dumps(assertion_dict).encode("utf-8")).decode( |
131 |
| - "utf-8" |
132 |
| - ) |
133 |
| - return WorkloadIdentityAttestation( |
134 |
| - AttestationProvider.AWS, credential, {"arn": arn} |
135 |
| - ) |
136 |
| - except Exception: |
137 |
| - logger.debug("Failed to create AWS attestation", exc_info=True) |
138 |
| - return None |
| 107 | + assertion_dict = { |
| 108 | + "url": request.url, |
| 109 | + "method": request.method, |
| 110 | + "headers": dict(request.headers.items()), |
| 111 | + } |
| 112 | + credential = b64encode(json.dumps(assertion_dict).encode("utf-8")).decode("utf-8") |
| 113 | + return WorkloadIdentityAttestation( |
| 114 | + AttestationProvider.AWS, credential, {"arn": arn} |
| 115 | + ) |
139 | 116 |
|
140 | 117 |
|
141 | 118 | async def create_gcp_attestation() -> WorkloadIdentityAttestation | None:
|
|
0 commit comments