Skip to content

Commit fab3af9

Browse files
SNOW-2183023: fixed http traffix
1 parent fbbc223 commit fab3af9

File tree

2 files changed

+37
-10
lines changed

2 files changed

+37
-10
lines changed

test/csp_helpers.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,20 @@ def _clean_env_vars_for_scope() -> dict[str, str]:
100100
return {k: "" for k in AWS_CREDENTIAL_ENV_KEYS + AWS_REGION_ENV_KEYS}
101101

102102
@abstractmethod
103-
def reset_defaults(self) -> None: ...
103+
def reset_defaults(self) -> None:
104+
"""Resets any default values for test parameters.
105+
106+
This is called in the constructor and when entering as a context manager.
107+
"""
108+
pass
104109

105110
@abstractmethod
106-
def is_expected_hostname(self, host: str | None) -> bool: ...
111+
def is_expected_hostname(self, host: str | None) -> bool:
112+
"""Returns true if the passed hostname is the one at which this metadata service is listening.
113+
114+
Used to raise a ConnectTimeout for requests not targeted to this hostname.
115+
"""
116+
pass
107117

108118
@abstractmethod
109119
def handle_request(
@@ -112,9 +122,12 @@ def handle_request(
112122
parsed_url,
113123
headers,
114124
timeout,
115-
) -> Response: ...
125+
) -> Response:
126+
"""Main business logic for handling this request. Should return a Response object."""
127+
pass
116128

117129
def __call__(self, method, url, headers=None, timeout=None, **_kw):
130+
"""Entry-point for the requests monkey-patch."""
118131
headers = headers or {}
119132
parsed = urlparse(url)
120133
logger.debug("FakeMetadataService received %s %s %s", method, url, headers)
@@ -128,6 +141,7 @@ def __call__(self, method, url, headers=None, timeout=None, **_kw):
128141
return self.handle_request(method.upper(), parsed, headers, timeout)
129142

