Skip to content

Commit fd0f3be

Browse files
authored
Merge pull request #153 from techouse/feat/support-shorthand-fks
✨ add support for SQLite shorthand `REFERENCE`s
2 parents 5ae7f88 + fa6c27e commit fd0f3be

File tree

2 files changed

+161
-21
lines changed

2 files changed

+161
-21
lines changed

src/sqlite3_to_mysql/transporter.py

Lines changed: 68 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,25 @@ def _sqlite_quote_ident(name: str) -> str:
297297
"""Return a SQLite identifier with internal quotes escaped."""
298298
return str(name).replace('"', '""')
299299

300+
def _get_table_info(self, table_name: str) -> t.List[t.Dict[str, t.Any]]:
301+
"""Fetch SQLite PRAGMA table information for a table."""
302+
quoted_table_name: str = self._sqlite_quote_ident(table_name)
303+
pragma: str = "table_xinfo" if self._sqlite_table_xinfo_support else "table_info"
304+
self._sqlite_cur.execute(f'PRAGMA {pragma}("{quoted_table_name}")')
305+
return [dict(row) for row in self._sqlite_cur.fetchall()]
306+
307+
def _get_table_primary_key_columns(self, table_name: str) -> t.List[str]:
308+
"""Return visible primary key columns ordered by their PK sequence."""
309+
primary_key_rows: t.List[t.Dict[str, t.Any]] = sorted(
310+
(
311+
column
312+
for column in self._get_table_info(table_name)
313+
if column.get("pk", 0) > 0 and column.get("hidden", 0) not in (1, 2, 3)
314+
),
315+
key=lambda column: column.get("pk", 0),
316+
)
317+
return [safe_identifier_length(column["name"]) for column in primary_key_rows]
318+
300319
def _sqlite_table_has_rowid(self, table: str) -> bool:
301320
try:
302321
quoted_table: str = self._sqlite_quote_ident(table)
@@ -1159,51 +1178,79 @@ def _add_foreign_keys(self, table_name: str) -> None:
11591178
quoted_table_name: str = self._sqlite_quote_ident(table_name)
11601179
self._sqlite_cur.execute(f'PRAGMA foreign_key_list("{quoted_table_name}")')
11611180

1181+
foreign_keys: t.Dict[int, t.List[t.Dict[str, t.Any]]] = {}
11621182
for row in self._sqlite_cur.fetchall():
11631183
foreign_key: t.Dict[str, t.Any] = dict(row)
1184+
foreign_keys.setdefault(int(foreign_key["id"]), []).append(foreign_key)
1185+
1186+
for fk_id, fk_rows in foreign_keys.items():
1187+
fk_rows.sort(key=lambda fk_row: fk_row["seq"])
1188+
ref_table: str = fk_rows[0]["table"]
1189+
from_columns: t.List[str] = [safe_identifier_length(fk_row["from"]) for fk_row in fk_rows]
1190+
1191+
referenced_columns: t.List[str]
1192+
missing_references: t.List[t.Dict[str, t.Any]] = [fk_row for fk_row in fk_rows if not fk_row["to"]]
1193+
if missing_references:
1194+
if len(missing_references) != len(fk_rows):
1195+
self._logger.warning(
1196+
'Skipping foreign key "%s" in table %s: partially defined reference columns.',
1197+
safe_identifier_length(fk_rows[0]["from"]),
1198+
safe_identifier_length(table_name),
1199+
)
1200+
continue
1201+
1202+
primary_keys: t.List[str] = self._get_table_primary_key_columns(ref_table)
1203+
if not primary_keys or len(primary_keys) != len(from_columns):
1204+
self._logger.warning(
1205+
'Skipping foreign key "%s" in table %s: unable to resolve referenced primary key columns from table %s.',
1206+
safe_identifier_length(fk_rows[0]["from"]),
1207+
safe_identifier_length(table_name),
1208+
safe_identifier_length(ref_table),
1209+
)
1210+
continue
1211+
referenced_columns = primary_keys
1212+
else:
1213+
referenced_columns = [safe_identifier_length(fk_row["to"]) for fk_row in fk_rows]
1214+
11641215
sql = """
11651216
ALTER TABLE `{table}`
11661217
ADD CONSTRAINT `{table}_FK_{id}_{seq}`
1167-
FOREIGN KEY (`{column}`)
1168-
REFERENCES `{ref_table}`(`{ref_column}`)
1218+
FOREIGN KEY ({columns})
1219+
REFERENCES `{ref_table}`({ref_columns})
11691220
ON DELETE {on_delete}
11701221
ON UPDATE {on_update}
11711222
""".format(
1172-
id=foreign_key["id"],
1173-
seq=foreign_key["seq"],
1223+
id=fk_id,
1224+
seq=fk_rows[0]["seq"],
11741225
table=safe_identifier_length(table_name),
1175-
column=safe_identifier_length(foreign_key["from"]),
1176-
ref_table=safe_identifier_length(foreign_key["table"]),
1177-
ref_column=safe_identifier_length(foreign_key["to"]),
1226+
columns=", ".join(f"`{column}`" for column in from_columns),
1227+
ref_table=safe_identifier_length(ref_table),
1228+
ref_columns=", ".join(f"`{column}`" for column in referenced_columns),
11781229
on_delete=(
1179-
foreign_key["on_delete"].upper()
1180-
if foreign_key["on_delete"].upper() != "SET DEFAULT"
1181-
else "NO ACTION"
1230+
fk_rows[0]["on_delete"].upper() if fk_rows[0]["on_delete"].upper() != "SET DEFAULT" else "NO ACTION"
11821231
),
11831232
on_update=(
1184-
foreign_key["on_update"].upper()
1185-
if foreign_key["on_update"].upper() != "SET DEFAULT"
1186-
else "NO ACTION"
1233+
fk_rows[0]["on_update"].upper() if fk_rows[0]["on_update"].upper() != "SET DEFAULT" else "NO ACTION"
11871234
),
11881235
)
11891236

