Skip to content

Commit c300df6

Browse files
SNOW-2268606 zero timeout disables endpoint-based cloud platform detection (#2490)
1 parent 5e7fe20 commit c300df6

File tree

7 files changed

+148
-59
lines changed

7 files changed

+148
-59
lines changed

DESCRIPTION.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
99
# Release Notes
1010
- v3.17.2(TBD)
1111
- Fixed a bug where platform_detection was retrying failed requests with warnings to non-existent endpoints.
12+
- Added disabling endpoint-based platform detection by setting `platform_detection_timeout_seconds` to zero.
1213

1314
- v3.17.1(August 17,2025)
1415
- Added `infer_schema` parameter to `write_pandas` to perform schema inference on the passed data.

src/snowflake/connector/auth/_auth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def authenticate(
185185
self._rest._connection.login_timeout,
186186
self._rest._connection._network_timeout,
187187
self._rest._connection._socket_timeout,
188-
self._rest._connection._platform_detection_timeout_seconds,
188+
self._rest._connection.platform_detection_timeout_seconds,
189189
session_manager=self._rest.session_manager.clone(use_pooling=False),
190190
)
191191

src/snowflake/connector/auth/okta.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,9 @@ def _step1(
167167
conn._internal_application_version,
168168
conn._ocsp_mode(),
169169
conn.login_timeout,
170-
conn._network_timeout,
170+
conn.network_timeout,
171+
conn.socket_timeout,
172+
conn.platform_detection_timeout_seconds,
171173
session_manager=conn._session_manager.clone(use_pooling=False),
172174
)
173175

src/snowflake/connector/auth/webbrowser.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -456,12 +456,14 @@ def _get_sso_url(
456456
body = Auth.base_auth_data(
457457
user,
458458
account,
459-
conn._rest._connection.application,
460-
conn._rest._connection._internal_application_name,
461-
conn._rest._connection._internal_application_version,
462-
conn._rest._connection._ocsp_mode(),
463-
conn._rest._connection.login_timeout,
464-
conn._rest._connection._network_timeout,
459+
conn.application,
460+
conn._internal_application_name,
461+
conn._internal_application_version,
462+
conn._ocsp_mode(),
463+
conn.login_timeout,
464+
conn.network_timeout,
465+
conn.socket_timeout,
466+
conn.platform_detection_timeout_seconds,
465467
session_manager=conn.rest.session_manager.clone(use_pooling=False),
466468
)
467469

src/snowflake/connector/platform_detection.py

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -411,33 +411,36 @@ def detect_platforms(
411411
}
412412

413413
# Run network-calling functions in parallel
414-
with ThreadPoolExecutor(max_workers=6) as executor:
415-
futures = {
416-
"is_ec2_instance": executor.submit(
417-
is_ec2_instance, platform_detection_timeout_seconds
418-
),
419-
"has_aws_identity": executor.submit(
420-
has_aws_identity, platform_detection_timeout_seconds
421-
),
422-
"is_azure_vm": executor.submit(
423-
is_azure_vm, platform_detection_timeout_seconds, session_manager
424-
),
425-
"has_azure_managed_identity": executor.submit(
426-
has_azure_managed_identity,
427-
platform_detection_timeout_seconds,
428-
session_manager,
429-
),
430-
"is_gce_vm": executor.submit(
431-
is_gce_vm, platform_detection_timeout_seconds, session_manager
432-
),
433-
"has_gcp_identity": executor.submit(
434-
has_gcp_identity,
435-
platform_detection_timeout_seconds,
436-
session_manager,
437-
),
438-
}
439-
440-
platforms.update({key: future.result() for key, future in futures.items()})
414+
if platform_detection_timeout_seconds != 0.0:
415+
with ThreadPoolExecutor(max_workers=6) as executor:
416+
futures = {
417+
"is_ec2_instance": executor.submit(
418+
is_ec2_instance, platform_detection_timeout_seconds
419+
),
420+
"has_aws_identity": executor.submit(
421+
has_aws_identity, platform_detection_timeout_seconds
422+
),
423+
"is_azure_vm": executor.submit(
424+
is_azure_vm, platform_detection_timeout_seconds, session_manager
425+
),
426+
"has_azure_managed_identity": executor.submit(
427+
has_azure_managed_identity,
428+
platform_detection_timeout_seconds,
429+
session_manager,
430+
),
431+
"is_gce_vm": executor.submit(
432+
is_gce_vm, platform_detection_timeout_seconds, session_manager
433+
),
434+
"has_gcp_identity": executor.submit(
435+
has_gcp_identity,
436+
platform_detection_timeout_seconds,
437+
session_manager,
438+
),
439+
}
440+
441+
platforms.update(
442+
{key: future.result() for key, future in futures.items()}
443+
)
441444

442445
detected_platforms = []
443446
for platform_name, detection_state in platforms.items():

test/integ/test_connection.py

Lines changed: 66 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,37 @@ def test_platform_detection_timeout(conn_cnx):
207207
assert cnx.platform_detection_timeout_seconds == 2.5
208208

209209

210+
@pytest.mark.skipolddriver
211+
def test_platform_detection_zero_timeout(conn_cnx):
212+
"""Tests platform detection with timeout set to zero.
213+
214+
The expectation is that it mustn't do diagnostic requests at all.
215+
"""
216+
with (
217+
mock.patch(
218+
"snowflake.connector.platform_detection.is_ec2_instance"
219+
) as is_ec2_instance,
220+
mock.patch(
221+
"snowflake.connector.platform_detection.has_aws_identity"
222+
) as has_aws_identity,
223+
mock.patch("snowflake.connector.platform_detection.is_azure_vm") as is_azure_vm,
224+
mock.patch(
225+
"snowflake.connector.platform_detection.has_azure_managed_identity"
226+
) as has_azure_managed_identity,
227+
mock.patch("snowflake.connector.platform_detection.is_gce_vm") as is_gce_vm,
228+
mock.patch(
229+
"snowflake.connector.platform_detection.has_gcp_identity"
230+
) as has_gcp_identity,
231+
):
232+
with conn_cnx(platform_detection_timeout_seconds=0):
233+
assert not is_ec2_instance.called
234+
assert not has_aws_identity.called
235+
assert not is_azure_vm.called
236+
assert not has_azure_managed_identity.called
237+
assert not is_gce_vm.called
238+
assert not has_gcp_identity.called
239+
240+
210241
def test_bad_db(conn_cnx):
211242
"""Attempts to use a bad DB."""
212243
with conn_cnx(database="baddb") as cnx:
@@ -1145,9 +1176,10 @@ def check_packages(message: str, expected_packages: list[str]) -> bool:
11451176
"math",
11461177
]
11471178

1148-
with conn_cnx() as conn, capture_sf_telemetry.patch_connection(
1149-
conn, False
1150-
) as telemetry_test:
1179+
with (
1180+
conn_cnx() as conn,
1181+
capture_sf_telemetry.patch_connection(conn, False) as telemetry_test,
1182+
):
11511183
conn._log_telemetry_imported_packages()
11521184
assert len(telemetry_test.records) > 0
11531185
assert any(
@@ -1162,10 +1194,13 @@ def check_packages(message: str, expected_packages: list[str]) -> bool:
11621194

11631195
# test different application
11641196
new_application_name = "PythonSnowpark"
1165-
with conn_cnx(
1166-
timezone="UTC",
1167-
application=new_application_name,
1168-
) as conn, capture_sf_telemetry.patch_connection(conn, False) as telemetry_test:
1197+
with (
1198+
conn_cnx(
1199+
timezone="UTC",
1200+
application=new_application_name,
1201+
) as conn,
1202+
capture_sf_telemetry.patch_connection(conn, False) as telemetry_test,
1203+
):
11691204
conn._log_telemetry_imported_packages()
11701205
assert len(telemetry_test.records) > 0
11711206
assert any(
@@ -1178,11 +1213,14 @@ def check_packages(message: str, expected_packages: list[str]) -> bool:
11781213
)
11791214

11801215
# test opt out
1181-
with conn_cnx(
1182-
timezone="UTC",
1183-
application=new_application_name,
1184-
log_imported_packages_in_telemetry=False,
1185-
) as conn, capture_sf_telemetry.patch_connection(conn, False) as telemetry_test:
1216+
with (
1217+
conn_cnx(
1218+
timezone="UTC",
1219+
application=new_application_name,
1220+
log_imported_packages_in_telemetry=False,
1221+
) as conn,
1222+
capture_sf_telemetry.patch_connection(conn, False) as telemetry_test,
1223+
):
11861224
conn._log_telemetry_imported_packages()
11871225
assert len(telemetry_test.records) == 0
11881226

@@ -1318,9 +1356,10 @@ def test_ocsp_mode_insecure_mode_and_disable_ocsp_checks_match(
13181356
conn_cnx, is_public_test, is_local_dev_setup, caplog
13191357
):
13201358
caplog.set_level(logging.DEBUG, "snowflake.connector.ocsp_snowflake")
1321-
with conn_cnx(
1322-
insecure_mode=True, disable_ocsp_checks=True
1323-
) as conn, conn.cursor() as cur:
1359+
with (
1360+
conn_cnx(insecure_mode=True, disable_ocsp_checks=True) as conn,
1361+
conn.cursor() as cur,
1362+
):
13241363
assert cur.execute("select 1").fetchall() == [(1,)]
13251364
assert "snowflake.connector.ocsp_snowflake" not in caplog.text
13261365
if is_public_test or is_local_dev_setup:
@@ -1336,9 +1375,10 @@ def test_ocsp_mode_insecure_mode_and_disable_ocsp_checks_mismatch_ocsp_disabled(
13361375
conn_cnx, is_public_test, is_local_dev_setup, caplog
13371376
):
13381377
caplog.set_level(logging.DEBUG, "snowflake.connector.ocsp_snowflake")
1339-
with conn_cnx(
1340-
insecure_mode=False, disable_ocsp_checks=True
1341-
) as conn, conn.cursor() as cur:
1378+
with (
1379+
conn_cnx(insecure_mode=False, disable_ocsp_checks=True) as conn,
1380+
conn.cursor() as cur,
1381+
):
13421382
assert cur.execute("select 1").fetchall() == [(1,)]
13431383
assert "snowflake.connector.ocsp_snowflake" not in caplog.text
13441384
if is_public_test or is_local_dev_setup:
@@ -1524,9 +1564,10 @@ def test_ocsp_mode_insecure_mode_and_disable_ocsp_checks_mismatch_ocsp_enabled(
15241564
conn_cnx, is_public_test, is_local_dev_setup, caplog
15251565
):
15261566
caplog.set_level(logging.DEBUG, "snowflake.connector.ocsp_snowflake")
1527-
with conn_cnx(
1528-
insecure_mode=True, disable_ocsp_checks=False
1529-
) as conn, conn.cursor() as cur:
1567+
with (
1568+
conn_cnx(insecure_mode=True, disable_ocsp_checks=False) as conn,
1569+
conn.cursor() as cur,
1570+
):
15301571
assert cur.execute("select 1").fetchall() == [(1,)]
15311572
if is_public_test or is_local_dev_setup:
15321573
assert "snowflake.connector.ocsp_snowflake" in caplog.text
@@ -1625,9 +1666,10 @@ def test_disable_telemetry(conn_cnx, caplog):
16251666

16261667
# set session parameters to false
16271668
with caplog.at_level(logging.DEBUG):
1628-
with conn_cnx(
1629-
session_parameters={"CLIENT_TELEMETRY_ENABLED": False}
1630-
) as conn, conn.cursor() as cur:
1669+
with (
1670+
conn_cnx(session_parameters={"CLIENT_TELEMETRY_ENABLED": False}) as conn,
1671+
conn.cursor() as cur,
1672+
):
16311673
cur.execute("select 1").fetchall()
16321674
assert not conn.telemetry_enabled and not conn._telemetry._log_batch
16331675
# this enable won't work as the session parameter is set to false

test/unit/test_detect_platforms.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,24 @@ def unavailable_metadata_service_with_request_exception(unavailable_metadata_ser
2626
return unavailable_metadata_service
2727

2828

29+
@pytest.fixture
30+
def labels_detected_by_endpoints():
31+
return {
32+
"is_ec2_instance",
33+
"is_ec2_instance_timeout",
34+
"has_aws_identity",
35+
"has_aws_identity_timeout",
36+
"is_azure_vm",
37+
"is_azure_vm_timeout",
38+
"has_azure_managed_identity",
39+
"has_azure_managed_identity_timeout",
40+
"is_gce_vm",
41+
"is_gce_vm_timeout",
42+
"has_gcp_identity",
43+
"has_gcp_identity_timeout",
44+
}
45+
46+
2947
@pytest.mark.xdist_group(name="serial_tests")
3048
class TestDetectPlatforms:
3149
@pytest.fixture(autouse=True)
@@ -288,3 +306,24 @@ def test_gce_cloud_run_job_missing_cloud_run_job(
288306
):
289307
result = detect_platforms(platform_detection_timeout_seconds=None)
290308
assert "is_gce_cloud_run_job" not in result
309+
310+
def test_zero_platform_detection_timeout_disables_endpoints_detection_on_cloud(
311+
self,
312+
fake_azure_vm_metadata_service,
313+
fake_azure_function_metadata_service,
314+
fake_gce_metadata_service,
315+
fake_gce_cloud_run_service_metadata_service,
316+
fake_gce_cloud_run_job_metadata_service,
317+
fake_github_actions_metadata_service,
318+
labels_detected_by_endpoints,
319+
):
320+
result = detect_platforms(platform_detection_timeout_seconds=0)
321+
assert not labels_detected_by_endpoints.intersection(result)
322+
323+
def test_zero_platform_detection_timeout_disables_endpoints_detection_out_of_cloud(
324+
self,
325+
unavailable_metadata_service_with_request_exception,
326+
labels_detected_by_endpoints,
327+
):
328+
result = detect_platforms(platform_detection_timeout_seconds=0)
329+
assert not labels_detected_by_endpoints.intersection(result)

0 commit comments

Comments
 (0)