22# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
33#
44
5+ import os
56from contextlib import asynccontextmanager
6- from test .integ .conftest import get_db_parameters , is_public_testaccount
7- from typing import AsyncContextManager , Callable , Generator
7+ from test .integ .conftest import (
8+ _get_private_key_bytes_for_olddriver ,
9+ get_db_parameters ,
10+ is_public_testaccount ,
11+ )
12+ from typing import AsyncContextManager , AsyncGenerator , Callable
813
914import pytest
1015
@@ -44,7 +49,7 @@ async def patch_connection(
4449 self ,
4550 con : SnowflakeConnection ,
4651 propagate : bool = True ,
47- ) -> Generator [TelemetryCaptureHandlerAsync , None , None ]:
52+ ) -> AsyncGenerator [TelemetryCaptureHandlerAsync , None ]:
4853 original_telemetry = con ._telemetry
4954 new_telemetry = TelemetryCaptureHandlerAsync (
5055 original_telemetry ,
@@ -57,6 +62,9 @@ async def patch_connection(
5762 con ._telemetry = original_telemetry
5863
5964
65+ RUNNING_OLD_DRIVER = os .getenv ("TOX_ENV_NAME" ) == "olddriver"
66+
67+
6068@pytest .fixture (scope = "session" )
6169def capture_sf_telemetry_async () -> TelemetryCaptureFixtureAsync :
6270 return TelemetryCaptureFixtureAsync ()
@@ -71,6 +79,16 @@ async def create_connection(connection_name: str, **kwargs) -> SnowflakeConnecti
7179 """
7280 ret = get_db_parameters (connection_name )
7381 ret .update (kwargs )
82+
83+ # Handle private key authentication for old driver if applicable
84+ if RUNNING_OLD_DRIVER and "private_key_file" in ret and "private_key" not in ret :
85+ private_key_file = ret .get ("private_key_file" )
86+ if private_key_file :
87+ private_key_bytes = _get_private_key_bytes_for_olddriver (private_key_file )
88+ ret ["authenticator" ] = "SNOWFLAKE_JWT"
89+ ret ["private_key" ] = private_key_bytes
90+ ret .pop ("private_key_file" , None )
91+
7492 connection = SnowflakeConnection (** ret )
7593 await connection .connect ()
7694 return connection
@@ -80,7 +98,7 @@ async def create_connection(connection_name: str, **kwargs) -> SnowflakeConnecti
8098async def db (
8199 connection_name : str = "default" ,
82100 ** kwargs ,
83- ) -> Generator [SnowflakeConnection , None , None ]:
101+ ) -> AsyncGenerator [SnowflakeConnection , None ]:
84102 if not kwargs .get ("timezone" ):
85103 kwargs ["timezone" ] = "UTC"
86104 if not kwargs .get ("converter_class" ):
@@ -96,7 +114,7 @@ async def db(
96114async def negative_db (
97115 connection_name : str = "default" ,
98116 ** kwargs ,
99- ) -> Generator [SnowflakeConnection , None , None ]:
117+ ) -> AsyncGenerator [SnowflakeConnection , None ]:
100118 if not kwargs .get ("timezone" ):
101119 kwargs ["timezone" ] = "UTC"
102120 if not kwargs .get ("converter_class" ):
@@ -116,7 +134,7 @@ def conn_cnx():
116134
117135
118136@pytest .fixture ()
119- async def conn_testaccount () -> SnowflakeConnection :
137+ async def conn_testaccount () -> AsyncGenerator [ SnowflakeConnection , None ] :
120138 connection = await create_connection ("default" )
121139 yield connection
122140 await connection .close ()
@@ -129,18 +147,43 @@ def negative_conn_cnx() -> Callable[..., AsyncContextManager[SnowflakeConnection
129147
130148
131149@pytest .fixture ()
132- async def aio_connection (db_parameters ):
133- cnx = SnowflakeConnection (
134- user = db_parameters ["user" ],
135- password = db_parameters ["password" ],
136- host = db_parameters ["host" ],
137- port = db_parameters ["port" ],
138- account = db_parameters ["account" ],
139- database = db_parameters ["database" ],
140- schema = db_parameters ["schema" ],
141- warehouse = db_parameters ["warehouse" ],
142- protocol = db_parameters ["protocol" ],
143- timezone = "UTC" ,
144- )
145- yield cnx
146- await cnx .close ()
150+ async def aio_connection (db_parameters ) -> AsyncGenerator [SnowflakeConnection , None ]:
151+ # Build connection params supporting both password and key-pair auth depending on environment
152+ connection_params = {
153+ "user" : db_parameters ["user" ],
154+ "host" : db_parameters ["host" ],
155+ "port" : db_parameters ["port" ],
156+ "account" : db_parameters ["account" ],
157+ "database" : db_parameters ["database" ],
158+ "schema" : db_parameters ["schema" ],
159+ "protocol" : db_parameters ["protocol" ],
160+ "timezone" : "UTC" ,
161+ }
162+
163+ # Optional fields
164+ warehouse = db_parameters .get ("warehouse" )
165+ if warehouse is not None :
166+ connection_params ["warehouse" ] = warehouse
167+
168+ role = db_parameters .get ("role" )
169+ if role is not None :
170+ connection_params ["role" ] = role
171+
172+ if "password" in db_parameters and db_parameters ["password" ]:
173+ connection_params ["password" ] = db_parameters ["password" ]
174+ elif "private_key_file" in db_parameters :
175+ # Use key-pair authentication
176+ connection_params ["authenticator" ] = "SNOWFLAKE_JWT"
177+ if RUNNING_OLD_DRIVER :
178+ private_key_bytes = _get_private_key_bytes_for_olddriver (
179+ db_parameters ["private_key_file" ]
180+ )
181+ connection_params ["private_key" ] = private_key_bytes
182+ else :
183+ connection_params ["private_key_file" ] = db_parameters ["private_key_file" ]
184+
185+ cnx = SnowflakeConnection (** connection_params )
186+ try :
187+ yield cnx
188+ finally :
189+ await cnx .close ()
0 commit comments