Skip to content

Commit 59ab6d4

Browse files
Remove duplication in wif_util
1 parent ecfa609 commit 59ab6d4

File tree

1 file changed

+9
-137
lines changed

1 file changed

+9
-137
lines changed

src/snowflake/connector/aio/_wif_util.py

Lines changed: 9 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -8,50 +8,22 @@
88
import json
99
import logging
1010
import os
11-
from base64 import b64encode
12-
from dataclasses import dataclass
13-
from enum import Enum, unique
1411

1512
import aiohttp
16-
import boto3
17-
import jwt
18-
from botocore.auth import SigV4Auth
19-
from botocore.awsrequest import AWSRequest
20-
from botocore.utils import InstanceMetadataRegionFetcher
2113

2214
from ..errorcode import ER_WIF_CREDENTIALS_NOT_FOUND
2315
from ..errors import ProgrammingError
16+
from ..wif_util import (
17+
DEFAULT_ENTRA_SNOWFLAKE_RESOURCE,
18+
SNOWFLAKE_AUDIENCE,
19+
AttestationProvider,
20+
WorkloadIdentityAttestation,
21+
create_aws_attestation,
22+
create_oidc_attestation,
23+
extract_iss_and_sub_without_signature_verification,
24+
)
2425

2526
logger = logging.getLogger(__name__)
26-
SNOWFLAKE_AUDIENCE = "snowflakecomputing.com"
27-
# TODO: use real app ID or domain name once it's available.
28-
DEFAULT_ENTRA_SNOWFLAKE_RESOURCE = "NOT REAL - WILL BREAK"
29-
30-
31-
@unique
32-
class AttestationProvider(Enum):
33-
"""A WIF provider implementation that can produce an attestation."""
34-
35-
AWS = "AWS"
36-
"""Provider that builds an encoded pre-signed GetCallerIdentity request using the current workload's IAM role."""
37-
AZURE = "AZURE"
38-
"""Provider that requests an OAuth access token for the workload's managed identity."""
39-
GCP = "GCP"
40-
"""Provider that requests an ID token for the workload's attached service account."""
41-
OIDC = "OIDC"
42-
"""Provider that looks for an OIDC ID token."""
43-
44-
@staticmethod
45-
def from_string(provider: str) -> AttestationProvider:
46-
"""Converts a string to a strongly-typed enum value of AttestationProvider."""
47-
return AttestationProvider[provider.upper()]
48-
49-
50-
@dataclass
51-
class WorkloadIdentityAttestation:
52-
provider: AttestationProvider
53-
credential: str
54-
user_identifier_components: dict
5527

5628

5729
async def try_metadata_service_call(
@@ -77,88 +49,6 @@ async def try_metadata_service_call(
7749
return None
7850

7951

80-
def extract_iss_and_sub_without_signature_verification(jwt_str: str) -> tuple[str, str]:
81-
"""Extracts the 'iss' and 'sub' claims from the given JWT, without verifying the signature.
82-
83-
Note: the real token verification (including signature verification) happens on the Snowflake side. The driver doesn't have
84-
the keys to verify these JWTs, and in any case that's not where the security boundary is drawn.
85-
86-
We only decode the JWT here to get some basic claims, which will be used for a) a quick smoke test to ensure we got the right
87-
issuer, and b) to find the unique user being asserted and populate assertion_content. The latter may be used for logging
88-
and possibly caching.
89-
90-
If there are any errors in parsing the token or extracting iss and sub, this will return (None, None).
91-
"""
92-
try:
93-
claims = jwt.decode(jwt_str, options={"verify_signature": False})
94-
except jwt.exceptions.InvalidTokenError:
95-
logger.warning("Token is not a valid JWT.", exc_info=True)
96-
return None, None
97-
98-
if not ("iss" in claims and "sub" in claims):
99-
logger.warning("Token is missing 'iss' or 'sub' claims.")
100-
return None, None
101-
102-
return claims["iss"], claims["sub"]
103-
104-
105-
def get_aws_region() -> str | None:
106-
"""Get the current AWS workload's region, if any."""
107-
if "AWS_REGION" in os.environ: # Lambda
108-
return os.environ["AWS_REGION"]
109-
else: # EC2
110-
return InstanceMetadataRegionFetcher().retrieve_region()
111-
112-
113-
def get_aws_arn() -> str | None:
114-
"""Get the current AWS workload's ARN, if any."""
115-
caller_identity = boto3.client("sts").get_caller_identity()
116-
if not caller_identity or "Arn" not in caller_identity:
117-
return None
118-
return caller_identity["Arn"]
119-
120-
121-
def create_aws_attestation() -> WorkloadIdentityAttestation | None:
122-
"""Tries to create a workload identity attestation for AWS.
123-
124-
If the application isn't running on AWS or no credentials were found, returns None.
125-
"""
126-
aws_creds = boto3.session.Session().get_credentials()
127-
if not aws_creds:
128-
logger.debug("No AWS credentials were found.")
129-
return None
130-
region = get_aws_region()
131-
if not region:
132-
logger.debug("No AWS region was found.")
133-
return None
134-
arn = get_aws_arn()
135-
if not arn:
136-
logger.debug("No AWS caller identity was found.")
137-
return None
138-
139-
sts_hostname = f"sts.{region}.amazonaws.com"
140-
request = AWSRequest(
141-
method="POST",
142-
url=f"https://{sts_hostname}/?Action=GetCallerIdentity&Version=2011-06-15",
143-
headers={
144-
"Host": sts_hostname,
145-
"X-Snowflake-Audience": SNOWFLAKE_AUDIENCE,
146-
},
147-
)
148-
149-
SigV4Auth(aws_creds, "sts", region).add_auth(request)
150-
151-
assertion_dict = {
152-
"url": request.url,
153-
"method": request.method,
154-
"headers": dict(request.headers.items()),
155-
}
156-
credential = b64encode(json.dumps(assertion_dict).encode("utf-8")).decode("utf-8")
157-
return WorkloadIdentityAttestation(
158-
AttestationProvider.AWS, credential, {"arn": arn}
159-
)
160-
161-
16252
async def create_gcp_attestation() -> WorkloadIdentityAttestation | None:
16353
"""Tries to create a workload identity attestation for GCP.
16454
@@ -256,24 +146,6 @@ async def create_azure_attestation(
256146
)
257147

258148

259-
def create_oidc_attestation(token: str | None) -> WorkloadIdentityAttestation | None:
260-
"""Tries to create an attestation using the given token.
261-
262-
If this is not populated, returns None.
263-
"""
264-
if not token:
265-
logger.debug("No OIDC token was specified.")
266-
return None
267-
268-
issuer, subject = extract_iss_and_sub_without_signature_verification(token)
269-
if not issuer or not subject:
270-
return None
271-
272-
return WorkloadIdentityAttestation(
273-
AttestationProvider.OIDC, token, {"iss": issuer, "sub": subject}
274-
)
275-
276-
277149
async def create_autodetect_attestation(
278150
entra_resource: str, token: str | None = None
279151
) -> WorkloadIdentityAttestation | None:

0 commit comments

Comments
 (0)