11901237
try:
11911238
self._logger.info(
1192-
"Adding foreign key to %s.%s referencing %s.%s",
1239+
"Adding foreign key to %s.(%s) referencing %s.(%s)",
11931240
safe_identifier_length(table_name),
1194-
safe_identifier_length(foreign_key["from"]),
1195-
safe_identifier_length(foreign_key["table"]),
1196-
safe_identifier_length(foreign_key["to"]),
1241+
", ".join(from_columns),
1242+
safe_identifier_length(ref_table),
1243+
", ".join(referenced_columns),
11971244
)
11981245
self._mysql_cur.execute(sql)
11991246
self._mysql.commit()
12001247
except mysql.connector.Error as err:
12011248
self._logger.error(
1202-
"MySQL failed adding foreign key to %s.%s referencing %s.%s: %s",
1249+
"MySQL failed adding foreign key to %s.(%s) referencing %s.(%s): %s",
12031250
safe_identifier_length(table_name),
1204-
safe_identifier_length(foreign_key["from"]),
1205-
safe_identifier_length(foreign_key["table"]),
1206-
safe_identifier_length(foreign_key["to"]),
1251+
", ".join(from_columns),
1252+
safe_identifier_length(ref_table),
1253+
", ".join(referenced_columns),
12071254
err,
12081255
)
12091256
raise

tests/unit/sqlite3_to_mysql_test.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1519,6 +1519,99 @@ def execute(self, statement):
15191519
sqlite_cnx.close()
15201520
sqlite_engine.dispose()
15211521

1522+
def test_add_foreign_keys_shorthand_references_primary_key(
1523+
self,
1524+
sqlite_database: str,
1525+
mysql_database: Engine,
1526+
mysql_credentials: MySQLCredentials,
1527+
mocker: MockFixture,
1528+
) -> None:
1529+
proc = SQLite3toMySQL( # type: ignore[call-arg]
1530+
sqlite_file=sqlite_database,
1531+
mysql_user=mysql_credentials.user,
1532+
mysql_password=mysql_credentials.password,
1533+
mysql_host=mysql_credentials.host,
1534+
mysql_port=mysql_credentials.port,
1535+
mysql_database=mysql_credentials.database,
1536+
)
1537+
sqlite_cursor = mocker.MagicMock()
1538+
sqlite_cursor.fetchall.side_effect = [
1539+
[
1540+
{
1541+
"id": 0,
1542+
"seq": 0,
1543+
"table": "parent",
1544+
"from": "parent_id",
1545+
"to": "",
1546+
"on_delete": "NO ACTION",
1547+
"on_update": "NO ACTION",
1548+
}
1549+
],
1550+
[
1551+
{"name": "id", "pk": 1, "hidden": 0},
1552+
],
1553+
]
1554+
proc._sqlite_cur = sqlite_cursor
1555+
proc._sqlite_table_xinfo_support = False
1556+
proc._mysql_cur = mocker.MagicMock()
1557+
proc._mysql = mocker.MagicMock()
1558+
proc._logger = mocker.MagicMock()
1559+
1560+
proc._add_foreign_keys("child")
1561+
1562+
assert proc._mysql_cur.execute.call_count == 1
1563+
executed_sql: str = proc._mysql_cur.execute.call_args[0][0]
1564+
assert "FOREIGN KEY (`parent_id`)" in executed_sql
1565+
assert "REFERENCES `parent`(`id`)" in executed_sql
1566+
proc._mysql.commit.assert_called_once()
1567+
1568+
def test_add_foreign_keys_shorthand_pk_mismatch_is_skipped(
1569+
self,
1570+
sqlite_database: str,
1571+
mysql_database: Engine,
1572+
mysql_credentials: MySQLCredentials,
1573+
mocker: MockFixture,
1574+
) -> None:
1575+
proc = SQLite3toMySQL( # type: ignore[call-arg]
1576+
sqlite_file=sqlite_database,
1577+
mysql_user=mysql_credentials.user,
1578+
mysql_password=mysql_credentials.password,
1579+
mysql_host=mysql_credentials.host,
1580+
mysql_port=mysql_credentials.port,
1581+
mysql_database=mysql_credentials.database,
1582+
)
1583+
sqlite_cursor = mocker.MagicMock()
1584+
sqlite_cursor.fetchall.side_effect = [
1585+
[
1586+
{
1587+
"id": 1,
1588+
"seq": 0,
1589+
"table": "parent",
1590+
"from": "parent_id",
1591+
"to": "",
1592+
"on_delete": "NO ACTION",
1593+
"on_update": "NO ACTION",
1594+
}
1595+
],
1596+
[
1597+
{"name": "id", "pk": 1, "hidden": 0},
1598+
{"name": "second", "pk": 2, "hidden": 0},
1599+
],
1600+
]
1601+
proc._sqlite_cur = sqlite_cursor
1602+
proc._sqlite_table_xinfo_support = False
1603+
proc._mysql_cur = mocker.MagicMock()
1604+
proc._mysql = mocker.MagicMock()
1605+
proc._logger = mocker.MagicMock()
1606+
1607+
proc._add_foreign_keys("child")
1608+
1609+
proc._mysql_cur.execute.assert_not_called()
1610+
assert any(
1611+
"unable to resolve referenced primary key columns" in call.args[0]
1612+
for call in proc._logger.warning.call_args_list
1613+
)
1614+
15221615
@pytest.mark.parametrize("quiet", [False, True])
15231616
def test_add_index_duplicate_key_error(
15241617
self,

0 commit comments

Comments
 (0)