Skip to content

Commit b32e8b3

Browse files
adjust async conftest.py
1 parent f316c82 commit b32e8b3

File tree

1 file changed

+64
-21
lines changed

1 file changed

+64
-21
lines changed

test/integ/aio/conftest.py

Lines changed: 64 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,14 @@
22
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
33
#
44

5+
import os
56
from contextlib import asynccontextmanager
6-
from test.integ.conftest import get_db_parameters, is_public_testaccount
7-
from typing import AsyncContextManager, Callable, Generator
7+
from test.integ.conftest import (
8+
_get_private_key_bytes_for_olddriver,
9+
get_db_parameters,
10+
is_public_testaccount,
11+
)
12+
from typing import AsyncContextManager, AsyncGenerator, Callable
813

914
import pytest
1015

@@ -44,7 +49,7 @@ async def patch_connection(
4449
self,
4550
con: SnowflakeConnection,
4651
propagate: bool = True,
47-
) -> Generator[TelemetryCaptureHandlerAsync, None, None]:
52+
) -> AsyncGenerator[TelemetryCaptureHandlerAsync, None]:
4853
original_telemetry = con._telemetry
4954
new_telemetry = TelemetryCaptureHandlerAsync(
5055
original_telemetry,
@@ -57,6 +62,9 @@ async def patch_connection(
5762
con._telemetry = original_telemetry
5863

5964

65+
RUNNING_OLD_DRIVER = os.getenv("TOX_ENV_NAME") == "olddriver"
66+
67+
6068
@pytest.fixture(scope="session")
6169
def capture_sf_telemetry_async() -> TelemetryCaptureFixtureAsync:
6270
return TelemetryCaptureFixtureAsync()
@@ -71,6 +79,16 @@ async def create_connection(connection_name: str, **kwargs) -> SnowflakeConnecti
7179
"""
7280
ret = get_db_parameters(connection_name)
7381
ret.update(kwargs)
82+
83+
# Handle private key authentication for old driver if applicable
84+
if RUNNING_OLD_DRIVER and "private_key_file" in ret and "private_key" not in ret:
85+
private_key_file = ret.get("private_key_file")
86+
if private_key_file:
87+
private_key_bytes = _get_private_key_bytes_for_olddriver(private_key_file)
88+
ret["authenticator"] = "SNOWFLAKE_JWT"
89+
ret["private_key"] = private_key_bytes
90+
ret.pop("private_key_file", None)
91+
7492
connection = SnowflakeConnection(**ret)
7593
await connection.connect()
7694
return connection
@@ -80,7 +98,7 @@ async def create_connection(connection_name: str, **kwargs) -> SnowflakeConnecti
8098
async def db(
8199
connection_name: str = "default",
82100
**kwargs,
83-
) -> Generator[SnowflakeConnection, None, None]:
101+
) -> AsyncGenerator[SnowflakeConnection, None]:
84102
if not kwargs.get("timezone"):
85103
kwargs["timezone"] = "UTC"
86104
if not kwargs.get("converter_class"):
@@ -96,7 +114,7 @@ async def db(
96114
async def negative_db(
97115
connection_name: str = "default",
98116
**kwargs,
99-
) -> Generator[SnowflakeConnection, None, None]:
117+
) -> AsyncGenerator[SnowflakeConnection, None]:
100118
if not kwargs.get("timezone"):
101119
kwargs["timezone"] = "UTC"
102120
if not kwargs.get("converter_class"):
@@ -116,7 +134,7 @@ def conn_cnx():
116134

117135

118136
@pytest.fixture()
119-
async def conn_testaccount() -> SnowflakeConnection:
137+
async def conn_testaccount() -> AsyncGenerator[SnowflakeConnection, None]:
120138
connection = await create_connection("default")
121139
yield connection
122140
await connection.close()
@@ -129,18 +147,43 @@ def negative_conn_cnx() -> Callable[..., AsyncContextManager[SnowflakeConnection
129147

130148

131149
@pytest.fixture()
132-
async def aio_connection(db_parameters):
133-
cnx = SnowflakeConnection(
134-
user=db_parameters["user"],
135-
password=db_parameters["password"],
136-
host=db_parameters["host"],
137-
port=db_parameters["port"],
138-
account=db_parameters["account"],
139-
database=db_parameters["database"],
140-
schema=db_parameters["schema"],
141-
warehouse=db_parameters["warehouse"],
142-
protocol=db_parameters["protocol"],
143-
timezone="UTC",
144-
)
145-
yield cnx
146-
await cnx.close()
150+
async def aio_connection(db_parameters) -> AsyncGenerator[SnowflakeConnection, None]:
151+
# Build connection params supporting both password and key-pair auth depending on environment
152+
connection_params = {
153+
"user": db_parameters["user"],
154+
"host": db_parameters["host"],
155+
"port": db_parameters["port"],
156+
"account": db_parameters["account"],
157+
"database": db_parameters["database"],
158+
"schema": db_parameters["schema"],
159+
"protocol": db_parameters["protocol"],
160+
"timezone": "UTC",
161+
}
162+
163+
# Optional fields
164+
warehouse = db_parameters.get("warehouse")
165+
if warehouse is not None:
166+
connection_params["warehouse"] = warehouse
167+
168+
role = db_parameters.get("role")
169+
if role is not None:
170+
connection_params["role"] = role
171+
172+
if "password" in db_parameters and db_parameters["password"]:
173+
connection_params["password"] = db_parameters["password"]
174+
elif "private_key_file" in db_parameters:
175+
# Use key-pair authentication
176+
connection_params["authenticator"] = "SNOWFLAKE_JWT"
177+
if RUNNING_OLD_DRIVER:
178+
private_key_bytes = _get_private_key_bytes_for_olddriver(
179+
db_parameters["private_key_file"]
180+
)
181+
connection_params["private_key"] = private_key_bytes
182+
else:
183+
connection_params["private_key_file"] = db_parameters["private_key_file"]
184+
185+
cnx = SnowflakeConnection(**connection_params)
186+
try:
187+
yield cnx
188+
finally:
189+
await cnx.close()

0 commit comments

Comments
 (0)