Skip to content

Commit 0747dab

Browse files
authored
table.update() method
* Also now set .last_pk to lastrowid for rowid tables * table.pks introspection now returns ["rowid"] for rowid tables Closes #35
2 parents a6749cd + 16d7008 commit 0747dab

File tree

6 files changed

+162
-20
lines changed

6 files changed

+162
-20
lines changed

docs/python-api.rst

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,30 @@ The function can accept an iterator or generator of rows and will commit them ac
331331
332332
You can skip inserting any records that have a primary key that already exists using ``ignore=True``. This works with both ``.insert({...}, ignore=True)`` and ``.insert_all([...], ignore=True)``.
333333

334+
.. _python_api_update:
335+
336+
Updating a specific record
337+
==========================
338+
339+
You can update a record by its primary key using ``table.update()``::
340+
341+
>>> db = sqlite_utils.Database("dogs.db")
342+
>>> print(db["dogs"].get(1))
343+
{'id': 1, 'age': 4, 'name': 'Cleo'}
344+
>>> db["dogs"].update(1, {"age": 5})
345+
>>> print(db["dogs"].get(1))
346+
{'id': 1, 'age': 5, 'name': 'Cleo'}
347+
348+
The first argument to ``update()`` is the primary key. This can be a single value, or a tuple if that table has a compound primary key::
349+
350+
>>> db["compound_dogs"].update((5, 3), {"name": "Updated"})
351+
352+
The second argument is a dictonary of columns that should be updated, along with their new values.
353+
354+
You can cause any missing columns to be added automatically using ``alter=True``::
355+
356+
>>> db["dogs"].update(1, {"breed": "Mutt"}, alter=True)
357+
334358
Upserting data
335359
==============
336360

sqlite_utils/db.py

Lines changed: 47 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .utils import sqlite3
1+
from .utils import sqlite3, OperationalError
22
from collections import namedtuple
33
import datetime
44
import hashlib
@@ -456,31 +456,24 @@ def rows_where(self, where=None, where_args=None):
456456

457457
@property
458458
def pks(self):
459-
return [column.name for column in self.columns if column.is_pk]
459+
names = [column.name for column in self.columns if column.is_pk]
460+
if not names:
461+
names = ["rowid"]
462+
return names
460463

461464
def get(self, pk_values):
462465
if not isinstance(pk_values, (list, tuple)):
463466
pk_values = [pk_values]
464467
pks = self.pks
465-
pk_names = []
466-
if len(pks) == 0:
467-
# rowid table
468-
pk_names = ["rowid"]
469-
last_pk = pk_values[0]
470-
elif len(pks) == 1:
471-
pk_names = [pks[0]]
472-
last_pk = pk_values[0]
473-
elif len(pks) > 1:
474-
pk_names = pks
475-
last_pk = pk_values
476-
if len(pk_names) != len(pk_values):
468+
last_pk = pk_values[0] if len(pks) == 1 else pk_values
469+
if len(pks) != len(pk_values):
477470
raise NotFoundError(
478471
"Need {} primary key value{}".format(
479-
len(pk_names), "" if len(pk_names) == 1 else "s"
472+
len(pks), "" if len(pks) == 1 else "s"
480473
)
481474
)
482475

483-
wheres = ["[{}] = ?".format(pk_name) for pk_name in pk_names]
476+
wheres = ["[{}] = ?".format(pk_name) for pk_name in pks]
484477
rows = self.rows_where(" and ".join(wheres), pk_values)
485478
try:
486479
row = list(rows)[0]
@@ -784,6 +777,41 @@ def search(self, q):
784777
def value_or_default(self, key, value):
785778
return self._defaults[key] if value is DEFAULT else value
786779

