Skip to content

Commit 10966b7

Browse files
fix async connection tests
1 parent c77cd31 commit 10966b7

File tree

2 files changed

+44
-90
lines changed

2 files changed

+44
-90
lines changed

test/integ/aio/conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,12 @@ async def create_connection(connection_name: str, **kwargs) -> SnowflakeConnecti
8989
ret["private_key"] = private_key_bytes
9090
ret.pop("private_key_file", None)
9191

92+
# If authenticator is explicitly provided and it's not key-pair based, drop key-pair fields
93+
authenticator_value = ret.get("authenticator")
94+
if authenticator_value.lower() not in {"key_pair_authenticator", "snowflake_jwt"}:
95+
ret.pop("private_key", None)
96+
ret.pop("private_key_file", None)
97+
9298
connection = SnowflakeConnection(**ret)
9399
await connection.connect()
94100
return connection

test/integ/aio/test_connection_async.py

Lines changed: 38 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ async def test_with_config(conn_cnx):
8383
async with conn_cnx(timezone="UTC") as cnx:
8484
assert cnx, "invalid cnx"
8585
# Default depends on server; if unreachable, fall back to False
86-
from ..conftest import get_server_parameter_value
86+
from ...conftest import get_server_parameter_value
8787

8888
server_default_str = get_server_parameter_value(
8989
cnx, "CLIENT_SESSION_KEEP_ALIVE"
@@ -620,6 +620,7 @@ async def mock_auth(self, auth_instance):
620620
async with conn_cnx(
621621
timezone="UTC",
622622
authenticator=orig_authenticator,
623+
password="test-password",
623624
) as cnx:
624625
assert cnx
625626

@@ -713,82 +714,42 @@ async def test_dashed_url_account_name(db_parameters):
713714
),
714715
],
715716
)
716-
async def test_invalid_connection_parameter(db_parameters, name, value, exc_warn):
717-
with warnings.catch_warnings(record=True) as w:
718-
conn_params = {
719-
"account": db_parameters["account"],
720-
"user": db_parameters["user"],
721-
"password": db_parameters["password"],
722-
"schema": db_parameters["schema"],
723-
"database": db_parameters["database"],
724-
"protocol": db_parameters["protocol"],
725-
"host": db_parameters["host"],
726-
"port": db_parameters["port"],
727-
"validate_default_parameters": True,
728-
name: value,
729-
}
730-
try:
731-
conn = snowflake.connector.aio.SnowflakeConnection(**conn_params)
732-
await conn.connect()
717+
async def test_invalid_connection_parameter(conn_cnx, name, value, exc_warn):
718+
with warnings.catch_warnings(record=True) as warns:
719+
async with conn_cnx(validate_default_parameters=True, **{name: value}) as conn:
733720
assert getattr(conn, "_" + name) == value
734-
assert len(w) == 1
735-
assert str(w[0].message) == str(exc_warn)
736-
finally:
737-
await conn.close()
721+
assert any(str(exc_warn) == str(w.message) for w in warns)
738722

739723

740-
async def test_invalid_connection_parameters_turned_off(db_parameters):
724+
async def test_invalid_connection_parameters_turned_off(conn_cnx):
741725
"""Makes sure parameter checking can be turned off."""
742-
with warnings.catch_warnings(record=True) as w:
743-
conn_params = {
744-
"account": db_parameters["account"],
745-
"user": db_parameters["user"],
746-
"password": db_parameters["password"],
747-
"schema": db_parameters["schema"],
748-
"database": db_parameters["database"],
749-
"protocol": db_parameters["protocol"],
750-
"host": db_parameters["host"],
751-
"port": db_parameters["port"],
752-
"validate_default_parameters": False,
753-
"autocommit": "True", # Wrong type
754-
"applucation": "this is a typo or my own variable", # Wrong name
755-
}
756-
try:
757-
conn = snowflake.connector.aio.SnowflakeConnection(**conn_params)
758-
await conn.connect()
759-
assert conn._autocommit == conn_params["autocommit"]
760-
assert conn._applucation == conn_params["applucation"]
761-
assert len(w) == 0
762-
finally:
763-
await conn.close()
726+
with warnings.catch_warnings(record=True) as warns:
727+
async with conn_cnx(
728+
validate_default_parameters=False,
729+
autocommit="True",
730+
applucation="this is a typo or my own variable",
731+
) as conn:
732+
assert conn._autocommit == "True"
733+
assert conn._applucation == "this is a typo or my own variable"
734+
assert not any(
735+
"_autocommit" in w.message or "_applucation" in w.message for w in warns
736+
)
764737

