99from dataclasses import dataclass
1010from datetime import datetime , timezone
1111from enum import Enum , unique
12- from urllib .parse import urlparse
12+ from urllib .parse import parse_qsl , quote , urlparse
1313
1414import jwt
1515
@@ -89,7 +89,9 @@ def try_metadata_service_call(
8989 return res
9090
9191
92- def extract_iss_and_sub_without_signature_verification (jwt_str : str ) -> tuple [str , str ]:
92+ def extract_iss_and_sub_without_signature_verification (
93+ jwt_str : str ,
94+ ) -> tuple [str | None , str | None ]:
9395 """Extracts the 'iss' and 'sub' claims from the given JWT, without verifying the signature.
9496
9597 Note: the real token verification (including signature verification) happens on the Snowflake side. The driver doesn't have
@@ -114,6 +116,18 @@ def extract_iss_and_sub_without_signature_verification(jwt_str: str) -> tuple[st
114116 return claims ["iss" ], claims ["sub" ]
115117
116118
119+ # --------------------------------------------------------------------------- #
120+ # AWS helper utilities (token, credentials, region) #
121+ # --------------------------------------------------------------------------- #
def _imds_v2_token() -> str | None:
    """Request a session token from the EC2 instance metadata service (IMDSv2).

    Returns:
        The token text with surrounding whitespace removed, or None when the
        metadata call yields a falsy result.
    """
    response = try_metadata_service_call(
        method="PUT",
        url="http://169.254.169.254/latest/api/token",
        headers={"X-aws-ec2-metadata-token-ttl-seconds": "300"},
    )
    # Truthiness (not an `is None` check) is intentional here: it mirrors the
    # helper's contract of treating any falsy response as "no token".
    if not response:
        return None
    return response.text.strip()
129+
130+
117131def get_aws_credentials () -> AwsCredentials | None :
118132 """Get AWS credentials from environment variables or instance metadata.
119133
@@ -129,24 +143,18 @@ def get_aws_credentials() -> AwsCredentials | None:
129143
130144 # Try instance metadata service (IMDSv2)
131145 try :
132- # First, get a token for IMDSv2
133- token_res = try_metadata_service_call (
134- method = "PUT" ,
135- url = "http://169.254.169.254/latest/api/token" ,
136- headers = {"X-aws-ec2-metadata-token-ttl-seconds" : "300" },
137- )
138-
139- if token_res is None :
146+ token = _imds_v2_token ()
147+ if token is None :
140148 logger .debug ("Failed to get IMDSv2 token from metadata service." )
141149 return None
142150
143- token = token_res . text . strip ()
151+ token_hdr = { "X-aws-ec2-metadata-token" : token } if token else {}
144152
145153 # Get the security credentials from the metadata service
146154 res = try_metadata_service_call (
147155 method = "GET" ,
148156 url = "http://169.254.169.254/latest/meta-data/iam/security-credentials/" ,
149- headers = { "X-aws-ec2-metadata-token" : token } ,
157+ headers = token_hdr ,
150158 )
151159 if res is None :
152160 logger .debug ("Failed to get IAM role list from metadata service." )
@@ -161,7 +169,7 @@ def get_aws_credentials() -> AwsCredentials | None:
161169 res = try_metadata_service_call (
162170 method = "GET" ,
163171 url = f"http://169.254.169.254/latest/meta-data/iam/security-credentials/{ role_name } " ,
164- headers = { "X-aws-ec2-metadata-token" : token } ,
172+ headers = token_hdr ,
165173 )
166174 if res is None :
167175 logger .debug ("Failed to get IAM role credentials from metadata service." )
@@ -174,7 +182,6 @@ def get_aws_credentials() -> AwsCredentials | None:
174182
175183 if access_key and secret_key :
176184 return AwsCredentials (access_key , secret_key , token )
177-
178185 except Exception as e :
179186 logger .debug (f"Error getting AWS credentials from metadata service: { e } " )
180187
@@ -183,43 +190,47 @@ def get_aws_credentials() -> AwsCredentials | None:
183190
184191def get_aws_region () -> str | None :
185192 """Get the current AWS workload's region, if any."""
186- # Try environment variable first
187193 region = os .environ .get ("AWS_REGION" )
188194 if region :
189195 return region
190196
191- # Try instance metadata service (IMDSv2)
192197 try :
193- # First, get a token for IMDSv2
194- token_res = try_metadata_service_call (
195- method = "PUT" ,
196- url = "http://169.254.169.254/latest/api/token" ,
197- headers = {"X-aws-ec2-metadata-token-ttl-seconds" : "300" },
198- )
199-
200- if token_res is None :
198+ token = _imds_v2_token ()
199+ if token is None :
201200 logger .debug ("Failed to get IMDSv2 token from metadata service." )
202201 return None
203202
204- token = token_res . text . strip ()
203+ token_hdr = { "X-aws-ec2-metadata-token" : token } if token else {}
205204
206205 # Get region from metadata service
207206 res = try_metadata_service_call (
208207 method = "GET" ,
209208 url = "http://169.254.169.254/latest/meta-data/placement/region" ,
210- headers = { "X-aws-ec2-metadata-token" : token } ,
209+ headers = token_hdr ,
211210 )
212211 if res is not None :
213212 return res .text .strip ()
213+
214+ res = try_metadata_service_call (
215+ method = "GET" ,
216+ url = "http://169.254.169.254/latest/meta-data/placement/availability-zone" ,
217+ headers = token_hdr ,
218+ )
219+ if res is not None :
220+ return res .text .strip ()[:- 1 ]
214221 except Exception as e :
215222 logger .debug (f"Error getting AWS region from metadata service: { e } " )
216223
217224 return None
218225
219226
220- def get_aws_sts_hostname (region : str ) -> str :
227+ def get_aws_sts_hostname (region : str ) -> str | None :
221228 """Constructs the AWS STS hostname for a given region.
222229
230+ * China regions (`cn-*`) → sts.<region>.amazonaws.com.cn
231+ * All other regions → sts.<region>.amazonaws.com
232+ * Any invalid input → None
233+
223234 Args:
224235 region (str): The AWS region (e.g., 'us-east-1', 'cn-north-1').
225236
@@ -231,6 +242,10 @@ def get_aws_sts_hostname(region: str) -> str:
231242 - https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_region-endpoints.html
232243 - https://docs.aws.amazon.com/general/latest/gr/sts.html
233244 """
245+
246+ if not region or not isinstance (region , str ):
247+ return None
248+
234249 if region .startswith ("cn-" ):
235250 # China regions have a different domain suffix
236251 return f"sts.{ region } .amazonaws.com.cn"
@@ -239,6 +254,19 @@ def get_aws_sts_hostname(region: str) -> str:
239254 return f"sts.{ region } .amazonaws.com"
240255
241256
def _aws_percent_encode(s: str) -> str:
    """Percent-encode *s* the way AWS Signature Version 4 expects.

    Every character except the RFC 3986 unreserved set
    (``A-Z a-z 0-9 - _ . ~``) is encoded; notably ``/`` becomes ``%2F`` and a
    space becomes ``%20``.
    """
    return quote(s, safe="~")


def _canonical_query(query: str) -> str:
    """Build the SigV4 canonical query string from a raw URL query.

    Args:
        query: The query component of a URL, without the leading ``?``.

    Returns:
        The ``name=value`` pairs, percent-encoded and joined with ``&``, or an
        empty string when *query* is empty. Parameters with blank values are
        kept.
    """
    if not query:
        return ""
    # SigV4 orders parameters by their *encoded* names (then values), so the
    # encoding has to happen before the sort: "a|" encodes to "a%7C", which
    # sorts before "az" even though "|" sorts after "z".
    encoded_pairs = sorted(
        (_aws_percent_encode(name), _aws_percent_encode(value))
        for name, value in parse_qsl(query, keep_blank_values=True)
    )
    return "&".join(f"{name}={value}" for name, value in encoded_pairs)
268+
269+
242270def aws_signature_v4_sign (
243271 credentials : AwsCredentials ,
244272 method : str ,
@@ -252,47 +280,43 @@ def aws_signature_v4_sign(
252280
253281 Based on the C# implementation in AwsSignature4Signer.cs.
254282 """
255- # Parse the URL
256283 parsed_url = urlparse (url )
257284
258- # Create timestamp
259285 utc_now = datetime .now (timezone .utc )
260286 amz_date = utc_now .strftime ("%Y%m%dT%H%M%SZ" )
261287 date_string = utc_now .strftime ("%Y%m%d" )
262288
263- # Add required headers
264- headers = headers . copy ()
265- headers ["x-amz-date" ] = amz_date
289+ headers_lower = { k . lower (): str ( v ). strip () for k , v in headers . items ()}
290+ headers_lower [ "host" ] = parsed_url . netloc
291+ headers_lower ["x-amz-date" ] = amz_date
266292 if credentials .token :
267- headers ["x-amz-security-token" ] = credentials .token
293+ headers_lower ["x-amz-security-token" ] = credentials .token
268294
269- # Create canonical request
270- canonical_uri = parsed_url . path or "/"
271- canonical_querystring = parsed_url . query or ""
295+ sorted_header_keys = sorted ( headers_lower . keys ())
296+ canonical_headers = "" . join ( f" { k } : { headers_lower [ k ] } \n " for k in sorted_header_keys )
297+ signed_headers = ";" . join ( sorted_header_keys )
272298
273- # Sort headers and create canonical headers
274- sorted_headers = sorted (headers .items (), key = lambda x : x [0 ].lower ())
275- canonical_headers = ""
276- signed_headers = ""
277-
278- for key , value in sorted_headers :
279- canonical_headers += f"{ key .lower ()} :{ str (value ).strip ()} \n "
280- if signed_headers :
281- signed_headers += ";"
282- signed_headers += key .lower ()
283-
284- # Create payload hash
299+ canonical_querystring = _canonical_query (parsed_url .query )
285300 payload_hash = hashlib .sha256 (payload .encode ("utf-8" )).hexdigest ()
286301
287- # Create canonical request
288- canonical_request = f"{ method } \n { canonical_uri } \n { canonical_querystring } \n { canonical_headers } \n { signed_headers } \n { payload_hash } "
302+ canonical_request = (
303+ f"{ method } \n "
304+ f"{ parsed_url .path or '/' } \n "
305+ f"{ canonical_querystring } \n "
306+ f"{ canonical_headers } "
307+ f"{ signed_headers } \n "
308+ f"{ payload_hash } "
309+ )
289310
290- # Create string to sign
291311 algorithm = "AWS4-HMAC-SHA256"
292312 credential_scope = f"{ date_string } /{ region } /{ service } /aws4_request"
293- string_to_sign = f"{ algorithm } \n { amz_date } \n { credential_scope } \n { hashlib .sha256 (canonical_request .encode ('utf-8' )).hexdigest ()} "
313+ string_to_sign = (
314+ f"{ algorithm } \n "
315+ f"{ amz_date } \n "
316+ f"{ credential_scope } \n "
317+ f"{ hashlib .sha256 (canonical_request .encode ('utf-8' )).hexdigest ()} "
318+ )
294319
295- # Calculate signature
296320 def hmac_sha256 (key : bytes , msg : str ) -> bytes :
297321 return hmac .new (key , msg .encode ("utf-8" ), hashlib .sha256 ).digest ()
298322
@@ -305,11 +329,20 @@ def hmac_sha256(key: bytes, msg: str) -> bytes:
305329 k_signing , string_to_sign .encode ("utf-8" ), hashlib .sha256
306330 ).hexdigest ()
307331
308- # Create authorization header
309- authorization = f"{ algorithm } Credential={ credentials .access_key } /{ credential_scope } , SignedHeaders={ signed_headers } , Signature={ signature } "
310- headers ["authorization" ] = authorization
332+ authorization = (
333+ f"{ algorithm } "
334+ f"Credential={ credentials .access_key } /{ credential_scope } , "
335+ f"SignedHeaders={ signed_headers } , Signature={ signature } "
336+ )
337+
338+ final_headers = headers .copy ()
339+ final_headers ["Host" ] = parsed_url .netloc
340+ final_headers ["X-Amz-Date" ] = amz_date
341+ if credentials .token :
342+ final_headers ["X-Amz-Security-Token" ] = credentials .token
343+ final_headers ["Authorization" ] = authorization
311344
312- return headers
345+ return final_headers
313346
314347
315348def create_aws_attestation () -> WorkloadIdentityAttestation | None :
@@ -331,19 +364,17 @@ def create_aws_attestation() -> WorkloadIdentityAttestation | None:
331364 sts_hostname = get_aws_sts_hostname (region )
332365 url = f"https://{ sts_hostname } /?Action=GetCallerIdentity&Version=2011-06-15"
333366
334- headers = {
335- "Host" : sts_hostname ,
367+ base_headers = {
336368 "X-Snowflake-Audience" : SNOWFLAKE_AUDIENCE ,
337369 }
338370
339- # Sign the request
340371 signed_headers = aws_signature_v4_sign (
341372 credentials = credentials ,
342373 method = "POST" ,
343374 url = url ,
344375 region = region ,
345376 service = "sts" ,
346- headers = headers ,
377+ headers = base_headers ,
347378 )
348379
349380 # Create attestation request
@@ -353,7 +384,6 @@ def create_aws_attestation() -> WorkloadIdentityAttestation | None:
353384 "headers" : signed_headers ,
354385 }
355386
356- # Encode to base64
357387 credential = b64encode (json .dumps (attestation_request ).encode ("utf-8" )).decode (
358388 "utf-8"
359389 )
0 commit comments