Skip to content

Commit 3fd9aa9

Browse files
sfc-gh-mmishchenkosfc-gh-pczajka
authored andcommitted
SNOW-2268606 zero timeout disables endpoint-based cloud platform detection (#2490)
1 parent 36f1171 commit 3fd9aa9

File tree

6 files changed

+147
-59
lines changed

6 files changed

+147
-59
lines changed

src/snowflake/connector/auth/_auth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def authenticate(
195195
self._rest._connection.login_timeout,
196196
self._rest._connection._network_timeout,
197197
self._rest._connection._socket_timeout,
198-
self._rest._connection._platform_detection_timeout_seconds,
198+
self._rest._connection.platform_detection_timeout_seconds,
199199
session_manager=self._rest.session_manager.clone(use_pooling=False),
200200
)
201201

src/snowflake/connector/auth/okta.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,9 @@ def _step1(
167167
conn._internal_application_version,
168168
conn._ocsp_mode(),
169169
conn.login_timeout,
170-
conn._network_timeout,
170+
conn.network_timeout,
171+
conn.socket_timeout,
172+
conn.platform_detection_timeout_seconds,
171173
session_manager=conn._session_manager.clone(use_pooling=False),
172174
)
173175

src/snowflake/connector/auth/webbrowser.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -456,12 +456,14 @@ def _get_sso_url(
456456
body = Auth.base_auth_data(
457457
user,
458458
account,
459-
conn._rest._connection.application,
460-
conn._rest._connection._internal_application_name,
461-
conn._rest._connection._internal_application_version,
462-
conn._rest._connection._ocsp_mode(),
463-
conn._rest._connection.login_timeout,
464-
conn._rest._connection._network_timeout,
459+
conn.application,
460+
conn._internal_application_name,
461+
conn._internal_application_version,
462+
conn._ocsp_mode(),
463+
conn.login_timeout,
464+
conn.network_timeout,
465+
conn.socket_timeout,
466+
conn.platform_detection_timeout_seconds,
465467
session_manager=conn.rest.session_manager.clone(use_pooling=False),
466468
)
467469

src/snowflake/connector/platform_detection.py

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -417,33 +417,36 @@ def detect_platforms(
417417
}
418418

419419
# Run network-calling functions in parallel
420-
with ThreadPoolExecutor(max_workers=6) as executor:
421-
futures = {
422-
"is_ec2_instance": executor.submit(
423-
is_ec2_instance, platform_detection_timeout_seconds
424-
),
425-
"has_aws_identity": executor.submit(
426-
has_aws_identity, platform_detection_timeout_seconds
427-
),
428-
"is_azure_vm": executor.submit(
429-
is_azure_vm, platform_detection_timeout_seconds, session_manager
430-
),
431-
"has_azure_managed_identity": executor.submit(
432-
has_azure_managed_identity,
433-
platform_detection_timeout_seconds,
434-
session_manager,
435-
),
436-
"is_gce_vm": executor.submit(
437-
is_gce_vm, platform_detection_timeout_seconds, session_manager
438-
),
439-
"has_gcp_identity": executor.submit(
440-
has_gcp_identity,
441-
platform_detection_timeout_seconds,
442-
session_manager,
443-
),
444-
}
445-
446-
platforms.update({key: future.result() for key, future in futures.items()})
420+
if platform_detection_timeout_seconds != 0.0:
421+
with ThreadPoolExecutor(max_workers=6) as executor:
422+
futures = {
423+
"is_ec2_instance": executor.submit(
424+
is_ec2_instance, platform_detection_timeout_seconds
425+
),
426+
"has_aws_identity": executor.submit(
427+
has_aws_identity, platform_detection_timeout_seconds
428+
),
429+
"is_azure_vm": executor.submit(
430+
is_azure_vm, platform_detection_timeout_seconds, session_manager
431+
),
432+
"has_azure_managed_identity": executor.submit(
433+
has_azure_managed_identity,
434+
platform_detection_timeout_seconds,
435+
session_manager,
436+
),
437+
"is_gce_vm": executor.submit(
438+
is_gce_vm, platform_detection_timeout_seconds, session_manager
439+
),
440+
"has_gcp_identity": executor.submit(
441+
has_gcp_identity,
442+
platform_detection_timeout_seconds,
443+
session_manager,
444+
),
445+
}
446+
447+
platforms.update(
448+
{key: future.result() for key, future in futures.items()}
449+
)
447450

448451
detected_platforms = []
449452
for platform_name, detection_state in platforms.items():

test/integ/test_connection.py

Lines changed: 66 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,37 @@ def test_platform_detection_timeout(conn_cnx):
204204
assert cnx.platform_detection_timeout_seconds == 2.5
205205

206206

207+
@pytest.mark.skipolddriver
208+
def test_platform_detection_zero_timeout(conn_cnx):
209+
"""Tests platform detection with timeout set to zero.
210+
211+
The expectation is that it mustn't do diagnostic requests at all.
212+
"""
213+
with (
214+
mock.patch(
215+
"snowflake.connector.platform_detection.is_ec2_instance"
216+
) as is_ec2_instance,
217+
mock.patch(
218+
"snowflake.connector.platform_detection.has_aws_identity"
219+
) as has_aws_identity,
220+
mock.patch("snowflake.connector.platform_detection.is_azure_vm") as is_azure_vm,
221+
mock.patch(
222+
"snowflake.connector.platform_detection.has_azure_managed_identity"
223+
) as has_azure_managed_identity,
224+
mock.patch("snowflake.connector.platform_detection.is_gce_vm") as is_gce_vm,
225+
mock.patch(
226+
"snowflake.connector.platform_detection.has_gcp_identity"
227+
) as has_gcp_identity,
228+
):
229+
with conn_cnx(platform_detection_timeout_seconds=0):
230+
assert not is_ec2_instance.called
231+
assert not has_aws_identity.called
232+
assert not is_azure_vm.called
233+
assert not has_azure_managed_identity.called
234+
assert not is_gce_vm.called
235+
assert not has_gcp_identity.called
236+
237+
207238
def test_bad_db(conn_cnx):
208239
"""Attempts to use a bad DB."""
209240
with conn_cnx(database="baddb") as cnx:
@@ -1119,9 +1150,10 @@ def check_packages(message: str, expected_packages: list[str]) -> bool:
11191150
"math",
11201151
]
11211152

1122-
with conn_cnx() as conn, capture_sf_telemetry.patch_connection(
1123-
conn, False
1124-
) as telemetry_test:
1153+
with (
1154+
conn_cnx() as conn,
1155+
capture_sf_telemetry.patch_connection(conn, False) as telemetry_test,
1156+
):
11251157
conn._log_telemetry_imported_packages()
11261158
assert len(telemetry_test.records) > 0
11271159
assert any(
@@ -1136,10 +1168,13 @@ def check_packages(message: str, expected_packages: list[str]) -> bool:
11361168

11371169
# test different application
11381170
new_application_name = "PythonSnowpark"
1139-
with conn_cnx(
1140-
timezone="UTC",
1141-
application=new_application_name,
1142-
) as conn, capture_sf_telemetry.patch_connection(conn, False) as telemetry_test:
1171+
with (
1172+
conn_cnx(
1173+
timezone="UTC",
1174+
application=new_application_name,
1175+
) as conn,
1176+
capture_sf_telemetry.patch_connection(conn, False) as telemetry_test,
1177+
):
11431178
conn._log_telemetry_imported_packages()
11441179
assert len(telemetry_test.records) > 0
11451180
assert any(
@@ -1152,11 +1187,14 @@ def check_packages(message: str, expected_packages: list[str]) -> bool:
11521187
)
11531188

11541189
# test opt out
1155-
with conn_cnx(
1156-
timezone="UTC",
1157-
application=new_application_name,
1158-
log_imported_packages_in_telemetry=False,
1159-
) as conn, capture_sf_telemetry.patch_connection(conn, False) as telemetry_test:
1190+
with (
1191+
conn_cnx(
1192+
timezone="UTC",
1193+
application=new_application_name,
1194+
log_imported_packages_in_telemetry=False,
1195+
) as conn,
1196+
capture_sf_telemetry.patch_connection(conn, False) as telemetry_test,
1197+
):
11601198
conn._log_telemetry_imported_packages()
11611199
assert len(telemetry_test.records) == 0
11621200

@@ -1293,9 +1331,10 @@ def test_ocsp_mode_insecure_mode_and_disable_ocsp_checks_match(
12931331
conn_cnx, is_public_test, is_local_dev_setup, caplog
12941332
):
12951333
caplog.set_level(logging.DEBUG, "snowflake.connector.ocsp_snowflake")
1296-
with conn_cnx(
1297-
insecure_mode=True, disable_ocsp_checks=True
1298-
) as conn, conn.cursor() as cur:
1334+
with (
1335+
conn_cnx(insecure_mode=True, disable_ocsp_checks=True) as conn,
1336+
conn.cursor() as cur,
1337+
):
12991338
assert cur.execute("select 1").fetchall() == [(1,)]
13001339
assert "snowflake.connector.ocsp_snowflake" not in caplog.text
13011340
if is_public_test or is_local_dev_setup:
@@ -1311,9 +1350,10 @@ def test_ocsp_mode_insecure_mode_and_disable_ocsp_checks_mismatch_ocsp_disabled(
13111350
conn_cnx, is_public_test, is_local_dev_setup, caplog
13121351
):
13131352
caplog.set_level(logging.DEBUG, "snowflake.connector.ocsp_snowflake")
1314-
with conn_cnx(
1315-
insecure_mode=False, disable_ocsp_checks=True
1316-
) as conn, conn.cursor() as cur:
1353+
with (
1354+
conn_cnx(insecure_mode=False, disable_ocsp_checks=True) as conn,
1355+
conn.cursor() as cur,
1356+
):
13171357
assert cur.execute("select 1").fetchall() == [(1,)]
13181358
assert "snowflake.connector.ocsp_snowflake" not in caplog.text
13191359
if is_public_test or is_local_dev_setup:
@@ -1329,9 +1369,10 @@ def test_ocsp_mode_insecure_mode_and_disable_ocsp_checks_mismatch_ocsp_enabled(
13291369
conn_cnx, is_public_test, is_local_dev_setup, caplog
13301370
):
13311371
caplog.set_level(logging.DEBUG, "snowflake.connector.ocsp_snowflake")
1332-
with conn_cnx(
1333-
insecure_mode=True, disable_ocsp_checks=False
1334-
) as conn, conn.cursor() as cur:
1372+
with (
1373+
conn_cnx(insecure_mode=True, disable_ocsp_checks=False) as conn,
1374+
conn.cursor() as cur,
1375+
):
13351376
assert cur.execute("select 1").fetchall() == [(1,)]
13361377
if is_public_test or is_local_dev_setup:
13371378
assert "snowflake.connector.ocsp_snowflake" in caplog.text
@@ -1430,9 +1471,10 @@ def test_disable_telemetry(conn_cnx, caplog):
14301471

14311472
# set session parameters to false
14321473
with caplog.at_level(logging.DEBUG):
1433-
with conn_cnx(
1434-
session_parameters={"CLIENT_TELEMETRY_ENABLED": False}
1435-
) as conn, conn.cursor() as cur:
1474+
with (
1475+
conn_cnx(session_parameters={"CLIENT_TELEMETRY_ENABLED": False}) as conn,
1476+
conn.cursor() as cur,
1477+
):
14361478
cur.execute("select 1").fetchall()
14371479
assert not conn.telemetry_enabled and not conn._telemetry._log_batch
14381480
# this enable won't work as the session parameter is set to false

test/unit/test_detect_platforms.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,24 @@ def unavailable_metadata_service_with_request_exception(unavailable_metadata_ser
2626
return unavailable_metadata_service
2727

2828

29+
@pytest.fixture
30+
def labels_detected_by_endpoints():
31+
return {
32+
"is_ec2_instance",
33+
"is_ec2_instance_timeout",
34+
"has_aws_identity",
35+
"has_aws_identity_timeout",
36+
"is_azure_vm",
37+
"is_azure_vm_timeout",
38+
"has_azure_managed_identity",
39+
"has_azure_managed_identity_timeout",
40+
"is_gce_vm",
41+
"is_gce_vm_timeout",
42+
"has_gcp_identity",
43+
"has_gcp_identity_timeout",
44+
}
45+
46+
2947
@pytest.mark.xdist_group(name="serial_tests")
3048
class TestDetectPlatforms:
3149
@pytest.fixture(autouse=True)
@@ -288,3 +306,24 @@ def test_gce_cloud_run_job_missing_cloud_run_job(
288306
):
289307
result = detect_platforms(platform_detection_timeout_seconds=None)
290308
assert "is_gce_cloud_run_job" not in result
309+
310+
def test_zero_platform_detection_timeout_disables_endpoints_detection_on_cloud(
311+
self,
312+
fake_azure_vm_metadata_service,
313+
fake_azure_function_metadata_service,
314+
fake_gce_metadata_service,
315+
fake_gce_cloud_run_service_metadata_service,
316+
fake_gce_cloud_run_job_metadata_service,
317+
fake_github_actions_metadata_service,
318+
labels_detected_by_endpoints,
319+
):
320+
result = detect_platforms(platform_detection_timeout_seconds=0)
321+
assert not labels_detected_by_endpoints.intersection(result)
322+
323+
def test_zero_platform_detection_timeout_disables_endpoints_detection_out_of_cloud(
324+
self,
325+
unavailable_metadata_service_with_request_exception,
326+
labels_detected_by_endpoints,
327+
):
328+
result = detect_platforms(platform_detection_timeout_seconds=0)
329+
assert not labels_detected_by_endpoints.intersection(result)

0 commit comments

Comments
 (0)