765738

766-
async def test_invalid_connection_parameters_only_warns(db_parameters):
739+
async def test_invalid_connection_parameters_only_warns(conn_cnx):
767740
"""This test supresses warnings to only have warehouse, database and schema checking."""
768-
with warnings.catch_warnings(record=True) as w:
769-
conn_params = {
770-
"account": db_parameters["account"],
771-
"user": db_parameters["user"],
772-
"password": db_parameters["password"],
773-
"schema": db_parameters["schema"],
774-
"database": db_parameters["database"],
775-
"protocol": db_parameters["protocol"],
776-
"host": db_parameters["host"],
777-
"port": db_parameters["port"],
778-
"validate_default_parameters": True,
779-
"autocommit": "True", # Wrong type
780-
"applucation": "this is a typo or my own variable", # Wrong name
781-
}
782-
try:
783-
with warnings.catch_warnings():
784-
warnings.simplefilter("ignore")
785-
conn = snowflake.connector.aio.SnowflakeConnection(**conn_params)
786-
await conn.connect()
787-
assert conn._autocommit == conn_params["autocommit"]
788-
assert conn._applucation == conn_params["applucation"]
789-
assert len(w) == 0
790-
finally:
791-
await conn.close()
741+
with warnings.catch_warnings(record=True) as warns:
742+
async with conn_cnx(
743+
validate_default_parameters=True,
744+
autocommit="True",
745+
applucation="this is a typo or my own variable",
746+
) as conn:
747+
assert conn._autocommit == "True"
748+
assert conn._applucation == "this is a typo or my own variable"
749+
assert not any(
750+
"_autocommit" in str(w.message) or "_applucation" in str(w.message)
751+
for w in warns
752+
)
792753

793754

794755
@pytest.mark.skipolddriver
@@ -1059,9 +1020,7 @@ async def test_ocsp_cache_working(conn_cnx):
10591020

10601021

10611022
@pytest.mark.skipolddriver
1062-
async def test_imported_packages_telemetry(
1063-
conn_cnx, capture_sf_telemetry_async, db_parameters
1064-
):
1023+
async def test_imported_packages_telemetry(conn_cnx, capture_sf_telemetry_async):
10651024
# these imports are not used but for testing
10661025
import html.parser # noqa: F401
10671026
import json # noqa: F401
@@ -1102,20 +1061,8 @@ def check_packages(message: str, expected_packages: list[str]) -> bool:
11021061

11031062
# test different application
11041063
new_application_name = "PythonSnowpark"
1105-
config = {
1106-
"user": db_parameters["user"],
1107-
"password": db_parameters["password"],
1108-
"host": db_parameters["host"],
1109-
"port": db_parameters["port"],
1110-
"account": db_parameters["account"],
1111-
"schema": db_parameters["schema"],
1112-
"database": db_parameters["database"],
1113-
"protocol": db_parameters["protocol"],
1114-
"timezone": "UTC",
1115-
"application": new_application_name,
1116-
}
1117-
async with snowflake.connector.aio.SnowflakeConnection(
1118-
**config
1064+
async with conn_cnx(
1065+
timezone="UTC", application=new_application_name
11191066
) as conn, capture_sf_telemetry_async.patch_connection(
11201067
conn, False
11211068
) as telemetry_test:
@@ -1131,9 +1078,10 @@ def check_packages(message: str, expected_packages: list[str]) -> bool:
11311078
)
11321079

11331080
# test opt out
1134-
config["log_imported_packages_in_telemetry"] = False
1135-
async with snowflake.connector.aio.SnowflakeConnection(
1136-
**config
1081+
async with conn_cnx(
1082+
timezone="UTC",
1083+
application=new_application_name,
1084+
log_imported_packages_in_telemetry=False,
11371085
) as conn, capture_sf_telemetry_async.patch_connection(
11381086
conn, False
11391087
) as telemetry_test:

0 commit comments

Comments
 (0)