|
1 | | -from .utils import sqlite3 |
| 1 | +from .utils import sqlite3, OperationalError |
2 | 2 | from collections import namedtuple |
3 | 3 | import datetime |
4 | 4 | import hashlib |
@@ -456,31 +456,24 @@ def rows_where(self, where=None, where_args=None): |
456 | 456 |
|
457 | 457 | @property |
458 | 458 | def pks(self): |
459 | | - return [column.name for column in self.columns if column.is_pk] |
| 459 | + names = [column.name for column in self.columns if column.is_pk] |
| 460 | + if not names: |
| 461 | + names = ["rowid"] |
| 462 | + return names |
460 | 463 |
|
461 | 464 | def get(self, pk_values): |
462 | 465 | if not isinstance(pk_values, (list, tuple)): |
463 | 466 | pk_values = [pk_values] |
464 | 467 | pks = self.pks |
465 | | - pk_names = [] |
466 | | - if len(pks) == 0: |
467 | | - # rowid table |
468 | | - pk_names = ["rowid"] |
469 | | - last_pk = pk_values[0] |
470 | | - elif len(pks) == 1: |
471 | | - pk_names = [pks[0]] |
472 | | - last_pk = pk_values[0] |
473 | | - elif len(pks) > 1: |
474 | | - pk_names = pks |
475 | | - last_pk = pk_values |
476 | | - if len(pk_names) != len(pk_values): |
| 468 | + last_pk = pk_values[0] if len(pks) == 1 else pk_values |
| 469 | + if len(pks) != len(pk_values): |
477 | 470 | raise NotFoundError( |
478 | 471 | "Need {} primary key value{}".format( |
479 | | - len(pk_names), "" if len(pk_names) == 1 else "s" |
| 472 | + len(pks), "" if len(pks) == 1 else "s" |
480 | 473 | ) |
481 | 474 | ) |
482 | 475 |
|
483 | | - wheres = ["[{}] = ?".format(pk_name) for pk_name in pk_names] |
| 476 | + wheres = ["[{}] = ?".format(pk_name) for pk_name in pks] |
484 | 477 | rows = self.rows_where(" and ".join(wheres), pk_values) |
485 | 478 | try: |
486 | 479 | row = list(rows)[0] |
@@ -784,6 +777,41 @@ def search(self, q): |
784 | 777 | def value_or_default(self, key, value): |
785 | 778 | return self._defaults[key] if value is DEFAULT else value |
786 | 779 |
|
| 780 | + def update(self, pk_values, updates=None, alter=False): |
| 781 | + updates = updates or {} |
| 782 | + if not isinstance(pk_values, (list, tuple)): |
| 783 | + pk_values = [pk_values] |
| 784 | + # Sanity check that the record exists (raises error if not): |
| 785 | + self.get(pk_values) |
| 786 | + if not updates: |
| 787 | + return self |
| 788 | + args = [] |
| 789 | + sets = [] |
| 790 | + wheres = [] |
| 791 | + for key, value in updates.items(): |
| 792 | + sets.append("[{}] = ?".format(key)) |
| 793 | + args.append(value) |
| 794 | + wheres = ["[{}] = ?".format(pk_name) for pk_name in self.pks] |
| 795 | + args.extend(pk_values) |
| 796 | + sql = "update [{table}] set {sets} where {wheres}".format( |
| 797 | + table=self.name, sets=", ".join(sets), wheres=" and ".join(wheres) |
| 798 | + ) |
| 799 | + with self.db.conn: |
| 800 | + try: |
| 801 | + rowcount = self.db.conn.execute(sql, args).rowcount |
| 802 | + except OperationalError as e: |
| 803 | + if alter and (" column" in e.args[0]): |
| 804 | + # Attempt to add any missing columns, then try again |
| 805 | + self.add_missing_columns([updates]) |
| 806 | + rowcount = self.db.conn.execute(sql, args).rowcount |
| 807 | + else: |
| 808 | + raise |
| 809 | + |
| 810 | + # TODO: Test this works (rolls back) - use better exception: |
| 811 | + assert rowcount == 1 |
| 812 | + self.last_pk = pk_values[0] if len(self.pks) == 1 else pk_values |
| 813 | + return self |
| 814 | + |
787 | 815 | def insert( |
788 | 816 | self, |
789 | 817 | record, |
@@ -918,15 +946,15 @@ def insert_all( |
918 | 946 | with self.db.conn: |
919 | 947 | try: |
920 | 948 | result = self.db.conn.execute(sql, values) |
921 | | - except sqlite3.OperationalError as e: |
922 | | - if alter and (" has no column " in e.args[0]): |
| 949 | + except OperationalError as e: |
| 950 | + if alter and (" column" in e.args[0]): |
923 | 951 | # Attempt to add any missing columns, then try again |
924 | 952 | self.add_missing_columns(chunk) |
925 | 953 | result = self.db.conn.execute(sql, values) |
926 | 954 | else: |
927 | 955 | raise |
928 | 956 | self.last_rowid = result.lastrowid |
929 | | - self.last_pk = None |
| 957 | + self.last_pk = self.last_rowid |
930 | 958 | # self.last_rowid will be 0 if a "INSERT OR IGNORE" happened |
931 | 959 | if (hash_id or pk) and self.last_rowid: |
932 | 960 | row = list(self.rows_where("rowid = ?", [self.last_rowid]))[0] |
|
0 commit comments