Skip to content

Commit a8085f7

Browse files
addressed comments related to fixed comments, optimizing platform detection code
1 parent e7995a6 commit a8085f7

File tree

5 files changed

+111
-99
lines changed

5 files changed

+111
-99
lines changed

src/snowflake/connector/auth/_auth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def base_auth_data(
102102
login_timeout: int | None = None,
103103
network_timeout: int | None = None,
104104
socket_timeout: int | None = None,
105-
platform_detection_timeout: float | None = None,
105+
platform_detection_timeout_seconds: float | None = None,
106106
):
107107
return {
108108
"data": {
@@ -125,7 +125,7 @@ def base_auth_data(
125125
"NETWORK_TIMEOUT": network_timeout,
126126
"SOCKET_TIMEOUT": socket_timeout,
127127
"PLATFORM": detect_platforms(
128-
timeout_seconds=platform_detection_timeout
128+
platform_detection_timeout_seconds=platform_detection_timeout_seconds
129129
),
130130
},
131131
},

src/snowflake/connector/platform_detection.py

Lines changed: 81 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class _DetectionState(Enum):
2121
TIMEOUT = "timeout"
2222

2323

24-
def is_ec2_instance(timeout_seconds: float):
24+
def is_ec2_instance(platform_detection_timeout_seconds: float):
2525
"""
2626
Check if the current environment is running on an AWS EC2 instance.
2727
@@ -31,13 +31,15 @@ def is_ec2_instance(timeout_seconds: float):
3131
It will ignore the token if on IMDSv1 and use the token if on IMDSv2.
3232
3333
Args:
34-
timeout_seconds: Timeout value for the metadata service request.
34+
platform_detection_timeout_seconds: Timeout value for the metadata service request.
3535
3636
Returns:
3737
_DetectionState: DETECTED if running on EC2, NOT_DETECTED otherwise.
3838
"""
3939
try:
40-
fetcher = IMDSFetcher(timeout=timeout_seconds, num_attempts=1)
40+
fetcher = IMDSFetcher(
41+
timeout=platform_detection_timeout_seconds, num_attempts=1
42+
)
4143
document = fetcher._get_request(
4244
"/latest/dynamic/instance-identity/document",
4345
None,
@@ -71,7 +73,7 @@ def is_aws_lambda():
7173

7274
def is_valid_arn_for_wif(arn: str) -> bool:
7375
"""
74-
Validate if an AWS ARN is suitable for Web Identity Federation (WIF).
76+
Validate if an AWS ARN is suitable for use with Snowflake's Workload Identity Federation (WIF).
7577
7678
Args:
7779
arn: The AWS ARN string to validate.
@@ -86,47 +88,46 @@ def is_valid_arn_for_wif(arn: str) -> bool:
8688
return any(re.match(p, arn) for p in patterns)
8789

8890

89-
def has_aws_identity(timeout_seconds: float):
91+
def has_aws_identity(platform_detection_timeout_seconds: float):
9092
"""
9193
Check if the current environment has a valid AWS identity for authentication.
9294
9395
If we retrieve an ARN from the caller identity and it is a valid WIF ARN,
9496
then we assume we have a valid AWS identity for authentication.
9597
9698
Args:
97-
timeout_seconds: Timeout value for AWS API calls.
99+
platform_detection_timeout_seconds: Timeout value for AWS API calls.
98100
99101
Returns:
100102
_DetectionState: DETECTED if valid AWS identity exists, NOT_DETECTED otherwise.
101103
"""
102104
try:
103105
config = Config(
104-
connect_timeout=timeout_seconds,
105-
read_timeout=timeout_seconds,
106+
connect_timeout=platform_detection_timeout_seconds,
107+
read_timeout=platform_detection_timeout_seconds,
106108
retries={"total_max_attempts": 1},
107109
)
108110
caller_identity = boto3.client("sts", config=config).get_caller_identity()
109111
if not caller_identity or "Arn" not in caller_identity:
110112
return _DetectionState.NOT_DETECTED
111-
else:
112-
return (
113-
_DetectionState.DETECTED
114-
if is_valid_arn_for_wif(caller_identity["Arn"])
115-
else _DetectionState.NOT_DETECTED
116-
)
113+
return (
114+
_DetectionState.DETECTED
115+
if is_valid_arn_for_wif(caller_identity["Arn"])
116+
else _DetectionState.NOT_DETECTED
117+
)
117118
except Exception:
118119
return _DetectionState.NOT_DETECTED
119120

120121

121-
def is_azure_vm(timeout_seconds: float):
122+
def is_azure_vm(platform_detection_timeout_seconds: float):
122123
"""
123124
Check if the current environment is running on an Azure Virtual Machine.
124125
125126
If we query the Azure Instance Metadata Service and receive an HTTP 200 response,
126127
then we assume we are running on an Azure VM.
127128
128129
Args:
129-
timeout_seconds: Timeout value for the metadata service request.
130+
platform_detection_timeout_seconds: Timeout value for the metadata service request.
130131
131132
Returns:
132133
_DetectionState: DETECTED if on Azure VM, TIMEOUT if request times out,
@@ -136,7 +137,7 @@ def is_azure_vm(timeout_seconds: float):
136137
token_resp = requests.get(
137138
"http://169.254.169.254/metadata/instance?api-version=2021-02-01",
138139
headers={"Metadata": "True"},
139-
timeout=timeout_seconds,
140+
timeout=platform_detection_timeout_seconds,
140141
)
141142
return (
142143
_DetectionState.DETECTED
@@ -174,7 +175,7 @@ def is_azure_function():
174175

175176

176177
def is_managed_identity_available_on_azure_vm(
177-
timeout_seconds, resource="https://management.azure.com/"
178+
platform_detection_timeout_seconds, resource="https://management.azure.com"
178179
):
179180
"""
180181
Check if Azure Managed Identity is available and accessible on an Azure VM.
@@ -184,7 +185,7 @@ def is_managed_identity_available_on_azure_vm(
184185
then we assume managed identity is available.
185186
186187
Args:
187-
timeout_seconds: Timeout value for the metadata service request.
188+
platform_detection_timeout_seconds: Timeout value for the metadata service request.
188189
resource: The Azure resource URI to request a token for.
189190
190191
Returns:
@@ -194,7 +195,9 @@ def is_managed_identity_available_on_azure_vm(
194195
endpoint = f"http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource={resource}"
195196
headers = {"Metadata": "true"}
196197
try:
197-
response = requests.get(endpoint, headers=headers, timeout=timeout_seconds)
198+
response = requests.get(
199+
endpoint, headers=headers, timeout=platform_detection_timeout_seconds
200+
)
198201
return (
199202
_DetectionState.DETECTED
200203
if response.status_code == 200
@@ -206,62 +209,56 @@ def is_managed_identity_available_on_azure_vm(
206209
return _DetectionState.NOT_DETECTED
207210

208211

209-
def has_azure_managed_identity(timeout_seconds: float):
212+
def is_managed_identity_available_on_azure_function():
213+
return bool(os.environ.get("IDENTITY_HEADER"))
214+
215+
216+
def has_azure_managed_identity(platform_detection_timeout_seconds: float):
210217
"""
211218
Determine if Azure Managed Identity is available in the current environment.
212219
213220
If we are on Azure Functions and the IDENTITY_HEADER environment variable exists,
214221
then we assume managed identity is available.
215222
If we are on an Azure VM and can mint an access token from the managed identity endpoint,
216223
then we assume managed identity is available.
217-
Assumes timeout state if either VM or Function detection timed out.
218-
Otherwise, assumes it is not available
224+
Handles Azure Functions first since the checks are faster.
225+
Handles Azure VM checks second since they involve network calls.
219226
220227
Args:
221-
timeout_seconds: Timeout value for managed identity checks.
228+
platform_detection_timeout_seconds: Timeout value for managed identity checks.
222229
223230
Returns:
224231
_DetectionState: DETECTED if managed identity is available, TIMEOUT if
225232
detection timed out, NOT_DETECTED otherwise.
226233
"""
227-
has_azure_function_managed_identity = (
228-
_DetectionState.DETECTED
229-
if os.environ.get("IDENTITY_HEADER")
230-
else _DetectionState.NOT_DETECTED
231-
)
232-
has_azure_vm_managed_identity = is_managed_identity_available_on_azure_vm(
233-
timeout_seconds
234-
)
235-
if (
236-
has_azure_vm_managed_identity == _DetectionState.DETECTED
237-
or has_azure_function_managed_identity == _DetectionState.DETECTED
238-
):
239-
return _DetectionState.DETECTED
240-
if (
241-
has_azure_vm_managed_identity == _DetectionState.TIMEOUT
242-
or has_azure_function_managed_identity == _DetectionState.TIMEOUT
243-
):
244-
return _DetectionState.TIMEOUT
245-
return _DetectionState.NOT_DETECTED
234+
# short circuit early to save on latency and avoid minting an unnecessary token
235+
if is_azure_function() == _DetectionState.DETECTED:
236+
return (
237+
_DetectionState.DETECTED
238+
if is_managed_identity_available_on_azure_function()
239+
else _DetectionState.NOT_DETECTED
240+
)
241+
return is_managed_identity_available_on_azure_vm(platform_detection_timeout_seconds)
246242

247243

248-
def is_gce_vm(timeout_seconds: float):
244+
def is_gce_vm(platform_detection_timeout_seconds: float):
249245
"""
250246
Check if the current environment is running on Google Compute Engine (GCE).
251247
252248
If we query the Google metadata server and receive a response with the
253249
"Metadata-Flavor: Google" header, then we assume we are running on GCE.
254250
255251
Args:
256-
timeout_seconds: Timeout value for the metadata service request.
252+
platform_detection_timeout_seconds: Timeout value for the metadata service request.
257253
258254
Returns:
259255
_DetectionState: DETECTED if on GCE, TIMEOUT if request times out,
260256
NOT_DETECTED otherwise.
261257
"""
262258
try:
263259
response = requests.get(
264-
"http://metadata.google.internal", timeout=timeout_seconds
260+
"http://metadata.google.internal",
261+
timeout=platform_detection_timeout_seconds,
265262
)
266263
return (
267264
_DetectionState.DETECTED
@@ -274,7 +271,7 @@ def is_gce_vm(timeout_seconds: float):
274271
return _DetectionState.NOT_DETECTED
275272

276273

277-
def is_gce_cloud_run_service():
274+
def is_gcp_cloud_run_service():
278275
"""
279276
Check if the current environment is running in Google Cloud Run service.
280277
@@ -293,7 +290,7 @@ def is_gce_cloud_run_service():
293290
)
294291

295292

296-
def is_gce_cloud_run_job():
293+
def is_gcp_cloud_run_job():
297294
"""
298295
Check if the current environment is running in Google Cloud Run job.
299296
@@ -312,15 +309,15 @@ def is_gce_cloud_run_job():
312309
)
313310

314311

315-
def has_gcp_identity(timeout_seconds: float):
312+
def has_gcp_identity(platform_detection_timeout_seconds: float):
316313
"""
317314
Check if the current environment has a valid Google Cloud Platform identity.
318315
319316
If we query the GCP metadata service for the default service account email
320317
and receive a non-empty response, then we assume we have a valid GCP identity.
321318
322319
Args:
323-
timeout_seconds: Timeout value for the metadata service request.
320+
platform_detection_timeout_seconds: Timeout value for the metadata service request.
324321
Returns:
325322
_DetectionState: DETECTED if valid GCP identity exists, TIMEOUT if request
326323
times out, NOT_DETECTED otherwise.
@@ -329,7 +326,7 @@ def has_gcp_identity(timeout_seconds: float):
329326
response = requests.get(
330327
"http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/email",
331328
headers={"Metadata-Flavor": "Google"},
332-
timeout=timeout_seconds,
329+
timeout=platform_detection_timeout_seconds,
333330
)
334331
return (
335332
_DetectionState.DETECTED
@@ -360,13 +357,13 @@ def is_github_action():
360357

361358

362359
@cache
363-
def detect_platforms(timeout_seconds: float | None) -> list[str]:
360+
def detect_platforms(platform_detection_timeout_seconds: float | None) -> list[str]:
364361
"""
365362
Detect all potential platforms that the current environment may be running on.
366363
Swallows all exceptions and returns an empty list if any exception occurs, so that main driver functionality is not affected.
367364
368365
Args:
369-
timeout_seconds: Timeout value for platform detection requests. Defaults to 0.2 seconds
366+
platform_detection_timeout_seconds: Timeout value for platform detection requests. Defaults to 0.2 seconds
370367
if None is provided.
371368
372369
Returns:
@@ -375,27 +372,42 @@ def detect_platforms(timeout_seconds: float | None) -> list[str]:
375372
exception occurs during detection.
376373
"""
377374
try:
378-
if timeout_seconds is None:
379-
timeout_seconds = 0.2
380-
381-
with ThreadPoolExecutor(max_workers=10) as executor:
375+
if platform_detection_timeout_seconds is None:
376+
platform_detection_timeout_seconds = 0.2
377+
378+
# Run environment-only checks synchronously (no network calls, no threading overhead)
379+
platforms = {
380+
"is_aws_lambda": is_aws_lambda(),
381+
"is_azure_function": is_azure_function(),
382+
"is_gce_cloud_run_service": is_gcp_cloud_run_service(),
383+
"is_gce_cloud_run_job": is_gcp_cloud_run_job(),
384+
"is_github_action": is_github_action(),
385+
}
386+
387+
# Run network-calling functions in parallel
388+
with ThreadPoolExecutor(max_workers=6) as executor:
382389
futures = {
383-
"is_ec2_instance": executor.submit(is_ec2_instance, timeout_seconds),
384-
"is_aws_lambda": executor.submit(is_aws_lambda),
385-
"has_aws_identity": executor.submit(has_aws_identity, timeout_seconds),
386-
"is_azure_vm": executor.submit(is_azure_vm, timeout_seconds),
387-
"is_azure_function": executor.submit(is_azure_function),
390+
"is_ec2_instance": executor.submit(
391+
is_ec2_instance, platform_detection_timeout_seconds
392+
),
393+
"has_aws_identity": executor.submit(
394+
has_aws_identity, platform_detection_timeout_seconds
395+
),
396+
"is_azure_vm": executor.submit(
397+
is_azure_vm, platform_detection_timeout_seconds
398+
),
388399
"azure_managed_identity": executor.submit(
389-
has_azure_managed_identity, timeout_seconds
400+
has_azure_managed_identity, platform_detection_timeout_seconds
401+
),
402+
"is_gce_vm": executor.submit(
403+
is_gce_vm, platform_detection_timeout_seconds
404+
),
405+
"has_gcp_identity": executor.submit(
406+
has_gcp_identity, platform_detection_timeout_seconds
390407
),
391-
"is_gce_vm": executor.submit(is_gce_vm, timeout_seconds),
392-
"is_gce_cloud_run_service": executor.submit(is_gce_cloud_run_service),
393-
"is_gce_cloud_run_job": executor.submit(is_gce_cloud_run_job),
394-
"has_gcp_identity": executor.submit(has_gcp_identity, timeout_seconds),
395-
"is_github_action": executor.submit(is_github_action),
396408
}
397409

398-
platforms = {key: future.result() for key, future in futures.items()}
410+
platforms.update({key: future.result() for key, future in futures.items()})
399411

400412
detected_platforms = []
401413
for platform_name, detection_state in platforms.items():

test/csp_helpers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def handle_request(self, method, parsed_url, headers, timeout):
8080
return ConnectTimeout()
8181

8282
def handle_unexpected_hostname(self):
83-
return ConnectTimeout()
83+
raise ConnectTimeout()
8484

8585
def get_environment_variables(self) -> dict[str, str]:
8686
"""Returns a dictionary of environment variables to patch in to fake the metadata service."""
@@ -101,7 +101,7 @@ def __call__(self, method, url, headers, timeout):
101101
logger.debug(
102102
f"Received request to unexpected hostname {parsed_url.hostname}"
103103
)
104-
raise self.handle_unexpected_hostname()
104+
self.handle_unexpected_hostname()
105105

106106
return self.handle_request(method, parsed_url, headers, timeout)
107107

@@ -143,7 +143,7 @@ def __exit__(self, *args, **kwargs):
143143

144144

145145
class UnavailableMetadataService(FakeMetadataService):
146-
"""Emulates an environment where all metadata services unavailable."""
146+
"""Emulates an environment where all metadata services are unavailable."""
147147

148148
def reset_defaults(self):
149149
pass
@@ -164,7 +164,7 @@ def reset_defaults(self):
164164
pass
165165

166166
def handle_unexpected_hostname(self):
167-
return RequestException()
167+
raise RequestException()
168168

169169
@property
170170
def expected_hostnames(self):

test/unit/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def disable_oob_telemetry():
2929

3030
@pytest.fixture
3131
def unavailable_metadata_service():
32-
"""Emulates an environment where all metadata services unavailable."""
32+
"""Emulates an environment where all metadata services are unavailable."""
3333
with UnavailableMetadataService() as server:
3434
yield server
3535

0 commit comments

Comments
 (0)