Skip to content

Commit 9fcbb2a

Browse files
committed
Robustify SQLA version determination logic
1 parent 3dd6556 commit 9fcbb2a

File tree

3 files changed

+28
-1
lines changed

3 files changed

+28
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ development = [
5050
"pytest",
5151
"setuptools",
5252
"pytest-cov",
53+
"pytest-mock",
5354
"pytest-timeout",
5455
"pytest-rerunfailures",
5556
"pytest-xdist",

src/snowflake/sqlalchemy/compat.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import functools
6+
import re
67
from typing import Callable
78

89
from sqlalchemy import __version__ as SA_VERSION
@@ -11,7 +12,8 @@
1112
string_types = (str,)
1213
returns_unicode = util.symbol("RETURNS_UNICODE")
1314

14-
IS_VERSION_20 = tuple(int(v) for v in SA_VERSION.split(".")) >= (2, 0, 0)
15+
_match = re.match(r"(\d+)\.(\d+)\.(\d+)", SA_VERSION)
16+
IS_VERSION_20 = tuple(int(x) for x in _match.groups()) >= (2, 0, 0)
1517

1618

1719
def args_reducer(positions_to_drop: tuple):

tests/test_compat.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import importlib
2+
3+
import pytest
4+
5+
import snowflake.sqlalchemy.compat as compat_module
6+
7+
8+
@pytest.mark.parametrize(
9+
("version", "expected"),
10+
[
11+
("2.0.0", True),
12+
("3.1.0", True),
13+
("1.3.0", False),
14+
("2.0.5.post1", True),
15+
("2.0.0rc2", True),
16+
("2.0.0b1", True),
17+
("0.5.0beta3", False),
18+
("0.4.2a", False),
19+
],
20+
)
21+
def test_is_version_20(version, expected, mocker):
22+
mocker.patch("sqlalchemy.__version__", version)
23+
importlib.reload(compat_module)
24+
assert compat_module.IS_VERSION_20 == expected

0 commit comments

Comments
 (0)