Skip to content

Commit b40e1df

Browse files
SNOW-2203079: Keep csp helpers untouched
1 parent 53723c2 commit b40e1df

File tree

1 file changed

+36
-101
lines changed

1 file changed

+36
-101
lines changed

test/csp_helpers.py

Lines changed: 36 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
#!/usr/bin/env python
2-
from __future__ import annotations
3-
42
import datetime
53
import json
64
import logging
@@ -11,14 +9,11 @@
119
from urllib.parse import parse_qs, urlparse
1210

1311
import jwt
14-
15-
# Boto is left as a development-dependency - to be sure our http requests correspond to the appropriate behavior and old driver tests are passing in the future
1612
from botocore.awsrequest import AWSRequest
1713
from botocore.credentials import Credentials
1814

1915
from snowflake.connector.vendored.requests.exceptions import ConnectTimeout, HTTPError
2016
from snowflake.connector.vendored.requests.models import Response
21-
from snowflake.connector.wif_util import AwsCredentials
2217

2318
logger = logging.getLogger(__name__)
2419

@@ -97,12 +92,11 @@ def __enter__(self):
9792
"""Patches the relevant HTTP calls when entering as a context manager."""
9893
self.reset_defaults()
9994
self.patchers = []
100-
# Session.request is used by the direct metadata service API calls from our code. This is the main
95+
# requests.request is used by the direct metadata service API calls from our code. This is the main
10196
# thing being faked here.
10297
self.patchers.append(
10398
mock.patch(
104-
"snowflake.connector.vendored.requests.Session.request",
105-
side_effect=self,
99+
"snowflake.connector.vendored.requests.request", side_effect=self
106100
)
107101
)
108102
# HTTPConnection.request is used by the AWS boto libraries. We're not mocking those calls here, so we
@@ -250,124 +244,65 @@ def handle_request(self, method, parsed_url, headers, timeout):
250244

251245

252246
class FakeAwsEnvironment:
253-
"""Emulates AWS for both the legacy boto path and the new SDK-free helpers."""
247+
"""Emulates the AWS environment-specific functions used in wif_util.py.
248+
249+
Unlike the other metadata services, the HTTP calls made by AWS are deep within boto libaries, so
250+
emulating them here would be complex and fragile. Instead, we emulate the higher-level functions
251+
called by the connector code.
252+
"""
254253

255254
def __init__(self):
256255
# Defaults used for generating a token. Can be overriden in individual tests.
257-
self.arn = "arn:aws:sts::123456789:assumed-role/My-Role/i-abc123"
256+
self.arn = "arn:aws:sts::123456789:assumed-role/My-Role/i-34afe100cad287fab"
258257
self.region = "us-east-1"
258+
self.credentials = Credentials(access_key="ak", secret_key="sk")
259259

260-
# boto-style creds (used by old tests / patches)
261-
self.boto_creds = Credentials("AKIA123", "SECRET123", token="SESSION_TOKEN")
262-
263-
# util-style creds (returned by get_aws_credentials)
264-
self.util_creds = AwsCredentials(
265-
access_key=self.boto_creds.access_key,
266-
secret_key=self.boto_creds.secret_key,
267-
token=self.boto_creds.token,
268-
)
269-
270-
def get_region(self, *_, **__) -> str:
260+
def get_region(self):
271261
return self.region
272262

273-
def get_arn(self, *_, **__) -> str:
263+
def get_arn(self):
274264
return self.arn
275265

276-
def get_boto_credentials(self, *_, **__) -> Credentials | None:
277-
return self.boto_creds
278-
279-
def get_aws_credentials(self, *_, **__) -> AwsCredentials | None:
280-
return self.util_creds
266+
def get_credentials(self):
267+
return self.credentials
281268

