Skip to content

Commit 41497c3

Browse files
committed
Fix for use_old_upsert branch
Refs #653 (comment)
1 parent efeba1f commit 41497c3

File tree

3 files changed

+52
-44
lines changed

3 files changed

+52
-44
lines changed

sqlite_utils/db.py

Lines changed: 42 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3086,45 +3086,48 @@ def build_insert_queries_and_params(
30863086
# At this point we need compatibility UPSERT for SQLite < 3.24.0
30873087
# (INSERT OR IGNORE + second UPDATE stage)
30883088
queries_and_params = []
3089-
3090-
insert_sql = (
3091-
f"INSERT OR IGNORE INTO [{self.name}] "
3092-
f"({columns_sql}) VALUES {row_placeholders_sql}"
3093-
)
3094-
queries_and_params.append((insert_sql, flat_params))
3095-
3096-
# If there is nothing to update we are done.
3097-
if not non_pk_cols:
3098-
return queries_and_params
3099-
3100-
# We can use UPDATE … FROM (VALUES …) on SQLite ≥ 3.33.0
3101-
# Older SQLite versions will run this as one UPDATE per row
3102-
# – which is what sqlite-utils did prior to this refactor.
3103-
alias_cols_sql = ", ".join(pk_cols + non_pk_cols)
3104-
3105-
assignments = []
3106-
for c in non_pk_cols:
3107-
if c in conversions:
3108-
assignments.append(f"[{c}] = {conversions[c].replace('?', f'v.[{c}]')}")
3109-
else:
3110-
assignments.append(f"[{c}] = v.[{c}]")
3111-
assignments_sql = ", ".join(assignments)
3112-
3113-
update_sql = (
3114-
f"UPDATE [{self.name}] AS m SET {assignments_sql} "
3115-
f"FROM (VALUES {row_placeholders_sql}) "
3116-
f"AS v({alias_cols_sql}) "
3117-
f"WHERE " + " AND ".join(f"m.[{c}] = v.[{c}]" for c in pk_cols)
3118-
)
3119-
3120-
# Parameters for the UPDATE – pk cols first then non-pk cols
3121-
update_params = []
3122-
for row in values:
3123-
row_dict = dict(zip(all_columns, row))
3124-
ordered = [row_dict[c] for c in pk_cols + non_pk_cols]
3125-
update_params.extend(ordered)
3126-
3127-
queries_and_params.append((update_sql, update_params))
3089+
if isinstance(pk, str):
3090+
pks = [pk]
3091+
else:
3092+
pks = pk
3093+
self.last_pk = None
3094+
for record_values in values:
3095+
record = dict(zip(all_columns, record_values))
3096+
placeholders = list(pks)
3097+
# Need to populate not-null columns too, or INSERT OR IGNORE ignores
3098+
# them since it ignores the resulting integrity errors
3099+
if not_null:
3100+
placeholders.extend(not_null)
3101+
sql = "INSERT OR IGNORE INTO [{table}]({cols}) VALUES({placeholders});".format(
3102+
table=self.name,
3103+
cols=", ".join(["[{}]".format(p) for p in placeholders]),
3104+
placeholders=", ".join(["?" for p in placeholders]),
3105+
)
3106+
queries_and_params.append(
3107+
(sql, [record[col] for col in pks] + ["" for _ in (not_null or [])])
3108+
)
3109+
# UPDATE [book] SET [name] = 'Programming' WHERE [id] = 1001;
3110+
set_cols = [col for col in all_columns if col not in pks]
3111+
if set_cols:
3112+
sql2 = "UPDATE [{table}] SET {pairs} WHERE {wheres}".format(
3113+
table=self.name,
3114+
pairs=", ".join(
3115+
"[{}] = {}".format(col, conversions.get(col, "?"))
3116+
for col in set_cols
3117+
),
3118+
wheres=" AND ".join("[{}] = ?".format(pk) for pk in pks),
3119+
)
3120+
queries_and_params.append(
3121+
(
3122+
sql2,
3123+
[record[col] for col in set_cols] + [record[pk] for pk in pks],
3124+
)
3125+
)
3126+
# We can populate .last_pk right here
3127+
if num_records_processed == 1:
3128+
self.last_pk = tuple(record[pk] for pk in pks)
3129+
if len(self.last_pk) == 1:
3130+
self.last_pk = self.last_pk[0]
31283131
return queries_and_params
31293132

31303133
def insert_chunk(

tests/test_create.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,14 +173,16 @@ def test_create_table_from_example_with_compound_primary_keys(fresh_db):
173173
@pytest.mark.parametrize(
174174
"method_name", ("insert", "upsert", "insert_all", "upsert_all")
175175
)
176-
def test_create_table_with_custom_columns(fresh_db, method_name):
177-
table = fresh_db["dogs"]
176+
@pytest.mark.parametrize("use_old_upsert", (False, True))
177+
def test_create_table_with_custom_columns(method_name, use_old_upsert):
178+
db = Database(memory=True, use_old_upsert=use_old_upsert)
179+
table = db["dogs"]
178180
method = getattr(table, method_name)
179181
record = {"id": 1, "name": "Cleo", "age": "5"}
180182
if method_name.endswith("_all"):
181183
record = [record]
182184
method(record, pk="id", columns={"age": int, "weight": float})
183-
assert ["dogs"] == fresh_db.table_names()
185+
assert ["dogs"] == db.table_names()
184186
expected_columns = [
185187
{"name": "id", "type": "INTEGER"},
186188
{"name": "name", "type": "TEXT"},

tests/test_upsert.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
from sqlite_utils.db import PrimaryKeyRequired
2+
from sqlite_utils import Database
23
import pytest
34

45

5-
def test_upsert(fresh_db):
6-
table = fresh_db["table"]
6+
@pytest.mark.parametrize("use_old_upsert", (False, True))
7+
def test_upsert(use_old_upsert):
8+
db = Database(memory=True, use_old_upsert=use_old_upsert)
9+
table = db["table"]
710
table.insert({"id": 1, "name": "Cleo"}, pk="id")
811
table.upsert({"id": 1, "age": 5}, pk="id", alter=True)
912
assert list(table.rows) == [{"id": 1, "name": "Cleo", "age": 5}]

0 commit comments

Comments
 (0)