
Commit e209415

sfc-gh-pcyrek authored and sfc-gh-pczajka committed
SNOW-2021009: test optimisation (#2388)
1 parent e0a46ea commit e209415

File tree

14 files changed (+376 −97 lines)


.github/workflows/build_test.yml

Lines changed: 2 additions & 2 deletions
@@ -173,8 +173,8 @@ jobs:
         run: python -m pip install tox>=4
       - name: Run tests
         # To run a single test on GHA use the below command:
-        # run: python -m tox run -e `echo py${PYTHON_VERSION/\./}-single-ci | sed 's/ /,/g'`
-        run: python -m tox run -e `echo py${PYTHON_VERSION/\./}-{extras,unit,integ,pandas,sso}-ci | sed 's/ /,/g'`
+        # run: python -m tox run -e `echo py${PYTHON_VERSION/\./}-single-ci | sed 's/ /,/g'`
+        run: python -m tox run -e `echo py${PYTHON_VERSION/\./}-{extras,unit-parallel,integ-parallel,pandas-parallel,sso}-ci | sed 's/ /,/g'`
 
         env:
           PYTHON_VERSION: ${{ matrix.python-version }}
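
Note: the backtick expression builds the tox env list via shell brace expansion. A minimal sketch (not from the commit) of what it resolves to, assuming PYTHON_VERSION is 3.9:

    # Mirrors `echo py${PYTHON_VERSION/\./}-{...}-ci | sed 's/ /,/g'`
    short = "3.9".replace(".", "")  # "39"
    suites = ["extras", "unit-parallel", "integ-parallel", "pandas-parallel", "sso"]
    print(",".join(f"py{short}-{s}-ci" for s in suites))
    # py39-extras-ci,py39-unit-parallel-ci,py39-integ-parallel-ci,py39-pandas-parallel-ci,py39-sso-ci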

ci/test_fips.sh

Lines changed: 7 additions & 1 deletion
@@ -14,6 +14,10 @@ curl https://repo1.maven.org/maven2/org/wiremock/wiremock-standalone/3.11.0/wire
 python3 -m venv fips_env
 source fips_env/bin/activate
 pip install -U setuptools pip
+
+# Install pytest-xdist for parallel execution
+pip install pytest-xdist
+
 pip install "${CONNECTOR_WHL}[pandas,secure-local-storage,development]"
 
 echo "!!! Environment description !!!"
@@ -24,6 +28,8 @@ python -c "from cryptography.hazmat.backends.openssl import backend;print('Cryp
 pip freeze
 
 cd $CONNECTOR_DIR
-pytest -vvv --cov=snowflake.connector --cov-report=xml:coverage.xml test --ignore=test/integ/aio --ignore=test/unit/aio
+
+# Run tests in parallel using pytest-xdist
+pytest -n auto -vvv --cov=snowflake.connector --cov-report=xml:coverage.xml test --ignore=test/integ/aio --ignore=test/unit/aio
 
 deactivate
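
Note: `-n auto` is provided by the newly installed pytest-xdist plugin and sizes the worker pool to the number of available CPU cores. A hedged equivalent through pytest's Python entry point (an illustration, not part of the commit):

    import pytest

    # Requires pytest-xdist; "-n auto" spawns one worker per CPU core.
    exit_code = pytest.main([
        "-n", "auto", "-vvv",
        "--cov=snowflake.connector", "--cov-report=xml:coverage.xml",
        "test", "--ignore=test/integ/aio", "--ignore=test/unit/aio",
    ])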

ci/test_linux.sh

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ else
     echo "[Info] Testing with ${PYTHON_VERSION}"
     SHORT_VERSION=$(python3.10 -c "print('${PYTHON_VERSION}'.replace('.', ''))")
     CONNECTOR_WHL=$(ls $CONNECTOR_DIR/dist/snowflake_connector_python*cp${SHORT_VERSION}*manylinux2014*.whl | sort -r | head -n 1)
-    TEST_LIST=`echo py${PYTHON_VERSION/\./}-{unit,integ,pandas,sso}-ci | sed 's/ /,/g'`
+    TEST_LIST=`echo py${PYTHON_VERSION/\./}-{unit-parallel,integ,pandas-parallel,sso}-ci | sed 's/ /,/g'`
     TEST_ENVLIST=fix_lint,$TEST_LIST,py${PYTHON_VERSION/\./}-coverage
     echo "[Info] Running tox for ${TEST_ENVLIST}"

src/snowflake/connector/ocsp_snowflake.py

Lines changed: 2 additions & 2 deletions
@@ -576,7 +576,7 @@ def _download_ocsp_response_cache(ocsp, url, do_retry: bool = True) -> bool:
                     response.status_code,
                     sleep_time,
                 )
-            time.sleep(sleep_time)
+                time.sleep(sleep_time)
         else:
             logger.error(
                 "Failed to get OCSP response after %s attempt.", max_retry
@@ -1649,7 +1649,7 @@ def _fetch_ocsp_response(
                     response.status_code,
                     sleep_time,
                 )
-            time.sleep(sleep_time)
+                time.sleep(sleep_time)
         except Exception as ex:
             if max_retry > 1:
                 sleep_time = next(backoff)

test/conftest.py

Lines changed: 15 additions & 0 deletions
@@ -146,3 +146,18 @@ def pytest_runtest_setup(item) -> None:
         pytest.skip("cannot run this test on public Snowflake deployment")
     elif INTERNAL_SKIP_TAGS.intersection(test_tags) and not running_on_public_ci():
         pytest.skip("cannot run this test on private Snowflake deployment")
+
+    if "auth" in test_tags:
+        if os.getenv("RUN_AUTH_TESTS") != "true":
+            pytest.skip("Skipping auth test in current environment")
+
+
+def get_server_parameter_value(connection, parameter_name: str) -> str | None:
+    """Get server parameter value, returns None if parameter doesn't exist."""
+    try:
+        with connection.cursor() as cur:
+            cur.execute(f"show parameters like '{parameter_name}'")
+            ret = cur.fetchone()
+            return ret[1] if ret else None
+    except Exception:
+        return None
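
A usage sketch for the new helper (hypothetical call site; assumes an open connection `cnx`, and that SHOW PARAMETERS returns the value in its second column, hence the `ret[1]` index):

    # e.g. inside a test module, as test_connection.py does below:
    from ..conftest import get_server_parameter_value

    value = get_server_parameter_value(cnx, "CLIENT_SESSION_KEEP_ALIVE")
    # value is "true"/"false" as a string, or None if the lookup failed
    keep_alive = value is not None and value.lower() == "true"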

test/helpers.py

Lines changed: 28 additions & 1 deletion
@@ -198,7 +198,34 @@ def _arrow_error_stream_chunk_remove_single_byte_test(use_table_iterator):
     decode_bytes = base64.b64decode(b64data)
     exception_result = []
     result_array = []
-    for i in range(len(decode_bytes)):
+
+    # Test strategic positions instead of every byte for performance
+    # Test header (first 50), middle section, end (last 50), and some random positions
+    data_len = len(decode_bytes)
+    test_positions = set()
+
+    # Critical positions: beginning (headers/metadata)
+    test_positions.update(range(min(50, data_len)))
+
+    # Middle section positions
+    mid_start = data_len // 2 - 25
+    mid_end = data_len // 2 + 25
+    test_positions.update(range(max(0, mid_start), min(data_len, mid_end)))
+
+    # End positions
+    test_positions.update(range(max(0, data_len - 50), data_len))
+
+    # Some random positions throughout the data (for broader coverage)
+    import random
+
+    random.seed(42)  # Deterministic for reproducible tests
+    random_positions = random.sample(range(data_len), min(50, data_len))
+    test_positions.update(random_positions)
+
+    # Convert to sorted list for consistent execution
+    test_positions = sorted(test_positions)
+
+    for i in test_positions:
         try:
             # removing the i-th char in the bytes
             iterator = create_nanoarrow_pyarrow_iterator(
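
Back-of-envelope check (illustrative only, not in the commit): the sampling caps the tested positions at roughly 200 regardless of payload size, e.g. for a 10 kB payload:

    import random

    data_len = 10_000
    positions = set(range(min(50, data_len)))                        # header
    positions.update(range(data_len // 2 - 25, data_len // 2 + 25))  # middle
    positions.update(range(max(0, data_len - 50), data_len))         # end
    random.seed(42)
    positions.update(random.sample(range(data_len), min(50, data_len)))
    print(len(positions))  # <= 200, versus 10_000 byte-by-byte iterations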

test/integ/conftest.py

Lines changed: 33 additions & 6 deletions
@@ -45,10 +45,30 @@
 
 logger = getLogger(__name__)
 
-if RUNNING_ON_GH:
-    TEST_SCHEMA = "GH_JOB_{}".format(str(uuid.uuid4()).replace("-", "_"))
-else:
-    TEST_SCHEMA = "python_connector_tests_" + str(uuid.uuid4()).replace("-", "_")
+
+
+def _get_worker_specific_schema():
+    """Generate worker-specific schema name for parallel test execution."""
+    base_uuid = str(uuid.uuid4()).replace("-", "_")
+
+    # Check if running in pytest-xdist parallel mode
+    worker_id = os.getenv("PYTEST_XDIST_WORKER")
+    if worker_id:
+        # Use worker ID to ensure unique schema per worker
+        worker_suffix = worker_id.replace("-", "_")
+        if RUNNING_ON_GH:
+            return f"GH_JOB_{worker_suffix}_{base_uuid}"
+        else:
+            return f"python_connector_tests_{worker_suffix}_{base_uuid}"
+    else:
+        # Single worker mode (original behavior)
+        if RUNNING_ON_GH:
+            return f"GH_JOB_{base_uuid}"
+        else:
+            return f"python_connector_tests_{base_uuid}"
+
+
+TEST_SCHEMA = _get_worker_specific_schema()
+
 
 if TEST_USING_VENDORED_ARROW:
     snowflake.connector.cursor.NANOARR_USAGE = (
@@ -140,8 +160,15 @@ def get_db_parameters(connection_name: str = "default") -> dict[str, Any]:
         print_help()
         sys.exit(2)
 
-    # a unique table name
-    ret["name"] = "python_tests_" + str(uuid.uuid4()).replace("-", "_")
+    # a unique table name (worker-specific for parallel execution)
+    base_uuid = str(uuid.uuid4()).replace("-", "_")
+    worker_id = os.getenv("PYTEST_XDIST_WORKER")
+    if worker_id:
+        # Include worker ID to prevent conflicts between parallel workers
+        worker_suffix = worker_id.replace("-", "_")
+        ret["name"] = f"python_tests_{worker_suffix}_{base_uuid}"
+    else:
+        ret["name"] = f"python_tests_{base_uuid}"
     ret["name_wh"] = ret["name"] + "wh"
 
     ret["schema"] = TEST_SCHEMA

test/integ/test_arrow_result.py

Lines changed: 39 additions & 28 deletions
@@ -303,7 +303,7 @@ def pandas_verify(cur, data, deserialize):
         ), f"Result value {value} should match input example {datum}."
 
 
-@pytest.mark.parametrize("datatype", ICEBERG_UNSUPPORTED_TYPES)
+@pytest.mark.parametrize("datatype", sorted(ICEBERG_UNSUPPORTED_TYPES))
 def test_iceberg_negative(datatype, conn_cnx, iceberg_support, structured_type_support):
     if not iceberg_support:
         pytest.skip("Test requires iceberg support.")
@@ -1002,35 +1002,46 @@ def test_select_vector(conn_cnx, is_public_test):
 
 
 def test_select_time(conn_cnx):
-    for scale in range(10):
-        select_time_with_scale(conn_cnx, scale)
-
-
-def select_time_with_scale(conn_cnx, scale):
+    # Test key scales and meaningful cases in a single table operation
+    # Cover: no fractional seconds, milliseconds, microseconds, nanoseconds
+    scales = [0, 3, 6, 9]  # Key precision levels
     cases = [
-        "00:01:23",
-        "00:01:23.1",
-        "00:01:23.12",
-        "00:01:23.123",
-        "00:01:23.1234",
-        "00:01:23.12345",
-        "00:01:23.123456",
-        "00:01:23.1234567",
-        "00:01:23.12345678",
-        "00:01:23.123456789",
+        "00:01:23",  # Basic time
+        "00:01:23.123456789",  # Max precision
+        "23:59:59.999999999",  # Edge case - max time with max precision
+        "00:00:00.000000001",  # Edge case - min time with min precision
     ]
-    table = "test_arrow_time"
-    column = f"(a time({scale}))"
-    values = (
-        "(-1, NULL), ("
-        + "),(".join([f"{i}, '{c}'" for i, c in enumerate(cases)])
-        + f"), ({len(cases)}, NULL)"
-    )
-    init(conn_cnx, table, column, values)
-    sql_text = f"select a from {table} order by s"
-    row_count = len(cases) + 2
-    col_count = 1
-    iterate_over_test_chunk("time", conn_cnx, sql_text, row_count, col_count)
+
+    table = "test_arrow_time_scales"
+
+    # Create columns for selected scales only (init function will add 's number' automatically)
+    columns = ", ".join([f"a{i} time({i})" for i in scales])
+    column_def = f"({columns})"
+
+    # Create values for selected scales - each case tests all scales simultaneously
+    value_rows = []
+    for i, case in enumerate(cases):
+        # Each row has the same time value for all scale columns
+        time_values = ", ".join([f"'{case}'" for _ in scales])
+        value_rows.append(f"({i}, {time_values})")
+
+    # Add NULL rows
+    null_values = ", ".join(["NULL" for _ in scales])
+    value_rows.append(f"(-1, {null_values})")
+    value_rows.append(f"({len(cases)}, {null_values})")
+
+    values = ", ".join(value_rows)
+
+    # Single table creation and test
+    init(conn_cnx, table, column_def, values)
+
+    # Test each scale column
+    for scale in scales:
+        sql_text = f"select a{scale} from {table} order by s"
+        row_count = len(cases) + 2
+        col_count = 1
+        iterate_over_test_chunk("time", conn_cnx, sql_text, row_count, col_count)
+
     finish(conn_cnx, table)
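
For reference, a sketch of the SQL fragments the rewritten test generates (the `s number` ordering column is added by init(), per the comment in the diff):

    scales = [0, 3, 6, 9]
    cases = ["00:01:23", "00:01:23.123456789",
             "23:59:59.999999999", "00:00:00.000000001"]
    print("(" + ", ".join(f"a{i} time({i})" for i in scales) + ")")
    # (a0 time(0), a3 time(3), a6 time(6), a9 time(9))
    print("(0, " + ", ".join(f"'{cases[0]}'" for _ in scales) + ")")
    # (0, '00:01:23', '00:01:23', '00:01:23', '00:01:23')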

test/integ/test_connection.py

Lines changed: 18 additions & 1 deletion
@@ -115,6 +115,8 @@ def test_connection_without_database2(db_parameters):
 
 def test_with_config(db_parameters):
     """Creates a connection with the config parameter."""
+    from ..conftest import get_server_parameter_value
+
     config = {
         "user": db_parameters["user"],
         "password": db_parameters["password"],
@@ -129,7 +131,22 @@ def test_with_config(db_parameters):
     cnx = snowflake.connector.connect(**config)
     try:
         assert cnx, "invalid cnx"
-        assert not cnx.client_session_keep_alive  # default is False
+
+        # Check what the server default is to make test environment-aware
+        server_default_str = get_server_parameter_value(
+            cnx, "CLIENT_SESSION_KEEP_ALIVE"
+        )
+        if server_default_str:
+            server_default = server_default_str.lower() == "true"
+            # Test that connection respects server default when not explicitly set
+            assert (
+                cnx.client_session_keep_alive == server_default
+            ), f"Expected client_session_keep_alive={server_default} (server default), got {cnx.client_session_keep_alive}"
+        else:
+            # Fallback: if we can't determine server default, expect False
+            assert (
+                not cnx.client_session_keep_alive
+            ), "Expected client_session_keep_alive=False when server default unknown"
     finally:
         cnx.close()
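
How the new assertion resolves, as a sketch (the SHOW PARAMETERS row layout with the value in the second column is an assumption carried over from the helper's `ret[1]` indexing):

    # Hypothetical row shape: (key, value, default, level, description, type)
    row = ("CLIENT_SESSION_KEEP_ALIVE", "false", "false", "SESSION", "...", "BOOLEAN")
    server_default = row[1].lower() == "true"  # False here
    # The test then expects cnx.client_session_keep_alive == server_default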

test/integ/test_dbapi.py

Lines changed: 59 additions & 9 deletions
@@ -728,15 +728,65 @@ def test_escape(conn_local):
     with conn_local() as con:
         cur = con.cursor()
         executeDDL1(cur)
-        for i in teststrings:
-            args = {"dbapi_ddl2": i}
-            cur.execute("insert into %s values (%%(dbapi_ddl2)s)" % TABLE1, args)
-            cur.execute("select * from %s" % TABLE1)
-            row = cur.fetchone()
-            cur.execute("delete from %s where name=%%s" % TABLE1, i)
-            assert (
-                i == row[0]
-            ), f"newline not properly converted, got {row[0]}, should be {i}"
+
+        # Test 1: Batch INSERT with dictionary parameters (executemany)
+        # This tests the same dictionary parameter binding as the original
+        batch_args = [{"dbapi_ddl2": test_string} for test_string in teststrings]
+        cur.executemany("insert into %s values (%%(dbapi_ddl2)s)" % TABLE1, batch_args)
+
+        # Test 2: Batch SELECT with no parameters
+        # This tests the same SELECT functionality as the original
+        cur.execute("select name from %s" % TABLE1)
+        rows = cur.fetchall()
+
+        # Verify each test string was properly escaped/handled
+        assert len(rows) == len(
+            teststrings
+        ), f"Expected {len(teststrings)} rows, got {len(rows)}"
+
+        # Extract actual strings from result set
+        actual_strings = {row[0] for row in rows}  # Use set to ignore order
+        expected_strings = set(teststrings)
+
+        # Verify all expected strings are present
+        missing_strings = expected_strings - actual_strings
+        extra_strings = actual_strings - expected_strings
+
+        assert len(missing_strings) == 0, f"Missing strings: {missing_strings}"
+        assert len(extra_strings) == 0, f"Extra strings: {extra_strings}"
+        assert actual_strings == expected_strings, "String sets don't match"
+
+        # Test 3: DELETE with positional parameters (batched for efficiency)
+        # This maintains the same DELETE parameter binding test as the original
+        # We test a representative subset to maintain coverage while being efficient
+        critical_test_strings = [
+            teststrings[0],  # Basic newline: "abc\ndef"
+            teststrings[5],  # Double quote: 'abc"def'
+            teststrings[7],  # Single quote: "abc'def"
+            teststrings[13],  # Tab: "abc\tdef"
+            teststrings[16],  # Backslash-x: "\\x"
+        ]
+
+        # Batch DELETE with positional parameters using executemany
+        # This tests the same positional parameter binding as the original individual DELETEs
+        cur.executemany(
+            "delete from %s where name=%%s" % TABLE1,
+            [(test_string,) for test_string in critical_test_strings],
+        )
+
+        # Batch verification: check that all critical strings were deleted
+        cur.execute(
+            "select name from %s where name in (%s)"
+            % (TABLE1, ",".join(["%s"] * len(critical_test_strings))),
+            critical_test_strings,
+        )
+        remaining_critical = cur.fetchall()
+        assert (
+            len(remaining_critical) == 0
+        ), f"Failed to delete strings: {[row[0] for row in remaining_critical]}"
+
+        # Clean up remaining rows
+        cur.execute("delete from %s" % TABLE1)
 
 
 @pytest.mark.skipolddriver
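
A minimal standalone sketch of the two binding styles the rewritten test exercises, with a hypothetical table t(name string) and an open cursor `cur` (not from the commit):

    # Named/pyformat binding, batched -- mirrors the INSERT above
    cur.executemany(
        "insert into t values (%(v)s)",
        [{"v": "a\nb"}, {"v": 'a"b'}],
    )
    # Positional binding, batched -- mirrors the DELETE above
    cur.executemany(
        "delete from t where name=%s",
        [("a\nb",), ('a"b',)],
    )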
