1515
1616import pytest
1717
18+ # Add cryptography imports for private key handling
19+ from cryptography .hazmat .backends import default_backend
20+ from cryptography .hazmat .primitives import serialization
21+ from cryptography .hazmat .primitives .serialization import (
22+ Encoding ,
23+ NoEncryption ,
24+ PrivateFormat ,
25+ )
26+
1827import snowflake .connector
1928from snowflake .connector .compat import IS_WINDOWS
2029from snowflake .connector .connection import DefaultConverterClass
3241 from snowflake .connector import SnowflakeConnection
3342
3443RUNNING_ON_GH = os .getenv ("GITHUB_ACTIONS" ) == "true"
44+ RUNNING_ON_JENKINS = os .getenv ("JENKINS_HOME" ) not in (None , "false" )
45+ RUNNING_OLD_DRIVER = os .getenv ("TOX_ENV_NAME" ) == "olddriver"
3546TEST_USING_VENDORED_ARROW = os .getenv ("TEST_USING_VENDORED_ARROW" ) == "true"
3647
48+
49+ def _get_private_key_bytes_for_olddriver (private_key_file : str ) -> bytes :
50+ """Load private key file and convert to DER format bytes for olddriver compatibility.
51+
52+ The olddriver expects private keys in DER format as bytes.
53+ This function handles both PEM and DER input formats.
54+ """
55+ with open (private_key_file , "rb" ) as key_file :
56+ key_data = key_file .read ()
57+
58+ # Try to load as PEM first, then DER
59+ try :
60+ # Try PEM format first
61+ private_key = serialization .load_pem_private_key (
62+ key_data ,
63+ password = None ,
64+ backend = default_backend (),
65+ )
66+ except ValueError :
67+ try :
68+ # Try DER format
69+ private_key = serialization .load_der_private_key (
70+ key_data ,
71+ password = None ,
72+ backend = default_backend (),
73+ )
74+ except ValueError as e :
75+ raise ValueError (f"Could not load private key from { private_key_file } : { e } " )
76+
77+ # Convert to DER format bytes as expected by olddriver
78+ return private_key .private_bytes (
79+ encoding = Encoding .DER ,
80+ format = PrivateFormat .PKCS8 ,
81+ encryption_algorithm = NoEncryption (),
82+ )
83+
84+
3785if not isinstance (CONNECTION_PARAMETERS ["host" ], str ):
3886 raise Exception ("default host is not a string in parameters.py" )
3987RUNNING_AGAINST_LOCAL_SNOWFLAKE = CONNECTION_PARAMETERS ["host" ].endswith ("local" )
@@ -76,16 +124,42 @@ def _get_worker_specific_schema():
76124 )
77125
78126
79- DEFAULT_PARAMETERS : dict [str , Any ] = {
80- "account" : "<account_name>" ,
81- "user" : "<user_name>" ,
82- "password" : "<password>" ,
83- "database" : "<database_name>" ,
84- "schema" : "<schema_name>" ,
85- "protocol" : "https" ,
86- "host" : "<host>" ,
87- "port" : "443" ,
88- }
127+ if RUNNING_ON_JENKINS :
128+ DEFAULT_PARAMETERS : dict [str , Any ] = {
129+ "account" : "<account_name>" ,
130+ "user" : "<user_name>" ,
131+ "password" : "<password>" ,
132+ "database" : "<database_name>" ,
133+ "schema" : "<schema_name>" ,
134+ "protocol" : "https" ,
135+ "host" : "<host>" ,
136+ "port" : "443" ,
137+ }
138+ else :
139+ if RUNNING_OLD_DRIVER :
140+ DEFAULT_PARAMETERS : dict [str , Any ] = {
141+ "account" : "<account_name>" ,
142+ "user" : "<user_name>" ,
143+ "database" : "<database_name>" ,
144+ "schema" : "<schema_name>" ,
145+ "protocol" : "https" ,
146+ "host" : "<host>" ,
147+ "port" : "443" ,
148+ "authenticator" : "SNOWFLAKE_JWT" ,
149+ "private_key_file" : "<private_key_file>" ,
150+ }
151+ else :
152+ DEFAULT_PARAMETERS : dict [str , Any ] = {
153+ "account" : "<account_name>" ,
154+ "user" : "<user_name>" ,
155+ "database" : "<database_name>" ,
156+ "schema" : "<schema_name>" ,
157+ "protocol" : "https" ,
158+ "host" : "<host>" ,
159+ "port" : "443" ,
160+ "authenticator" : "<authenticator>" ,
161+ "private_key_file" : "<private_key_file>" ,
162+ }
89163
90164
91165def print_help () -> None :
@@ -95,9 +169,10 @@ def print_help() -> None:
95169CONNECTION_PARAMETERS = {
96170 'account': 'testaccount',
97171 'user': 'user1',
98- 'password': 'test',
99172 'database': 'testdb',
100173 'schema': 'public',
174+ 'authenticator': 'KEY_PAIR_AUTHENTICATOR',
175+ 'private_key_file': '/path/to/private_key.p8',
101176}
102177"""
103178 )
@@ -200,16 +275,55 @@ def init_test_schema(db_parameters) -> Generator[None]:
200275
201276 This is automatically called per test session.
202277 """
203- ret = db_parameters
204- with snowflake .connector .connect (
205- user = ret ["user" ],
206- password = ret ["password" ],
207- host = ret ["host" ],
208- port = ret ["port" ],
209- database = ret ["database" ],
210- account = ret ["account" ],
211- protocol = ret ["protocol" ],
212- ) as con :
278+ if RUNNING_ON_JENKINS :
279+ connection_params = {
280+ "user" : db_parameters ["user" ],
281+ "password" : db_parameters ["password" ],
282+ "host" : db_parameters ["host" ],
283+ "port" : db_parameters ["port" ],
284+ "database" : db_parameters ["database" ],
285+ "account" : db_parameters ["account" ],
286+ "protocol" : db_parameters ["protocol" ],
287+ }
288+ else :
289+ connection_params = {
290+ "user" : db_parameters ["user" ],
291+ "host" : db_parameters ["host" ],
292+ "port" : db_parameters ["port" ],
293+ "database" : db_parameters ["database" ],
294+ "account" : db_parameters ["account" ],
295+ "protocol" : db_parameters ["protocol" ],
296+ }
297+
298+ # Handle private key authentication differently for old vs new driver
299+ if RUNNING_OLD_DRIVER :
300+ # Old driver expects private_key as bytes and SNOWFLAKE_JWT authenticator
301+ private_key_file = db_parameters .get ("private_key_file" )
302+ if private_key_file :
303+ private_key_bytes = _get_private_key_bytes_for_olddriver (
304+ private_key_file
305+ )
306+ connection_params .update (
307+ {
308+ "authenticator" : "SNOWFLAKE_JWT" ,
309+ "private_key" : private_key_bytes ,
310+ }
311+ )
312+ else :
313+ # New driver expects private_key_file and KEY_PAIR_AUTHENTICATOR
314+ connection_params .update (
315+ {
316+ "authenticator" : db_parameters ["authenticator" ],
317+ "private_key_file" : db_parameters ["private_key_file" ],
318+ }
319+ )
320+
321+ # Role may be needed when running on preprod, but is not present on Jenkins jobs
322+ optional_role = db_parameters .get ("role" )
323+ if optional_role is not None :
324+ connection_params .update (role = optional_role )
325+
326+ with snowflake .connector .connect (** connection_params ) as con :
213327 con .cursor ().execute (f"CREATE SCHEMA IF NOT EXISTS { TEST_SCHEMA } " )
214328 yield
215329 con .cursor ().execute (f"DROP SCHEMA IF EXISTS { TEST_SCHEMA } " )
@@ -224,6 +338,24 @@ def create_connection(connection_name: str, **kwargs) -> SnowflakeConnection:
224338 """
225339 ret = get_db_parameters (connection_name )
226340 ret .update (kwargs )
341+
342+ # Handle private key authentication differently for old vs new driver (only if not on Jenkins)
343+ if not RUNNING_ON_JENKINS and "private_key_file" in ret :
344+ if RUNNING_OLD_DRIVER :
345+ # Old driver (3.1.0) expects private_key as bytes and SNOWFLAKE_JWT authenticator
346+ private_key_file = ret .get ("private_key_file" )
347+ if (
348+ private_key_file and "private_key" not in ret
349+ ): # Don't override if private_key already set
350+ private_key_bytes = _get_private_key_bytes_for_olddriver (
351+ private_key_file
352+ )
353+ ret ["authenticator" ] = "SNOWFLAKE_JWT"
354+ ret ["private_key" ] = private_key_bytes
355+ ret .pop (
356+ "private_key_file" , None
357+ ) # Remove private_key_file for old driver
358+
227359 connection = snowflake .connector .connect (** ret )
228360 return connection
229361
0 commit comments