15
15
16
16
import pytest
17
17
18
+ # Add cryptography imports for private key handling
19
+ from cryptography .hazmat .backends import default_backend
20
+ from cryptography .hazmat .primitives import serialization
21
+ from cryptography .hazmat .primitives .serialization import (
22
+ Encoding ,
23
+ NoEncryption ,
24
+ PrivateFormat ,
25
+ )
26
+
18
27
import snowflake .connector
19
28
from snowflake .connector .compat import IS_WINDOWS
20
29
from snowflake .connector .connection import DefaultConverterClass
32
41
from snowflake .connector import SnowflakeConnection
33
42
34
43
RUNNING_ON_GH = os .getenv ("GITHUB_ACTIONS" ) == "true"
44
+ RUNNING_ON_JENKINS = os .getenv ("JENKINS_HOME" ) not in (None , "false" )
45
+ RUNNING_OLD_DRIVER = os .getenv ("TOX_ENV_NAME" ) == "olddriver"
35
46
TEST_USING_VENDORED_ARROW = os .getenv ("TEST_USING_VENDORED_ARROW" ) == "true"
36
47
48
+
49
+ def _get_private_key_bytes_for_olddriver (private_key_file : str ) -> bytes :
50
+ """Load private key file and convert to DER format bytes for olddriver compatibility.
51
+
52
+ The olddriver expects private keys in DER format as bytes.
53
+ This function handles both PEM and DER input formats.
54
+ """
55
+ with open (private_key_file , "rb" ) as key_file :
56
+ key_data = key_file .read ()
57
+
58
+ # Try to load as PEM first, then DER
59
+ try :
60
+ # Try PEM format first
61
+ private_key = serialization .load_pem_private_key (
62
+ key_data ,
63
+ password = None ,
64
+ backend = default_backend (),
65
+ )
66
+ except ValueError :
67
+ try :
68
+ # Try DER format
69
+ private_key = serialization .load_der_private_key (
70
+ key_data ,
71
+ password = None ,
72
+ backend = default_backend (),
73
+ )
74
+ except ValueError as e :
75
+ raise ValueError (f"Could not load private key from { private_key_file } : { e } " )
76
+
77
+ # Convert to DER format bytes as expected by olddriver
78
+ return private_key .private_bytes (
79
+ encoding = Encoding .DER ,
80
+ format = PrivateFormat .PKCS8 ,
81
+ encryption_algorithm = NoEncryption (),
82
+ )
83
+
84
+
37
85
if not isinstance (CONNECTION_PARAMETERS ["host" ], str ):
38
86
raise Exception ("default host is not a string in parameters.py" )
39
87
RUNNING_AGAINST_LOCAL_SNOWFLAKE = CONNECTION_PARAMETERS ["host" ].endswith ("local" )
@@ -76,16 +124,42 @@ def _get_worker_specific_schema():
76
124
)
77
125
78
126
79
- DEFAULT_PARAMETERS : dict [str , Any ] = {
80
- "account" : "<account_name>" ,
81
- "user" : "<user_name>" ,
82
- "password" : "<password>" ,
83
- "database" : "<database_name>" ,
84
- "schema" : "<schema_name>" ,
85
- "protocol" : "https" ,
86
- "host" : "<host>" ,
87
- "port" : "443" ,
88
- }
127
+ if RUNNING_ON_JENKINS :
128
+ DEFAULT_PARAMETERS : dict [str , Any ] = {
129
+ "account" : "<account_name>" ,
130
+ "user" : "<user_name>" ,
131
+ "password" : "<password>" ,
132
+ "database" : "<database_name>" ,
133
+ "schema" : "<schema_name>" ,
134
+ "protocol" : "https" ,
135
+ "host" : "<host>" ,
136
+ "port" : "443" ,
137
+ }
138
+ else :
139
+ if RUNNING_OLD_DRIVER :
140
+ DEFAULT_PARAMETERS : dict [str , Any ] = {
141
+ "account" : "<account_name>" ,
142
+ "user" : "<user_name>" ,
143
+ "database" : "<database_name>" ,
144
+ "schema" : "<schema_name>" ,
145
+ "protocol" : "https" ,
146
+ "host" : "<host>" ,
147
+ "port" : "443" ,
148
+ "authenticator" : "SNOWFLAKE_JWT" ,
149
+ "private_key_file" : "<private_key_file>" ,
150
+ }
151
+ else :
152
+ DEFAULT_PARAMETERS : dict [str , Any ] = {
153
+ "account" : "<account_name>" ,
154
+ "user" : "<user_name>" ,
155
+ "database" : "<database_name>" ,
156
+ "schema" : "<schema_name>" ,
157
+ "protocol" : "https" ,
158
+ "host" : "<host>" ,
159
+ "port" : "443" ,
160
+ "authenticator" : "<authenticator>" ,
161
+ "private_key_file" : "<private_key_file>" ,
162
+ }
89
163
90
164
91
165
def print_help () -> None :
@@ -95,9 +169,10 @@ def print_help() -> None:
95
169
CONNECTION_PARAMETERS = {
96
170
'account': 'testaccount',
97
171
'user': 'user1',
98
- 'password': 'test',
99
172
'database': 'testdb',
100
173
'schema': 'public',
174
+ 'authenticator': 'KEY_PAIR_AUTHENTICATOR',
175
+ 'private_key_file': '/path/to/private_key.p8',
101
176
}
102
177
"""
103
178
)
@@ -200,16 +275,55 @@ def init_test_schema(db_parameters) -> Generator[None]:
200
275
201
276
This is automatically called per test session.
202
277
"""
203
- ret = db_parameters
204
- with snowflake .connector .connect (
205
- user = ret ["user" ],
206
- password = ret ["password" ],
207
- host = ret ["host" ],
208
- port = ret ["port" ],
209
- database = ret ["database" ],
210
- account = ret ["account" ],
211
- protocol = ret ["protocol" ],
212
- ) as con :
278
+ if RUNNING_ON_JENKINS :
279
+ connection_params = {
280
+ "user" : db_parameters ["user" ],
281
+ "password" : db_parameters ["password" ],
282
+ "host" : db_parameters ["host" ],
283
+ "port" : db_parameters ["port" ],
284
+ "database" : db_parameters ["database" ],
285
+ "account" : db_parameters ["account" ],
286
+ "protocol" : db_parameters ["protocol" ],
287
+ }
288
+ else :
289
+ connection_params = {
290
+ "user" : db_parameters ["user" ],
291
+ "host" : db_parameters ["host" ],
292
+ "port" : db_parameters ["port" ],
293
+ "database" : db_parameters ["database" ],
294
+ "account" : db_parameters ["account" ],
295
+ "protocol" : db_parameters ["protocol" ],
296
+ }
297
+
298
+ # Handle private key authentication differently for old vs new driver
299
+ if RUNNING_OLD_DRIVER :
300
+ # Old driver expects private_key as bytes and SNOWFLAKE_JWT authenticator
301
+ private_key_file = db_parameters .get ("private_key_file" )
302
+ if private_key_file :
303
+ private_key_bytes = _get_private_key_bytes_for_olddriver (
304
+ private_key_file
305
+ )
306
+ connection_params .update (
307
+ {
308
+ "authenticator" : "SNOWFLAKE_JWT" ,
309
+ "private_key" : private_key_bytes ,
310
+ }
311
+ )
312
+ else :
313
+ # New driver expects private_key_file and KEY_PAIR_AUTHENTICATOR
314
+ connection_params .update (
315
+ {
316
+ "authenticator" : db_parameters ["authenticator" ],
317
+ "private_key_file" : db_parameters ["private_key_file" ],
318
+ }
319
+ )
320
+
321
+ # Role may be needed when running on preprod, but is not present on Jenkins jobs
322
+ optional_role = db_parameters .get ("role" )
323
+ if optional_role is not None :
324
+ connection_params .update (role = optional_role )
325
+
326
+ with snowflake .connector .connect (** connection_params ) as con :
213
327
con .cursor ().execute (f"CREATE SCHEMA IF NOT EXISTS { TEST_SCHEMA } " )
214
328
yield
215
329
con .cursor ().execute (f"DROP SCHEMA IF EXISTS { TEST_SCHEMA } " )
@@ -224,6 +338,24 @@ def create_connection(connection_name: str, **kwargs) -> SnowflakeConnection:
224
338
"""
225
339
ret = get_db_parameters (connection_name )
226
340
ret .update (kwargs )
341
+
342
+ # Handle private key authentication differently for old vs new driver (only if not on Jenkins)
343
+ if not RUNNING_ON_JENKINS and "private_key_file" in ret :
344
+ if RUNNING_OLD_DRIVER :
345
+ # Old driver (3.1.0) expects private_key as bytes and SNOWFLAKE_JWT authenticator
346
+ private_key_file = ret .get ("private_key_file" )
347
+ if (
348
+ private_key_file and "private_key" not in ret
349
+ ): # Don't override if private_key already set
350
+ private_key_bytes = _get_private_key_bytes_for_olddriver (
351
+ private_key_file
352
+ )
353
+ ret ["authenticator" ] = "SNOWFLAKE_JWT"
354
+ ret ["private_key" ] = private_key_bytes
355
+ ret .pop (
356
+ "private_key_file" , None
357
+ ) # Remove private_key_file for old driver
358
+
227
359
connection = snowflake .connector .connect (** ret )
228
360
return connection
229
361
0 commit comments