Skip to content

Commit c2cda00

Browse files
use aiohttp in wif_util
1 parent f1dd2bb commit c2cda00

File tree

4 files changed

+464
-18
lines changed

4 files changed

+464
-18
lines changed
Lines changed: 336 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,336 @@
1+
#
2+
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
3+
#
4+
5+
from __future__ import annotations
6+
7+
import asyncio
8+
import json
9+
import logging
10+
import os
11+
from base64 import b64encode
12+
from dataclasses import dataclass
13+
from enum import Enum, unique
14+
15+
import aiohttp
16+
import boto3
17+
import jwt
18+
from botocore.auth import SigV4Auth
19+
from botocore.awsrequest import AWSRequest
20+
from botocore.utils import InstanceMetadataRegionFetcher
21+
22+
from ..errorcode import ER_WIF_CREDENTIALS_NOT_FOUND
23+
from ..errors import ProgrammingError
24+
25+
logger = logging.getLogger(__name__)
26+
SNOWFLAKE_AUDIENCE = "snowflakecomputing.com"
27+
# TODO: use real app ID or domain name once it's available.
28+
DEFAULT_ENTRA_SNOWFLAKE_RESOURCE = "NOT REAL - WILL BREAK"
29+
30+
31+
@unique
32+
class AttestationProvider(Enum):
33+
"""A WIF provider implementation that can produce an attestation."""
34+
35+
AWS = "AWS"
36+
"""Provider that builds an encoded pre-signed GetCallerIdentity request using the current workload's IAM role."""
37+
AZURE = "AZURE"
38+
"""Provider that requests an OAuth access token for the workload's managed identity."""
39+
GCP = "GCP"
40+
"""Provider that requests an ID token for the workload's attached service account."""
41+
OIDC = "OIDC"
42+
"""Provider that looks for an OIDC ID token."""
43+
44+
@staticmethod
45+
def from_string(provider: str) -> AttestationProvider:
46+
"""Converts a string to a strongly-typed enum value of AttestationProvider."""
47+
return AttestationProvider[provider.upper()]
48+
49+
50+
@dataclass
51+
class WorkloadIdentityAttestation:
52+
provider: AttestationProvider
53+
credential: str
54+
user_identifier_components: dict
55+
56+
57+
async def try_metadata_service_call(
58+
method: str, url: str, headers: dict, timeout_sec: int = 3
59+
) -> aiohttp.ClientResponse | None:
60+
"""Tries to make a HTTP request to the metadata service with the given URL, method, headers and timeout.
61+
62+
If we receive an error response or any exceptions are raised, returns None. Otherwise returns the response.
63+
"""
64+
try:
65+
timeout = aiohttp.ClientTimeout(total=timeout_sec)
66+
async with aiohttp.ClientSession(timeout=timeout) as session:
67+
async with session.request(
68+
method=method, url=url, headers=headers
69+
) as response:
70+
if not response.ok:
71+
return None
72+
# Create a copy of the response data since the response will be closed
73+
content = await response.read()
74+
response._content = content
75+
return response
76+
except (aiohttp.ClientError, asyncio.TimeoutError):
77+
return None
78+
79+
80+
def extract_iss_and_sub_without_signature_verification(jwt_str: str) -> tuple[str, str]:
81+
"""Extracts the 'iss' and 'sub' claims from the given JWT, without verifying the signature.
82+
83+
Note: the real token verification (including signature verification) happens on the Snowflake side. The driver doesn't have
84+
the keys to verify these JWTs, and in any case that's not where the security boundary is drawn.
85+
86+
We only decode the JWT here to get some basic claims, which will be used for a) a quick smoke test to ensure we got the right
87+
issuer, and b) to find the unique user being asserted and populate assertion_content. The latter may be used for logging
88+
and possibly caching.
89+
90+
If there are any errors in parsing the token or extracting iss and sub, this will return (None, None).
91+
"""
92+
try:
93+
claims = jwt.decode(jwt_str, options={"verify_signature": False})
94+
except jwt.exceptions.InvalidTokenError:
95+
logger.warning("Token is not a valid JWT.", exc_info=True)
96+
return None, None
97+
98+
if not ("iss" in claims and "sub" in claims):
99+
logger.warning("Token is missing 'iss' or 'sub' claims.")
100+
return None, None
101+
102+
return claims["iss"], claims["sub"]
103+
104+
105+
def get_aws_region() -> str | None:
106+
"""Get the current AWS workload's region, if any."""
107+
if "AWS_REGION" in os.environ: # Lambda
108+
return os.environ["AWS_REGION"]
109+
else: # EC2
110+
return InstanceMetadataRegionFetcher().retrieve_region()
111+
112+
113+
def get_aws_arn() -> str | None:
114+
"""Get the current AWS workload's ARN, if any."""
115+
caller_identity = boto3.client("sts").get_caller_identity()
116+
if not caller_identity or "Arn" not in caller_identity:
117+
return None
118+
return caller_identity["Arn"]
119+
120+
121+
def create_aws_attestation() -> WorkloadIdentityAttestation | None:
122+
"""Tries to create a workload identity attestation for AWS.
123+
124+
If the application isn't running on AWS or no credentials were found, returns None.
125+
"""
126+
aws_creds = boto3.session.Session().get_credentials()
127+
if not aws_creds:
128+
logger.debug("No AWS credentials were found.")
129+
return None
130+
region = get_aws_region()
131+
if not region:
132+
logger.debug("No AWS region was found.")
133+
return None
134+
arn = get_aws_arn()
135+
if not arn:
136+
logger.debug("No AWS caller identity was found.")
137+
return None
138+
139+
sts_hostname = f"sts.{region}.amazonaws.com"
140+
request = AWSRequest(
141+
method="POST",
142+
url=f"https://{sts_hostname}/?Action=GetCallerIdentity&Version=2011-06-15",
143+
headers={
144+
"Host": sts_hostname,
145+
"X-Snowflake-Audience": SNOWFLAKE_AUDIENCE,
146+
},
147+
)
148+
149+
SigV4Auth(aws_creds, "sts", region).add_auth(request)
150+
151+
assertion_dict = {
152+
"url": request.url,
153+
"method": request.method,
154+
"headers": dict(request.headers.items()),
155+
}
156+
credential = b64encode(json.dumps(assertion_dict).encode("utf-8")).decode("utf-8")
157+
return WorkloadIdentityAttestation(
158+
AttestationProvider.AWS, credential, {"arn": arn}
159+
)
160+
161+
162+
async def create_gcp_attestation() -> WorkloadIdentityAttestation | None:
163+
"""Tries to create a workload identity attestation for GCP.
164+
165+
If the application isn't running on GCP or no credentials were found, returns None.
166+
"""
167+
res = await try_metadata_service_call(
168+
method="GET",
169+
url=f"http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/identity?audience={SNOWFLAKE_AUDIENCE}",
170+
headers={
171+
"Metadata-Flavor": "Google",
172+
},
173+
)
174+
if res is None:
175+
# Most likely we're just not running on GCP, which may be expected.
176+
logger.debug("GCP metadata server request was not successful.")
177+
return None
178+
179+
jwt_str = res._content.decode("utf-8")
180+
issuer, subject = extract_iss_and_sub_without_signature_verification(jwt_str)
181+
if not issuer or not subject:
182+
return None
183+
if issuer != "https://accounts.google.com":
184+
# This might happen if we're running on a different platform that responds to the same metadata request signature as GCP.
185+
logger.debug("Unexpected GCP token issuer '%s'", issuer)
186+
return None
187+
188+
return WorkloadIdentityAttestation(
189+
AttestationProvider.GCP, jwt_str, {"sub": subject}
190+
)
191+
192+
193+
async def create_azure_attestation(
194+
snowflake_entra_resource: str,
195+
) -> WorkloadIdentityAttestation | None:
196+
"""Tries to create a workload identity attestation for Azure.
197+
198+
If the application isn't running on Azure or no credentials were found, returns None.
199+
"""
200+
headers = {"Metadata": "True"}
201+
url_without_query_string = "http://169.254.169.254/metadata/identity/oauth2/token"
202+
query_params = f"api-version=2018-02-01&resource={snowflake_entra_resource}"
203+
204+
# Check if running in Azure Functions environment
205+
identity_endpoint = os.environ.get("IDENTITY_ENDPOINT")
206+
identity_header = os.environ.get("IDENTITY_HEADER")
207+
is_azure_functions = identity_endpoint is not None
208+
209+
if is_azure_functions:
210+
if not identity_header:
211+
logger.warning("Managed identity is not enabled on this Azure function.")
212+
return None
213+
214+
# Azure Functions uses a different endpoint, headers and API version.
215+
url_without_query_string = identity_endpoint
216+
headers = {"X-IDENTITY-HEADER": identity_header}
217+
query_params = f"api-version=2019-08-01&resource={snowflake_entra_resource}"
218+
219+
# Some Azure Functions environments may require client_id in the URL
220+
managed_identity_client_id = os.environ.get("MANAGED_IDENTITY_CLIENT_ID")
221+
if managed_identity_client_id:
222+
query_params += f"&client_id={managed_identity_client_id}"
223+
224+
res = await try_metadata_service_call(
225+
method="GET",
226+
url=f"{url_without_query_string}?{query_params}",
227+
headers=headers,
228+
)
229+
if res is None:
230+
# Most likely we're just not running on Azure, which may be expected.
231+
logger.debug("Azure metadata server request was not successful.")
232+
return None
233+
234+
try:
235+
response_text = res._content.decode("utf-8")
236+
response_data = json.loads(response_text)
237+
jwt_str = response_data.get("access_token")
238+
if not jwt_str:
239+
# Could be that Managed Identity is disabled.
240+
logger.debug("No access token found in Azure response.")
241+
return None
242+
except (ValueError, KeyError) as e:
243+
logger.debug(f"Error parsing Azure response: {e}")
244+
return None
245+
246+
issuer, subject = extract_iss_and_sub_without_signature_verification(jwt_str)
247+
if not issuer or not subject:
248+
return None
249+
if not issuer.startswith("https://sts.windows.net/"):
250+
# This might happen if we're running on a different platform that responds to the same metadata request signature as Azure.
251+
logger.debug("Unexpected Azure token issuer '%s'", issuer)
252+
return None
253+
254+
return WorkloadIdentityAttestation(
255+
AttestationProvider.AZURE, jwt_str, {"iss": issuer, "sub": subject}
256+
)
257+
258+
259+
def create_oidc_attestation(token: str | None) -> WorkloadIdentityAttestation | None:
260+
"""Tries to create an attestation using the given token.
261+
262+
If this is not populated, returns None.
263+
"""
264+
if not token:
265+
logger.debug("No OIDC token was specified.")
266+
return None
267+
268+
issuer, subject = extract_iss_and_sub_without_signature_verification(token)
269+
if not issuer or not subject:
270+
return None
271+
272+
return WorkloadIdentityAttestation(
273+
AttestationProvider.OIDC, token, {"iss": issuer, "sub": subject}
274+
)
275+
276+
277+
async def create_autodetect_attestation(
278+
entra_resource: str, token: str | None = None
279+
) -> WorkloadIdentityAttestation | None:
280+
"""Tries to create an attestation using the auto-detected runtime environment.
281+
282+
If no attestation can be found, returns None.
283+
"""
284+
attestation = create_oidc_attestation(token)
285+
if attestation:
286+
return attestation
287+
288+
attestation = create_aws_attestation()
289+
if attestation:
290+
return attestation
291+
292+
attestation = await create_azure_attestation(entra_resource)
293+
if attestation:
294+
return attestation
295+
296+
attestation = await create_gcp_attestation()
297+
if attestation:
298+
return attestation
299+
300+
return None
301+
302+
303+
async def create_attestation(
304+
provider: AttestationProvider | None,
305+
entra_resource: str | None = None,
306+
token: str | None = None,
307+
) -> WorkloadIdentityAttestation:
308+
"""Entry point to create an attestation using the given provider.
309+
310+
If the provider is None, this will try to auto-detect a credential from the runtime environment. If the provider fails to detect a credential,
311+
a ProgrammingError will be raised.
312+
313+
If an explicit entra_resource was provided to the connector, this will be used. Otherwise, the default Snowflake Entra resource will be used.
314+
"""
315+
entra_resource = entra_resource or DEFAULT_ENTRA_SNOWFLAKE_RESOURCE
316+
317+
attestation: WorkloadIdentityAttestation = None
318+
if provider == AttestationProvider.AWS:
319+
attestation = create_aws_attestation()
320+
elif provider == AttestationProvider.AZURE:
321+
attestation = await create_azure_attestation(entra_resource)
322+
elif provider == AttestationProvider.GCP:
323+
attestation = await create_gcp_attestation()
324+
elif provider == AttestationProvider.OIDC:
325+
attestation = create_oidc_attestation(token)
326+
elif provider is None:
327+
attestation = await create_autodetect_attestation(entra_resource, token)
328+
329+
if not attestation:
330+
provider_str = "auto-detect" if provider is None else provider.value
331+
raise ProgrammingError(
332+
msg=f"No workload identity credential was found for '{provider_str}'.",
333+
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
334+
)
335+
336+
return attestation

0 commit comments

Comments
 (0)