Skip to content

Commit 6fa523c

Browse files
authored
SNOW-629084 follow sqlalchemy 2.0 migration guide to verify the requirements are met (#317)
1 parent 193e596 commit 6fa523c

18 files changed

+608
-452
lines changed

DESCRIPTION.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ Source code is also available at:
99

1010
# Release Notes
1111

12+
- v1.4.1(Unreleased)
13+
- snowflake-sqlalchemy is now SQLAlchemy 2.0 compatible.
14+
1215
- v1.4.0(July 20, 2022)
1316

1417
- Added support for `regexp_match`, `regexp_replace` in `sqlalchemy.sql.expression.ColumnOperators`.

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ universal = 1
33

44
[metadata]
55
name = snowflake-sqlalchemy
6-
version = 1.4.0
6+
version = 1.4.1
77
description = Snowflake SQLAlchemy Dialect
88
long_description = file: DESCRIPTION.md
99
long_description_content_type = text/markdown

src/snowflake/sqlalchemy/snowdialect.py

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -329,15 +329,15 @@ def _get_schema_primary_keys(self, connection, schema, **kw):
329329
ans = {}
330330
key_sequence_order_map = defaultdict(list)
331331
for row in result:
332-
table_name = self.normalize_name(row["table_name"])
333-
key_sequence_order_map[table_name].append(row["key_sequence"])
332+
table_name = self.normalize_name(row._mapping["table_name"])
333+
key_sequence_order_map[table_name].append(row._mapping["key_sequence"])
334334
if table_name not in ans:
335335
ans[table_name] = {
336336
"constrained_columns": [],
337-
"name": self.normalize_name(row["constraint_name"]),
337+
"name": self.normalize_name(row._mapping["constraint_name"]),
338338
}
339339
ans[table_name]["constrained_columns"].append(
340-
self.normalize_name(row["column_name"])
340+
self.normalize_name(row._mapping["column_name"])
341341
)
342342

343343
for k, v in ans.items():
@@ -368,16 +368,16 @@ def _get_schema_unique_constraints(self, connection, schema, **kw):
368368
)
369369
unique_constraints = {}
370370
for row in result:
371-
name = self.normalize_name(row["constraint_name"])
371+
name = self.normalize_name(row._mapping["constraint_name"])
372372
if name not in unique_constraints:
373373
unique_constraints[name] = {
374-
"column_names": [self.normalize_name(row["column_name"])],
374+
"column_names": [self.normalize_name(row._mapping["column_name"])],
375375
"name": name,
376-
"table_name": self.normalize_name(row["table_name"]),
376+
"table_name": self.normalize_name(row._mapping["table_name"]),
377377
}
378378
else:
379379
unique_constraints[name]["column_names"].append(
380-
self.normalize_name(row["column_name"])
380+
self.normalize_name(row._mapping["column_name"])
381381
)
382382

383383
ans = defaultdict(list)
@@ -409,12 +409,14 @@ def _get_schema_foreign_keys(self, connection, schema, **kw):
409409
foreign_key_map = {}
410410
key_sequence_order_map = defaultdict(list)
411411
for row in result:
412-
name = self.normalize_name(row["fk_name"])
413-
key_sequence_order_map[name].append(row["key_sequence"])
412+
name = self.normalize_name(row._mapping["fk_name"])
413+
key_sequence_order_map[name].append(row._mapping["key_sequence"])
414414
if name not in foreign_key_map:
415-
referred_schema = self.normalize_name(row["pk_schema_name"])
415+
referred_schema = self.normalize_name(row._mapping["pk_schema_name"])
416416
foreign_key_map[name] = {
417-
"constrained_columns": [self.normalize_name(row["fk_column_name"])],
417+
"constrained_columns": [
418+
self.normalize_name(row._mapping["fk_column_name"])
419+
],
418420
# referred schema should be None in context where it doesn't need to be specified
419421
# https://docs.sqlalchemy.org/en/14/core/reflection.html#reflection-schema-qualified-interaction
420422
"referred_schema": (
@@ -423,23 +425,31 @@ def _get_schema_foreign_keys(self, connection, schema, **kw):
423425
not in (self.default_schema_name, current_schema)
424426
else None
425427
),
426-
"referred_table": self.normalize_name(row["pk_table_name"]),
427-
"referred_columns": [self.normalize_name(row["pk_column_name"])],
428+
"referred_table": self.normalize_name(
429+
row._mapping["pk_table_name"]
430+
),
431+
"referred_columns": [
432+
self.normalize_name(row._mapping["pk_column_name"])
433+
],
428434
"name": name,
429-
"table_name": self.normalize_name(row["fk_table_name"]),
435+
"table_name": self.normalize_name(row._mapping["fk_table_name"]),
430436
}
431437
options = {}
432-
if self.normalize_name(row["delete_rule"]) != "NO ACTION":
433-
options["ondelete"] = self.normalize_name(row["delete_rule"])
434-
if self.normalize_name(row["update_rule"]) != "NO ACTION":
435-
options["onupdate"] = self.normalize_name(row["update_rule"])
438+
if self.normalize_name(row._mapping["delete_rule"]) != "NO ACTION":
439+
options["ondelete"] = self.normalize_name(
440+
row._mapping["delete_rule"]
441+
)
442+
if self.normalize_name(row._mapping["update_rule"]) != "NO ACTION":
443+
options["onupdate"] = self.normalize_name(
444+
row._mapping["update_rule"]
445+
)
436446
foreign_key_map[name]["options"] = options
437447
else:
438448
foreign_key_map[name]["constrained_columns"].append(
439-
self.normalize_name(row["fk_column_name"])
449+
self.normalize_name(row._mapping["fk_column_name"])
440450
)
441451
foreign_key_map[name]["referred_columns"].append(
442-
self.normalize_name(row["pk_column_name"])
452+
self.normalize_name(row._mapping["pk_column_name"])
443453
)
444454

