|
8 | 8 | import json
|
9 | 9 | import logging
|
10 | 10 | import os
|
| 11 | +from base64 import b64encode |
11 | 12 |
|
12 | 13 | import aiohttp
|
13 | 14 |
|
| 15 | +try: |
| 16 | + import aioboto3 |
| 17 | + from botocore.auth import SigV4Auth |
| 18 | + from botocore.awsrequest import AWSRequest |
| 19 | + from botocore.utils import InstanceMetadataRegionFetcher |
| 20 | +except ImportError: |
| 21 | + aioboto3 = None |
| 22 | + SigV4Auth = None |
| 23 | + AWSRequest = None |
| 24 | + InstanceMetadataRegionFetcher = None |
| 25 | + |
14 | 26 | from ..errorcode import ER_WIF_CREDENTIALS_NOT_FOUND
|
15 | 27 | from ..errors import ProgrammingError
|
16 | 28 | from ..wif_util import (
|
17 | 29 | DEFAULT_ENTRA_SNOWFLAKE_RESOURCE,
|
18 | 30 | SNOWFLAKE_AUDIENCE,
|
19 | 31 | AttestationProvider,
|
20 | 32 | WorkloadIdentityAttestation,
|
21 |
| - create_aws_attestation, |
22 | 33 | create_oidc_attestation,
|
23 | 34 | extract_iss_and_sub_without_signature_verification,
|
24 | 35 | )
|
@@ -49,6 +60,91 @@ async def try_metadata_service_call(
|
49 | 60 | return None
|
50 | 61 |
|
51 | 62 |
|
| 63 | +async def get_aws_region() -> str | None: |
| 64 | + """Get the current AWS workload's region, if any.""" |
| 65 | + # Use sync implementation which has proper mocking support |
| 66 | + from ..wif_util import get_aws_region as sync_get_aws_region |
| 67 | + |
| 68 | + return sync_get_aws_region() |
| 69 | + |
| 70 | + |
| 71 | +async def get_aws_arn() -> str | None: |
| 72 | + """Get the current AWS workload's ARN, if any.""" |
| 73 | + if aioboto3 is None: |
| 74 | + logger.debug("aioboto3 not available, falling back to sync implementation") |
| 75 | + from ..wif_util import get_aws_arn as sync_get_aws_arn |
| 76 | + |
| 77 | + return sync_get_aws_arn() |
| 78 | + |
| 79 | + try: |
| 80 | + session = aioboto3.Session() |
| 81 | + async with session.client("sts") as client: |
| 82 | + caller_identity = await client.get_caller_identity() |
| 83 | + if not caller_identity or "Arn" not in caller_identity: |
| 84 | + return None |
| 85 | + return caller_identity["Arn"] |
| 86 | + except Exception: |
| 87 | + logger.debug("Failed to get AWS ARN", exc_info=True) |
| 88 | + return None |
| 89 | + |
| 90 | + |
| 91 | +async def create_aws_attestation() -> WorkloadIdentityAttestation | None: |
| 92 | + """Tries to create a workload identity attestation for AWS. |
| 93 | +
|
| 94 | + If the application isn't running on AWS or no credentials were found, returns None. |
| 95 | + """ |
| 96 | + if aioboto3 is None: |
| 97 | + logger.debug("aioboto3 not available, falling back to sync implementation") |
| 98 | + from ..wif_util import create_aws_attestation as sync_create_aws_attestation |
| 99 | + |
| 100 | + return sync_create_aws_attestation() |
| 101 | + |
| 102 | + try: |
| 103 | + # Get credentials using aioboto3 |
| 104 | + session = aioboto3.Session() |
| 105 | + aws_creds = await session.get_credentials() # This IS async in aioboto3 |
| 106 | + if not aws_creds: |
| 107 | + logger.debug("No AWS credentials were found.") |
| 108 | + return None |
| 109 | + |
| 110 | + region = await get_aws_region() |
| 111 | + if not region: |
| 112 | + logger.debug("No AWS region was found.") |
| 113 | + return None |
| 114 | + |
| 115 | + arn = await get_aws_arn() |
| 116 | + if not arn: |
| 117 | + logger.debug("No AWS caller identity was found.") |
| 118 | + return None |
| 119 | + |
| 120 | + sts_hostname = f"sts.{region}.amazonaws.com" |
| 121 | + request = AWSRequest( |
| 122 | + method="POST", |
| 123 | + url=f"https://{sts_hostname}/?Action=GetCallerIdentity&Version=2011-06-15", |
| 124 | + headers={ |
| 125 | + "Host": sts_hostname, |
| 126 | + "X-Snowflake-Audience": SNOWFLAKE_AUDIENCE, |
| 127 | + }, |
| 128 | + ) |
| 129 | + |
| 130 | + SigV4Auth(aws_creds, "sts", region).add_auth(request) |
| 131 | + |
| 132 | + assertion_dict = { |
| 133 | + "url": request.url, |
| 134 | + "method": request.method, |
| 135 | + "headers": dict(request.headers.items()), |
| 136 | + } |
| 137 | + credential = b64encode(json.dumps(assertion_dict).encode("utf-8")).decode( |
| 138 | + "utf-8" |
| 139 | + ) |
| 140 | + return WorkloadIdentityAttestation( |
| 141 | + AttestationProvider.AWS, credential, {"arn": arn} |
| 142 | + ) |
| 143 | + except Exception: |
| 144 | + logger.debug("Failed to create AWS attestation", exc_info=True) |
| 145 | + return None |
| 146 | + |
| 147 | + |
52 | 148 | async def create_gcp_attestation() -> WorkloadIdentityAttestation | None:
|
53 | 149 | """Tries to create a workload identity attestation for GCP.
|
54 | 150 |
|
@@ -157,7 +253,7 @@ async def create_autodetect_attestation(
|
157 | 253 | if attestation:
|
158 | 254 | return attestation
|
159 | 255 |
|
160 |
| - attestation = create_aws_attestation() |
| 256 | + attestation = await create_aws_attestation() |
161 | 257 | if attestation:
|
162 | 258 | return attestation
|
163 | 259 |
|
@@ -188,7 +284,7 @@ async def create_attestation(
|
188 | 284 |
|
189 | 285 | attestation: WorkloadIdentityAttestation = None
|
190 | 286 | if provider == AttestationProvider.AWS:
|
191 |
| - attestation = create_aws_attestation() |
| 287 | + attestation = await create_aws_attestation() |
192 | 288 | elif provider == AttestationProvider.AZURE:
|
193 | 289 | attestation = await create_azure_attestation(entra_resource)
|
194 | 290 | elif provider == AttestationProvider.GCP:
|
|
0 commit comments