11
11
12
12
import pytest
13
13
14
+ # Add cryptography imports for private key handling
15
+ from cryptography .hazmat .backends import default_backend
16
+ from cryptography .hazmat .primitives import serialization
17
+ from cryptography .hazmat .primitives .serialization import (
18
+ Encoding ,
19
+ NoEncryption ,
20
+ PrivateFormat ,
21
+ )
22
+
14
23
import snowflake .connector
15
24
from snowflake .connector .compat import IS_WINDOWS
16
25
from snowflake .connector .connection import DefaultConverterClass
28
37
from snowflake .connector import SnowflakeConnection
29
38
30
39
RUNNING_ON_GH = os .getenv ("GITHUB_ACTIONS" ) == "true"
40
+ RUNNING_ON_JENKINS = os .getenv ("JENKINS_HOME" ) not in (None , "false" )
41
+ RUNNING_OLD_DRIVER = os .getenv ("TOX_ENV_NAME" ) == "olddriver"
31
42
TEST_USING_VENDORED_ARROW = os .getenv ("TEST_USING_VENDORED_ARROW" ) == "true"
32
43
44
+
45
+ def _get_private_key_bytes_for_olddriver (private_key_file : str ) -> bytes :
46
+ """Load private key file and convert to DER format bytes for olddriver compatibility.
47
+
48
+ The olddriver expects private keys in DER format as bytes.
49
+ This function handles both PEM and DER input formats.
50
+ """
51
+ with open (private_key_file , "rb" ) as key_file :
52
+ key_data = key_file .read ()
53
+
54
+ # Try to load as PEM first, then DER
55
+ try :
56
+ # Try PEM format first
57
+ private_key = serialization .load_pem_private_key (
58
+ key_data ,
59
+ password = None ,
60
+ backend = default_backend (),
61
+ )
62
+ except ValueError :
63
+ try :
64
+ # Try DER format
65
+ private_key = serialization .load_der_private_key (
66
+ key_data ,
67
+ password = None ,
68
+ backend = default_backend (),
69
+ )
70
+ except ValueError as e :
71
+ raise ValueError (f"Could not load private key from { private_key_file } : { e } " )
72
+
73
+ # Convert to DER format bytes as expected by olddriver
74
+ return private_key .private_bytes (
75
+ encoding = Encoding .DER ,
76
+ format = PrivateFormat .PKCS8 ,
77
+ encryption_algorithm = NoEncryption (),
78
+ )
79
+
80
+
33
81
if not isinstance (CONNECTION_PARAMETERS ["host" ], str ):
34
82
raise Exception ("default host is not a string in parameters.py" )
35
83
RUNNING_AGAINST_LOCAL_SNOWFLAKE = CONNECTION_PARAMETERS ["host" ].endswith ("local" )
@@ -72,16 +120,42 @@ def _get_worker_specific_schema():
72
120
)
73
121
74
122
75
- DEFAULT_PARAMETERS : dict [str , Any ] = {
76
- "account" : "<account_name>" ,
77
- "user" : "<user_name>" ,
78
- "password" : "<password>" ,
79
- "database" : "<database_name>" ,
80
- "schema" : "<schema_name>" ,
81
- "protocol" : "https" ,
82
- "host" : "<host>" ,
83
- "port" : "443" ,
84
- }
123
+ if RUNNING_ON_JENKINS :
124
+ DEFAULT_PARAMETERS : dict [str , Any ] = {
125
+ "account" : "<account_name>" ,
126
+ "user" : "<user_name>" ,
127
+ "password" : "<password>" ,
128
+ "database" : "<database_name>" ,
129
+ "schema" : "<schema_name>" ,
130
+ "protocol" : "https" ,
131
+ "host" : "<host>" ,
132
+ "port" : "443" ,
133
+ }
134
+ else :
135
+ if RUNNING_OLD_DRIVER :
136
+ DEFAULT_PARAMETERS : dict [str , Any ] = {
137
+ "account" : "<account_name>" ,
138
+ "user" : "<user_name>" ,
139
+ "database" : "<database_name>" ,
140
+ "schema" : "<schema_name>" ,
141
+ "protocol" : "https" ,
142
+ "host" : "<host>" ,
143
+ "port" : "443" ,
144
+ "authenticator" : "SNOWFLAKE_JWT" ,
145
+ "private_key_file" : "<private_key_file>" ,
146
+ }
147
+ else :
148
+ DEFAULT_PARAMETERS : dict [str , Any ] = {
149
+ "account" : "<account_name>" ,
150
+ "user" : "<user_name>" ,
151
+ "database" : "<database_name>" ,
152
+ "schema" : "<schema_name>" ,
153
+ "protocol" : "https" ,
154
+ "host" : "<host>" ,
155
+ "port" : "443" ,
156
+ "authenticator" : "<authenticator>" ,
157
+ "private_key_file" : "<private_key_file>" ,
158
+ }
85
159
86
160
87
161
def print_help () -> None :
@@ -91,9 +165,10 @@ def print_help() -> None:
91
165
CONNECTION_PARAMETERS = {
92
166
'account': 'testaccount',
93
167
'user': 'user1',
94
- 'password': 'test',
95
168
'database': 'testdb',
96
169
'schema': 'public',
170
+ 'authenticator': 'KEY_PAIR_AUTHENTICATOR',
171
+ 'private_key_file': '/path/to/private_key.p8',
97
172
}
98
173
"""
99
174
)
@@ -196,15 +271,48 @@ def init_test_schema(db_parameters) -> Generator[None]:
196
271
197
272
This is automatically called per test session.
198
273
"""
199
- connection_params = {
200
- "user" : db_parameters ["user" ],
201
- "password" : db_parameters ["password" ],
202
- "host" : db_parameters ["host" ],
203
- "port" : db_parameters ["port" ],
204
- "database" : db_parameters ["database" ],
205
- "account" : db_parameters ["account" ],
206
- "protocol" : db_parameters ["protocol" ],
207
- }
274
+ if RUNNING_ON_JENKINS :
275
+ connection_params = {
276
+ "user" : db_parameters ["user" ],
277
+ "password" : db_parameters ["password" ],
278
+ "host" : db_parameters ["host" ],
279
+ "port" : db_parameters ["port" ],
280
+ "database" : db_parameters ["database" ],
281
+ "account" : db_parameters ["account" ],
282
+ "protocol" : db_parameters ["protocol" ],
283
+ }
284
+ else :
285
+ connection_params = {
286
+ "user" : db_parameters ["user" ],
287
+ "host" : db_parameters ["host" ],
288
+ "port" : db_parameters ["port" ],
289
+ "database" : db_parameters ["database" ],
290
+ "account" : db_parameters ["account" ],
291
+ "protocol" : db_parameters ["protocol" ],
292
+ }
293
+
294
+ # Handle private key authentication differently for old vs new driver
295
+ if RUNNING_OLD_DRIVER :
296
+ # Old driver expects private_key as bytes and SNOWFLAKE_JWT authenticator
297
+ private_key_file = db_parameters .get ("private_key_file" )
298
+ if private_key_file :
299
+ private_key_bytes = _get_private_key_bytes_for_olddriver (
300
+ private_key_file
301
+ )
302
+ connection_params .update (
303
+ {
304
+ "authenticator" : "SNOWFLAKE_JWT" ,
305
+ "private_key" : private_key_bytes ,
306
+ }
307
+ )
308
+ else :
309
+ # New driver expects private_key_file and KEY_PAIR_AUTHENTICATOR
310
+ connection_params .update (
311
+ {
312
+ "authenticator" : db_parameters ["authenticator" ],
313
+ "private_key_file" : db_parameters ["private_key_file" ],
314
+ }
315
+ )
208
316
209
317
# Role may be needed when running on preprod, but is not present on Jenkins jobs
210
318
optional_role = db_parameters .get ("role" )
@@ -226,6 +334,24 @@ def create_connection(connection_name: str, **kwargs) -> SnowflakeConnection:
226
334
"""
227
335
ret = get_db_parameters (connection_name )
228
336
ret .update (kwargs )
337
+
338
+ # Handle private key authentication differently for old vs new driver (only if not on Jenkins)
339
+ if not RUNNING_ON_JENKINS and "private_key_file" in ret :
340
+ if RUNNING_OLD_DRIVER :
341
+ # Old driver (3.1.0) expects private_key as bytes and SNOWFLAKE_JWT authenticator
342
+ private_key_file = ret .get ("private_key_file" )
343
+ if (
344
+ private_key_file and "private_key" not in ret
345
+ ): # Don't override if private_key already set
346
+ private_key_bytes = _get_private_key_bytes_for_olddriver (
347
+ private_key_file
348
+ )
349
+ ret ["authenticator" ] = "SNOWFLAKE_JWT"
350
+ ret ["private_key" ] = private_key_bytes
351
+ ret .pop (
352
+ "private_key_file" , None
353
+ ) # Remove private_key_file for old driver
354
+
229
355
connection = snowflake .connector .connect (** ret )
230
356
return connection
231
357
0 commit comments