780+
def update(self, pk_values, updates=None, alter=False):
781+
updates = updates or {}
782+
if not isinstance(pk_values, (list, tuple)):
783+
pk_values = [pk_values]
784+
# Sanity check that the record exists (raises error if not):
785+
self.get(pk_values)
786+
if not updates:
787+
return self
788+
args = []
789+
sets = []
790+
wheres = []
791+
for key, value in updates.items():
792+
sets.append("[{}] = ?".format(key))
793+
args.append(value)
794+
wheres = ["[{}] = ?".format(pk_name) for pk_name in self.pks]
795+
args.extend(pk_values)
796+
sql = "update [{table}] set {sets} where {wheres}".format(
797+
table=self.name, sets=", ".join(sets), wheres=" and ".join(wheres)
798+
)
799+
with self.db.conn:
800+
try:
801+
rowcount = self.db.conn.execute(sql, args).rowcount
802+
except OperationalError as e:
803+
if alter and (" column" in e.args[0]):
804+
# Attempt to add any missing columns, then try again
805+
self.add_missing_columns([updates])
806+
rowcount = self.db.conn.execute(sql, args).rowcount
807+
else:
808+
raise
809+
810+
# TODO: Test this works (rolls back) - use better exception:
811+
assert rowcount == 1
812+
self.last_pk = pk_values[0] if len(self.pks) == 1 else pk_values
813+
return self
814+
787815
def insert(
788816
self,
789817
record,
@@ -918,15 +946,15 @@ def insert_all(
918946
with self.db.conn:
919947
try:
920948
result = self.db.conn.execute(sql, values)
921-
except sqlite3.OperationalError as e:
922-
if alter and (" has no column " in e.args[0]):
949+
except OperationalError as e:
950+
if alter and (" column" in e.args[0]):
923951
# Attempt to add any missing columns, then try again
924952
self.add_missing_columns(chunk)
925953
result = self.db.conn.execute(sql, values)
926954
else:
927955
raise
928956
self.last_rowid = result.lastrowid
929-
self.last_pk = None
957+
self.last_pk = self.last_rowid
930958
# self.last_rowid will be 0 if a "INSERT OR IGNORE" happened
931959
if (hash_id or pk) and self.last_rowid:
932960
row = list(self.rows_where("rowid = ?", [self.last_rowid]))[0]

sqlite_utils/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
11
try:
22
import pysqlite3 as sqlite3
3+
import pysqlite3.dbapi2
4+
5+
OperationalError = pysqlite3.dbapi2.OperationalError
36
except ImportError:
47
import sqlite3
8+
9+
OperationalError = sqlite3.OperationalError

tests/test_create.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
NoObviousTable,
77
ForeignKey,
88
)
9+
from sqlite_utils.utils import sqlite3
910
import collections
1011
import datetime
1112
import json
1213
import pathlib
1314
import pytest
14-
import sqlite3
1515

1616
from .utils import collapse_whitespace
1717

tests/test_introspect.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,11 @@ def test_guess_foreign_table(fresh_db, column, expected_table_guess):
9898
fresh_db.create_table("authors", {"name": str})
9999
fresh_db.create_table("genre", {"name": str})
100100
assert expected_table_guess == fresh_db["books"].guess_foreign_table(column)
101+
102+
103+
@pytest.mark.parametrize(
104+
"pk,expected", ((None, ["rowid"]), ("id", ["id"]), (["id", "id2"], ["id", "id2"]))
105+
)
106+
def test_pks(fresh_db, pk, expected):
107+
fresh_db["foo"].insert_all([{"id": 1, "id2": 2}], pk=pk)
108+
assert expected == fresh_db["foo"].pks

tests/test_update.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from sqlite_utils.db import NotFoundError
2+
import pytest
3+
4+
5+
def test_update_rowid_table(fresh_db):
6+
table = fresh_db["table"]
7+
rowid = table.insert({"foo": "bar"}).last_pk
8+
table.update(rowid, {"foo": "baz"})
9+
assert [{"foo": "baz"}] == list(table.rows)
10+
11+
12+
def test_update_pk_table(fresh_db):
13+
table = fresh_db["table"]
14+
pk = table.insert({"foo": "bar", "id": 5}, pk="id").last_pk
15+
assert 5 == pk
16+
table.update(pk, {"foo": "baz"})
17+
assert [{"id": 5, "foo": "baz"}] == list(table.rows)
18+
19+
20+
def test_update_compound_pk_table(fresh_db):
21+
table = fresh_db["table"]
22+
pk = table.insert({"id1": 5, "id2": 3, "v": 1}, pk=("id1", "id2")).last_pk
23+
assert (5, 3) == pk
24+
table.update(pk, {"v": 2})
25+
assert [{"id1": 5, "id2": 3, "v": 2}] == list(table.rows)
26+
27+
28+
@pytest.mark.parametrize(
29+
"pk,update_pk",
30+
(
31+
(None, 2),
32+
(None, None),
33+
("id1", None),
34+
("id1", 4),
35+
(("id1", "id2"), None),
36+
(("id1", "id2"), 4),
37+
(("id1", "id2"), (4, 5)),
38+
),
39+
)
40+
def test_update_invalid_pk(fresh_db, pk, update_pk):
41+
table = fresh_db["table"]
42+
table.insert({"id1": 5, "id2": 3, "v": 1}, pk=pk).last_pk
43+
with pytest.raises(NotFoundError):
44+
table.update(update_pk, {"v": 2})
45+
46+
47+
def test_update_alter(fresh_db):
48+
table = fresh_db["table"]
49+
rowid = table.insert({"foo": "bar"}).last_pk
50+
table.update(rowid, {"new_col": 1.2}, alter=True)
51+
assert [{"foo": "bar", "new_col": 1.2}] == list(table.rows)
52+
# Let's try adding three cols at once
53+
table.update(
54+
rowid,
55+
{"str_col": "str", "bytes_col": b"\xa0 has bytes", "int_col": -10},
56+
alter=True,
57+
)
58+
assert [
59+
{
60+
"foo": "bar",
61+
"new_col": 1.2,
62+
"str_col": "str",
63+
"bytes_col": b"\xa0 has bytes",
64+
"int_col": -10,
65+
}
66+
] == list(table.rows)
67+
68+
69+
def test_update_with_no_values_sets_last_pk(fresh_db):
70+
table = fresh_db.table("dogs", pk="id")
71+
table.insert_all([{"id": 1, "name": "Cleo"}, {"id": 2, "name": "Pancakes"}])
72+
table.update(1)
73+
assert 1 == table.last_pk
74+
table.update(2)
75+
assert 2 == table.last_pk
76+
with pytest.raises(NotFoundError):
77+
table.update(3)

0 commit comments

Comments
 (0)