3232 Generator ,
3333 Iterable ,
3434 Sequence ,
35+ Set ,
36+ Type ,
3537 Union ,
3638 Optional ,
3739 List ,
@@ -287,7 +289,7 @@ class DescIndex(str):
287289class BadMultiValues (Exception ):
288290 "With multi=True code must return a Python dictionary"
289291
290- def __init__ (self , values ) :
292+ def __init__ (self , values : object ) -> None :
291293 self .values = values
292294
293295
@@ -386,15 +388,20 @@ def __init__(
386388 def __enter__ (self ) -> "Database" :
387389 return self
388390
389- def __exit__ (self , exc_type , exc_val , exc_tb ) -> None :
391+ def __exit__ (
392+ self ,
393+ exc_type : Optional [Type [BaseException ]],
394+ exc_val : Optional [BaseException ],
395+ exc_tb : Optional [object ],
396+ ) -> None :
390397 self .close ()
391398
392399 def close (self ) -> None :
393400 "Close the SQLite connection, and the underlying database file"
394401 self .conn .close ()
395402
396403 @contextlib .contextmanager
397- def ensure_autocommit_off (self ):
404+ def ensure_autocommit_off (self ) -> Generator [ None , None , None ] :
398405 """
399406 Ensure autocommit is off for this database connection.
400407
@@ -413,7 +420,9 @@ def ensure_autocommit_off(self):
413420 self .conn .isolation_level = old_isolation_level
414421
415422 @contextlib .contextmanager
416- def tracer (self , tracer : Optional [Callable ] = None ):
423+ def tracer (
424+ self , tracer : Optional [Callable [[str , Optional [Sequence ]], None ]] = None
425+ ) -> Generator ["Database" , None , None ]:
417426 """
418427 Context manager to temporarily set a tracer function - all executed SQL queries will
419428 be passed to this.
@@ -456,7 +465,7 @@ def register_function(
456465 deterministic : bool = False ,
457466 replace : bool = False ,
458467 name : Optional [str ] = None ,
459- ):
468+ ) -> Optional [ Callable [[ Callable ], Callable ]] :
460469 """
461470 ``fn`` will be made available as a function within SQL, with the same name and number
462471 of arguments. Can be used as a decorator::
@@ -479,12 +488,12 @@ def upper(value):
479488 :param name: name of the SQLite function - if not specified, the Python function name will be used
480489 """
481490
482- def register (fn ) :
491+ def register (fn : Callable ) -> Callable :
483492 fn_name = name or fn .__name__
484493 arity = len (inspect .signature (fn ).parameters )
485494 if not replace and (fn_name , arity ) in self ._registered_functions :
486495 return fn
487- kwargs = {}
496+ kwargs : Dict [ str , bool ] = {}
488497 registered = False
489498 if deterministic :
490499 # Try this, but fall back if sqlite3.NotSupportedError
@@ -504,12 +513,13 @@ def register(fn):
504513 return register
505514 else :
506515 register (fn )
516+ return None
507517
508- def register_fts4_bm25 (self ):
518+ def register_fts4_bm25 (self ) -> None :
509519 "Register the ``rank_bm25(match_info)`` function used for calculating relevance with SQLite FTS4."
510520 self .register_function (rank_bm25 , deterministic = True , replace = True )
511521
512- def attach (self , alias : str , filepath : Union [str , pathlib .Path ]):
522+ def attach (self , alias : str , filepath : Union [str , pathlib .Path ]) -> None :
513523 """
514524 Attach another SQLite database file to this connection with the specified alias, equivalent to::
515525
@@ -567,7 +577,7 @@ def executescript(self, sql: str) -> sqlite3.Cursor:
567577 self ._tracer (sql , None )
568578 return self .conn .executescript (sql )
569579
570- def table (self , table_name : str , ** kwargs ) -> "Table" :
580+ def table (self , table_name : str , ** kwargs : Any ) -> "Table" :
571581 """
572582 Return a table object, optionally configured with default options.
573583
@@ -766,25 +776,25 @@ def journal_mode(self) -> str:
766776 """
767777 return self .execute ("PRAGMA journal_mode;" ).fetchone ()[0 ]
768778
769- def enable_wal (self ):
779+ def enable_wal (self ) -> None :
770780 """
771781 Sets ``journal_mode`` to ``'wal'`` to enable Write-Ahead Log mode.
772782 """
773783 if self .journal_mode != "wal" :
774784 with self .ensure_autocommit_off ():
775785 self .execute ("PRAGMA journal_mode=wal;" )
776786
777- def disable_wal (self ):
787+ def disable_wal (self ) -> None :
778788 "Sets ``journal_mode`` back to ``'delete'`` to disable Write-Ahead Log mode."
779789 if self .journal_mode != "delete" :
780790 with self .ensure_autocommit_off ():
781791 self .execute ("PRAGMA journal_mode=delete;" )
782792
783- def _ensure_counts_table (self ):
793+ def _ensure_counts_table (self ) -> None :
784794 with self .conn :
785795 self .execute (_COUNTS_TABLE_CREATE_SQL .format (self ._counts_table_name ))
786796
787- def enable_counts (self ):
797+ def enable_counts (self ) -> None :
788798 """
789799 Enable trigger-based count caching for every table in the database, see
790800 :ref:`python_api_cached_table_counts`.
@@ -814,7 +824,7 @@ def cached_counts(self, tables: Optional[Iterable[str]] = None) -> Dict[str, int
814824 except OperationalError :
815825 return {}
816826
817- def reset_counts (self ):
827+ def reset_counts (self ) -> None :
818828 "Re-calculate cached counts for tables."
819829 tables = [table for table in self .tables if table .has_counts_triggers ]
820830 with self .conn :
@@ -1159,7 +1169,7 @@ def create_table(
11591169 hash_id_columns = hash_id_columns ,
11601170 )
11611171
1162- def rename_table (self , name : str , new_name : str ):
1172+ def rename_table (self , name : str , new_name : str ) -> None :
11631173 """
11641174 Rename a table.
11651175
@@ -1174,7 +1184,7 @@ def rename_table(self, name: str, new_name: str):
11741184
11751185 def create_view (
11761186 self , name : str , sql : str , ignore : bool = False , replace : bool = False
1177- ):
1187+ ) -> "Database" :
11781188 """
11791189 Create a new SQL view with the specified name - ``sql`` should start with ``SELECT ...``.
11801190
@@ -1220,7 +1230,9 @@ def m2m_table_candidates(self, table: str, other_table: str) -> List[str]:
12201230 candidates .append (table_obj .name )
12211231 return candidates
12221232
1223- def add_foreign_keys (self , foreign_keys : Iterable [Tuple [str , str , str , str ]]):
1233+ def add_foreign_keys (
1234+ self , foreign_keys : Iterable [Tuple [str , str , str , str ]]
1235+ ) -> None :
12241236 """
12251237 See :ref:`python_api_add_foreign_keys`.
12261238
@@ -1272,7 +1284,7 @@ def add_foreign_keys(self, foreign_keys: Iterable[Tuple[str, str, str, str]]):
12721284
12731285 self .vacuum ()
12741286
1275- def index_foreign_keys (self ):
1287+ def index_foreign_keys (self ) -> None :
12761288 "Create indexes for every foreign key column on every table in the database."
12771289 for table_name in self .table_names ():
12781290 table = self .table (table_name )
@@ -1283,11 +1295,11 @@ def index_foreign_keys(self):
12831295 if fk .column not in existing_indexes :
12841296 table .create_index ([fk .column ], find_unique_name = True )
12851297
1286- def vacuum (self ):
1298+ def vacuum (self ) -> None :
12871299 "Run a SQLite ``VACUUM`` against the database."
12881300 self .execute ("VACUUM;" )
12891301
1290- def analyze (self , name = None ):
1302+ def analyze (self , name : Optional [ str ] = None ) -> None :
12911303 """
12921304 Run ``ANALYZE`` against the entire database or a named table or index.
12931305
@@ -1355,18 +1367,21 @@ def init_spatialite(self, path: Optional[str] = None) -> bool:
13551367
13561368
13571369class Queryable :
1370+ db : "Database"
1371+ name : str
1372+
13581373 def exists (self ) -> bool :
13591374 "Does this table or view exist yet?"
13601375 return False
13611376
1362- def __init__ (self , db , name ) :
1377+ def __init__ (self , db : "Database" , name : str ) -> None :
13631378 self .db = db
13641379 self .name = name
13651380
13661381 def count_where (
13671382 self ,
13681383 where : Optional [str ] = None ,
1369- where_args : Optional [Union [Iterable , dict ]] = None ,
1384+ where_args : Optional [Union [Sequence , Dict [ str , Any ] ]] = None ,
13701385 ) -> int :
13711386 """
13721387 Executes ``SELECT count(*) FROM table WHERE ...`` and returns a count.
@@ -1380,7 +1395,7 @@ def count_where(
13801395 sql += " where " + where
13811396 return self .db .execute (sql , where_args or []).fetchone ()[0 ]
13821397
1383- def execute_count (self ):
1398+ def execute_count (self ) -> int :
13841399 # Backwards compatibility, see https://github.com/simonw/sqlite-utils/issues/305#issuecomment-890713185
13851400 return self .count_where ()
13861401
@@ -1390,19 +1405,19 @@ def count(self) -> int:
13901405 return self .count_where ()
13911406
13921407 @property
1393- def rows (self ) -> Generator [dict , None , None ]:
1408+ def rows (self ) -> Generator [Dict [ str , Any ] , None , None ]:
13941409 "Iterate over every dictionaries for each row in this table or view."
13951410 return self .rows_where ()
13961411
13971412 def rows_where (
13981413 self ,
13991414 where : Optional [str ] = None ,
1400- where_args : Optional [Union [Iterable , dict ]] = None ,
1415+ where_args : Optional [Union [Sequence , Dict [ str , Any ] ]] = None ,
14011416 order_by : Optional [str ] = None ,
14021417 select : str = "*" ,
14031418 limit : Optional [int ] = None ,
14041419 offset : Optional [int ] = None ,
1405- ) -> Generator [dict , None , None ]:
1420+ ) -> Generator [Dict [ str , Any ] , None , None ]:
14061421 """
14071422 Iterate over every row in this table or view that matches the specified where clause.
14081423
@@ -2350,7 +2365,7 @@ def add_column(
23502365 self .add_foreign_key (col_name , fk , fk_col )
23512366 return self
23522367
2353- def drop (self , ignore : bool = False ):
2368+ def drop (self , ignore : bool = False ) -> None :
23542369 """
23552370 Drop this table.
23562371
@@ -2394,7 +2409,7 @@ def guess_foreign_table(self, column: str) -> str:
23942409 )
23952410 )
23962411
2397- def guess_foreign_column (self , other_table : str ):
2412+ def guess_foreign_column (self , other_table : str ) -> str :
23982413 pks = [c for c in self .db [other_table ].columns if c .is_pk ]
23992414 if len (pks ) != 1 :
24002415 raise BadPrimaryKey (
@@ -2453,7 +2468,7 @@ def add_foreign_key(
24532468 self .db .add_foreign_keys ([(self .name , column , other_table , other_column )])
24542469 return self
24552470
2456- def enable_counts (self ):
2471+ def enable_counts (self ) -> None :
24572472 """
24582473 Set up triggers to update a cache of the count of rows in this table.
24592474
@@ -2665,7 +2680,7 @@ def disable_fts(self) -> "Table":
26652680 )
26662681 return self
26672682
2668- def rebuild_fts (self ):
2683+ def rebuild_fts (self ) -> None :
26692684 "Run the ``rebuild`` operation against the associated full-text search index table."
26702685 fts_table = self .detect_fts ()
26712686 if fts_table is None :
@@ -2849,7 +2864,7 @@ def search(
28492864 for row in cursor :
28502865 yield dict (zip (columns , row ))
28512866
2852- def value_or_default (self , key , value ) :
2867+ def value_or_default (self , key : str , value : Any ) -> Any :
28532868 return self ._defaults [key ] if value is DEFAULT else value
28542869
28552870 def delete (self , pk_values : Union [list , tuple , str , int , float ]) -> "Table" :
@@ -3913,7 +3928,7 @@ def m2m(
39133928 )
39143929 return self
39153930
3916- def analyze (self ):
3931+ def analyze (self ) -> None :
39173932 "Run ANALYZE against this table"
39183933 self .db .analyze (self .name )
39193934
@@ -4105,15 +4120,15 @@ def create_spatial_index(self, column_name) -> bool:
41054120
41064121
41074122class View (Queryable ):
4108- def exists (self ):
4123+ def exists (self ) -> bool :
41094124 return True
41104125
41114126 def __repr__ (self ) -> str :
41124127 return "<View {} ({})>" .format (
41134128 self .name , ", " .join (c .name for c in self .columns )
41144129 )
41154130
4116- def drop (self , ignore = False ):
4131+ def drop (self , ignore : bool = False ) -> None :
41174132 """
41184133 Drop this view.
41194134
@@ -4126,14 +4141,14 @@ def drop(self, ignore=False):
41264141 if not ignore :
41274142 raise
41284143
4129- def enable_fts (self , * args , ** kwargs ) :
4144+ def enable_fts (self , * args : object , ** kwargs : object ) -> None :
41304145 "``enable_fts()`` is supported on tables but not on views."
41314146 raise NotImplementedError (
41324147 "enable_fts() is supported on tables but not on views"
41334148 )
41344149
41354150
4136- def jsonify_if_needed (value ) :
4151+ def jsonify_if_needed (value : object ) -> object :
41374152 if isinstance (value , decimal .Decimal ):
41384153 return float (value )
41394154 if isinstance (value , (dict , list , tuple )):
@@ -4158,7 +4173,7 @@ def resolve_extracts(
41584173 return extracts
41594174
41604175
4161- def _decode_default_value (value ) :
4176+ def _decode_default_value (value : str ) -> object :
41624177 if value .startswith ("'" ) and value .endswith ("'" ):
41634178 # It's a string
41644179 return value [1 :- 1 ]
0 commit comments