Skip to content

Commit f8b0944

Browse files
SNOW-2183023: Removed boto
1 parent 965786a commit f8b0944

File tree

2 files changed

+132
-130
lines changed

2 files changed

+132
-130
lines changed

src/snowflake/connector/wif_util.py

Lines changed: 92 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from dataclasses import dataclass
1010
from datetime import datetime, timezone
1111
from enum import Enum, unique
12-
from urllib.parse import urlparse
12+
from urllib.parse import parse_qsl, quote, urlparse
1313

1414
import jwt
1515

@@ -89,7 +89,9 @@ def try_metadata_service_call(
8989
return res
9090

9191

92-
def extract_iss_and_sub_without_signature_verification(jwt_str: str) -> tuple[str, str]:
92+
def extract_iss_and_sub_without_signature_verification(
93+
jwt_str: str,
94+
) -> tuple[str | None, str | None]:
9395
"""Extracts the 'iss' and 'sub' claims from the given JWT, without verifying the signature.
9496
9597
Note: the real token verification (including signature verification) happens on the Snowflake side. The driver doesn't have
@@ -114,6 +116,18 @@ def extract_iss_and_sub_without_signature_verification(jwt_str: str) -> tuple[st
114116
return claims["iss"], claims["sub"]
115117

116118

119+
# --------------------------------------------------------------------------- #
120+
# AWS helper utilities (token, credentials, region) #
121+
# --------------------------------------------------------------------------- #
122+
def _imds_v2_token() -> str | None:
    """Request an IMDSv2 session token from the instance metadata endpoint.

    Sends a PUT to ``http://169.254.169.254/latest/api/token`` asking for a
    token with a 300-second TTL.

    Returns:
        The response text with surrounding whitespace stripped, or ``None``
        when the metadata call produced a falsy result.
    """
    response = try_metadata_service_call(
        method="PUT",
        url="http://169.254.169.254/latest/api/token",
        headers={"X-aws-ec2-metadata-token-ttl-seconds": "300"},
    )
    if not response:
        return None
    return response.text.strip()
129+
130+
117131
def get_aws_credentials() -> AwsCredentials | None:
118132
"""Get AWS credentials from environment variables or instance metadata.
119133
@@ -129,24 +143,18 @@ def get_aws_credentials() -> AwsCredentials | None:
129143

130144
# Try instance metadata service (IMDSv2)
131145
try:
132-
# First, get a token for IMDSv2
133-
token_res = try_metadata_service_call(
134-
method="PUT",
135-
url="http://169.254.169.254/latest/api/token",
136-
headers={"X-aws-ec2-metadata-token-ttl-seconds": "300"},
137-
)
138-
139-
if token_res is None:
146+
token = _imds_v2_token()
147+
if token is None:
140148
logger.debug("Failed to get IMDSv2 token from metadata service.")
141149
return None
142150

143-
token = token_res.text.strip()
151+
token_hdr = {"X-aws-ec2-metadata-token": token} if token else {}
144152

145153
# Get the security credentials from the metadata service
146154
res = try_metadata_service_call(
147155
method="GET",
148156
url="http://169.254.169.254/latest/meta-data/iam/security-credentials/",
149-
headers={"X-aws-ec2-metadata-token": token},
157+
headers=token_hdr,
150158
)
151159
if res is None:
152160
logger.debug("Failed to get IAM role list from metadata service.")
@@ -161,7 +169,7 @@ def get_aws_credentials() -> AwsCredentials | None:
161169
res = try_metadata_service_call(
162170
method="GET",
163171
url=f"http://169.254.169.254/latest/meta-data/iam/security-credentials/{role_name}",
164-
headers={"X-aws-ec2-metadata-token": token},
172+
headers=token_hdr,
165173
)
166174
if res is None:
167175
logger.debug("Failed to get IAM role credentials from metadata service.")
@@ -174,7 +182,6 @@ def get_aws_credentials() -> AwsCredentials | None:
174182

175183
if access_key and secret_key:
176184
return AwsCredentials(access_key, secret_key, token)
177-
178185
except Exception as e:
179186
logger.debug(f"Error getting AWS credentials from metadata service: {e}")
180187

@@ -183,43 +190,47 @@ def get_aws_credentials() -> AwsCredentials | None:
183190

184191
def get_aws_region() -> str | None:
185192
"""Get the current AWS workload's region, if any."""
186-
# Try environment variable first
187193
region = os.environ.get("AWS_REGION")
188194
if region:
189195
return region
190196

191-
# Try instance metadata service (IMDSv2)
192197
try:
193-
# First, get a token for IMDSv2
194-
token_res = try_metadata_service_call(
195-
method="PUT",
196-
url="http://169.254.169.254/latest/api/token",
197-
headers={"X-aws-ec2-metadata-token-ttl-seconds": "300"},
198-
)
199-
200-
if token_res is None:
198+
token = _imds_v2_token()
199+
if token is None:
201200
logger.debug("Failed to get IMDSv2 token from metadata service.")
202201
return None
203202

204-
token = token_res.text.strip()
203+
token_hdr = {"X-aws-ec2-metadata-token": token} if token else {}
205204

206205
# Get region from metadata service
207206
res = try_metadata_service_call(
208207
method="GET",
209208
url="http://169.254.169.254/latest/meta-data/placement/region",
210-
headers={"X-aws-ec2-metadata-token": token},
209+
headers=token_hdr,
211210
)
212211
if res is not None:
213212
return res.text.strip()
213+
214+
res = try_metadata_service_call(
215+
method="GET",
216+
url="http://169.254.169.254/latest/meta-data/placement/availability-zone",
217+
headers=token_hdr,
218+
)
219+
if res is not None:
220+
return res.text.strip()[:-1]
214221
except Exception as e:
215222
logger.debug(f"Error getting AWS region from metadata service: {e}")
216223

217224
return None
218225

219226

220-
def get_aws_sts_hostname(region: str) -> str:
227+
def get_aws_sts_hostname(region: str) -> str | None:
221228
"""Constructs the AWS STS hostname for a given region.
222229
230+
* China regions (`cn-*`) → sts.<region>.amazonaws.com.cn
231+
* All other regions → sts.<region>.amazonaws.com
232+
* Any invalid input → None
233+
223234
Args:
224235
region (str): The AWS region (e.g., 'us-east-1', 'cn-north-1').
225236
@@ -231,6 +242,10 @@ def get_aws_sts_hostname(region: str) -> str:
231242
- https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_region-endpoints.html
232243
- https://docs.aws.amazon.com/general/latest/gr/sts.html
233244
"""
245+
246+
if not region or not isinstance(region, str):
247+
return None
248+
234249
if region.startswith("cn-"):
235250
# China regions have a different domain suffix
236251
return f"sts.{region}.amazonaws.com.cn"
@@ -239,6 +254,19 @@ def get_aws_sts_hostname(region: str) -> str:
239254
return f"sts.{region}.amazonaws.com"
240255

241256

257+
def _aws_percent_encode(s: str) -> str:
258+
return quote(s, safe="~")
259+
260+
261+
def _canonical_query(query: str) -> str:
262+
if not query:
263+
return ""
264+
pairs = sorted(parse_qsl(query, keep_blank_values=True))
265+
return "&".join(
266+
f"{_aws_percent_encode(k)}={_aws_percent_encode(v)}" for k, v in pairs
267+
)
268+
269+
242270
def aws_signature_v4_sign(
243271
credentials: AwsCredentials,
244272
method: str,
@@ -252,47 +280,43 @@ def aws_signature_v4_sign(
252280
253281
Based on the C# implementation in AwsSignature4Signer.cs.
254282
"""
255-
# Parse the URL
256283
parsed_url = urlparse(url)
257284

258-
# Create timestamp
259285
utc_now = datetime.now(timezone.utc)
260286
amz_date = utc_now.strftime("%Y%m%dT%H%M%SZ")
261287
date_string = utc_now.strftime("%Y%m%d")
262288

263-
# Add required headers
264-
headers = headers.copy()
265-
headers["x-amz-date"] = amz_date
289+
headers_lower = {k.lower(): str(v).strip() for k, v in headers.items()}
290+
headers_lower["host"] = parsed_url.netloc
291+
headers_lower["x-amz-date"] = amz_date
266292
if credentials.token:
267-
headers["x-amz-security-token"] = credentials.token
293+
headers_lower["x-amz-security-token"] = credentials.token
268294

269-
# Create canonical request
270-
canonical_uri = parsed_url.path or "/"
271-
canonical_querystring = parsed_url.query or ""
295+
sorted_header_keys = sorted(headers_lower.keys())
296+
canonical_headers = "".join(f"{k}:{headers_lower[k]}\n" for k in sorted_header_keys)
297+
signed_headers = ";".join(sorted_header_keys)
272298

273-
# Sort headers and create canonical headers
274-
sorted_headers = sorted(headers.items(), key=lambda x: x[0].lower())
275-
canonical_headers = ""
276-
signed_headers = ""
277-
278-
for key, value in sorted_headers:
279-
canonical_headers += f"{key.lower()}:{str(value).strip()}\n"
280-
if signed_headers:
281-
signed_headers += ";"
282-
signed_headers += key.lower()
283-
284-
# Create payload hash
299+
canonical_querystring = _canonical_query(parsed_url.query)
285300
payload_hash = hashlib.sha256(payload.encode("utf-8")).hexdigest()
286301

287-
# Create canonical request
288-
canonical_request = f"{method}\n{canonical_uri}\n{canonical_querystring}\n{canonical_headers}\n{signed_headers}\n{payload_hash}"
302+
canonical_request = (
303+
f"{method}\n"
304+
f"{parsed_url.path or '/'}\n"
305+
f"{canonical_querystring}\n"
306+
f"{canonical_headers}"
307+
f"{signed_headers}\n"
308+
f"{payload_hash}"
309+
)
289310

290-
# Create string to sign
291311
algorithm = "AWS4-HMAC-SHA256"
292312
credential_scope = f"{date_string}/{region}/{service}/aws4_request"
293-
string_to_sign = f"{algorithm}\n{amz_date}\n{credential_scope}\n{hashlib.sha256(canonical_request.encode('utf-8')).hexdigest()}"
313+
string_to_sign = (
314+
f"{algorithm}\n"
315+
f"{amz_date}\n"
316+
f"{credential_scope}\n"
317+
f"{hashlib.sha256(canonical_request.encode('utf-8')).hexdigest()}"
318+
)
294319

295-
# Calculate signature
296320
def hmac_sha256(key: bytes, msg: str) -> bytes:
297321
return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest()
298322

@@ -305,11 +329,20 @@ def hmac_sha256(key: bytes, msg: str) -> bytes:
305329
k_signing, string_to_sign.encode("utf-8"), hashlib.sha256
306330
).hexdigest()
307331

308-
# Create authorization header
309-
authorization = f"{algorithm} Credential={credentials.access_key}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}"
310-
headers["authorization"] = authorization
332+
authorization = (
333+
f"{algorithm} "
334+
f"Credential={credentials.access_key}/{credential_scope}, "
335+
f"SignedHeaders={signed_headers}, Signature={signature}"
336+
)
337+
338+
final_headers = headers.copy()
339+
final_headers["Host"] = parsed_url.netloc
340+
final_headers["X-Amz-Date"] = amz_date
341+
if credentials.token:
342+
final_headers["X-Amz-Security-Token"] = credentials.token
343+
final_headers["Authorization"] = authorization
311344

312-
return headers
345+
return final_headers
313346

314347

315348
def create_aws_attestation() -> WorkloadIdentityAttestation | None:
@@ -331,19 +364,17 @@ def create_aws_attestation() -> WorkloadIdentityAttestation | None:
331364
sts_hostname = get_aws_sts_hostname(region)
332365
url = f"https://{sts_hostname}/?Action=GetCallerIdentity&Version=2011-06-15"
333366

334-
headers = {
335-
"Host": sts_hostname,
367+
base_headers = {
336368
"X-Snowflake-Audience": SNOWFLAKE_AUDIENCE,
337369
}
338370

339-
# Sign the request
340371
signed_headers = aws_signature_v4_sign(
341372
credentials=credentials,
342373
method="POST",
343374
url=url,
344375
region=region,
345376
service="sts",
346-
headers=headers,
377+
headers=base_headers,
347378
)
348379

349380
# Create attestation request
@@ -353,7 +384,6 @@ def create_aws_attestation() -> WorkloadIdentityAttestation | None:
353384
"headers": signed_headers,
354385
}
355386

356-
# Encode to base64
357387
credential = b64encode(json.dumps(attestation_request).encode("utf-8")).decode(
358388
"utf-8"
359389
)

0 commit comments

Comments
 (0)