2
2
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
3
3
#
4
4
5
+ import os
5
6
from contextlib import asynccontextmanager
6
- from test .integ .conftest import get_db_parameters , is_public_testaccount
7
- from typing import AsyncContextManager , Callable , Generator
7
+ from test .integ .conftest import (
8
+ _get_private_key_bytes_for_olddriver ,
9
+ get_db_parameters ,
10
+ is_public_testaccount ,
11
+ )
12
+ from typing import AsyncContextManager , AsyncGenerator , Callable
8
13
9
14
import pytest
10
15
@@ -44,7 +49,7 @@ async def patch_connection(
44
49
self ,
45
50
con : SnowflakeConnection ,
46
51
propagate : bool = True ,
47
- ) -> Generator [TelemetryCaptureHandlerAsync , None , None ]:
52
+ ) -> AsyncGenerator [TelemetryCaptureHandlerAsync , None ]:
48
53
original_telemetry = con ._telemetry
49
54
new_telemetry = TelemetryCaptureHandlerAsync (
50
55
original_telemetry ,
@@ -57,6 +62,9 @@ async def patch_connection(
57
62
con ._telemetry = original_telemetry
58
63
59
64
65
+ RUNNING_OLD_DRIVER = os .getenv ("TOX_ENV_NAME" ) == "olddriver"
66
+
67
+
60
68
@pytest .fixture (scope = "session" )
61
69
def capture_sf_telemetry_async () -> TelemetryCaptureFixtureAsync :
62
70
return TelemetryCaptureFixtureAsync ()
@@ -71,6 +79,22 @@ async def create_connection(connection_name: str, **kwargs) -> SnowflakeConnecti
71
79
"""
72
80
ret = get_db_parameters (connection_name )
73
81
ret .update (kwargs )
82
+
83
+ # Handle private key authentication for old driver if applicable
84
+ if RUNNING_OLD_DRIVER and "private_key_file" in ret and "private_key" not in ret :
85
+ private_key_file = ret .get ("private_key_file" )
86
+ if private_key_file :
87
+ private_key_bytes = _get_private_key_bytes_for_olddriver (private_key_file )
88
+ ret ["authenticator" ] = "SNOWFLAKE_JWT"
89
+ ret ["private_key" ] = private_key_bytes
90
+ ret .pop ("private_key_file" , None )
91
+
92
+ # If authenticator is explicitly provided and it's not key-pair based, drop key-pair fields
93
+ authenticator_value = ret .get ("authenticator" )
94
+ if authenticator_value .lower () not in {"key_pair_authenticator" , "snowflake_jwt" }:
95
+ ret .pop ("private_key" , None )
96
+ ret .pop ("private_key_file" , None )
97
+
74
98
connection = SnowflakeConnection (** ret )
75
99
await connection .connect ()
76
100
return connection
@@ -80,7 +104,7 @@ async def create_connection(connection_name: str, **kwargs) -> SnowflakeConnecti
80
104
async def db (
81
105
connection_name : str = "default" ,
82
106
** kwargs ,
83
- ) -> Generator [SnowflakeConnection , None , None ]:
107
+ ) -> AsyncGenerator [SnowflakeConnection , None ]:
84
108
if not kwargs .get ("timezone" ):
85
109
kwargs ["timezone" ] = "UTC"
86
110
if not kwargs .get ("converter_class" ):
@@ -96,7 +120,7 @@ async def db(
96
120
async def negative_db (
97
121
connection_name : str = "default" ,
98
122
** kwargs ,
99
- ) -> Generator [SnowflakeConnection , None , None ]:
123
+ ) -> AsyncGenerator [SnowflakeConnection , None ]:
100
124
if not kwargs .get ("timezone" ):
101
125
kwargs ["timezone" ] = "UTC"
102
126
if not kwargs .get ("converter_class" ):
@@ -116,7 +140,7 @@ def conn_cnx():
116
140
117
141
118
142
@pytest .fixture ()
119
- async def conn_testaccount () -> SnowflakeConnection :
143
+ async def conn_testaccount () -> AsyncGenerator [ SnowflakeConnection , None ] :
120
144
connection = await create_connection ("default" )
121
145
yield connection
122
146
await connection .close ()
@@ -129,18 +153,43 @@ def negative_conn_cnx() -> Callable[..., AsyncContextManager[SnowflakeConnection
129
153
130
154
131
155
@pytest .fixture ()
132
- async def aio_connection (db_parameters ):
133
- cnx = SnowflakeConnection (
134
- user = db_parameters ["user" ],
135
- password = db_parameters ["password" ],
136
- host = db_parameters ["host" ],
137
- port = db_parameters ["port" ],
138
- account = db_parameters ["account" ],
139
- database = db_parameters ["database" ],
140
- schema = db_parameters ["schema" ],
141
- warehouse = db_parameters ["warehouse" ],
142
- protocol = db_parameters ["protocol" ],
143
- timezone = "UTC" ,
144
- )
145
- yield cnx
146
- await cnx .close ()
156
+ async def aio_connection (db_parameters ) -> AsyncGenerator [SnowflakeConnection , None ]:
157
+ # Build connection params supporting both password and key-pair auth depending on environment
158
+ connection_params = {
159
+ "user" : db_parameters ["user" ],
160
+ "host" : db_parameters ["host" ],
161
+ "port" : db_parameters ["port" ],
162
+ "account" : db_parameters ["account" ],
163
+ "database" : db_parameters ["database" ],
164
+ "schema" : db_parameters ["schema" ],
165
+ "protocol" : db_parameters ["protocol" ],
166
+ "timezone" : "UTC" ,
167
+ }
168
+
169
+ # Optional fields
170
+ warehouse = db_parameters .get ("warehouse" )
171
+ if warehouse is not None :
172
+ connection_params ["warehouse" ] = warehouse
173
+
174
+ role = db_parameters .get ("role" )
175
+ if role is not None :
176
+ connection_params ["role" ] = role
177
+
178
+ if "password" in db_parameters and db_parameters ["password" ]:
179
+ connection_params ["password" ] = db_parameters ["password" ]
180
+ elif "private_key_file" in db_parameters :
181
+ # Use key-pair authentication
182
+ connection_params ["authenticator" ] = "SNOWFLAKE_JWT"
183
+ if RUNNING_OLD_DRIVER :
184
+ private_key_bytes = _get_private_key_bytes_for_olddriver (
185
+ db_parameters ["private_key_file" ]
186
+ )
187
+ connection_params ["private_key" ] = private_key_bytes
188
+ else :
189
+ connection_params ["private_key_file" ] = db_parameters ["private_key_file" ]
190
+
191
+ cnx = SnowflakeConnection (** connection_params )
192
+ try :
193
+ yield cnx
194
+ finally :
195
+ await cnx .close ()
0 commit comments