Skip to content

Commit c27b8ec

Browse files
updated os dictionary patching so it shouldn't use the actual OS environment variables and can be independent of the platfrom it is running on. This helps with the GitHub actions test since the test is actually run on GitHub actions.
1 parent a2c4049 commit c27b8ec

File tree

1 file changed

+27
-46
lines changed

1 file changed

+27
-46
lines changed

test/csp_helpers.py

Lines changed: 27 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from abc import ABC, abstractmethod
77
from time import time
88
from unittest import mock
9+
from unittest.mock import patch
910
from urllib.parse import parse_qs, urlparse
1011

1112
import jwt
@@ -81,6 +82,10 @@ def handle_request(self, method, parsed_url, headers, timeout):
8182
def handle_unexpected_hostname(self):
8283
return ConnectTimeout()
8384

85+
def get_environment_variables(self) -> dict[str, str]:
86+
"""Returns a dictionary of environment variables to patch in to fake the metadata service."""
87+
return {}
88+
8489
def _handle_get(self, url, headers=None, timeout=None):
8590
"""Handles requests.get() calls by converting them to request() format."""
8691
if headers is None:
@@ -125,6 +130,7 @@ def __enter__(self):
125130
side_effect=ConnectTimeout(),
126131
)
127132
)
133+
self.patchers.append(patch.dict(os.environ, self.get_environment_variables()))
128134
for patcher in self.patchers:
129135
patcher.__enter__()
130136
return self
@@ -246,22 +252,14 @@ def handle_request(self, method, parsed_url, headers, timeout):
246252
self.token = gen_dummy_id_token(sub=self.sub, iss=self.iss, aud=resource)
247253
return build_response(json.dumps({"access_token": self.token}).encode("utf-8"))
248254

249-
def __enter__(self):
250-
# In addition to the normal patching, we need to set the environment variables that Azure Functions would set.
251-
os.environ["IDENTITY_ENDPOINT"] = self.identity_endpoint
252-
os.environ["IDENTITY_HEADER"] = self.identity_header
253-
os.environ["FUNCTIONS_WORKER_RUNTIME"] = self.functions_worker_runtime
254-
os.environ["FUNCTIONS_EXTENSION_VERSION"] = self.functions_extension_version
255-
os.environ["AzureWebJobsStorage"] = self.azure_web_jobs_storage
256-
return super().__enter__()
257-
258-
def __exit__(self, *args, **kwargs):
259-
os.environ.pop("IDENTITY_ENDPOINT")
260-
os.environ.pop("IDENTITY_HEADER")
261-
os.environ.pop("FUNCTIONS_WORKER_RUNTIME")
262-
os.environ.pop("FUNCTIONS_EXTENSION_VERSION")
263-
os.environ.pop("AzureWebJobsStorage")
264-
return super().__exit__(*args, **kwargs)
255+
def get_environment_variables(self) -> dict[str, str]:
256+
return {
257+
"IDENTITY_ENDPOINT": self.identity_endpoint,
258+
"IDENTITY_HEADER": self.identity_header,
259+
"FUNCTIONS_WORKER_RUNTIME": self.functions_worker_runtime,
260+
"FUNCTIONS_EXTENSION_VERSION": self.functions_extension_version,
261+
"AzureWebJobsStorage": self.azure_web_jobs_storage,
262+
}
265263

266264

267265
class FakeGceMetadataService(FakeMetadataService):
@@ -323,18 +321,12 @@ def reset_defaults(self):
323321
self.k_revision = "test-revision"
324322
self.k_configuration = "test-configuration"
325323

326-
def __enter__(self):
327-
# We need to set the environment variables that GCE Cloud Run Service would set.
328-
os.environ["K_SERVICE"] = self.k_service
329-
os.environ["K_REVISION"] = self.k_revision
330-
os.environ["K_CONFIGURATION"] = self.k_configuration
331-
return super().__enter__()
332-
333-
def __exit__(self, *args, **kwargs):
334-
os.environ.pop("K_SERVICE")
335-
os.environ.pop("K_REVISION")
336-
os.environ.pop("K_CONFIGURATION")
337-
return super().__exit__(*args, **kwargs)
324+
def get_environment_variables(self) -> dict[str, str]:
325+
return {
326+
"K_SERVICE": self.k_service,
327+
"K_REVISION": self.k_revision,
328+
"K_CONFIGURATION": self.k_configuration,
329+
}
338330

339331

340332
class FakeGceCloudRunJobService(FakeMetadataService):
@@ -351,16 +343,11 @@ def reset_defaults(self):
351343
self.cloud_run_job = "test-job"
352344
self.cloud_run_execution = "test-execution"
353345

354-
def __enter__(self):
355-
# We need to set the environment variables that GCE Cloud Run Service would set.
356-
os.environ["CLOUD_RUN_JOB"] = self.cloud_run_job
357-
os.environ["CLOUD_RUN_EXECUTION"] = self.cloud_run_execution
358-
return super().__enter__()
359-
360-
def __exit__(self, *args, **kwargs):
361-
os.environ.pop("CLOUD_RUN_JOB")
362-
os.environ.pop("CLOUD_RUN_EXECUTION")
363-
return super().__exit__(*args, **kwargs)
346+
def get_environment_variables(self) -> dict[str, str]:
347+
return {
348+
"CLOUD_RUN_JOB": self.cloud_run_job,
349+
"CLOUD_RUN_EXECUTION": self.cloud_run_execution,
350+
}
364351

365352

366353
class FakeGitHubActionsService(FakeMetadataService):
@@ -376,14 +363,8 @@ def handle_request(self, method, parsed_url, headers, timeout):
376363
def reset_defaults(self):
377364
self.github_actions = "github-actions"
378365

379-
def __enter__(self):
380-
# We need to set the environment variables that GCE Cloud Run Service would set.
381-
os.environ["GITHUB_ACTIONS"] = self.github_actions
382-
return super().__enter__()
383-
384-
def __exit__(self, *args, **kwargs):
385-
os.environ.pop("GITHUB_ACTIONS")
386-
return super().__exit__(*args, **kwargs)
366+
def get_environment_variables(self) -> dict[str, str]:
367+
return {"GITHUB_ACTIONS": self.github_actions}
387368

388369

389370
class FakeAwsEnvironment:

0 commit comments

Comments
 (0)