|
1 | 1 | #!/usr/bin/env python |
2 | | -from __future__ import annotations |
3 | | - |
4 | 2 | import datetime |
5 | 3 | import json |
6 | 4 | import logging |
|
11 | 9 | from urllib.parse import parse_qs, urlparse |
12 | 10 |
|
13 | 11 | import jwt |
14 | | - |
15 | | -# Boto is left as a development-dependency - to be sure our http requests correspond to the appropriate behavior and old driver tests are passing in the future |
16 | 12 | from botocore.awsrequest import AWSRequest |
17 | 13 | from botocore.credentials import Credentials |
18 | 14 |
|
19 | 15 | from snowflake.connector.vendored.requests.exceptions import ConnectTimeout, HTTPError |
20 | 16 | from snowflake.connector.vendored.requests.models import Response |
21 | | -from snowflake.connector.wif_util import AwsCredentials |
22 | 17 |
|
23 | 18 | logger = logging.getLogger(__name__) |
24 | 19 |
|
@@ -97,12 +92,11 @@ def __enter__(self): |
97 | 92 | """Patches the relevant HTTP calls when entering as a context manager.""" |
98 | 93 | self.reset_defaults() |
99 | 94 | self.patchers = [] |
100 | | - # Session.request is used by the direct metadata service API calls from our code. This is the main |
| 95 | + # requests.request is used by the direct metadata service API calls from our code. This is the main |
101 | 96 | # thing being faked here. |
102 | 97 | self.patchers.append( |
103 | 98 | mock.patch( |
104 | | - "snowflake.connector.vendored.requests.Session.request", |
105 | | - side_effect=self, |
| 99 | + "snowflake.connector.vendored.requests.request", side_effect=self |
106 | 100 | ) |
107 | 101 | ) |
108 | 102 | # HTTPConnection.request is used by the AWS boto libraries. We're not mocking those calls here, so we |
@@ -250,124 +244,65 @@ def handle_request(self, method, parsed_url, headers, timeout): |
250 | 244 |
|
251 | 245 |
|
252 | 246 | class FakeAwsEnvironment: |
253 | | - """Emulates AWS for both the legacy boto path and the new SDK-free helpers.""" |
| 247 | + """Emulates the AWS environment-specific functions used in wif_util.py. |
| 248 | +
|
| 249 | + Unlike the other metadata services, the HTTP calls made by AWS are deep within boto libaries, so |
| 250 | + emulating them here would be complex and fragile. Instead, we emulate the higher-level functions |
| 251 | + called by the connector code. |
| 252 | + """ |
254 | 253 |
|
255 | 254 | def __init__(self): |
256 | 255 | # Defaults used for generating a token. Can be overriden in individual tests. |
257 | | - self.arn = "arn:aws:sts::123456789:assumed-role/My-Role/i-abc123" |
| 256 | + self.arn = "arn:aws:sts::123456789:assumed-role/My-Role/i-34afe100cad287fab" |
258 | 257 | self.region = "us-east-1" |
| 258 | + self.credentials = Credentials(access_key="ak", secret_key="sk") |
259 | 259 |
|
260 | | - # boto-style creds (used by old tests / patches) |
261 | | - self.boto_creds = Credentials("AKIA123", "SECRET123", token="SESSION_TOKEN") |
262 | | - |
263 | | - # util-style creds (returned by get_aws_credentials) |
264 | | - self.util_creds = AwsCredentials( |
265 | | - access_key=self.boto_creds.access_key, |
266 | | - secret_key=self.boto_creds.secret_key, |
267 | | - token=self.boto_creds.token, |
268 | | - ) |
269 | | - |
270 | | - def get_region(self, *_, **__) -> str: |
| 260 | + def get_region(self): |
271 | 261 | return self.region |
272 | 262 |
|
273 | | - def get_arn(self, *_, **__) -> str: |
| 263 | + def get_arn(self): |
274 | 264 | return self.arn |
275 | 265 |
|
276 | | - def get_boto_credentials(self, *_, **__) -> Credentials | None: |
277 | | - return self.boto_creds |
278 | | - |
279 | | - def get_aws_credentials(self, *_, **__) -> AwsCredentials | None: |
280 | | - return self.util_creds |
| 266 | + def get_credentials(self): |
| 267 | + return self.credentials |
281 | 268 |
|
282 | | - def sign_request(self, request: AWSRequest) -> None: |
283 | | - """ |
284 | | - Fake replacement for botocore SigV4Auth.add_auth that produces the same |
285 | | - *static* parts of the Authorization header (everything before |
286 | | - `Signature=`). |
287 | | - """ |
288 | | - # Add the headers a real signer would inject |
289 | | - utc_now = datetime.datetime.utcnow() |
290 | | - amz_date = utc_now.strftime("%Y%m%dT%H%M%SZ") |
291 | | - date_stamp = utc_now.strftime("%Y%m%d") |
292 | | - |
293 | | - request.headers["X-Amz-Date"] = amz_date |
294 | | - request.headers["X-Amz-Security-Token"] = self.util_creds.token |
295 | | - |
296 | | - # Host header is already set by the test; add it if a future test forgets |
297 | | - if "Host" not in request.headers: |
298 | | - request.headers["Host"] = urlparse(request.url).netloc |
299 | | - |
300 | | - # Build the signed-headers list |
301 | | - signed_headers = ";".join(sorted(h.lower() for h in request.headers.keys())) |
302 | | - |
303 | | - credential_scope = f"{date_stamp}/{self.region}/sts/aws4_request" |
304 | | - |
305 | | - request.headers["Authorization"] = ( |
306 | | - "AWS4-HMAC-SHA256 " |
307 | | - f"Credential={self.util_creds.access_key}/{credential_scope}, " |
308 | | - f"SignedHeaders={signed_headers}, Signature=<sig>" |
| 269 | + def sign_request(self, request: AWSRequest): |
| 270 | + request.headers.add_header("X-Amz-Date", datetime.time().isoformat()) |
| 271 | + request.headers.add_header("X-Amz-Security-Token", "<TOKEN>") |
| 272 | + request.headers.add_header( |
| 273 | + "Authorization", |
| 274 | + f"AWS4-HMAC-SHA256 Credential=<cred>, SignedHeaders={';'.join(request.headers.keys())}, Signature=<sig>", |
309 | 275 | ) |
310 | 276 |
|
311 | 277 | def __enter__(self): |
312 | | - # Preserve existing env and then set creds/region for util fallback |
313 | | - self._old_env = { |
314 | | - k: os.environ.get(k) |
315 | | - for k in ( |
316 | | - "AWS_ACCESS_KEY_ID", |
317 | | - "AWS_SECRET_ACCESS_KEY", |
318 | | - "AWS_SESSION_TOKEN", |
319 | | - "AWS_REGION", |
320 | | - ) |
321 | | - } |
322 | | - os.environ.update( |
323 | | - { |
324 | | - "AWS_ACCESS_KEY_ID": self.util_creds.access_key, |
325 | | - "AWS_SECRET_ACCESS_KEY": self.util_creds.secret_key, |
326 | | - "AWS_SESSION_TOKEN": self.util_creds.token or "", |
327 | | - "AWS_REGION": self.region, |
328 | | - } |
329 | | - ) |
330 | | - |
331 | | - self.patchers = [ |
332 | | - # boto patches - for old driver tests |
| 278 | + # Patch the relevant functions to do what we want. |
| 279 | + self.patchers = [] |
| 280 | + self.patchers.append( |
333 | 281 | mock.patch( |
334 | 282 | "boto3.session.Session.get_credentials", |
335 | | - side_effect=self.get_boto_credentials, |
336 | | - ), |
| 283 | + side_effect=self.get_credentials, |
| 284 | + ) |
| 285 | + ) |
| 286 | + self.patchers.append( |
337 | 287 | mock.patch( |
338 | 288 | "botocore.auth.SigV4Auth.add_auth", side_effect=self.sign_request |
339 | | - ), |
340 | | - # http approach patches - for new driver tests |
| 289 | + ) |
| 290 | + ) |
| 291 | + self.patchers.append( |
341 | 292 | mock.patch( |
342 | 293 | "snowflake.connector.wif_util.get_aws_region", |
343 | 294 | side_effect=self.get_region, |
344 | | - ), |
345 | | - mock.patch( |
346 | | - "snowflake.connector.wif_util.get_aws_credentials", |
347 | | - side_effect=self.get_aws_credentials, |
348 | | - ), |
349 | | - mock.patch( |
350 | | - "snowflake.connector.wif_util.get_aws_arn", |
351 | | - side_effect=self.get_arn, |
352 | | - create=True, |
353 | | - ), |
354 | | - # never contact IMDS for token |
| 295 | + ) |
| 296 | + ) |
| 297 | + self.patchers.append( |
355 | 298 | mock.patch( |
356 | | - "snowflake.connector.wif_util._imds_v2_token", return_value=None |
357 | | - ), |
358 | | - ] |
359 | | - |
| 299 | + "snowflake.connector.wif_util.get_aws_arn", side_effect=self.get_arn |
| 300 | + ) |
| 301 | + ) |
360 | 302 | for patcher in self.patchers: |
361 | 303 | patcher.__enter__() |
362 | 304 | return self |
363 | 305 |
|
364 | 306 | def __exit__(self, *args, **kwargs): |
365 | 307 | for patcher in self.patchers: |
366 | 308 | patcher.__exit__(*args, **kwargs) |
367 | | - |
368 | | - # restore previous env |
369 | | - for k, v in self._old_env.items(): |
370 | | - if v is None: |
371 | | - os.environ.pop(k, None) |
372 | | - else: |
373 | | - os.environ[k] = v |
0 commit comments