Skip to content

Commit d833a5c

Browse files
simonwclaude
andcommitted
Add comprehensive type annotations
- mypy.ini: expanded configuration with module-specific settings - hookspecs.py: type annotations for hook functions - plugins.py: typed get_plugins() return value - recipes.py: full type annotations for parsedate, parsedatetime, jsonsplit - utils.py: extensive type annotations including Row type alias, TypeTracker, ValueTracker, and all utility functions - db.py: type annotations for Database methods (__exit__, ensure_autocommit_off, tracer, register_function, etc.) and Queryable class methods - tests/test_docs.py: updated to match new signature display format 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 8710385 commit d833a5c

File tree

7 files changed

+232
-117
lines changed

7 files changed

+232
-117
lines changed

mypy.ini

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,32 @@
11
[mypy]
2+
python_version = 3.10
3+
warn_return_any = False
4+
warn_unused_configs = True
5+
warn_redundant_casts = False
6+
warn_unused_ignores = False
7+
check_untyped_defs = True
8+
disallow_untyped_defs = False
9+
disallow_incomplete_defs = False
10+
no_implicit_optional = True
11+
strict_equality = True
212

3-
[mypy-pysqlite3,sqlean,sqlite_dump]
4-
ignore_missing_imports = True
13+
[mypy-sqlite_utils.cli]
14+
ignore_errors = True
15+
16+
[mypy-pysqlite3.*]
17+
ignore_missing_imports = True
18+
19+
[mypy-sqlean.*]
20+
ignore_missing_imports = True
21+
22+
[mypy-sqlite_dump.*]
23+
ignore_missing_imports = True
24+
25+
[mypy-sqlite_fts4.*]
26+
ignore_missing_imports = True
27+
28+
[mypy-pandas.*]
29+
ignore_missing_imports = True
30+
31+
[mypy-numpy.*]
32+
ignore_missing_imports = True

sqlite_utils/db.py

Lines changed: 53 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
Generator,
3333
Iterable,
3434
Sequence,
35+
Set,
36+
Type,
3537
Union,
3638
Optional,
3739
List,
@@ -287,7 +289,7 @@ class DescIndex(str):
287289
class BadMultiValues(Exception):
288290
"With multi=True code must return a Python dictionary"
289291

290-
def __init__(self, values):
292+
def __init__(self, values: object) -> None:
291293
self.values = values
292294

293295

@@ -386,15 +388,20 @@ def __init__(
386388
def __enter__(self) -> "Database":
387389
return self
388390

389-
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
391+
def __exit__(
392+
self,
393+
exc_type: Optional[Type[BaseException]],
394+
exc_val: Optional[BaseException],
395+
exc_tb: Optional[object],
396+
) -> None:
390397
self.close()
391398

392399
def close(self) -> None:
393400
"Close the SQLite connection, and the underlying database file"
394401
self.conn.close()
395402

396403
@contextlib.contextmanager
397-
def ensure_autocommit_off(self):
404+
def ensure_autocommit_off(self) -> Generator[None, None, None]:
398405
"""
399406
Ensure autocommit is off for this database connection.
400407
@@ -413,7 +420,9 @@ def ensure_autocommit_off(self):
413420
self.conn.isolation_level = old_isolation_level
414421

415422
@contextlib.contextmanager
416-
def tracer(self, tracer: Optional[Callable] = None):
423+
def tracer(
424+
self, tracer: Optional[Callable[[str, Optional[Sequence]], None]] = None
425+
) -> Generator["Database", None, None]:
417426
"""
418427
Context manager to temporarily set a tracer function - all executed SQL queries will
419428
be passed to this.
@@ -456,7 +465,7 @@ def register_function(
456465
deterministic: bool = False,
457466
replace: bool = False,
458467
name: Optional[str] = None,
459-
):
468+
) -> Optional[Callable[[Callable], Callable]]:
460469
"""
461470
``fn`` will be made available as a function within SQL, with the same name and number
462471
of arguments. Can be used as a decorator::
@@ -479,12 +488,12 @@ def upper(value):
479488
:param name: name of the SQLite function - if not specified, the Python function name will be used
480489
"""
481490

