8
8
import json
9
9
import logging
10
10
import os
11
- from base64 import b64encode
12
- from dataclasses import dataclass
13
- from enum import Enum , unique
14
11
15
12
import aiohttp
16
- import boto3
17
- import jwt
18
- from botocore .auth import SigV4Auth
19
- from botocore .awsrequest import AWSRequest
20
- from botocore .utils import InstanceMetadataRegionFetcher
21
13
22
14
from ..errorcode import ER_WIF_CREDENTIALS_NOT_FOUND
23
15
from ..errors import ProgrammingError
16
+ from ..wif_util import (
17
+ DEFAULT_ENTRA_SNOWFLAKE_RESOURCE ,
18
+ SNOWFLAKE_AUDIENCE ,
19
+ AttestationProvider ,
20
+ WorkloadIdentityAttestation ,
21
+ create_aws_attestation ,
22
+ create_oidc_attestation ,
23
+ extract_iss_and_sub_without_signature_verification ,
24
+ )
24
25
25
26
logger = logging .getLogger (__name__ )
26
- SNOWFLAKE_AUDIENCE = "snowflakecomputing.com"
27
- # TODO: use real app ID or domain name once it's available.
28
- DEFAULT_ENTRA_SNOWFLAKE_RESOURCE = "NOT REAL - WILL BREAK"
29
-
30
-
31
- @unique
32
- class AttestationProvider (Enum ):
33
- """A WIF provider implementation that can produce an attestation."""
34
-
35
- AWS = "AWS"
36
- """Provider that builds an encoded pre-signed GetCallerIdentity request using the current workload's IAM role."""
37
- AZURE = "AZURE"
38
- """Provider that requests an OAuth access token for the workload's managed identity."""
39
- GCP = "GCP"
40
- """Provider that requests an ID token for the workload's attached service account."""
41
- OIDC = "OIDC"
42
- """Provider that looks for an OIDC ID token."""
43
-
44
- @staticmethod
45
- def from_string (provider : str ) -> AttestationProvider :
46
- """Converts a string to a strongly-typed enum value of AttestationProvider."""
47
- return AttestationProvider [provider .upper ()]
48
-
49
-
50
- @dataclass
51
- class WorkloadIdentityAttestation :
52
- provider : AttestationProvider
53
- credential : str
54
- user_identifier_components : dict
55
27
56
28
57
29
async def try_metadata_service_call (
@@ -77,88 +49,6 @@ async def try_metadata_service_call(
77
49
return None
78
50
79
51
80
- def extract_iss_and_sub_without_signature_verification (jwt_str : str ) -> tuple [str , str ]:
81
- """Extracts the 'iss' and 'sub' claims from the given JWT, without verifying the signature.
82
-
83
- Note: the real token verification (including signature verification) happens on the Snowflake side. The driver doesn't have
84
- the keys to verify these JWTs, and in any case that's not where the security boundary is drawn.
85
-
86
- We only decode the JWT here to get some basic claims, which will be used for a) a quick smoke test to ensure we got the right
87
- issuer, and b) to find the unique user being asserted and populate assertion_content. The latter may be used for logging
88
- and possibly caching.
89
-
90
- If there are any errors in parsing the token or extracting iss and sub, this will return (None, None).
91
- """
92
- try :
93
- claims = jwt .decode (jwt_str , options = {"verify_signature" : False })
94
- except jwt .exceptions .InvalidTokenError :
95
- logger .warning ("Token is not a valid JWT." , exc_info = True )
96
- return None , None
97
-
98
- if not ("iss" in claims and "sub" in claims ):
99
- logger .warning ("Token is missing 'iss' or 'sub' claims." )
100
- return None , None
101
-
102
- return claims ["iss" ], claims ["sub" ]
103
-
104
-
105
- def get_aws_region () -> str | None :
106
- """Get the current AWS workload's region, if any."""
107
- if "AWS_REGION" in os .environ : # Lambda
108
- return os .environ ["AWS_REGION" ]
109
- else : # EC2
110
- return InstanceMetadataRegionFetcher ().retrieve_region ()
111
-
112
-
113
- def get_aws_arn () -> str | None :
114
- """Get the current AWS workload's ARN, if any."""
115
- caller_identity = boto3 .client ("sts" ).get_caller_identity ()
116
- if not caller_identity or "Arn" not in caller_identity :
117
- return None
118
- return caller_identity ["Arn" ]
119
-
120
-
121
- def create_aws_attestation () -> WorkloadIdentityAttestation | None :
122
- """Tries to create a workload identity attestation for AWS.
123
-
124
- If the application isn't running on AWS or no credentials were found, returns None.
125
- """
126
- aws_creds = boto3 .session .Session ().get_credentials ()
127
- if not aws_creds :
128
- logger .debug ("No AWS credentials were found." )
129
- return None
130
- region = get_aws_region ()
131
- if not region :
132
- logger .debug ("No AWS region was found." )
133
- return None
134
- arn = get_aws_arn ()
135
- if not arn :
136
- logger .debug ("No AWS caller identity was found." )
137
- return None
138
-
139
- sts_hostname = f"sts.{ region } .amazonaws.com"
140
- request = AWSRequest (
141
- method = "POST" ,
142
- url = f"https://{ sts_hostname } /?Action=GetCallerIdentity&Version=2011-06-15" ,
143
- headers = {
144
- "Host" : sts_hostname ,
145
- "X-Snowflake-Audience" : SNOWFLAKE_AUDIENCE ,
146
- },
147
- )
148
-
149
- SigV4Auth (aws_creds , "sts" , region ).add_auth (request )
150
-
151
- assertion_dict = {
152
- "url" : request .url ,
153
- "method" : request .method ,
154
- "headers" : dict (request .headers .items ()),
155
- }
156
- credential = b64encode (json .dumps (assertion_dict ).encode ("utf-8" )).decode ("utf-8" )
157
- return WorkloadIdentityAttestation (
158
- AttestationProvider .AWS , credential , {"arn" : arn }
159
- )
160
-
161
-
162
52
async def create_gcp_attestation () -> WorkloadIdentityAttestation | None :
163
53
"""Tries to create a workload identity attestation for GCP.
164
54
@@ -256,24 +146,6 @@ async def create_azure_attestation(
256
146
)
257
147
258
148
259
- def create_oidc_attestation (token : str | None ) -> WorkloadIdentityAttestation | None :
260
- """Tries to create an attestation using the given token.
261
-
262
- If this is not populated, returns None.
263
- """
264
- if not token :
265
- logger .debug ("No OIDC token was specified." )
266
- return None
267
-
268
- issuer , subject = extract_iss_and_sub_without_signature_verification (token )
269
- if not issuer or not subject :
270
- return None
271
-
272
- return WorkloadIdentityAttestation (
273
- AttestationProvider .OIDC , token , {"iss" : issuer , "sub" : subject }
274
- )
275
-
276
-
277
149
async def create_autodetect_attestation (
278
150
entra_resource : str , token : str | None = None
279
151
) -> WorkloadIdentityAttestation | None :
0 commit comments