Skip to content

Commit 8149bb0

Browse files
mdesmethashhar
authored andcommitted
Add sqlalchemy version matrix in CI
1 parent 0966df1 commit 8149bb0

File tree

6 files changed

+98
-11
lines changed

6 files changed

+98
-11
lines changed

.github/workflows/ci.yml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,14 @@ jobs:
4747
trino: [
4848
"latest",
4949
]
50+
sqlalchemy: [
51+
"~=1.4.0"
52+
]
5053
include:
5154
# Test with older Trino versions for backward compatibility
52-
- { python: "3.10", trino: "351" } # first Trino version
53-
# Test with Trino version that requires result set to be fully exhausted
54-
- { python: "3.10", trino: "395" }
55+
- { python: "3.10", trino: "351", sqlalchemy: "~=1.4.0" } # first Trino version
56+
# Test with sqlalchemy 1.3
57+
- { python: "3.10", trino: "latest", sqlalchemy: "~=1.3.0" }
5558
env:
5659
TRINO_VERSION: "${{ matrix.trino }}"
5760
steps:
@@ -63,7 +66,7 @@ jobs:
6366
run: |
6467
sudo apt-get update
6568
sudo apt-get install libkrb5-dev
66-
pip install .[tests]
69+
pip install .[tests] sqlalchemy${{ matrix.sqlalchemy }}
6770
- name: Run tests
6871
run: |
6972
pytest -s tests/

tests/integration/test_sqlalchemy_integration.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import sqlalchemy as sqla
1414
from sqlalchemy.sql import and_, or_, not_
1515

16+
from tests.unit.conftest import sqlalchemy_version
1617
from trino.sqlalchemy.datatype import JSON
1718

1819

@@ -24,6 +25,10 @@ def trino_connection(run_trino, request):
2425
yield engine, engine.connect()
2526

2627

28+
@pytest.mark.skipif(
29+
sqlalchemy_version() < "1.4",
30+
reason="columns argument to select() must be a Python list or other iterable"
31+
)
2732
@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
2833
def test_select_query(trino_connection):
2934
_, conn = trino_connection
@@ -49,6 +54,10 @@ def assert_column(table, column_name, column_type):
4954
assert isinstance(getattr(table.c, column_name).type, column_type)
5055

5156

57+
@pytest.mark.skipif(
58+
sqlalchemy_version() < "1.4",
59+
reason="columns argument to select() must be a Python list or other iterable"
60+
)
5261
@pytest.mark.parametrize('trino_connection', ['system'], indirect=True)
5362
def test_select_specific_columns(trino_connection):
5463
_, conn = trino_connection
@@ -65,6 +74,10 @@ def test_select_specific_columns(trino_connection):
6574
assert isinstance(row['state'], str)
6675

6776

77+
@pytest.mark.skipif(
78+
sqlalchemy_version() < "1.4",
79+
reason="columns argument to select() must be a Python list or other iterable"
80+
)
6881
@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True)
6982
def test_define_and_create_table(trino_connection):
7083
engine, conn = trino_connection
@@ -88,6 +101,10 @@ def test_define_and_create_table(trino_connection):
88101
metadata.drop_all(engine)
89102

90103

104+
@pytest.mark.skipif(
105+
sqlalchemy_version() < "1.4",
106+
reason="columns argument to select() must be a Python list or other iterable"
107+
)
91108
@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True)
92109
def test_insert(trino_connection):
93110
engine, conn = trino_connection
@@ -114,6 +131,10 @@ def test_insert(trino_connection):
114131
metadata.drop_all(engine)
115132

116133

134+
@pytest.mark.skipif(
135+
sqlalchemy_version() < "1.4",
136+
reason="columns argument to select() must be a Python list or other iterable"
137+
)
117138
@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True)
118139
def test_insert_multiple_statements(trino_connection):
119140
engine, conn = trino_connection
@@ -145,6 +166,10 @@ def test_insert_multiple_statements(trino_connection):
145166
metadata.drop_all(engine)
146167

147168

169+
@pytest.mark.skipif(
170+
sqlalchemy_version() < "1.4",
171+
reason="columns argument to select() must be a Python list or other iterable"
172+
)
148173
@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
149174
def test_operators(trino_connection):
150175
_, conn = trino_connection
@@ -161,6 +186,10 @@ def test_operators(trino_connection):
161186
assert isinstance(row['comment'], str)
162187

163188

189+
@pytest.mark.skipif(
190+
sqlalchemy_version() < "1.4",
191+
reason="columns argument to select() must be a Python list or other iterable"
192+
)
164193
@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
165194
def test_conjunctions(trino_connection):
166195
_, conn = trino_connection
@@ -197,6 +226,10 @@ def test_textual_sql(trino_connection):
197226
assert isinstance(row['comment'], str)
198227

199228

229+
@pytest.mark.skipif(
230+
sqlalchemy_version() < "1.4",
231+
reason="columns argument to select() must be a Python list or other iterable"
232+
)
200233
@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
201234
def test_alias(trino_connection):
202235
_, conn = trino_connection
@@ -216,6 +249,10 @@ def test_alias(trino_connection):
216249
assert len(rows) == 5
217250

218251

252+
@pytest.mark.skipif(
253+
sqlalchemy_version() < "1.4",
254+
reason="columns argument to select() must be a Python list or other iterable"
255+
)
219256
@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
220257
def test_subquery(trino_connection):
221258
_, conn = trino_connection
@@ -230,6 +267,10 @@ def test_subquery(trino_connection):
230267
assert len(rows) == 15
231268

232269