482-
def register(fn):
491+
def register(fn: Callable) -> Callable:
483492
fn_name = name or fn.__name__
484493
arity = len(inspect.signature(fn).parameters)
485494
if not replace and (fn_name, arity) in self._registered_functions:
486495
return fn
487-
kwargs = {}
496+
kwargs: Dict[str, bool] = {}
488497
registered = False
489498
if deterministic:
490499
# Try this, but fall back if sqlite3.NotSupportedError
@@ -504,12 +513,13 @@ def register(fn):
504513
return register
505514
else:
506515
register(fn)
516+
return None
507517

508-
def register_fts4_bm25(self):
518+
def register_fts4_bm25(self) -> None:
509519
"Register the ``rank_bm25(match_info)`` function used for calculating relevance with SQLite FTS4."
510520
self.register_function(rank_bm25, deterministic=True, replace=True)
511521

512-
def attach(self, alias: str, filepath: Union[str, pathlib.Path]):
522+
def attach(self, alias: str, filepath: Union[str, pathlib.Path]) -> None:
513523
"""
514524
Attach another SQLite database file to this connection with the specified alias, equivalent to::
515525
@@ -567,7 +577,7 @@ def executescript(self, sql: str) -> sqlite3.Cursor:
567577
self._tracer(sql, None)
568578
return self.conn.executescript(sql)
569579

570-
def table(self, table_name: str, **kwargs) -> "Table":
580+
def table(self, table_name: str, **kwargs: Any) -> "Table":
571581
"""
572582
Return a table object, optionally configured with default options.
573583
@@ -766,25 +776,25 @@ def journal_mode(self) -> str:
766776
"""
767777
return self.execute("PRAGMA journal_mode;").fetchone()[0]
768778

769-
def enable_wal(self):
779+
def enable_wal(self) -> None:
770780
"""
771781
Sets ``journal_mode`` to ``'wal'`` to enable Write-Ahead Log mode.
772782
"""
773783
if self.journal_mode != "wal":
774784
with self.ensure_autocommit_off():
775785
self.execute("PRAGMA journal_mode=wal;")
776786

777-
def disable_wal(self):
787+
def disable_wal(self) -> None:
778788
"Sets ``journal_mode`` back to ``'delete'`` to disable Write-Ahead Log mode."
779789
if self.journal_mode != "delete":
780790
with self.ensure_autocommit_off():
781791
self.execute("PRAGMA journal_mode=delete;")
782792

783-
def _ensure_counts_table(self):
793+
def _ensure_counts_table(self) -> None:
784794
with self.conn:
785795
self.execute(_COUNTS_TABLE_CREATE_SQL.format(self._counts_table_name))
786796

787-
def enable_counts(self):
797+
def enable_counts(self) -> None:
788798
"""
789799
Enable trigger-based count caching for every table in the database, see
790800
:ref:`python_api_cached_table_counts`.
@@ -814,7 +824,7 @@ def cached_counts(self, tables: Optional[Iterable[str]] = None) -> Dict[str, int
814824
except OperationalError:
815825
return {}
816826

817-
def reset_counts(self):
827+
def reset_counts(self) -> None:
818828
"Re-calculate cached counts for tables."
819829
tables = [table for table in self.tables if table.has_counts_triggers]
820830
with self.conn:
@@ -1159,7 +1169,7 @@ def create_table(
11591169
hash_id_columns=hash_id_columns,
11601170
)
11611171

1162-
def rename_table(self, name: str, new_name: str):
1172+
def rename_table(self, name: str, new_name: str) -> None:
11631173
"""
11641174
Rename a table.
11651175
@@ -1174,7 +1184,7 @@ def rename_table(self, name: str, new_name: str):
11741184

11751185
def create_view(
11761186
self, name: str, sql: str, ignore: bool = False, replace: bool = False
1177-
):
1187+
) -> "Database":
11781188
"""
11791189
Create a new SQL view with the specified name - ``sql`` should start with ``SELECT ...``.
11801190
@@ -1220,7 +1230,9 @@ def m2m_table_candidates(self, table: str, other_table: str) -> List[str]:
12201230
candidates.append(table_obj.name)
12211231
return candidates
12221232

1223-
def add_foreign_keys(self, foreign_keys: Iterable[Tuple[str, str, str, str]]):
1233+
def add_foreign_keys(
1234+
self, foreign_keys: Iterable[Tuple[str, str, str, str]]
1235+
) -> None:
12241236
"""
12251237
See :ref:`python_api_add_foreign_keys`.
12261238
@@ -1272,7 +1284,7 @@ def add_foreign_keys(self, foreign_keys: Iterable[Tuple[str, str, str, str]]):
12721284

12731285
self.vacuum()
12741286

1275-
def index_foreign_keys(self):
1287+
def index_foreign_keys(self) -> None:
12761288
"Create indexes for every foreign key column on every table in the database."
12771289
for table_name in self.table_names():
12781290
table = self.table(table_name)
@@ -1283,11 +1295,11 @@ def index_foreign_keys(self):
12831295
if fk.column not in existing_indexes:
12841296
table.create_index([fk.column], find_unique_name=True)
12851297

1286-
def vacuum(self):
1298+
def vacuum(self) -> None:
12871299
"Run a SQLite ``VACUUM`` against the database."
12881300
self.execute("VACUUM;")
12891301

1290-
def analyze(self, name=None):
1302+
def analyze(self, name: Optional[str] = None) -> None:
12911303
"""
12921304
Run ``ANALYZE`` against the entire database or a named table or index.
12931305
@@ -1355,18 +1367,21 @@ def init_spatialite(self, path: Optional[str] = None) -> bool:
13551367

13561368

13571369
class Queryable:
1370+
db: "Database"
1371+
name: str
1372+
13581373
def exists(self) -> bool:
13591374
"Does this table or view exist yet?"
13601375
return False
13611376

1362-
def __init__(self, db, name):
1377+
def __init__(self, db: "Database", name: str) -> None:
13631378
self.db = db
13641379
self.name = name
13651380

13661381
def count_where(
13671382
self,
13681383
where: Optional[str] = None,
1369-
where_args: Optional[Union[Iterable, dict]] = None,
1384+
where_args: Optional[Union[Sequence, Dict[str, Any]]] = None,
13701385
) -> int:
13711386
"""
13721387
Executes ``SELECT count(*) FROM table WHERE ...`` and returns a count.
@@ -1380,7 +1395,7 @@ def count_where(
13801395
sql += " where " + where
13811396
return self.db.execute(sql, where_args or []).fetchone()[0]
13821397

1383-
def execute_count(self):
1398+
def execute_count(self) -> int:
13841399
# Backwards compatibility, see https://github.com/simonw/sqlite-utils/issues/305#issuecomment-890713185
13851400
return self.count_where()
13861401

@@ -1390,19 +1405,19 @@ def count(self) -> int:
13901405
return self.count_where()
13911406

13921407
@property
1393-
def rows(self) -> Generator[dict, None, None]:
1408+
def rows(self) -> Generator[Dict[str, Any], None, None]:
13941409
"Iterate over every dictionaries for each row in this table or view."
13951410
return self.rows_where()
13961411

13971412
def rows_where(
13981413
self,
13991414
where: Optional[str] = None,
1400-
where_args: Optional[Union[Iterable, dict]] = None,
1415+
where_args: Optional[Union[Sequence, Dict[str, Any]]] = None,
14011416
order_by: Optional[str] = None,
14021417
select: str = "*",
14031418
limit: Optional[int] = None,
14041419
offset: Optional[int] = None,
1405-
) -> Generator[dict, None, None]:
1420+
) -> Generator[Dict[str, Any], None, None]:
14061421
"""
14071422
Iterate over every row in this table or view that matches the specified where clause.
14081423
@@ -2350,7 +2365,7 @@ def add_column(
23502365
self.add_foreign_key(col_name, fk, fk_col)
23512366
return self
23522367

2353-
def drop(self, ignore: bool = False):
2368+
def drop(self, ignore: bool = False) -> None:
23542369
"""
23552370
Drop this table.
23562371
@@ -2394,7 +2409,7 @@ def guess_foreign_table(self, column: str) -> str:
23942409
)
23952410
)
23962411

2397-
def guess_foreign_column(self, other_table: str):
2412+
def guess_foreign_column(self, other_table: str) -> str:
23982413
pks = [c for c in self.db[other_table].columns if c.is_pk]
23992414
if len(pks) != 1:
24002415
raise BadPrimaryKey(
@@ -2453,7 +2468,7 @@ def add_foreign_key(
24532468
self.db.add_foreign_keys([(self.name, column, other_table, other_column)])
24542469
return self
24552470

2456-
def enable_counts(self):
2471+
def enable_counts(self) -> None:
24572472
"""
24582473
Set up triggers to update a cache of the count of rows in this table.
24592474
@@ -2665,7 +2680,7 @@ def disable_fts(self) -> "Table":
26652680
)
26662681
return self
26672682

2668-
def rebuild_fts(self):
2683+
def rebuild_fts(self) -> None:
26692684
"Run the ``rebuild`` operation against the associated full-text search index table."
26702685
fts_table = self.detect_fts()
26712686
if fts_table is None:
@@ -2849,7 +2864,7 @@ def search(
28492864
for row in cursor:
28502865
yield dict(zip(columns, row))
28512866

2852-
def value_or_default(self, key, value):
2867+
def value_or_default(self, key: str, value: Any) -> Any:
28532868
return self._defaults[key] if value is DEFAULT else value
28542869

28552870
def delete(self, pk_values: Union[list, tuple, str, int, float]) -> "Table":
@@ -3913,7 +3928,7 @@ def m2m(
39133928
)
39143929
return self
39153930

3916-
def analyze(self):
3931+
def analyze(self) -> None:
39173932
"Run ANALYZE against this table"
39183933
self.db.analyze(self.name)
39193934

@@ -4105,15 +4120,15 @@ def create_spatial_index(self, column_name) -> bool:
41054120

41064121

41074122
class View(Queryable):
4108-
def exists(self):
4123+
def exists(self) -> bool:
41094124
return True
41104125

41114126
def __repr__(self) -> str:
41124127
return "<View {} ({})>".format(
41134128
self.name, ", ".join(c.name for c in self.columns)
41144129
)
41154130

4116-
def drop(self, ignore=False):
4131+
def drop(self, ignore: bool = False) -> None:
41174132
"""
41184133
Drop this view.
41194134
@@ -4126,14 +4141,14 @@ def drop(self, ignore=False):
41264141
if not ignore:
41274142
raise
41284143

4129-
def enable_fts(self, *args, **kwargs):
4144+
def enable_fts(self, *args: object, **kwargs: object) -> None:
41304145
"``enable_fts()`` is supported on tables but not on views."
41314146
raise NotImplementedError(
41324147
"enable_fts() is supported on tables but not on views"
41334148
)
41344149

41354150

4136-
def jsonify_if_needed(value):
4151+
def jsonify_if_needed(value: object) -> object:
41374152
if isinstance(value, decimal.Decimal):
41384153
return float(value)
41394154
if isinstance(value, (dict, list, tuple)):
@@ -4158,7 +4173,7 @@ def resolve_extracts(
41584173
return extracts
41594174

41604175

4161-
def _decode_default_value(value):
4176+
def _decode_default_value(value: str) -> object:
41624177
if value.startswith("'") and value.endswith("'"):
41634178
# It's a string
41644179
return value[1:-1]

0 commit comments

Comments
 (0)