Skip to content

Commit 322852d

Browse files
Apply changes to async tests and workflows
1 parent 407d131 commit 322852d

17 files changed

+426
-644
lines changed

.github/workflows/build_test.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,13 @@ jobs:
424424
run: |
425425
gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \
426426
.github/workflows/parameters/public/parameters_${{ matrix.cloud-provider }}.py.gpg > test/parameters.py
427+
- name: Setup private key file
428+
shell: bash
429+
env:
430+
PYTHON_PRIVATE_KEY_SECRET: ${{ secrets.PYTHON_PRIVATE_KEY_SECRET }}
431+
run: |
432+
gpg --quiet --batch --yes --decrypt --passphrase="$PYTHON_PRIVATE_KEY_SECRET" \
433+
.github/workflows/parameters/public/rsa_keys/rsa_key_python_${{ matrix.cloud-provider }}.p8.gpg > test/rsa_key_python_${{ matrix.cloud-provider }}.p8
427434
- name: Download wheel(s)
428435
uses: actions/download-artifact@v4
429436
with:

test/integ/aio/conftest.py

Lines changed: 70 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,14 @@
22
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
33
#
44

5+
import os
56
from contextlib import asynccontextmanager
6-
from test.integ.conftest import get_db_parameters, is_public_testaccount
7-
from typing import AsyncContextManager, Callable, Generator
7+
from test.integ.conftest import (
8+
_get_private_key_bytes_for_olddriver,
9+
get_db_parameters,
10+
is_public_testaccount,
11+
)
12+
from typing import AsyncContextManager, AsyncGenerator, Callable
813

914
import pytest
1015

