@@ -100,10 +100,20 @@ def _clean_env_vars_for_scope() -> dict[str, str]:
100100 return {k : "" for k in AWS_CREDENTIAL_ENV_KEYS + AWS_REGION_ENV_KEYS }
101101
102102 @abstractmethod
103- def reset_defaults (self ) -> None : ...
103+ def reset_defaults (self ) -> None :
104+ """Resets any default values for test parameters.
105+
106+ This is called in the constructor and when entering as a context manager.
107+ """
108+ pass
104109
105110 @abstractmethod
106- def is_expected_hostname (self , host : str | None ) -> bool : ...
111+ def is_expected_hostname (self , host : str | None ) -> bool :
112+ """Returns true if the passed hostname is the one at which this metadata service is listening.
113+
114+ Used to raise a ConnectTimeout for requests not targeted to this hostname.
115+ """
116+ pass
107117
108118 @abstractmethod
109119 def handle_request (
@@ -112,9 +122,12 @@ def handle_request(
112122 parsed_url ,
113123 headers ,
114124 timeout ,
115- ) -> Response : ...
125+ ) -> Response :
126+ """Main business logic for handling this request. Should return a Response object."""
127+ pass
116128
117129 def __call__ (self , method , url , headers = None , timeout = None , ** _kw ):
130+ """Entry-point for the requests monkey-patch."""
118131 headers = headers or {}
119132 parsed = urlparse (url )
120133 logger .debug ("FakeMetadataService received %s %s %s" , method , url , headers )
@@ -128,6 +141,7 @@ def __call__(self, method, url, headers=None, timeout=None, **_kw):
128141 return self .handle_request (method .upper (), parsed , headers , timeout )
129142
130143 def __enter__ (self ):
144+ """Patches the relevant HTTP calls when entering as a context manager."""
131145 self .reset_defaults ()
132146 self ._context_stack = ExitStack ()
133147 self ._context_stack .enter_context (
@@ -164,6 +178,7 @@ def is_expected_hostname(self, host: str | None) -> bool:
164178 return False
165179
166180 def handle_request (self , * _ ):
181+ # This should never be called because we always raise a ConnectTimeout.
167182 raise AssertionError (
168183 "This should never be called because we always raise a ConnectTimeout."
169184 )
@@ -173,6 +188,7 @@ class FakeAzureVmMetadataService(FakeMetadataService):
173188 """Emulates an environment with the Azure VM metadata service."""
174189
175190 def reset_defaults (self ) -> None :
191+ # Defaults used for generating an Entra ID token. Can be overriden in individual tests.
176192 self .sub = "611ab25b-2e81-4e18-92a7-b21f2bebb269"
177193 self .iss = "https://sts.windows.net/2c0183ed-cf17-480d-b3f7-df91bc0a97cd"
178194
@@ -193,6 +209,7 @@ def is_expected_hostname(self, host: str | None) -> bool:
193209 def handle_request (self , method , parsed_url , headers , timeout ):
194210 query_string = parse_qs (parsed_url .query )
195211
212+ # Reject malformed requests.
196213 if not (
197214 method == "GET"
198215 and parsed_url .path == AZURE_VM_TOKEN_PATH
@@ -205,7 +222,7 @@ def handle_request(self, method, parsed_url, headers, timeout):
205222
206223 resource = query_string ["resource" ][0 ]
207224 self .token = gen_dummy_id_token (sub = self .sub , iss = self .iss , aud = resource )
208- return build_response (json .dumps ({"access_token" : self .token }).encode ())
225+ return build_response (json .dumps ({"access_token" : self .token }).encode ("utf-8" ))
209226
210227
211228class FakeAzureFunctionMetadataService (FakeMetadataService ):
@@ -220,6 +237,7 @@ def reset_defaults(self) -> None:
220237
221238 def __enter__ (self ):
222239 self ._stack = contextlib .ExitStack ()
240+ # Inject the variables without touching os.environ directly
223241 self ._stack .enter_context (
224242 mock .patch .dict (
225243 os .environ ,
@@ -246,6 +264,7 @@ def is_expected_hostname(self, host: str | None) -> bool:
246264 def handle_request (self , method , parsed_url , headers , timeout ):
247265 query_string = parse_qs (parsed_url .query )
248266
267+ # Reject malformed requests.
249268 if not (
250269 method == "GET"
251270 and parsed_url .path == self .parsed_identity_endpoint .path
@@ -269,6 +288,7 @@ class FakeGceMetadataService(FakeMetadataService):
269288 """Simulates GCE metadata endpoint."""
270289
271290 def reset_defaults (self ) -> None :
291+ # Defaults used for generating a token. Can be overriden in individual tests.
272292 self .sub = "123"
273293 self .iss = "https://accounts.google.com"
274294
@@ -289,6 +309,7 @@ def is_expected_hostname(self, host: str | None) -> bool:
289309 def handle_request (self , method , parsed_url , headers , timeout ):
290310 query_string = parse_qs (parsed_url .query )
291311
312+ # Reject malformed requests.
292313 if not (
293314 method == "GET"
294315 and parsed_url .path == GCE_IDENTITY_PATH
@@ -370,11 +391,8 @@ def handle_request(self, method, parsed_url, headers, timeout):
370391 ):
371392 return build_response (self .region .encode ())
372393
373- # New: availability-zone path (region extracted by stripping last char)
374- if (
375- method == "GET" and url == f"{ _IMDS_BASE_URL } { self .__class__ .IMDS_AZ_PATH } "
376- ): # <-- new
377- return build_response (f"{ self .region } a" .encode ()) # <-- new
394+ if method == "GET" and url == f"{ _IMDS_BASE_URL } { self .__class__ .IMDS_AZ_PATH } " :
395+ return build_response (f"{ self .region } a" .encode ())
378396
379397 if (
380398 method == "GET"
@@ -392,6 +410,7 @@ class FakeAwsEnvironment:
392410 """
393411
394412 def __init__ (self ):
413+ # Defaults used for generating a token. Can be overriden in individual tests.
395414 self ._region = "us-east-1"
396415 self .arn = "arn:aws:sts::123456789:assumed-role/My-Role/i-34afe100cad287fab"
397416 self .credentials : Credentials | None = Credentials (
@@ -406,6 +425,9 @@ def region(self) -> str:
406425
407426 @region .setter
408427 def region (self , new_region : str ) -> None :
428+ """Change runtime region and, if the env-vars already exist,
429+ patch them via ExitStack so they’re cleaned up on __exit__.
430+ """
409431 self ._region = new_region
410432 self ._metadata .region = new_region
411433
@@ -417,6 +439,7 @@ def region(self, new_region: str) -> None:
417439 )
418440
419441 def _prepare_runtime (self ):
442+ """Sub-classes patch env / credentials here."""
420443 return None
421444
422445 def __enter__ (self ):
@@ -441,6 +464,7 @@ def __enter__(self):
441464 )
442465 )
443466
467+ # Keep the metadata stub in sync with the final credential set.
444468 self ._metadata .access_key = (
445469 self .credentials .access_key if self .credentials else None
446470 )
@@ -463,6 +487,7 @@ def __enter__(self):
463487 mock .patch .dict (os .environ , env_for_chain , clear = False )
464488 )
465489
490+ # Runtime-specific tweaks (may change creds / env).
466491 self ._prepare_runtime ()
467492 return self
468493
@@ -492,13 +517,15 @@ class FakeAwsLambda(FakeAwsEnvironment):
492517
493518 def __init__ (self ):
494519 super ().__init__ ()
520+ # Lambda always returns *session* credentials
495521 self .credentials = Credentials (
496522 access_key = "ak" ,
497523 secret_key = "sk" ,
498524 token = "dummy-session-token" ,
499525 )
500526
501527 def _prepare_runtime (self ) -> None :
528+ # Patch env vars via mock.patch.dict so nothing touches os.environ directly
502529 self ._stack .enter_context (
503530 mock .patch .dict (
504531 os .environ ,
0 commit comments