Skip to content

Commit 8285948

Browse files
committed
old driver fix2
1 parent 049f076 commit 8285948

File tree

1 file changed

+56
-13
lines changed

1 file changed

+56
-13
lines changed

test/integ/conftest.py

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
RUNNING_ON_GH = os.getenv("GITHUB_ACTIONS") == "true"
3131
RUNNING_ON_JENKINS = os.getenv("JENKINS_URL") is not None
32+
RUNNING_OLD_DRIVER = os.getenv("TOX_ENV_NAME") == "olddriver"
3233
TEST_USING_VENDORED_ARROW = os.getenv("TEST_USING_VENDORED_ARROW") == "true"
3334

3435
if not isinstance(CONNECTION_PARAMETERS["host"], str):
@@ -85,17 +86,30 @@ def _get_worker_specific_schema():
8586
"port": "443",
8687
}
8788
else:
88-
DEFAULT_PARAMETERS: dict[str, Any] = {
89-
"account": "<account_name>",
90-
"user": "<user_name>",
91-
"database": "<database_name>",
92-
"schema": "<schema_name>",
93-
"protocol": "https",
94-
"host": "<host>",
95-
"port": "443",
96-
"authenticator": "<authenticator>",
97-
"private_key_file": "<private_key_file>",
98-
}
89+
if RUNNING_OLD_DRIVER:
90+
DEFAULT_PARAMETERS: dict[str, Any] = {
91+
"account": "<account_name>",
92+
"user": "<user_name>",
93+
"database": "<database_name>",
94+
"schema": "<schema_name>",
95+
"protocol": "https",
96+
"host": "<host>",
97+
"port": "443",
98+
"authenticator": "SNOWFLAKE_JWT",
99+
"private_key_file": "<private_key_file>",
100+
}
101+
else:
102+
DEFAULT_PARAMETERS: dict[str, Any] = {
103+
"account": "<account_name>",
104+
"user": "<user_name>",
105+
"database": "<database_name>",
106+
"schema": "<schema_name>",
107+
"protocol": "https",
108+
"host": "<host>",
109+
"port": "443",
110+
"authenticator": "<authenticator>",
111+
"private_key_file": "<private_key_file>",
112+
}
99113

100114

101115
def print_help() -> None:
@@ -229,10 +243,26 @@ def init_test_schema(db_parameters) -> Generator[None]:
229243
"database": db_parameters["database"],
230244
"account": db_parameters["account"],
231245
"protocol": db_parameters["protocol"],
232-
"authenticator": db_parameters["authenticator"],
233-
"private_key_file": db_parameters["private_key_file"],
234246
}
235247

248+
# Handle private key authentication differently for old vs new driver
249+
if RUNNING_OLD_DRIVER:
250+
# Old driver expects private_key as bytes and SNOWFLAKE_JWT authenticator
251+
private_key_file = db_parameters.get("private_key_file")
252+
if private_key_file:
253+
with open(private_key_file, "rb") as f:
254+
private_key_content = f.read()
255+
connection_params.update({
256+
"authenticator": "SNOWFLAKE_JWT",
257+
"private_key": private_key_content,
258+
})
259+
else:
260+
# New driver expects private_key_file and KEY_PAIR_AUTHENTICATOR
261+
connection_params.update({
262+
"authenticator": db_parameters["authenticator"],
263+
"private_key_file": db_parameters["private_key_file"],
264+
})
265+
236266
# Role may be needed when running on preprod, but is not present on Jenkins jobs
237267
optional_role = db_parameters.get("role")
238268
if optional_role is not None:
@@ -253,6 +283,19 @@ def create_connection(connection_name: str, **kwargs) -> SnowflakeConnection:
253283
"""
254284
ret = get_db_parameters(connection_name)
255285
ret.update(kwargs)
286+
287+
# Handle private key authentication differently for old vs new driver (only if not on Jenkins)
288+
if not RUNNING_ON_JENKINS and "private_key_file" in ret:
289+
if RUNNING_OLD_DRIVER:
290+
# Old driver (3.1.0) expects private_key as bytes and SNOWFLAKE_JWT authenticator
291+
private_key_file = ret.get("private_key_file")
292+
if private_key_file and "private_key" not in ret: # Don't override if private_key already set
293+
with open(private_key_file, "rb") as f:
294+
private_key_content = f.read()
295+
ret["authenticator"] = "SNOWFLAKE_JWT"
296+
ret["private_key"] = private_key_content
297+
ret.pop("private_key_file", None) # Remove private_key_file for old driver
298+
256299
connection = snowflake.connector.connect(**ret)
257300
return connection
258301

0 commit comments

Comments
 (0)