Skip to content

Commit 996a2fd

Browse files
[async] apply #2490 - platform_detection_timeout
1 parent 1e9d52f commit 996a2fd

File tree

4 files changed

+95
-40
lines changed

4 files changed

+95
-40
lines changed

src/snowflake/connector/aio/auth/_auth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ async def authenticate(
102102
self._rest._connection._login_timeout,
103103
self._rest._connection._network_timeout,
104104
self._rest._connection._socket_timeout,
105-
self._rest._connection._platform_detection_timeout_seconds,
105+
self._rest._connection.platform_detection_timeout_seconds,
106106
http_config=self._rest.session_manager.config, # AioHttpConfig extends BaseHttpConfig
107107
)
108108

src/snowflake/connector/aio/auth/_okta.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,9 @@ async def _step1(
122122
conn._internal_application_version,
123123
conn._ocsp_mode(),
124124
conn.login_timeout,
125-
conn._network_timeout,
125+
conn.network_timeout,
126+
conn.socket_timeout,
127+
conn.platform_detection_timeout_seconds,
126128
http_config=conn._session_manager.config, # AioHttpConfig extends BaseHttpConfig
127129
)
128130

src/snowflake/connector/aio/auth/_webbrowser.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,9 @@ async def _get_sso_url(
370370
conn._internal_application_version,
371371
conn._ocsp_mode(),
372372
conn.login_timeout,
373-
conn._network_timeout,
373+
conn.network_timeout,
374+
conn.socket_timeout,
375+
conn.platform_detection_timeout_seconds,
374376
http_config=conn._session_manager.config, # AioHttpConfig extends BaseHttpConfig
375377
)
376378

test/integ/aio_it/test_connection_async.py

Lines changed: 88 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -195,16 +195,20 @@ async def test_keep_alive_heartbeat_send(conn_cnx, db_parameters):
195195
"client_session_keep_alive_heartbeat_frequency": "1",
196196
}
197197
)
198-
with mock.patch(
199-
"snowflake.connector.aio._connection.SnowflakeConnection._validate_client_session_keep_alive_heartbeat_frequency",
200-
return_value=900,
201-
), mock.patch(
202-
"snowflake.connector.aio._connection.SnowflakeConnection.client_session_keep_alive_heartbeat_frequency",
203-
new_callable=mock.PropertyMock,
204-
return_value=1,
205-
), mock.patch(
206-
"snowflake.connector.aio._connection.SnowflakeConnection._heartbeat_tick"
207-
) as mocked_heartbeat:
198+
with (
199+
mock.patch(
200+
"snowflake.connector.aio._connection.SnowflakeConnection._validate_client_session_keep_alive_heartbeat_frequency",
201+
return_value=900,
202+
),
203+
mock.patch(
204+
"snowflake.connector.aio._connection.SnowflakeConnection.client_session_keep_alive_heartbeat_frequency",
205+
new_callable=mock.PropertyMock,
206+
return_value=1,
207+
),
208+
mock.patch(
209+
"snowflake.connector.aio._connection.SnowflakeConnection._heartbeat_tick"
210+
) as mocked_heartbeat,
211+
):
208212
cnx = snowflake.connector.aio.SnowflakeConnection(**config)
209213
try:
210214
await cnx.connect()
@@ -1056,9 +1060,10 @@ def check_packages(message: str, expected_packages: list[str]) -> bool:
10561060
"math",
10571061
]
10581062

