From c2e8270b9f86a074ef5549033820fc0157febc57 Mon Sep 17 00:00:00 2001 From: Peter Gaultney Date: Tue, 2 Jan 2024 14:49:37 -0600 Subject: [PATCH 1/2] work-in-progress FTS for attached dbs --- sqlite_utils/db.py | 200 +++++++++++++++++++++++++++++---------------- 1 file changed, 130 insertions(+), 70 deletions(-) diff --git a/sqlite_utils/db.py b/sqlite_utils/db.py index 371eed9c..a60df432 100644 --- a/sqlite_utils/db.py +++ b/sqlite_utils/db.py @@ -601,6 +601,10 @@ def quote_default_value(self, value: str) -> str: return self.quote(value) + def database_names(self) -> List[str]: + "List of string database names available in this connection." + return [r[1] for r in self.execute("PRAGMA database_list").fetchall()] + def table_names(self, fts4: bool = False, fts5: bool = False) -> List[str]: """ List of string table names in this database. @@ -614,7 +618,20 @@ def table_names(self, fts4: bool = False, fts5: bool = False) -> List[str]: if fts5: where.append("sql like '%USING FTS5%'") sql = "select name from sqlite_master where {}".format(" AND ".join(where)) - return [r[0] for r in self.execute(sql).fetchall()] + + def _exec_in_db(db_name: str, sql: str) -> List[str]: + if db_name == "main": + db_name = "" + if db_name: + sql = sql.replace("sqlite_master", f"{db_name}.sqlite_master") + table_names = [r[0] for r in self.execute(sql).fetchall()] + if db_name: + return [f"{db_name}.{tbl_name}" for tbl_name in table_names] + return table_names + + return list( + itertools.chain(*[_exec_in_db(db_name, sql) for db_name in self.database_names()]) + ) def view_names(self) -> List[str]: "List of string view names in this database." @@ -1271,12 +1288,34 @@ def init_spatialite(self, path: Optional[str] = None) -> bool: return result and bool(result[0]) +def _split_names(fullname: str) -> Tuple[str, str]: + if '.' not in fullname: + return '', fullname + return fullname.split('.') + + +def dbname(fullname: str) -> str: + return _split_names(fullname)[0] + + +def tablename(fullname: str) -> str: + return _split_names(fullname)[1] + + +def escaped_name(fullname: str) -> str: + """This is how SQLite expects a database name joined to a table name to use the square-bracket escapes.""" + db, tbl = _split_names(fullname) + if not db: + return f'[{tbl}]' + return f'{db}.[{tbl}]' + + class Queryable: def exists(self) -> bool: "Does this table or view exist yet?" return False - def __init__(self, db, name): + def __init__(self, db, name: str): self.db = db self.name = name @@ -1292,7 +1331,7 @@ def count_where( :param where_args: Parameters to use with that fragment - an iterable for ``id > ?`` parameters, or a dictionary for ``id > :id`` """ - sql = "select count(*) from [{}]".format(self.name) + sql = "select count(*) from {}".format(escaped_name(self.name)) if where is not None: sql += " where " + where return self.db.execute(sql, where_args or []).fetchone()[0] @@ -1335,7 +1374,7 @@ def rows_where( """ if not self.exists(): return - sql = "select {} from [{}]".format(select, self.name) + sql = "select {} from {}".format(select, escaped_name(self.name)) if where is not None: sql += " where " + where if order_by is not None: @@ -1387,12 +1426,23 @@ def pks_and_rows_where( row_pk = row_pk[0] yield row_pk, row + @property + def is_attached(self) -> bool: + return dbname(self.name) not in {'', 'main'} + + @property + def _pragma_name(self) -> Tuple[str, str]: + if "." in self.name: + db, name = self.name.split(".") + return db + ".", name + return "", self.name + @property def columns(self) -> List["Column"]: "List of :ref:`Columns ` representing the columns in this table or view." if not self.exists(): return [] - rows = self.db.execute("PRAGMA table_info([{}])".format(self.name)).fetchall() + rows = self.db.execute("PRAGMA {}table_info([{}])".format(*self._pragma_name)).fetchall() return [Column(*row) for row in rows] @property @@ -1403,9 +1453,10 @@ def columns_dict(self) -> Dict[str, Any]: @property def schema(self) -> str: "SQL schema for this table or view." - return self.db.execute( - "select sql from sqlite_master where name = ?", (self.name,) - ).fetchone()[0] + db, name = self._pragma_name + return self.db.execute(f"select sql from {db}sqlite_master where name = ?", (name,)).fetchone()[ + 0 + ] class Table(Queryable): @@ -1544,7 +1595,7 @@ def foreign_keys(self) -> List["ForeignKey"]: "List of foreign keys defined on this table." fks = [] for row in self.db.execute( - "PRAGMA foreign_key_list([{}])".format(self.name) + "PRAGMA {}foreign_key_list([{}])".format(*self._pragma_name) ).fetchall(): if row is not None: id, seq, table_name, from_, to_, on_update, on_delete, match = row @@ -1569,7 +1620,8 @@ def virtual_table_using(self) -> Optional[str]: @property def indexes(self) -> List[Index]: "List of indexes defined on this table." - sql = 'PRAGMA index_list("{}")'.format(self.name) + db, table_name = self._pragma_name + sql = 'PRAGMA {}index_list("{}")'.format(db, table_name) indexes = [] for row in self.db.execute_returning_dicts(sql): index_name = row["name"] @@ -1578,7 +1630,7 @@ def indexes(self) -> List[Index]: if not index_name.startswith('"') else index_name ) - column_sql = "PRAGMA index_info({})".format(index_name_quoted) + column_sql = "PRAGMA {}index_info({})".format(db, index_name_quoted) columns = [] for seqno, cid, name in self.db.execute(column_sql).fetchall(): columns.append(name) @@ -1593,7 +1645,8 @@ def indexes(self) -> List[Index]: @property def xindexes(self) -> List[XIndex]: "List of indexes defined on this table using the more detailed ``XIndex`` format." - sql = 'PRAGMA index_list("{}")'.format(self.name) + db, table_name = self._pragma_name + sql = 'PRAGMA {}index_list("{}")'.format(db, table_name) indexes = [] for row in self.db.execute_returning_dicts(sql): index_name = row["name"] @@ -1602,7 +1655,7 @@ def xindexes(self) -> List[XIndex]: if not index_name.startswith('"') else index_name ) - column_sql = "PRAGMA index_xinfo({})".format(index_name_quoted) + column_sql = "PRAGMA {}index_xinfo({})".format(db, index_name_quoted) index_columns = [] for info in self.db.execute(column_sql).fetchall(): index_columns.append(XIndexColumn(*info)) @@ -1612,12 +1665,13 @@ def xindexes(self) -> List[XIndex]: @property def triggers(self) -> List[Trigger]: "List of triggers defined on this table." + db, table_name = self._pragma_name return [ Trigger(*r) for r in self.db.execute( - "select name, tbl_name, sql from sqlite_master where type = 'trigger'" + f"select name, tbl_name, sql from {db}sqlite_master where type = 'trigger'" " and tbl_name = ?", - (self.name,), + (table_name,), ).fetchall() ] @@ -1709,9 +1763,9 @@ def duplicate(self, new_name: str) -> "Table": if not self.exists(): raise NoTable(f"Table {self.name} does not exist") with self.db.conn: - sql = "CREATE TABLE [{new_table}] AS SELECT * FROM [{table}];".format( - new_table=new_name, - table=self.name, + sql = "CREATE TABLE {new_table} AS SELECT * FROM {table};".format( + new_table=escaped_name(new_name), + table=escaped_name(self.name), ) self.db.execute(sql) return self.db[new_name] @@ -1765,21 +1819,22 @@ def transform( column_order=column_order, keep_table=keep_table, ) - pragma_foreign_keys_was_on = self.db.execute("PRAGMA foreign_keys").fetchone()[ + db, _ = self._pragma_name + pragma_foreign_keys_was_on = self.db.execute(f"PRAGMA {db}foreign_keys").fetchone()[ 0 ] try: if pragma_foreign_keys_was_on: - self.db.execute("PRAGMA foreign_keys=0;") + self.db.execute(f"PRAGMA {db}foreign_keys=0;") with self.db.conn: for sql in sqls: self.db.execute(sql) # Run the foreign_key_check before we commit if pragma_foreign_keys_was_on: - self.db.execute("PRAGMA foreign_key_check;") + self.db.execute(f"PRAGMA {db}foreign_key_check;") finally: if pragma_foreign_keys_was_on: - self.db.execute("PRAGMA foreign_keys=1;") + self.db.execute(f"PRAGMA {db}foreign_keys=1;") return self def transform_sql( @@ -1944,9 +1999,9 @@ def transform_sql( if "rowid" not in new_cols: new_cols.insert(0, "rowid") old_cols.insert(0, "rowid") - copy_sql = "INSERT INTO [{new_table}] ({new_cols})\n SELECT {old_cols} FROM [{old_table}];".format( - new_table=new_table_name, - old_table=self.name, + copy_sql = "INSERT INTO {new_table} ({new_cols})\n SELECT {old_cols} FROM {old_table};".format( + new_table=escaped_name(new_table_name), + old_table=escaped_name(self.name), old_cols=", ".join("[{}]".format(col) for col in old_cols), new_cols=", ".join("[{}]".format(col) for col in new_cols), ) @@ -1954,13 +2009,13 @@ def transform_sql( # Drop (or keep) the old table if keep_table: sqls.append( - "ALTER TABLE [{}] RENAME TO [{}];".format(self.name, keep_table) + "ALTER TABLE {} RENAME TO {};".format(escaped_name(self.name), escaped_name(keep_table)) ) else: - sqls.append("DROP TABLE [{}];".format(self.name)) + sqls.append("DROP TABLE {};".format(escaped_name(self.name))) # Rename the new one sqls.append( - "ALTER TABLE [{}] RENAME TO [{}];".format(new_table_name, self.name) + "ALTER TABLE {} RENAME TO {};".format(escaped_name(new_table_name), escaped_name(self.name)) ) return sqls @@ -2023,11 +2078,11 @@ def extract( lookup_columns = [(rename.get(col) or col) for col in columns] lookup_table.create_index(lookup_columns, unique=True, if_not_exists=True) self.db.execute( - "INSERT OR IGNORE INTO [{lookup_table}] ({lookup_columns}) SELECT DISTINCT {table_cols} FROM [{table}]".format( - lookup_table=table, + "INSERT OR IGNORE INTO {lookup_table} ({lookup_columns}) SELECT DISTINCT {table_cols} FROM {table}".format( + lookup_table=escaped_name(table), lookup_columns=", ".join("[{}]".format(c) for c in lookup_columns), table_cols=", ".join("[{}]".format(c) for c in columns), - table=self.name, + table=escaped_name(self.name), ) ) @@ -2036,14 +2091,14 @@ def extract( # And populate it self.db.execute( - "UPDATE [{table}] SET [{magic_lookup_column}] = (SELECT id FROM [{lookup_table}] WHERE {where})".format( - table=self.name, + "UPDATE {table} SET [{magic_lookup_column}] = (SELECT id FROM {lookup_table} WHERE {where})".format( + table=escaped_name(self.name), magic_lookup_column=magic_lookup_column, - lookup_table=table, + lookup_table=escaped_name(table), where=" AND ".join( - "[{table}].[{column}] IS [{lookup_table}].[{lookup_column}]".format( - table=self.name, - lookup_table=table, + "{table}.[{column}] IS {lookup_table}.[{lookup_column}]".format( + table=escaped_name(self.name), + lookup_table=escaped_name(table), column=column, lookup_column=rename.get(column) or column, ) @@ -2117,13 +2172,13 @@ def create_index( textwrap.dedent( """ CREATE {unique}INDEX {if_not_exists}[{index_name}] - ON [{table_name}] ({columns}); + ON {table_name} ({columns}); """ ) .strip() .format( index_name=created_index_name, - table_name=self.name, + table_name=escaped_name(self.name), columns=", ".join(columns_sql), unique="UNIQUE " if unique else "", if_not_exists="IF NOT EXISTS " if if_not_exists else "", @@ -2193,8 +2248,8 @@ def add_column( not_null_sql = "NOT NULL DEFAULT {}".format( self.db.quote_default_value(not_null_default) ) - sql = "ALTER TABLE [{table}] ADD COLUMN [{col_name}] {col_type}{not_null_default};".format( - table=self.name, + sql = "ALTER TABLE {table} ADD COLUMN [{col_name}] {col_type}{not_null_default};".format( + table=escaped_name(self.name), col_name=col_name, col_type=fk_col_type or COLUMN_TYPE_MAPPING[col_type], not_null_default=(" " + not_null_sql) if not_null_sql else "", @@ -2211,7 +2266,7 @@ def drop(self, ignore: bool = False): :param ignore: Set to ``True`` to ignore the error if the table does not exist """ try: - self.db.execute("DROP TABLE [{}]".format(self.name)) + self.db.execute("DROP TABLE {}".format(escaped_name(self.name))) except sqlite3.OperationalError: if not ignore: raise @@ -2378,6 +2433,9 @@ def enable_fts( """ Enable SQLite full-text search against the specified columns. + Creates the FTS virtual table(s) in the `main` database, even if the + source table is in an attached database. + See :ref:`python_api_fts` for more details. :param columns: List of column names to include in the search index. @@ -2386,6 +2444,7 @@ def enable_fts( :param tokenize: Custom SQLite tokenizer to use, for example ``"porter"`` to enable Porter stemming. :param replace: Should any existing FTS index for this table be replaced by the new one? """ + table_name = tablename(self.name) create_fts_sql = ( textwrap.dedent( """ @@ -2397,19 +2456,19 @@ def enable_fts( ) .strip() .format( - table=self.name, + table=table_name, columns=", ".join("[{}]".format(c) for c in columns), fts_version=fts_version, tokenize="\n tokenize='{}',".format(tokenize) if tokenize else "", ) ) should_recreate = False - if replace and self.db["{}_fts".format(self.name)].exists(): + if replace and self.db["{}_fts".format(table_name)].exists(): # Does the table need to be recreated? - fts_schema = self.db["{}_fts".format(self.name)].schema + fts_schema = self.db["{}_fts".format(table_name)].schema if fts_schema != create_fts_sql: should_recreate = True - expected_triggers = {self.name + suffix for suffix in ("_ai", "_ad", "_au")} + expected_triggers = {table_name + suffix for suffix in ("_ai", "_ad", "_au")} existing_triggers = {t.name for t in self.triggers} has_triggers = existing_triggers.issuperset(expected_triggers) if has_triggers != create_triggers: @@ -2444,7 +2503,7 @@ def enable_fts( ) .strip() .format( - table=self.name, + table=table_name, columns=", ".join("[{}]".format(c) for c in columns), old_cols=old_cols, new_cols=new_cols, @@ -2469,7 +2528,7 @@ def populate_fts(self, columns: Iterable[str]) -> "Table": ) .strip() .format( - table=self.name, columns=", ".join("[{}]".format(c) for c in columns) + table=tablename(self.name), columns=", ".join("[{}]".format(c) for c in columns) ) ) self.db.executescript(sql) @@ -2505,9 +2564,9 @@ def rebuild_fts(self): fts_table = self.detect_fts() if fts_table is None: # Assume this is itself an FTS table - fts_table = self.name + fts_table = escaped_name(self.name) self.db.execute( - "INSERT INTO [{table}]([{table}]) VALUES('rebuild');".format( + "INSERT INTO {table}({table}) VALUES('rebuild');".format( table=fts_table ) ) @@ -2529,10 +2588,11 @@ def detect_fts(self) -> Optional[str]: ) """ ).strip() + table_name = tablename(self.name) args = { - "like": "%VIRTUAL TABLE%USING FTS%content=[{}]%".format(self.name), - "like2": '%VIRTUAL TABLE%USING FTS%content="{}"%'.format(self.name), - "table": self.name, + "like": "%VIRTUAL TABLE%USING FTS%content=[{}]%".format(table_name), + "like2": '%VIRTUAL TABLE%USING FTS%content="{}"%'.format(table_name), + "table": table_name, } rows = self.db.execute(sql, args).fetchall() if len(rows) == 0: @@ -2592,7 +2652,7 @@ def search_sql( select rowid, {columns} - from [{dbtable}]{where_clause} + from {dbtable}{where_clause} ) select {columns_with_prefix} @@ -2621,7 +2681,7 @@ def search_sql( if offset is not None: limit_offset += " offset {}".format(offset) return sql.format( - dbtable=self.name, + dbtable=escaped_name(self.name), where_clause="\n where {}".format(where) if where else "", original=original, columns=columns_sql, @@ -2692,8 +2752,8 @@ def delete(self, pk_values: Union[list, tuple, str, int, float]) -> "Table": pk_values = [pk_values] self.get(pk_values) wheres = ["[{}] = ?".format(pk_name) for pk_name in self.pks] - sql = "delete from [{table}] where {wheres}".format( - table=self.name, wheres=" and ".join(wheres) + sql = "delete from {table} where {wheres}".format( + table=escaped_name(self.name), wheres=" and ".join(wheres) ) with self.db.conn: self.db.execute(sql, pk_values) @@ -2717,7 +2777,7 @@ def delete_where( """ if not self.exists(): return self - sql = "delete from [{}]".format(self.name) + sql = f"delete from {escaped_name(self.name)}" if where is not None: sql += " where " + where self.db.execute(sql, where_args or []) @@ -2762,8 +2822,8 @@ def update( args.append(jsonify_if_needed(value)) wheres = ["[{}] = ?".format(pk_name) for pk_name in pks] args.extend(pk_values) - sql = "update [{table}] set {sets} where {wheres}".format( - table=self.name, sets=", ".join(sets), wheres=" and ".join(wheres) + sql = "update {table} set {sets} where {wheres}".format( + table=escaped_name(self.name), sets=", ".join(sets), wheres=" and ".join(wheres) ) with self.db.conn: try: @@ -2843,8 +2903,8 @@ def convert_value(v): if fn_name == "": fn_name = f"lambda_{abs(hash(fn))}" self.db.register_function(convert_value, name=fn_name) - sql = "update [{table}] set {sets}{where};".format( - table=self.name, + sql = "update {table} set {sets}{where};".format( + table=escaped_name(self.name), sets=", ".join( [ "[{output_column}] = {fn_name}([{column}])".format( @@ -2965,8 +3025,8 @@ def build_insert_queries_and_params( # them since it ignores the resulting integrity errors if not_null: placeholders.extend(not_null) - sql = "INSERT OR IGNORE INTO [{table}]({cols}) VALUES({placeholders});".format( - table=self.name, + sql = "INSERT OR IGNORE INTO {table}({cols}) VALUES({placeholders});".format( + table=escaped_name(self.name), cols=", ".join(["[{}]".format(p) for p in placeholders]), placeholders=", ".join(["?" for p in placeholders]), ) @@ -2976,8 +3036,8 @@ def build_insert_queries_and_params( # UPDATE [book] SET [name] = 'Programming' WHERE [id] = 1001; set_cols = [col for col in all_columns if col not in pks] if set_cols: - sql2 = "UPDATE [{table}] SET {pairs} WHERE {wheres}".format( - table=self.name, + sql2 = "UPDATE {table} SET {pairs} WHERE {wheres}".format( + table=escaped_name(self.name), pairs=", ".join( "[{}] = {}".format(col, conversions.get(col, "?")) for col in set_cols @@ -3004,10 +3064,10 @@ def build_insert_queries_and_params( elif ignore: or_what = "OR IGNORE " sql = """ - INSERT {or_what}INTO [{table}] ({columns}) VALUES {rows}; + INSERT {or_what}INTO {table} ({columns}) VALUES {rows}; """.strip().format( or_what=or_what, - table=self.name, + table=escaped_name(self.name), columns=", ".join("[{}]".format(c) for c in all_columns), rows=", ".join( "({placeholders})".format( @@ -3265,7 +3325,7 @@ def insert_all( self.last_rowid = None self.last_pk = None if truncate and self.exists(): - self.db.execute("DELETE FROM [{}];".format(self.name)) + self.db.execute("DELETE FROM {};".format(escaped_name(self.name))) for chunk in chunks(itertools.chain([first_record], records), batch_size): chunk = list(chunk) num_records_processed += len(chunk) @@ -3776,7 +3836,7 @@ def drop(self, ignore=False): """ try: - self.db.execute("DROP VIEW [{}]".format(self.name)) + self.db.execute("DROP VIEW {}".format(escaped_name(self.name))) except sqlite3.OperationalError: if not ignore: raise From 7b2488c1a2eee543c500d0e18ddb0b8d09b93b84 Mon Sep 17 00:00:00 2001 From: Peter Gaultney Date: Tue, 9 Jan 2024 09:27:59 -0600 Subject: [PATCH 2/2] rework to store schema name explicitly --- sqlite_utils/db.py | 201 +++++++++++++++++++++++---------------------- 1 file changed, 102 insertions(+), 99 deletions(-) diff --git a/sqlite_utils/db.py b/sqlite_utils/db.py index a60df432..c204b95c 100644 --- a/sqlite_utils/db.py +++ b/sqlite_utils/db.py @@ -532,7 +532,9 @@ def executescript(self, sql: str) -> sqlite3.Cursor: self._tracer(sql, None) return self.conn.executescript(sql) - def table(self, table_name: str, **kwargs) -> Union["Table", "View"]: + def table( + self, table_name: str, schema: str = "", **kwargs + ) -> Union["Table", "View"]: """ Return a table object, optionally configured with default options. @@ -541,10 +543,10 @@ def table(self, table_name: str, **kwargs) -> Union["Table", "View"]: :param table_name: Name of the table """ if table_name in self.view_names(): - return View(self, table_name, **kwargs) + return View(self, table_name, schema_name=schema, **kwargs) else: kwargs.setdefault("strict", self.strict) - return Table(self, table_name, **kwargs) + return Table(self, table_name, schema_name=schema, **kwargs) def quote(self, value: str) -> str: """ @@ -601,44 +603,46 @@ def quote_default_value(self, value: str) -> str: return self.quote(value) - def database_names(self) -> List[str]: - "List of string database names available in this connection." - return [r[1] for r in self.execute("PRAGMA database_list").fetchall()] + def schema_names(self) -> List[str]: + """List of string database schemas available in this connection. - def table_names(self, fts4: bool = False, fts5: bool = False) -> List[str]: + Unless other databases are ATTACHed using `attach`, this will only return + `['main']` or `['main', 'temp']`. See https://www.sqlite.org/lang_attach.html """ - List of string table names in this database. + return [r[1] for r in self.execute("PRAGMA database_list").fetchall()] + + def _from_schema(self, schema: str) -> str: + if schema and schema != "main": + return f"{schema}.sqlite_master" + return "sqlite_master" # keep SQL simple for the standard case. + + def table_names( + self, fts4: bool = False, fts5: bool = False, schema: str = "" + ) -> List[str]: + """List of string table names in the specified database schema. :param fts4: Only return tables that are part of FTS4 indexes :param fts5: Only return tables that are part of FTS5 indexes + :param schema: By default, the `main` schema is queried, but a different, + attached database can be queried instead. """ where = ["type = 'table'"] if fts4: where.append("sql like '%USING FTS4%'") if fts5: where.append("sql like '%USING FTS5%'") - sql = "select name from sqlite_master where {}".format(" AND ".join(where)) - - def _exec_in_db(db_name: str, sql: str) -> List[str]: - if db_name == "main": - db_name = "" - if db_name: - sql = sql.replace("sqlite_master", f"{db_name}.sqlite_master") - table_names = [r[0] for r in self.execute(sql).fetchall()] - if db_name: - return [f"{db_name}.{tbl_name}" for tbl_name in table_names] - return table_names - - return list( - itertools.chain(*[_exec_in_db(db_name, sql) for db_name in self.database_names()]) + + sql = "select name from {} where {}".format( + self._from_schema(schema), " AND ".join(where) ) + return [r[0] for r in self.execute(sql).fetchall()] - def view_names(self) -> List[str]: + def view_names(self, schema: str = "") -> List[str]: "List of string view names in this database." return [ r[0] for r in self.execute( - "select name from sqlite_master where type = 'view'" + f"select name from {self._from_schema(schema)} where type = 'view'" ).fetchall() ] @@ -1288,26 +1292,10 @@ def init_spatialite(self, path: Optional[str] = None) -> bool: return result and bool(result[0]) -def _split_names(fullname: str) -> Tuple[str, str]: - if '.' not in fullname: - return '', fullname - return fullname.split('.') - - -def dbname(fullname: str) -> str: - return _split_names(fullname)[0] - - -def tablename(fullname: str) -> str: - return _split_names(fullname)[1] - - -def escaped_name(fullname: str) -> str: - """This is how SQLite expects a database name joined to a table name to use the square-bracket escapes.""" - db, tbl = _split_names(fullname) - if not db: - return f'[{tbl}]' - return f'{db}.[{tbl}]' +def _fullname(schema_name: str, table_name: str) -> str: + if schema_name: + return f"{schema_name}.[{table_name}]" + return "[" + table_name + "]" class Queryable: @@ -1315,9 +1303,14 @@ def exists(self) -> bool: "Does this table or view exist yet?" return False - def __init__(self, db, name: str): + def __init__(self, db, name: str, schema_name: str = ""): self.db = db self.name = name + self.schema_name = schema_name # default is empty string, a.k.a. 'main' + + @property + def _fullname(self) -> str: + return _fullname(self.schema_name, self.name) def count_where( self, @@ -1331,7 +1324,7 @@ def count_where( :param where_args: Parameters to use with that fragment - an iterable for ``id > ?`` parameters, or a dictionary for ``id > :id`` """ - sql = "select count(*) from {}".format(escaped_name(self.name)) + sql = "select count(*) from {}".format(self._fullname) if where is not None: sql += " where " + where return self.db.execute(sql, where_args or []).fetchone()[0] @@ -1374,7 +1367,7 @@ def rows_where( """ if not self.exists(): return - sql = "select {} from {}".format(select, escaped_name(self.name)) + sql = "select {} from {}".format(select, self._fullname) if where is not None: sql += " where " + where if order_by is not None: @@ -1428,13 +1421,12 @@ def pks_and_rows_where( @property def is_attached(self) -> bool: - return dbname(self.name) not in {'', 'main'} + return self.schema_name not in {"", "main"} @property def _pragma_name(self) -> Tuple[str, str]: - if "." in self.name: - db, name = self.name.split(".") - return db + ".", name + if self.schema_name: + return self.schema_name + ".", self.name return "", self.name @property @@ -1442,7 +1434,9 @@ def columns(self) -> List["Column"]: "List of :ref:`Columns ` representing the columns in this table or view." if not self.exists(): return [] - rows = self.db.execute("PRAGMA {}table_info([{}])".format(*self._pragma_name)).fetchall() + rows = self.db.execute( + "PRAGMA {}table_info([{}])".format(*self._pragma_name) + ).fetchall() return [Column(*row) for row in rows] @property @@ -1454,9 +1448,9 @@ def columns_dict(self) -> Dict[str, Any]: def schema(self) -> str: "SQL schema for this table or view." db, name = self._pragma_name - return self.db.execute(f"select sql from {db}sqlite_master where name = ?", (name,)).fetchone()[ - 0 - ] + return self.db.execute( + f"select sql from {db}sqlite_master where name = ?", (name,) + ).fetchone()[0] class Table(Queryable): @@ -1509,8 +1503,9 @@ def __init__( conversions: Optional[dict] = None, columns: Optional[Dict[str, Any]] = None, strict: bool = False, + schema_name: str = "", ): - super().__init__(db, name) + super().__init__(db, name, schema_name=schema_name) self._defaults = dict( pk=pk, foreign_keys=foreign_keys, @@ -1547,7 +1542,7 @@ def count(self) -> int: return self.count_where() def exists(self) -> bool: - return self.name in self.db.table_names() + return self.name in self.db.table_names(schema=self.schema_name) @property def pks(self) -> List[str]: @@ -1764,8 +1759,8 @@ def duplicate(self, new_name: str) -> "Table": raise NoTable(f"Table {self.name} does not exist") with self.db.conn: sql = "CREATE TABLE {new_table} AS SELECT * FROM {table};".format( - new_table=escaped_name(new_name), - table=escaped_name(self.name), + new_table=new_name, + table=self._fullname, ) self.db.execute(sql) return self.db[new_name] @@ -1820,9 +1815,9 @@ def transform( keep_table=keep_table, ) db, _ = self._pragma_name - pragma_foreign_keys_was_on = self.db.execute(f"PRAGMA {db}foreign_keys").fetchone()[ - 0 - ] + pragma_foreign_keys_was_on = self.db.execute( + f"PRAGMA {db}foreign_keys" + ).fetchone()[0] try: if pragma_foreign_keys_was_on: self.db.execute(f"PRAGMA {db}foreign_keys=0;") @@ -1999,9 +1994,12 @@ def transform_sql( if "rowid" not in new_cols: new_cols.insert(0, "rowid") old_cols.insert(0, "rowid") + + old_fullname = _fullname(self.schema_name, self.name) + new_fullname = _fullname(self.schema_name, new_table_name) copy_sql = "INSERT INTO {new_table} ({new_cols})\n SELECT {old_cols} FROM {old_table};".format( - new_table=escaped_name(new_table_name), - old_table=escaped_name(self.name), + new_table=new_fullname, + old_table=old_fullname, old_cols=", ".join("[{}]".format(col) for col in old_cols), new_cols=", ".join("[{}]".format(col) for col in new_cols), ) @@ -2009,14 +2007,14 @@ def transform_sql( # Drop (or keep) the old table if keep_table: sqls.append( - "ALTER TABLE {} RENAME TO {};".format(escaped_name(self.name), escaped_name(keep_table)) + "ALTER TABLE {} RENAME TO {};".format( + old_fullname, _fullname(self.schema_name, keep_table) + ) ) else: - sqls.append("DROP TABLE {};".format(escaped_name(self.name))) + sqls.append("DROP TABLE {};".format(old_fullname)) # Rename the new one - sqls.append( - "ALTER TABLE {} RENAME TO {};".format(escaped_name(new_table_name), escaped_name(self.name)) - ) + sqls.append("ALTER TABLE {} RENAME TO {};".format(new_fullname, old_fullname)) return sqls def extract( @@ -2079,10 +2077,10 @@ def extract( lookup_table.create_index(lookup_columns, unique=True, if_not_exists=True) self.db.execute( "INSERT OR IGNORE INTO {lookup_table} ({lookup_columns}) SELECT DISTINCT {table_cols} FROM {table}".format( - lookup_table=escaped_name(table), + lookup_table=_fullname(self.schema_name, table), lookup_columns=", ".join("[{}]".format(c) for c in lookup_columns), table_cols=", ".join("[{}]".format(c) for c in columns), - table=escaped_name(self.name), + table=self._fullname, ) ) @@ -2090,15 +2088,16 @@ def extract( self.add_column(magic_lookup_column, int) # And populate it + lookup_table_full = _fullname(self.schema_name, table) self.db.execute( "UPDATE {table} SET [{magic_lookup_column}] = (SELECT id FROM {lookup_table} WHERE {where})".format( - table=escaped_name(self.name), + table=self._fullname, magic_lookup_column=magic_lookup_column, - lookup_table=escaped_name(table), + lookup_table=lookup_table_full, where=" AND ".join( "{table}.[{column}] IS {lookup_table}.[{lookup_column}]".format( - table=escaped_name(self.name), - lookup_table=escaped_name(table), + table=self._fullname, + lookup_table=lookup_table_full, column=column, lookup_column=rename.get(column) or column, ) @@ -2178,7 +2177,7 @@ def create_index( .strip() .format( index_name=created_index_name, - table_name=escaped_name(self.name), + table_name=self._fullname, columns=", ".join(columns_sql), unique="UNIQUE " if unique else "", if_not_exists="IF NOT EXISTS " if if_not_exists else "", @@ -2226,7 +2225,7 @@ def add_column( fk_col_type = None if fk is not None: # fk must be a valid table - if fk not in self.db.table_names(): + if fk not in self.db.table_names(schema=self.schema_name): raise AlterError("table '{}' does not exist".format(fk)) # if fk_col specified, must be a valid column if fk_col is not None: @@ -2249,7 +2248,7 @@ def add_column( self.db.quote_default_value(not_null_default) ) sql = "ALTER TABLE {table} ADD COLUMN [{col_name}] {col_type}{not_null_default};".format( - table=escaped_name(self.name), + table=self._fullname, col_name=col_name, col_type=fk_col_type or COLUMN_TYPE_MAPPING[col_type], not_null_default=(" " + not_null_sql) if not_null_sql else "", @@ -2266,7 +2265,7 @@ def drop(self, ignore: bool = False): :param ignore: Set to ``True`` to ignore the error if the table does not exist """ try: - self.db.execute("DROP TABLE {}".format(escaped_name(self.name))) + self.db.execute("DROP TABLE {}".format(self._fullname)) except sqlite3.OperationalError: if not ignore: raise @@ -2292,7 +2291,9 @@ def guess_foreign_table(self, column: str) -> str: possibilities.append(column_without_id + "s") elif not column.endswith("s"): possibilities.append(column + "s") - existing_tables = {t.lower(): t for t in self.db.table_names()} + existing_tables = { + t.lower(): t for t in self.db.table_names(schema=self.schema_name) + } for table in possibilities: if table in existing_tables: return existing_tables[table] @@ -2444,7 +2445,7 @@ def enable_fts( :param tokenize: Custom SQLite tokenizer to use, for example ``"porter"`` to enable Porter stemming. :param replace: Should any existing FTS index for this table be replaced by the new one? """ - table_name = tablename(self.name) + table_name = self.name create_fts_sql = ( textwrap.dedent( """ @@ -2468,7 +2469,9 @@ def enable_fts( fts_schema = self.db["{}_fts".format(table_name)].schema if fts_schema != create_fts_sql: should_recreate = True - expected_triggers = {table_name + suffix for suffix in ("_ai", "_ad", "_au")} + expected_triggers = { + table_name + suffix for suffix in ("_ai", "_ad", "_au") + } existing_triggers = {t.name for t in self.triggers} has_triggers = existing_triggers.issuperset(expected_triggers) if has_triggers != create_triggers: @@ -2528,7 +2531,7 @@ def populate_fts(self, columns: Iterable[str]) -> "Table": ) .strip() .format( - table=tablename(self.name), columns=", ".join("[{}]".format(c) for c in columns) + table=self.name, columns=", ".join("[{}]".format(c) for c in columns) ) ) self.db.executescript(sql) @@ -2564,11 +2567,9 @@ def rebuild_fts(self): fts_table = self.detect_fts() if fts_table is None: # Assume this is itself an FTS table - fts_table = escaped_name(self.name) + fts_table = self._fullname self.db.execute( - "INSERT INTO {table}({table}) VALUES('rebuild');".format( - table=fts_table - ) + "INSERT INTO {table}({table}) VALUES('rebuild');".format(table=fts_table) ) return self @@ -2588,7 +2589,7 @@ def detect_fts(self) -> Optional[str]: ) """ ).strip() - table_name = tablename(self.name) + table_name = self.name args = { "like": "%VIRTUAL TABLE%USING FTS%content=[{}]%".format(table_name), "like2": '%VIRTUAL TABLE%USING FTS%content="{}"%'.format(table_name), @@ -2681,7 +2682,7 @@ def search_sql( if offset is not None: limit_offset += " offset {}".format(offset) return sql.format( - dbtable=escaped_name(self.name), + dbtable=self._fullname, where_clause="\n where {}".format(where) if where else "", original=original, columns=columns_sql, @@ -2753,7 +2754,7 @@ def delete(self, pk_values: Union[list, tuple, str, int, float]) -> "Table": self.get(pk_values) wheres = ["[{}] = ?".format(pk_name) for pk_name in self.pks] sql = "delete from {table} where {wheres}".format( - table=escaped_name(self.name), wheres=" and ".join(wheres) + table=self._fullname, wheres=" and ".join(wheres) ) with self.db.conn: self.db.execute(sql, pk_values) @@ -2777,7 +2778,7 @@ def delete_where( """ if not self.exists(): return self - sql = f"delete from {escaped_name(self.name)}" + sql = f"delete from {self._fullname}" if where is not None: sql += " where " + where self.db.execute(sql, where_args or []) @@ -2823,7 +2824,7 @@ def update( wheres = ["[{}] = ?".format(pk_name) for pk_name in pks] args.extend(pk_values) sql = "update {table} set {sets} where {wheres}".format( - table=escaped_name(self.name), sets=", ".join(sets), wheres=" and ".join(wheres) + table=self._fullname, sets=", ".join(sets), wheres=" and ".join(wheres) ) with self.db.conn: try: @@ -2904,7 +2905,7 @@ def convert_value(v): fn_name = f"lambda_{abs(hash(fn))}" self.db.register_function(convert_value, name=fn_name) sql = "update {table} set {sets}{where};".format( - table=escaped_name(self.name), + table=self._fullname, sets=", ".join( [ "[{output_column}] = {fn_name}([{column}])".format( @@ -3026,7 +3027,7 @@ def build_insert_queries_and_params( if not_null: placeholders.extend(not_null) sql = "INSERT OR IGNORE INTO {table}({cols}) VALUES({placeholders});".format( - table=escaped_name(self.name), + table=self._fullname, cols=", ".join(["[{}]".format(p) for p in placeholders]), placeholders=", ".join(["?" for p in placeholders]), ) @@ -3037,7 +3038,7 @@ def build_insert_queries_and_params( set_cols = [col for col in all_columns if col not in pks] if set_cols: sql2 = "UPDATE {table} SET {pairs} WHERE {wheres}".format( - table=escaped_name(self.name), + table=self._fullname, pairs=", ".join( "[{}] = {}".format(col, conversions.get(col, "?")) for col in set_cols @@ -3067,7 +3068,7 @@ def build_insert_queries_and_params( INSERT {or_what}INTO {table} ({columns}) VALUES {rows}; """.strip().format( or_what=or_what, - table=escaped_name(self.name), + table=self._fullname, columns=", ".join("[{}]".format(c) for c in all_columns), rows=", ".join( "({placeholders})".format( @@ -3325,7 +3326,7 @@ def insert_all( self.last_rowid = None self.last_pk = None if truncate and self.exists(): - self.db.execute("DELETE FROM {};".format(escaped_name(self.name))) + self.db.execute("DELETE FROM {};".format(self._fullname)) for chunk in chunks(itertools.chain([first_record], records), batch_size): chunk = list(chunk) num_records_processed += len(chunk) @@ -3809,7 +3810,9 @@ def create_spatial_index(self, column_name) -> bool: :param column_name: Geometry column to create the spatial index against """ - if f"idx_{self.name}_{column_name}" in self.db.table_names(): + if f"idx_{self.name}_{column_name}" in self.db.table_names( + schema=self.schema_name + ): return False cursor = self.db.execute( @@ -3836,7 +3839,7 @@ def drop(self, ignore=False): """ try: - self.db.execute("DROP VIEW {}".format(escaped_name(self.name))) + self.db.execute("DROP VIEW {}".format(self._fullname)) except sqlite3.OperationalError: if not ignore: raise