282-
def sign_request(self, request: AWSRequest) -> None:
283-
"""
284-
Fake replacement for botocore SigV4Auth.add_auth that produces the same
285-
*static* parts of the Authorization header (everything before
286-
`Signature=`).
287-
"""
288-
# Add the headers a real signer would inject
289-
utc_now = datetime.datetime.utcnow()
290-
amz_date = utc_now.strftime("%Y%m%dT%H%M%SZ")
291-
date_stamp = utc_now.strftime("%Y%m%d")
292-
293-
request.headers["X-Amz-Date"] = amz_date
294-
request.headers["X-Amz-Security-Token"] = self.util_creds.token
295-
296-
# Host header is already set by the test; add it if a future test forgets
297-
if "Host" not in request.headers:
298-
request.headers["Host"] = urlparse(request.url).netloc
299-
300-
# Build the signed-headers list
301-
signed_headers = ";".join(sorted(h.lower() for h in request.headers.keys()))
302-
303-
credential_scope = f"{date_stamp}/{self.region}/sts/aws4_request"
304-
305-
request.headers["Authorization"] = (
306-
"AWS4-HMAC-SHA256 "
307-
f"Credential={self.util_creds.access_key}/{credential_scope}, "
308-
f"SignedHeaders={signed_headers}, Signature=<sig>"
269+
def sign_request(self, request: AWSRequest):
270+
request.headers.add_header("X-Amz-Date", datetime.time().isoformat())
271+
request.headers.add_header("X-Amz-Security-Token", "<TOKEN>")
272+
request.headers.add_header(
273+
"Authorization",
274+
f"AWS4-HMAC-SHA256 Credential=<cred>, SignedHeaders={';'.join(request.headers.keys())}, Signature=<sig>",
309275
)
310276

311277
def __enter__(self):
312-
# Preserve existing env and then set creds/region for util fallback
313-
self._old_env = {
314-
k: os.environ.get(k)
315-
for k in (
316-
"AWS_ACCESS_KEY_ID",
317-
"AWS_SECRET_ACCESS_KEY",
318-
"AWS_SESSION_TOKEN",
319-
"AWS_REGION",
320-
)
321-
}
322-
os.environ.update(
323-
{
324-
"AWS_ACCESS_KEY_ID": self.util_creds.access_key,
325-
"AWS_SECRET_ACCESS_KEY": self.util_creds.secret_key,
326-
"AWS_SESSION_TOKEN": self.util_creds.token or "",
327-
"AWS_REGION": self.region,
328-
}
329-
)
330-
331-
self.patchers = [
332-
# boto patches - for old driver tests
278+
# Patch the relevant functions to do what we want.
279+
self.patchers = []
280+
self.patchers.append(
333281
mock.patch(
334282
"boto3.session.Session.get_credentials",
335-
side_effect=self.get_boto_credentials,
336-
),
283+
side_effect=self.get_credentials,
284+
)
285+
)
286+
self.patchers.append(
337287
mock.patch(
338288
"botocore.auth.SigV4Auth.add_auth", side_effect=self.sign_request
339-
),
340-
# http approach patches - for new driver tests
289+
)
290+
)
291+
self.patchers.append(
341292
mock.patch(
342293
"snowflake.connector.wif_util.get_aws_region",
343294
side_effect=self.get_region,
344-
),
345-
mock.patch(
346-
"snowflake.connector.wif_util.get_aws_credentials",
347-
side_effect=self.get_aws_credentials,
348-
),
349-
mock.patch(
350-
"snowflake.connector.wif_util.get_aws_arn",
351-
side_effect=self.get_arn,
352-
create=True,
353-
),
354-
# never contact IMDS for token
295+
)
296+
)
297+
self.patchers.append(
355298
mock.patch(
356-
"snowflake.connector.wif_util._imds_v2_token", return_value=None
357-
),
358-
]
359-
299+
"snowflake.connector.wif_util.get_aws_arn", side_effect=self.get_arn
300+
)
301+
)
360302
for patcher in self.patchers:
361303
patcher.__enter__()
362304
return self
363305

364306
def __exit__(self, *args, **kwargs):
365307
for patcher in self.patchers:
366308
patcher.__exit__(*args, **kwargs)
367-
368-
# restore previous env
369-
for k, v in self._old_env.items():
370-
if v is None:
371-
os.environ.pop(k, None)
372-
else:
373-
os.environ[k] = v

0 commit comments

Comments
 (0)