445455
ans = {}
@@ -847,7 +857,11 @@ def get_table_comment(self, connection, table_name, schema=None, **kw):
847857
# the "table" being reflected is actually a view
848858
result = self._get_view_comment(connection, table_name, schema)
849859

850-
return {"text": result["comment"] if result and result["comment"] else None}
860+
return {
861+
"text": result._mapping["comment"]
862+
if result and result._mapping["comment"]
863+
else None
864+
}
851865

852866

853867
@sa_vnt.listens_for(Table, "before_create")

tests/conftest.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import sys
77
import time
88
import uuid
9+
from functools import partial
910
from logging import getLogger
1011

1112
import pytest
@@ -27,6 +28,18 @@
2728
else:
2829
TEST_SCHEMA = "sqlalchemy_tests_" + str(uuid.uuid4()).replace("-", "_")
2930

31+
create_engine_with_future_flag = create_engine
32+
33+
34+
def pytest_addoption(parser):
35+
parser.addoption(
36+
"--run_v20_sqlalchemy",
37+
help="Use only 2.0 SQLAlchemy APIs, any legacy features (< 2.0) will not be supported."
38+
"Turning on this option will set future flag to True on Engine and Session objects according to"
39+
"the migration guide: https://docs.sqlalchemy.org/en/14/changelog/migration_20.html",
40+
action="store_true",
41+
)
42+
3043

3144
@pytest.fixture(scope="session")
3245
def on_travis():
@@ -139,7 +152,7 @@ def get_engine(user=None, password=None, account=None, schema=None):
139152

140153
from sqlalchemy.pool import NullPool
141154

142-
engine = create_engine(
155+
engine = create_engine_with_future_flag(
143156
URL(
144157
user=ret["user"],
145158
password=ret["password"],
@@ -203,6 +216,19 @@ def sql_compiler():
203216
).replace("\n", "")
204217

205218

219+
@pytest.fixture(scope="session")
220+
def run_v20_sqlalchemy(pytestconfig):
221+
return pytestconfig.option.run_v20_sqlalchemy
222+
223+
224+
def pytest_sessionstart(session):
225+
# patch the create_engine with future flag
226+
global create_engine_with_future_flag
227+
create_engine_with_future_flag = partial(
228+
create_engine, future=session.config.option.run_v20_sqlalchemy
229+
)
230+
231+
206232
def running_on_public_ci() -> bool:
207233
"""Whether or not tests are currently running on one of our public CIs."""
208234
return os.getenv("GITHUB_ACTIONS") == "true"

tests/test_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,6 @@ def test_quoted_name_label(engine_testaccount):
9696

9797
for t in test_cases:
9898
col = column("colname").label(t["label"])
99-
sel_from_tbl = select([col]).group_by(col).select_from(table("abc"))
99+
sel_from_tbl = select(col).group_by(col).select_from(table("abc"))
100100
compiled_result = sel_from_tbl.compile()
101101
assert str(compiled_result) == t["output"]

tests/test_copy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def test_copy_into_location(engine_testaccount, sql_compiler):
6464
"(KMS_KEY_ID='1234abcd-12ab-34cd-56ef-1234567890ab' TYPE='AWS_SSE_KMS')"
6565
)
6666
copy_stmt_2 = CopyIntoStorage(
67-
from_=select([food_items]).where(food_items.c.id == 1), # Test sub-query
67+
from_=select(food_items).where(food_items.c.id == 1), # Test sub-query
6868
into=AWSBucket.from_uri("s3://backup")
6969
.credentials(aws_role="some_iam_role")
7070
.encryption_aws_sse_s3(),

0 commit comments

Comments
 (0)