66from abc import ABC , abstractmethod
77from time import time
88from unittest import mock
9+ from unittest .mock import patch
910from urllib .parse import parse_qs , urlparse
1011
1112import jwt
@@ -81,6 +82,10 @@ def handle_request(self, method, parsed_url, headers, timeout):
8182 def handle_unexpected_hostname (self ):
8283 return ConnectTimeout ()
8384
85+ def get_environment_variables (self ) -> dict [str , str ]:
86+ """Returns a dictionary of environment variables to patch in to fake the metadata service."""
87+ return {}
88+
8489 def _handle_get (self , url , headers = None , timeout = None ):
8590 """Handles requests.get() calls by converting them to request() format."""
8691 if headers is None :
@@ -125,6 +130,7 @@ def __enter__(self):
125130 side_effect = ConnectTimeout (),
126131 )
127132 )
133+ self .patchers .append (patch .dict (os .environ , self .get_environment_variables ()))
128134 for patcher in self .patchers :
129135 patcher .__enter__ ()
130136 return self
@@ -246,22 +252,14 @@ def handle_request(self, method, parsed_url, headers, timeout):
246252 self .token = gen_dummy_id_token (sub = self .sub , iss = self .iss , aud = resource )
247253 return build_response (json .dumps ({"access_token" : self .token }).encode ("utf-8" ))
248254
249- def __enter__ (self ):
250- # In addition to the normal patching, we need to set the environment variables that Azure Functions would set.
251- os .environ ["IDENTITY_ENDPOINT" ] = self .identity_endpoint
252- os .environ ["IDENTITY_HEADER" ] = self .identity_header
253- os .environ ["FUNCTIONS_WORKER_RUNTIME" ] = self .functions_worker_runtime
254- os .environ ["FUNCTIONS_EXTENSION_VERSION" ] = self .functions_extension_version
255- os .environ ["AzureWebJobsStorage" ] = self .azure_web_jobs_storage
256- return super ().__enter__ ()
257-
258- def __exit__ (self , * args , ** kwargs ):
259- os .environ .pop ("IDENTITY_ENDPOINT" )
260- os .environ .pop ("IDENTITY_HEADER" )
261- os .environ .pop ("FUNCTIONS_WORKER_RUNTIME" )
262- os .environ .pop ("FUNCTIONS_EXTENSION_VERSION" )
263- os .environ .pop ("AzureWebJobsStorage" )
264- return super ().__exit__ (* args , ** kwargs )
255+ def get_environment_variables (self ) -> dict [str , str ]:
256+ return {
257+ "IDENTITY_ENDPOINT" : self .identity_endpoint ,
258+ "IDENTITY_HEADER" : self .identity_header ,
259+ "FUNCTIONS_WORKER_RUNTIME" : self .functions_worker_runtime ,
260+ "FUNCTIONS_EXTENSION_VERSION" : self .functions_extension_version ,
261+ "AzureWebJobsStorage" : self .azure_web_jobs_storage ,
262+ }
265263
266264
267265class FakeGceMetadataService (FakeMetadataService ):
@@ -323,18 +321,12 @@ def reset_defaults(self):
323321 self .k_revision = "test-revision"
324322 self .k_configuration = "test-configuration"
325323
326- def __enter__ (self ):
327- # We need to set the environment variables that GCE Cloud Run Service would set.
328- os .environ ["K_SERVICE" ] = self .k_service
329- os .environ ["K_REVISION" ] = self .k_revision
330- os .environ ["K_CONFIGURATION" ] = self .k_configuration
331- return super ().__enter__ ()
332-
333- def __exit__ (self , * args , ** kwargs ):
334- os .environ .pop ("K_SERVICE" )
335- os .environ .pop ("K_REVISION" )
336- os .environ .pop ("K_CONFIGURATION" )
337- return super ().__exit__ (* args , ** kwargs )
324+ def get_environment_variables (self ) -> dict [str , str ]:
325+ return {
326+ "K_SERVICE" : self .k_service ,
327+ "K_REVISION" : self .k_revision ,
328+ "K_CONFIGURATION" : self .k_configuration ,
329+ }
338330
339331
340332class FakeGceCloudRunJobService (FakeMetadataService ):
@@ -351,16 +343,11 @@ def reset_defaults(self):
351343 self .cloud_run_job = "test-job"
352344 self .cloud_run_execution = "test-execution"
353345
354- def __enter__ (self ):
355- # We need to set the environment variables that GCE Cloud Run Service would set.
356- os .environ ["CLOUD_RUN_JOB" ] = self .cloud_run_job
357- os .environ ["CLOUD_RUN_EXECUTION" ] = self .cloud_run_execution
358- return super ().__enter__ ()
359-
360- def __exit__ (self , * args , ** kwargs ):
361- os .environ .pop ("CLOUD_RUN_JOB" )
362- os .environ .pop ("CLOUD_RUN_EXECUTION" )
363- return super ().__exit__ (* args , ** kwargs )
346+ def get_environment_variables (self ) -> dict [str , str ]:
347+ return {
348+ "CLOUD_RUN_JOB" : self .cloud_run_job ,
349+ "CLOUD_RUN_EXECUTION" : self .cloud_run_execution ,
350+ }
364351
365352
366353class FakeGitHubActionsService (FakeMetadataService ):
@@ -376,14 +363,8 @@ def handle_request(self, method, parsed_url, headers, timeout):
376363 def reset_defaults (self ):
377364 self .github_actions = "github-actions"
378365
379- def __enter__ (self ):
380- # We need to set the environment variables that GCE Cloud Run Service would set.
381- os .environ ["GITHUB_ACTIONS" ] = self .github_actions
382- return super ().__enter__ ()
383-
384- def __exit__ (self , * args , ** kwargs ):
385- os .environ .pop ("GITHUB_ACTIONS" )
386- return super ().__exit__ (* args , ** kwargs )
366+ def get_environment_variables (self ) -> dict [str , str ]:
367+ return {"GITHUB_ACTIONS" : self .github_actions }
387368
388369
389370class FakeAwsEnvironment :
0 commit comments