Skip to content

Commit b9bc689

Browse files
SNOW-2204396: Applied platform_detection.py changes
1 parent a7fe63c commit b9bc689

File tree

2 files changed

+31
-23
lines changed

2 files changed

+31
-23
lines changed

src/snowflake/connector/platform_detection.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ class _DetectionState(Enum):
3030

3131
DETECTED = "detected"
3232
NOT_DETECTED = "not_detected"
33-
TIMEOUT = "timeout"
33+
HTTP_TIMEOUT = "timeout"
34+
WORKER_TIMEOUT = "worker_timeout"
3435

3536

3637
# Result returned when platform detection is disabled via environment variable
@@ -157,7 +158,7 @@ def is_azure_vm(
157158
session_manager: SessionManager instance for making HTTP requests.
158159
159160
Returns:
160-
_DetectionState: DETECTED if on Azure VM, TIMEOUT if request times out,
161+
_DetectionState: DETECTED if on Azure VM, HTTP_TIMEOUT if request times out,
161162
NOT_DETECTED otherwise.
162163
"""
163164
try:
@@ -172,7 +173,7 @@ def is_azure_vm(
172173
else _DetectionState.NOT_DETECTED
173174
)
174175
except Timeout:
175-
return _DetectionState.TIMEOUT
176+
return _DetectionState.HTTP_TIMEOUT
176177
except RequestException:
177178
return _DetectionState.NOT_DETECTED
178179

@@ -219,7 +220,7 @@ def is_managed_identity_available_on_azure_vm(
219220
resource: The Azure resource URI to request a token for.
220221
221222
Returns:
222-
_DetectionState: DETECTED if managed identity is available, TIMEOUT if request
223+
_DetectionState: DETECTED if managed identity is available, HTTP_TIMEOUT if request
223224
times out, NOT_DETECTED otherwise.
224225
"""
225226
endpoint = f"http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource={resource}"
@@ -234,7 +235,7 @@ def is_managed_identity_available_on_azure_vm(
234235
else _DetectionState.NOT_DETECTED
235236
)
236237
except Timeout:
237-
return _DetectionState.TIMEOUT
238+
return _DetectionState.HTTP_TIMEOUT
238239
except RequestException:
239240
return _DetectionState.NOT_DETECTED
240241

@@ -261,7 +262,7 @@ def has_azure_managed_identity(
261262
session_manager: SessionManager instance for making HTTP requests.
262263
263264
Returns:
264-
_DetectionState: DETECTED if managed identity is available, TIMEOUT if
265+
_DetectionState: DETECTED if managed identity is available, HTTP_TIMEOUT if
265266
detection timed out, NOT_DETECTED otherwise.
266267
"""
267268
# short circuit early to save on latency and avoid minting an unnecessary token
@@ -290,7 +291,7 @@ def is_gce_vm(
290291
session_manager: SessionManager instance for making HTTP requests.
291292
292293
Returns:
293-
_DetectionState: DETECTED if on GCE, TIMEOUT if request times out,
294+
_DetectionState: DETECTED if on GCE, HTTP_TIMEOUT if request times out,
294295
NOT_DETECTED otherwise.
295296
"""
296297
try:
@@ -304,7 +305,7 @@ def is_gce_vm(
304305
else _DetectionState.NOT_DETECTED
305306
)
306307
except Timeout:
307-
return _DetectionState.TIMEOUT
308+
return _DetectionState.HTTP_TIMEOUT
308309
except RequestException:
309310
return _DetectionState.NOT_DETECTED
310311

@@ -360,7 +361,7 @@ def has_gcp_identity(
360361
platform_detection_timeout_seconds: Timeout value for the metadata service request.
361362
session_manager: SessionManager instance for making HTTP requests.
362363
Returns:
363-
_DetectionState: DETECTED if valid GCP identity exists, TIMEOUT if request
364+
_DetectionState: DETECTED if valid GCP identity exists, HTTP_TIMEOUT if request
364365
times out, NOT_DETECTED otherwise.
365366
"""
366367
try:
@@ -375,7 +376,7 @@ def has_gcp_identity(
375376
else _DetectionState.NOT_DETECTED
376377
)
377378
except Timeout:
378-
return _DetectionState.TIMEOUT
379+
return _DetectionState.HTTP_TIMEOUT
379380
except RequestException:
380381
return _DetectionState.NOT_DETECTED
381382

@@ -412,11 +413,11 @@ def detect_platforms(
412413
session_manager: SessionManager instance for making HTTP requests. If None, a new instance will be created.
413414
414415
Returns:
415-
list[str]: List of detected platform names. Platforms that timed out will have
416-
"_timeout" suffix appended to their name. Returns _PLATFORM_DETECTION_DISABLED_RESULT
417-
if the ENV_VAR_DISABLE_PLATFORM_DETECTION environment variable is set to a value
418-
in ENV_VAR_BOOL_POSITIVE_VALUES_LOWERCASED (case-insensitive).
419-
Returns empty list if any exception occurs during detection.
416+
list[str]: List of detected platform names. Platforms that timed out (either HTTP timeout
417+
or thread timeout) will have "_timeout" suffix appended to their name.
418+
Returns _PLATFORM_DETECTION_DISABLED_RESULT if the ENV_VAR_DISABLE_PLATFORM_DETECTION
419+
environment variable is set to a value in ENV_VAR_BOOL_POSITIVE_VALUES_LOWERCASED
420+
(case-insensitive). Returns empty list if any exception occurs during detection.
420421
"""
421422
try:
422423
# Check if platform detection is disabled via environment variable
@@ -491,15 +492,20 @@ def detect_platforms(
491492
timeout=platform_detection_timeout_seconds
492493
)
493494
except (FutureTimeoutError, FutureCancelledError):
494-
platforms[key] = _DetectionState.TIMEOUT
495+
# Thread/future timed out at executor level
496+
platforms[key] = _DetectionState.WORKER_TIMEOUT
495497
except Exception:
498+
# Any other error from the thread
496499
platforms[key] = _DetectionState.NOT_DETECTED
497500

498501
detected_platforms = []
499502
for platform_name, detection_state in platforms.items():
500503
if detection_state == _DetectionState.DETECTED:
501504
detected_platforms.append(platform_name)
502-
elif detection_state == _DetectionState.TIMEOUT:
505+
elif detection_state in (
506+
_DetectionState.HTTP_TIMEOUT,
507+
_DetectionState.WORKER_TIMEOUT,
508+
):
503509
detected_platforms.append(f"{platform_name}_timeout")
504510

505511
logger.debug(

test/unit/test_detect_platforms.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ def teardown(self):
6565
def test_no_platforms_detected(
6666
self, unavailable_metadata_service_with_request_exception
6767
):
68-
result = detect_platforms(platform_detection_timeout_seconds=None)
68+
result = detect_platforms(
69+
platform_detection_timeout_seconds=0.3
70+
) # increase timeout to make sure no Thread-based timeout messes the results
6971
assert result == []
7072

7173
def test_ec2_instance_detection(
@@ -407,7 +409,7 @@ def capture_timeout_azure(timeout, session_manager):
407409
"snowflake.connector.platform_detection.is_azure_vm",
408410
side_effect=capture_timeout_azure,
409411
):
410-
detect_platforms()
412+
detect_platforms(platform_detection_timeout_seconds=None)
411413

412414
# Verify that functions were called with timeout <= 200ms
413415
assert len(timeout_captured) > 0, "No timeout was captured"
@@ -434,9 +436,9 @@ def test_platform_detection_completes_within_timeout(
434436
# Allow ~30% overhead for thread management, environment variable checks, etc.
435437
# The timeout is 200ms per network call, but they run in parallel
436438
# So total time should be ~200ms + overhead, not 200ms * number_of_calls
437-
epsilon_for_overhead = 1.3
438-
max_allowed_time = (
439-
EXPECTED_MAX_TIMEOUT_FOR_PLATFORM_DETECTION * epsilon_for_overhead
439+
epsilon_for_overhead = 0.3
440+
max_allowed_time = EXPECTED_MAX_TIMEOUT_FOR_PLATFORM_DETECTION * (
441+
1 + epsilon_for_overhead
440442
)
441443
assert execution_time < max_allowed_time, (
442444
f"Platform detection took {execution_time:.3f}s, "
@@ -445,5 +447,5 @@ def test_platform_detection_completes_within_timeout(
445447
# Ensure it's not suspiciously fast (< 10ms would indicate something's wrong)
446448
assert execution_time > 0.01, (
447449
f"Platform detection completed too quickly ({execution_time:.3f}s), "
448-
"which may indicate detection was skipped"
450+
"which may indicate detection was skipped or some other issues happened"
449451
)

0 commit comments

Comments
 (0)