@@ -44,7 +49,7 @@ async def patch_connection(
4449
self,
4550
con: SnowflakeConnection,
4651
propagate: bool = True,
47-
) -> Generator[TelemetryCaptureHandlerAsync, None, None]:
52+
) -> AsyncGenerator[TelemetryCaptureHandlerAsync, None]:
4853
original_telemetry = con._telemetry
4954
new_telemetry = TelemetryCaptureHandlerAsync(
5055
original_telemetry,
@@ -57,6 +62,9 @@ async def patch_connection(
5762
con._telemetry = original_telemetry
5863

5964

65+
RUNNING_OLD_DRIVER = os.getenv("TOX_ENV_NAME") == "olddriver"
66+
67+
6068
@pytest.fixture(scope="session")
6169
def capture_sf_telemetry_async() -> TelemetryCaptureFixtureAsync:
6270
return TelemetryCaptureFixtureAsync()
@@ -71,6 +79,22 @@ async def create_connection(connection_name: str, **kwargs) -> SnowflakeConnecti
7179
"""
7280
ret = get_db_parameters(connection_name)
7381
ret.update(kwargs)
82+
83+
# Handle private key authentication for old driver if applicable
84+
if RUNNING_OLD_DRIVER and "private_key_file" in ret and "private_key" not in ret:
85+
private_key_file = ret.get("private_key_file")
86+
if private_key_file:
87+
private_key_bytes = _get_private_key_bytes_for_olddriver(private_key_file)
88+
ret["authenticator"] = "SNOWFLAKE_JWT"
89+
ret["private_key"] = private_key_bytes
90+
ret.pop("private_key_file", None)
91+
92+
# If authenticator is explicitly provided and it's not key-pair based, drop key-pair fields
93+
authenticator_value = ret.get("authenticator")
94+
if authenticator_value.lower() not in {"key_pair_authenticator", "snowflake_jwt"}:
95+
ret.pop("private_key", None)
96+
ret.pop("private_key_file", None)
97+
7498
connection = SnowflakeConnection(**ret)
7599
await connection.connect()
76100
return connection
@@ -80,7 +104,7 @@ async def create_connection(connection_name: str, **kwargs) -> SnowflakeConnecti
80104
async def db(
81105
connection_name: str = "default",
82106
**kwargs,
83-
) -> Generator[SnowflakeConnection, None, None]:
107+
) -> AsyncGenerator[SnowflakeConnection, None]:
84108
if not kwargs.get("timezone"):
85109
kwargs["timezone"] = "UTC"
86110
if not kwargs.get("converter_class"):
@@ -96,7 +120,7 @@ async def db(
96120
async def negative_db(
97121
connection_name: str = "default",
98122
**kwargs,
99-
) -> Generator[SnowflakeConnection, None, None]:
123+
) -> AsyncGenerator[SnowflakeConnection, None]:
100124
if not kwargs.get("timezone"):
101125
kwargs["timezone"] = "UTC"
102126
if not kwargs.get("converter_class"):
@@ -116,7 +140,7 @@ def conn_cnx():
116140

117141

118142
@pytest.fixture()
119-
async def conn_testaccount() -> SnowflakeConnection:
143+
async def conn_testaccount() -> AsyncGenerator[SnowflakeConnection, None]:
120144
connection = await create_connection("default")
121145
yield connection
122146
await connection.close()
@@ -129,18 +153,43 @@ def negative_conn_cnx() -> Callable[..., AsyncContextManager[SnowflakeConnection
129153

130154

131155
@pytest.fixture()
132-
async def aio_connection(db_parameters):
133-
cnx = SnowflakeConnection(
134-
user=db_parameters["user"],
135-
password=db_parameters["password"],
136-
host=db_parameters["host"],
137-
port=db_parameters["port"],
138-
account=db_parameters["account"],
139-
database=db_parameters["database"],
140-
schema=db_parameters["schema"],
141-
warehouse=db_parameters["warehouse"],
142-
protocol=db_parameters["protocol"],
143-
timezone="UTC",
144-
)
145-
yield cnx
146-
await cnx.close()
156+
async def aio_connection(db_parameters) -> AsyncGenerator[SnowflakeConnection, None]:
157+
# Build connection params supporting both password and key-pair auth depending on environment
158+
connection_params = {
159+
"user": db_parameters["user"],
160+
"host": db_parameters["host"],
161+
"port": db_parameters["port"],
162+
"account": db_parameters["account"],
163+
"database": db_parameters["database"],
164+
"schema": db_parameters["schema"],
165+
"protocol": db_parameters["protocol"],
166+
"timezone": "UTC",
167+
}
168+
169+
# Optional fields
170+
warehouse = db_parameters.get("warehouse")
171+
if warehouse is not None:
172+
connection_params["warehouse"] = warehouse
173+
174+
role = db_parameters.get("role")
175+
if role is not None:
176+
connection_params["role"] = role
177+
178+
if "password" in db_parameters and db_parameters["password"]:
179+
connection_params["password"] = db_parameters["password"]
180+
elif "private_key_file" in db_parameters:
181+
# Use key-pair authentication
182+
connection_params["authenticator"] = "SNOWFLAKE_JWT"
183+
if RUNNING_OLD_DRIVER:
184+
private_key_bytes = _get_private_key_bytes_for_olddriver(
185+
db_parameters["private_key_file"]
186+
)
187+
connection_params["private_key"] = private_key_bytes
188+
else:
189+
connection_params["private_key_file"] = db_parameters["private_key_file"]
190+
191+
cnx = SnowflakeConnection(**connection_params)
192+
try:
193+
yield cnx
194+
finally:
195+
await cnx.close()

test/integ/aio/test_arrow_result_async.py

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ async def structured_type_wrapped_conn(conn_cnx, structured_type_support):
136136

137137

138138
@pytest.mark.asyncio
139-
@pytest.mark.parametrize("datatype", ICEBERG_UNSUPPORTED_TYPES)
139+
@pytest.mark.parametrize("datatype", sorted(ICEBERG_UNSUPPORTED_TYPES))
140140
async def test_iceberg_negative(
141141
datatype, conn_cnx, iceberg_support, structured_type_support
142142
):
@@ -834,35 +834,46 @@ async def test_select_vector(conn_cnx, is_public_test):
834834

835835
@pytest.mark.asyncio
836836
async def test_select_time(conn_cnx):
837-
for scale in range(10):
838-
await select_time_with_scale(conn_cnx, scale)
839-
840-
841-
async def select_time_with_scale(conn_cnx, scale):
837+
# Test key scales and meaningful cases in a single table operation
838+
# Cover: no fractional seconds, milliseconds, microseconds, nanoseconds
839+
scales = [0, 3, 6, 9] # Key precision levels
842840
cases = [
843-
"00:01:23",
844-
"00:01:23.1",
845-
"00:01:23.12",
846-
"00:01:23.123",
847-
"00:01:23.1234",
848-
"00:01:23.12345",
849-
"00:01:23.123456",
850-
"00:01:23.1234567",
851-
"00:01:23.12345678",
852-
"00:01:23.123456789",
841+
"00:01:23", # Basic time
842+
"00:01:23.123456789", # Max precision
843+
"23:59:59.999999999", # Edge case - max time with max precision
844+
"00:00:00.000000001", # Edge case - min time with min precision
853845
]
854-
table = "test_arrow_time"
855-
column = f"(a time({scale}))"
856-
values = (
857-
"(-1, NULL), ("
858-
+ "),(".join([f"{i}, '{c}'" for i, c in enumerate(cases)])
859-
+ f"), ({len(cases)}, NULL)"
860-
)
861-
await init(conn_cnx, table, column, values)
862-
sql_text = f"select a from {table} order by s"
863-
row_count = len(cases) + 2
864-
col_count = 1
865-
await iterate_over_test_chunk("time", conn_cnx, sql_text, row_count, col_count)
846+
847+
table = "test_arrow_time_scales"
848+
849+
# Create columns for selected scales only (init function will add 's number' automatically)
850+
columns = ", ".join([f"a{i} time({i})" for i in scales])
851+
column_def = f"({columns})"
852+
853+
# Create values for selected scales - each case tests all scales simultaneously
854+
value_rows = []
855+
for i, case in enumerate(cases):
856+
# Each row has the same time value for all scale columns
857+
time_values = ", ".join([f"'{case}'" for _ in scales])
858+
value_rows.append(f"({i}, {time_values})")
859+
860+
# Add NULL rows
861+
null_values = ", ".join(["NULL" for _ in scales])
862+
value_rows.append(f"(-1, {null_values})")
863+
value_rows.append(f"({len(cases)}, {null_values})")
864+
865+
values = ", ".join(value_rows)
866+
867+
# Single table creation and test
868+
await init(conn_cnx, table, column_def, values)
869+
870+
# Test each scale column
871+
for scale in scales:
872+
sql_text = f"select a{scale} from {table} order by s"
873+
row_count = len(cases) + 2
874+
col_count = 1
875+
await iterate_over_test_chunk("time", conn_cnx, sql_text, row_count, col_count)
876+
866877
await finish(conn_cnx, table)
867878

868879

test/integ/aio/test_autocommit_async.py

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55

66
from __future__ import annotations
77

8-
import snowflake.connector.aio
9-
108

119
async def exe0(cnx, sql):
1210
return await cnx.cursor().execute(sql)
@@ -164,7 +162,7 @@ async def exe(cnx, sql):
164162
)
165163

166164

167-
async def test_autocommit_parameters(db_parameters):
165+
async def test_autocommit_parameters(db_parameters, conn_cnx):
168166
"""Tests autocommit parameter.
169167
170168
Args:
@@ -174,17 +172,7 @@ async def test_autocommit_parameters(db_parameters):
174172
async def exe(cnx, sql):
175173
return await cnx.cursor().execute(sql.format(name=db_parameters["name"]))
176174

177-
async with snowflake.connector.aio.SnowflakeConnection(
178-
user=db_parameters["user"],
179-
password=db_parameters["password"],
180-
host=db_parameters["host"],
181-
port=db_parameters["port"],
182-
account=db_parameters["account"],
183-
protocol=db_parameters["protocol"],
184-
schema=db_parameters["schema"],
185-
database=db_parameters["database"],
186-
autocommit=False,
187-
) as cnx:
175+
async with conn_cnx(autocommit=False) as cnx:
188176
await exe(
189177
cnx,
190178
"""
@@ -193,17 +181,7 @@ async def exe(cnx, sql):
193181
)
194182
await _run_autocommit_off(cnx, db_parameters)
195183

196-
async with snowflake.connector.aio.SnowflakeConnection(
197-
user=db_parameters["user"],
198-
password=db_parameters["password"],
199-
host=db_parameters["host"],
200-
port=db_parameters["port"],
201-
account=db_parameters["account"],
202-
protocol=db_parameters["protocol"],
203-
schema=db_parameters["schema"],
204-
database=db_parameters["database"],
205-
autocommit=True,
206-
) as cnx:
184+
async with conn_cnx(autocommit=True) as cnx:
207185
await _run_autocommit_on(cnx, db_parameters)
208186
await exe(
209187
cnx,

0 commit comments

Comments
 (0)