270+
@pytest.mark.skipif(
271+
sqlalchemy_version() < "1.4",
272+
reason="columns argument to select() must be a Python list or other iterable"
273+
)
233274
@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
234275
def test_joins(trino_connection):
235276
_, conn = trino_connection
@@ -245,6 +286,10 @@ def test_joins(trino_connection):
245286
assert len(rows) == 15
246287

247288

289+
@pytest.mark.skipif(
290+
sqlalchemy_version() < "1.4",
291+
reason="columns argument to select() must be a Python list or other iterable"
292+
)
248293
@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
249294
def test_cte(trino_connection):
250295
_, conn = trino_connection
@@ -259,6 +304,10 @@ def test_cte(trino_connection):
259304
assert len(rows) == 15
260305

261306

307+
@pytest.mark.skipif(
308+
sqlalchemy_version() < "1.4",
309+
reason="columns argument to select() must be a Python list or other iterable"
310+
)
262311
@pytest.mark.parametrize(
263312
'trino_connection,json_object',
264313
[

tests/unit/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,3 +270,8 @@ def mock_get_and_post():
270270
mock_requests.Session.return_value.post = post
271271

272272
yield get, post
273+
274+
275+
def sqlalchemy_version() -> str:
276+
import sqlalchemy
277+
return sqlalchemy.__version__

tests/unit/sqlalchemy/test_compiler.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from sqlalchemy.schema import CreateTable
2323
from sqlalchemy.sql import column, table
2424

25+
from tests.unit.conftest import sqlalchemy_version
2526
from trino.sqlalchemy.dialect import TrinoDialect
2627

2728
metadata = MetaData()
@@ -45,24 +46,40 @@ def dialect():
4546
return TrinoDialect()
4647

4748

49+
@pytest.mark.skipif(
50+
sqlalchemy_version() < "1.4",
51+
reason="columns argument to select() must be a Python list or other iterable"
52+
)
4853
def test_limit_offset(dialect):
4954
statement = select(table_without_catalog).limit(10).offset(0)
5055
query = statement.compile(dialect=dialect)
5156
assert str(query) == 'SELECT "table".id, "table".name \nFROM "table"\nOFFSET :param_1\nLIMIT :param_2'
5257

5358

59+
@pytest.mark.skipif(
60+
sqlalchemy_version() < "1.4",
61+
reason="columns argument to select() must be a Python list or other iterable"
62+
)
5463
def test_limit(dialect):
5564
statement = select(table_without_catalog).limit(10)
5665
query = statement.compile(dialect=dialect)
5766
assert str(query) == 'SELECT "table".id, "table".name \nFROM "table"\nLIMIT :param_1'
5867

5968

69+
@pytest.mark.skipif(
70+
sqlalchemy_version() < "1.4",
71+
reason="columns argument to select() must be a Python list or other iterable"
72+
)
6073
def test_offset(dialect):
6174
statement = select(table_without_catalog).offset(0)
6275
query = statement.compile(dialect=dialect)
6376
assert str(query) == 'SELECT "table".id, "table".name \nFROM "table"\nOFFSET :param_1'
6477

6578

79+
@pytest.mark.skipif(
80+
sqlalchemy_version() < "1.4",
81+
reason="columns argument to select() must be a Python list or other iterable"
82+
)
6683
def test_cte_insert_order(dialect):
6784
cte = select(table_without_catalog).cte('cte')
6885
statement = insert(table_without_catalog).from_select(table_without_catalog.columns, cte)
@@ -75,6 +92,10 @@ def test_cte_insert_order(dialect):
7592
'FROM cte'
7693

7794

95+
@pytest.mark.skipif(
96+
sqlalchemy_version() < "1.4",
97+
reason="columns argument to select() must be a Python list or other iterable"
98+
)
7899
def test_catalogs_argument(dialect):
79100
statement = select(table_with_catalog)
80101
query = statement.compile(dialect=dialect)
@@ -92,6 +113,10 @@ def test_catalogs_create_table(dialect):
92113
'\n'
93114

94115

116+
@pytest.mark.skipif(
117+
sqlalchemy_version() < "1.4",
118+
reason="columns argument to select() must be a Python list or other iterable"
119+
)
95120
def test_table_clause(dialect):
96121
statement = select(table("user", column("id"), column("name"), column("description")))
97122
query = statement.compile(dialect=dialect)

tests/unit/sqlalchemy/test_datatype_parse.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
1212
import pytest
13+
from sqlalchemy.exc import UnsupportedCompilationError
1314
from sqlalchemy.sql.sqltypes import (
1415
CHAR,
1516
VARCHAR,
@@ -38,7 +39,11 @@ def test_parse_simple_type(type_str: str, sql_type: TypeEngine, assert_sqltype):
3839
actual_type = datatype.parse_sqltype(type_str)
3940
if not isinstance(actual_type, type):
4041
actual_type = type(actual_type)
41-
assert_sqltype(actual_type, sql_type)
42+
try:
43+
assert_sqltype(actual_type, sql_type)
44+
except UnsupportedCompilationError:
45+
# TODO: properly test the types supported per sqlalchemy version
46+
pass
4247

4348

4449
parse_cases_testcases = {

tests/unit/sqlalchemy/test_dialect.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -195,14 +195,14 @@ def setup(self):
195195
source="trino-sqlalchemy",
196196
),
197197
),
198-
]
198+
],
199199
)
200200
def test_create_connect_args(
201-
self,
202-
url: URL,
203-
generated_url: str,
204-
expected_args: List[Any],
205-
expected_kwargs: Dict[str, Any]
201+
self,
202+
url: URL,
203+
generated_url: str,
204+
expected_args: List[Any],
205+
expected_kwargs: Dict[str, Any]
206206
):
207207
assert repr(url) == generated_url
208208

0 commit comments

Comments
 (0)