1059-
async with conn_cnx() as conn, capture_sf_telemetry_async.patch_connection(
1060-
conn, False
1061-
) as telemetry_test:
1063+
async with (
1064+
conn_cnx() as conn,
1065+
capture_sf_telemetry_async.patch_connection(conn, False) as telemetry_test,
1066+
):
10621067
await conn._log_telemetry_imported_packages()
10631068
assert len(telemetry_test.records) > 0
10641069
assert any(
@@ -1073,11 +1078,10 @@ def check_packages(message: str, expected_packages: list[str]) -> bool:
10731078

10741079
# test different application
10751080
new_application_name = "PythonSnowpark"
1076-
async with conn_cnx(
1077-
timezone="UTC", application=new_application_name
1078-
) as conn, capture_sf_telemetry_async.patch_connection(
1079-
conn, False
1080-
) as telemetry_test:
1081+
async with (
1082+
conn_cnx(timezone="UTC", application=new_application_name) as conn,
1083+
capture_sf_telemetry_async.patch_connection(conn, False) as telemetry_test,
1084+
):
10811085
await conn._log_telemetry_imported_packages()
10821086
assert len(telemetry_test.records) > 0
10831087
assert any(
@@ -1090,13 +1094,14 @@ def check_packages(message: str, expected_packages: list[str]) -> bool:
10901094
)
10911095

10921096
# test opt out
1093-
async with conn_cnx(
1094-
timezone="UTC",
1095-
application=new_application_name,
1096-
log_imported_packages_in_telemetry=False,
1097-
) as conn, capture_sf_telemetry_async.patch_connection(
1098-
conn, False
1099-
) as telemetry_test:
1097+
async with (
1098+
conn_cnx(
1099+
timezone="UTC",
1100+
application=new_application_name,
1101+
log_imported_packages_in_telemetry=False,
1102+
) as conn,
1103+
capture_sf_telemetry_async.patch_connection(conn, False) as telemetry_test,
1104+
):
11001105
await conn._log_telemetry_imported_packages()
11011106
assert len(telemetry_test.records) == 0
11021107

@@ -1245,9 +1250,10 @@ async def test_ocsp_mode_insecure_mode_and_disable_ocsp_checks_match(
12451250
conn_cnx, is_public_test, is_local_dev_setup, caplog
12461251
):
12471252
caplog.set_level(logging.DEBUG, "snowflake.connector.aio._ocsp_snowflake")
1248-
async with conn_cnx(
1249-
insecure_mode=True, disable_ocsp_checks=True
1250-
) as conn, conn.cursor() as cur:
1253+
async with (
1254+
conn_cnx(insecure_mode=True, disable_ocsp_checks=True) as conn,
1255+
conn.cursor() as cur,
1256+
):
12511257
assert await (await cur.execute("select 1")).fetchall() == [(1,)]
12521258
assert "snowflake.connector.aio._ocsp_snowflake" not in caplog.text
12531259
if is_public_test or is_local_dev_setup:
@@ -1263,9 +1269,10 @@ async def test_ocsp_mode_insecure_mode_and_disable_ocsp_checks_mismatch_ocsp_dis
12631269
conn_cnx, is_public_test, is_local_dev_setup, caplog
12641270
):
12651271
caplog.set_level(logging.DEBUG, "snowflake.connector.aio._ocsp_snowflake")
1266-
async with conn_cnx(
1267-
insecure_mode=False, disable_ocsp_checks=True
1268-
) as conn, conn.cursor() as cur:
1272+
async with (
1273+
conn_cnx(insecure_mode=False, disable_ocsp_checks=True) as conn,
1274+
conn.cursor() as cur,
1275+
):
12691276
assert await (await cur.execute("select 1")).fetchall() == [(1,)]
12701277
assert "snowflake.connector.aio._ocsp_snowflake" not in caplog.text
12711278
if is_public_test or is_local_dev_setup:
@@ -1281,9 +1288,10 @@ async def test_ocsp_mode_insecure_mode_and_disable_ocsp_checks_mismatch_ocsp_ena
12811288
conn_cnx, is_public_test, is_local_dev_setup, caplog
12821289
):
12831290
caplog.set_level(logging.DEBUG, "snowflake.connector.aio._ocsp_snowflake")
1284-
async with conn_cnx(
1285-
insecure_mode=True, disable_ocsp_checks=False
1286-
) as conn, conn.cursor() as cur:
1291+
async with (
1292+
conn_cnx(insecure_mode=True, disable_ocsp_checks=False) as conn,
1293+
conn.cursor() as cur,
1294+
):
12871295
assert await (await cur.execute("select 1")).fetchall() == [(1,)]
12881296
if is_public_test or is_local_dev_setup:
12891297
assert "snowflake.connector.aio._ocsp_snowflake" in caplog.text
@@ -1394,9 +1402,10 @@ async def test_disable_telemetry(conn_cnx, caplog):
13941402

13951403
# set session parameters to false
13961404
with caplog.at_level(logging.DEBUG):
1397-
async with conn_cnx(
1398-
session_parameters={"CLIENT_TELEMETRY_ENABLED": False}
1399-
) as conn, conn.cursor() as cur:
1405+
async with (
1406+
conn_cnx(session_parameters={"CLIENT_TELEMETRY_ENABLED": False}) as conn,
1407+
conn.cursor() as cur,
1408+
):
14001409
await (await cur.execute("select 1")).fetchall()
14011410
assert not conn.telemetry_enabled and not conn._telemetry._log_batch
14021411
# this enable won't work as the session parameter is set to false
@@ -1418,6 +1427,48 @@ async def test_disable_telemetry(conn_cnx, caplog):
14181427
assert "POST /telemetry/send" not in caplog.text
14191428

14201429

1430+
@pytest.mark.skipolddriver
1431+
async def test_platform_detection_timeout(conn_cnx):
1432+
"""Tests platform detection timeout.
1433+
1434+
Creates a connection with platform_detection_timeout parameter.
1435+
"""
1436+
async with conn_cnx(timezone="UTC", platform_detection_timeout_seconds=2.5) as cnx:
1437+
assert cnx.platform_detection_timeout_seconds == 2.5
1438+
1439+
1440+
@pytest.mark.skipolddriver
1441+
async def test_platform_detection_zero_timeout(conn_cnx):
1442+
with (
1443+
mock.patch(
1444+
"snowflake.connector.platform_detection.is_ec2_instance"
1445+
) as is_ec2_instance,
1446+
mock.patch(
1447+
"snowflake.connector.platform_detection.has_aws_identity"
1448+
) as has_aws_identity,
1449+
mock.patch("snowflake.connector.platform_detection.is_azure_vm") as is_azure_vm,
1450+
mock.patch(
1451+
"snowflake.connector.platform_detection.has_azure_managed_identity"
1452+
) as has_azure_managed_identity,
1453+
mock.patch("snowflake.connector.platform_detection.is_gce_vm") as is_gce_vm,
1454+
mock.patch(
1455+
"snowflake.connector.platform_detection.has_gcp_identity"
1456+
) as has_gcp_identity,
1457+
):
1458+
for kwargs in [
1459+
{}, # should be default
1460+
{"platform_detection_timeout_seconds": 0},
1461+
]:
1462+
async with conn_cnx(**kwargs) as conn:
1463+
assert conn.platform_detection_timeout_seconds == 0.0
1464+
assert not is_ec2_instance.called
1465+
assert not has_aws_identity.called
1466+
assert not is_azure_vm.called
1467+
assert not has_azure_managed_identity.called
1468+
assert not is_gce_vm.called
1469+
assert not has_gcp_identity.called
1470+
1471+
14211472
@pytest.mark.skipolddriver
14221473
async def test_is_valid(conn_cnx):
14231474
"""Tests whether connection and session validation happens."""

0 commit comments

Comments
 (0)