Skip to content

Commit 8a54c8b

Browse files
simonwclaude
andcommitted
Fix type errors in db.py
- Add type annotation for Database.conn - Add type: ignore for optional sqlite_dump import - Update execute/query parameter types to Sequence|Dict for sqlite3 compatibility - Use getattr for fn.__name__ access to handle callables without __name__ - Handle None return from find_spatialite() with OSError - Fix pk_values assignment to use local variable 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 5c15db9 commit 8a54c8b

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

sqlite_utils/db.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from sqlite_utils.plugins import pm
4242

4343
try:
44-
from sqlite_dump import iterdump
44+
from sqlite_dump import iterdump # type: ignore[import-not-found]
4545
except ImportError:
4646
iterdump = None
4747

@@ -526,7 +526,7 @@ def attach(self, alias: str, filepath: Union[str, pathlib.Path]):
526526
self.execute(attach_sql)
527527

528528
def query(
529-
self, sql: str, params: Optional[Union[Iterable, dict]] = None
529+
self, sql: str, params: Optional[Union[Sequence, Dict[str, Any]]] = None
530530
) -> Generator[dict, None, None]:
531531
"""
532532
Execute ``sql`` and return an iterable of dictionaries representing each row.
@@ -541,7 +541,7 @@ def query(
541541
yield dict(zip(keys, row))
542542

543543
def execute(
544-
self, sql: str, parameters: Optional[Union[Iterable, dict]] = None
544+
self, sql: str, parameters: Optional[Union[Sequence, Dict[str, Any]]] = None
545545
) -> sqlite3.Cursor:
546546
"""
547547
Execute SQL query and return a ``sqlite3.Cursor``.
@@ -806,10 +806,11 @@ def cached_counts(self, tables: Optional[Iterable[str]] = None) -> Dict[str, int
806806
:param tables: Subset list of tables to return counts for.
807807
"""
808808
sql = 'select "table", count from {}'.format(self._counts_table_name)
809-
if tables:
810-
sql += ' where "table" in ({})'.format(", ".join("?" for table in tables))
809+
tables_list = list(tables) if tables else None
810+
if tables_list:
811+
sql += ' where "table" in ({})'.format(", ".join("?" for _ in tables_list))
811812
try:
812-
return {r[0]: r[1] for r in self.execute(sql, tables).fetchall()}
813+
return {r[0]: r[1] for r in self.execute(sql, tables_list).fetchall()}
813814
except OperationalError:
814815
return {}
815816

@@ -826,7 +827,7 @@ def reset_counts(self):
826827
)
827828

828829
def execute_returning_dicts(
829-
self, sql: str, params: Optional[Union[Iterable, dict]] = None
830+
self, sql: str, params: Optional[Union[Sequence, Dict[str, Any]]] = None
830831
) -> List[dict]:
831832
return list(self.query(sql, params))
832833

@@ -1340,6 +1341,8 @@ def init_spatialite(self, path: Optional[str] = None) -> bool:
13401341
"""
13411342
if path is None:
13421343
path = find_spatialite()
1344+
if path is None:
1345+
raise OSError("Could not find SpatiaLite extension")
13431346

13441347
self.conn.enable_load_extension(True)
13451348
self.conn.load_extension(path)
@@ -3006,7 +3009,7 @@ def convert_value(v):
30063009
bar.update(1)
30073010
return jsonify_if_needed(fn(v))
30083011

3009-
fn_name = fn.__name__
3012+
fn_name = getattr(fn, "__name__", "fn")
30103013
if fn_name == "<lambda>":
30113014
fn_name = f"lambda_{abs(hash(fn))}"
30123015
self.db.register_function(convert_value, name=fn_name)
@@ -3251,9 +3254,11 @@ def build_insert_queries_and_params(
32513254
)
32523255
# We can populate .last_pk right here
32533256
if num_records_processed == 1:
3254-
self.last_pk = tuple(record[pk] for pk in pks)
3255-
if len(self.last_pk) == 1:
3256-
self.last_pk = self.last_pk[0]
3257+
pk_values = tuple(record[pk] for pk in pks)
3258+
if len(pk_values) == 1:
3259+
self.last_pk = pk_values[0]
3260+
else:
3261+
self.last_pk = pk_values
32573262
return queries_and_params
32583263

32593264
def insert_chunk(

0 commit comments

Comments
 (0)