130143
def __enter__(self):
144+
"""Patches the relevant HTTP calls when entering as a context manager."""
131145
self.reset_defaults()
132146
self._context_stack = ExitStack()
133147
self._context_stack.enter_context(
@@ -164,6 +178,7 @@ def is_expected_hostname(self, host: str | None) -> bool:
164178
return False
165179

166180
def handle_request(self, *_):
181+
# This should never be called because we always raise a ConnectTimeout.
167182
raise AssertionError(
168183
"This should never be called because we always raise a ConnectTimeout."
169184
)
@@ -173,6 +188,7 @@ class FakeAzureVmMetadataService(FakeMetadataService):
173188
"""Emulates an environment with the Azure VM metadata service."""
174189

175190
def reset_defaults(self) -> None:
191+
# Defaults used for generating an Entra ID token. Can be overriden in individual tests.
176192
self.sub = "611ab25b-2e81-4e18-92a7-b21f2bebb269"
177193
self.iss = "https://sts.windows.net/2c0183ed-cf17-480d-b3f7-df91bc0a97cd"
178194

@@ -193,6 +209,7 @@ def is_expected_hostname(self, host: str | None) -> bool:
193209
def handle_request(self, method, parsed_url, headers, timeout):
194210
query_string = parse_qs(parsed_url.query)
195211

212+
# Reject malformed requests.
196213
if not (
197214
method == "GET"
198215
and parsed_url.path == AZURE_VM_TOKEN_PATH
@@ -205,7 +222,7 @@ def handle_request(self, method, parsed_url, headers, timeout):
205222

206223
resource = query_string["resource"][0]
207224
self.token = gen_dummy_id_token(sub=self.sub, iss=self.iss, aud=resource)
208-
return build_response(json.dumps({"access_token": self.token}).encode())
225+
return build_response(json.dumps({"access_token": self.token}).encode("utf-8"))
209226

210227

211228
class FakeAzureFunctionMetadataService(FakeMetadataService):
@@ -220,6 +237,7 @@ def reset_defaults(self) -> None:
220237

221238
def __enter__(self):
222239
self._stack = contextlib.ExitStack()
240+
# Inject the variables without touching os.environ directly
223241
self._stack.enter_context(
224242
mock.patch.dict(
225243
os.environ,
@@ -246,6 +264,7 @@ def is_expected_hostname(self, host: str | None) -> bool:
246264
def handle_request(self, method, parsed_url, headers, timeout):
247265
query_string = parse_qs(parsed_url.query)
248266

267+
# Reject malformed requests.
249268
if not (
250269
method == "GET"
251270
and parsed_url.path == self.parsed_identity_endpoint.path
@@ -269,6 +288,7 @@ class FakeGceMetadataService(FakeMetadataService):
269288
"""Simulates GCE metadata endpoint."""
270289

271290
def reset_defaults(self) -> None:
291+
# Defaults used for generating a token. Can be overriden in individual tests.
272292
self.sub = "123"
273293
self.iss = "https://accounts.google.com"
274294

@@ -289,6 +309,7 @@ def is_expected_hostname(self, host: str | None) -> bool:
289309
def handle_request(self, method, parsed_url, headers, timeout):
290310
query_string = parse_qs(parsed_url.query)
291311

312+
# Reject malformed requests.
292313
if not (
293314
method == "GET"
294315
and parsed_url.path == GCE_IDENTITY_PATH
@@ -370,11 +391,8 @@ def handle_request(self, method, parsed_url, headers, timeout):
370391
):
371392
return build_response(self.region.encode())
372393

373-
# New: availability-zone path (region extracted by stripping last char)
374-
if (
375-
method == "GET" and url == f"{_IMDS_BASE_URL}{self.__class__.IMDS_AZ_PATH}"
376-
): # <-- new
377-
return build_response(f"{self.region}a".encode()) # <-- new
394+
if method == "GET" and url == f"{_IMDS_BASE_URL}{self.__class__.IMDS_AZ_PATH}":
395+
return build_response(f"{self.region}a".encode())
378396

379397
if (
380398
method == "GET"
@@ -392,6 +410,7 @@ class FakeAwsEnvironment:
392410
"""
393411

394412
def __init__(self):
413+
# Defaults used for generating a token. Can be overriden in individual tests.
395414
self._region = "us-east-1"
396415
self.arn = "arn:aws:sts::123456789:assumed-role/My-Role/i-34afe100cad287fab"
397416
self.credentials: Credentials | None = Credentials(
@@ -406,6 +425,9 @@ def region(self) -> str:
406425

407426
@region.setter
408427
def region(self, new_region: str) -> None:
428+
"""Change runtime region and, if the env-vars already exist,
429+
patch them via ExitStack so they’re cleaned up on __exit__.
430+
"""
409431
self._region = new_region
410432
self._metadata.region = new_region
411433

@@ -417,6 +439,7 @@ def region(self, new_region: str) -> None:
417439
)
418440

419441
def _prepare_runtime(self):
442+
"""Sub-classes patch env / credentials here."""
420443
return None
421444

422445
def __enter__(self):
@@ -441,6 +464,7 @@ def __enter__(self):
441464
)
442465
)
443466

467+
# Keep the metadata stub in sync with the final credential set.
444468
self._metadata.access_key = (
445469
self.credentials.access_key if self.credentials else None
446470
)
@@ -463,6 +487,7 @@ def __enter__(self):
463487
mock.patch.dict(os.environ, env_for_chain, clear=False)
464488
)
465489

490+
# Runtime-specific tweaks (may change creds / env).
466491
self._prepare_runtime()
467492
return self
468493

@@ -492,13 +517,15 @@ class FakeAwsLambda(FakeAwsEnvironment):
492517

493518
def __init__(self):
494519
super().__init__()
520+
# Lambda always returns *session* credentials
495521
self.credentials = Credentials(
496522
access_key="ak",
497523
secret_key="sk",
498524
token="dummy-session-token",
499525
)
500526

501527
def _prepare_runtime(self) -> None:
528+
# Patch env vars via mock.patch.dict so nothing touches os.environ directly
502529
self._stack.enter_context(
503530
mock.patch.dict(
504531
os.environ,

test/unit/test_auth_workload_identity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,7 @@ def test_explicit_aws_region_falls_back_to_imds(imds_only_aws_environment):
485485

486486
def test_autodetect_prefers_gcp_when_no_aws_env(fake_gce_metadata_service):
487487
"""
488-
No AWS env-vars + a responsive GCP metadata server GCP selected.
488+
No AWS env-vars + a responsive GCP metadata server -> GCP selected.
489489
"""
490490
auth_class = AuthByWorkloadIdentity(provider=None)
491491
auth_class.prepare()

0 commit comments

Comments
 (0)