Skip to content

Commit c79737b

Browse files
committed
Type signatures for .create_table() and .create_table_sql() and .create() and Table.__init__
Closes #314
1 parent 282e813 commit c79737b

File tree

2 files changed

+62
-43
lines changed

2 files changed

+62
-43
lines changed

sqlite_utils/db.py

Lines changed: 61 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,15 @@
147147
Trigger = namedtuple("Trigger", ("name", "table", "sql"))
148148

149149

150+
ForeignKeysType = Union[
151+
Iterable[str],
152+
Iterable[ForeignKey],
153+
Iterable[Tuple[str, str]],
154+
Iterable[Tuple[str, str, str]],
155+
Iterable[Tuple[str, str, str, str]],
156+
]
157+
158+
150159
class Default:
151160
pass
152161

@@ -572,18 +581,22 @@ def execute_returning_dicts(
572581
) -> List[dict]:
573582
return list(self.query(sql, params))
574583

575-
def resolve_foreign_keys(self, name, foreign_keys):
576-
# foreign_keys may be a list of strcolumn names, a list of ForeignKey tuples,
584+
def resolve_foreign_keys(
585+
self, name: str, foreign_keys: ForeignKeysType
586+
) -> List[ForeignKey]:
587+
# foreign_keys may be a list of column names, a list of ForeignKey tuples,
577588
# a list of tuple-pairs or a list of tuple-triples. We want to turn
578589
# it into a list of ForeignKey tuples
590+
table = cast(Table, self[name])
579591
if all(isinstance(fk, ForeignKey) for fk in foreign_keys):
580-
return foreign_keys
592+
return cast(List[ForeignKey], foreign_keys)
581593
if all(isinstance(fk, str) for fk in foreign_keys):
582594
# It's a list of columns
583595
fks = []
584596
for column in foreign_keys:
585-
other_table = self[name].guess_foreign_table(column)
586-
other_column = self[name].guess_foreign_column(other_table)
597+
column = cast(str, column)
598+
other_table = table.guess_foreign_table(column)
599+
other_column = table.guess_foreign_column(other_table)
587600
fks.append(ForeignKey(name, column, other_table, other_column))
588601
return fks
589602
assert all(
@@ -596,6 +609,7 @@ def resolve_foreign_keys(self, name, foreign_keys):
596609
3,
597610
), "foreign_keys= should be a list of tuple pairs or triples"
598611
if len(tuple_or_list) == 3:
612+
tuple_or_list = cast(Tuple[str, str, str], tuple_or_list)
599613
fks.append(
600614
ForeignKey(
601615
name, tuple_or_list[0], tuple_or_list[1], tuple_or_list[2]
@@ -608,7 +622,7 @@ def resolve_foreign_keys(self, name, foreign_keys):
608622
name,
609623
tuple_or_list[0],
610624
tuple_or_list[1],
611-
self[name].guess_foreign_column(tuple_or_list[1]),
625+
table.guess_foreign_column(tuple_or_list[1]),
612626
)
613627
)
614628
return fks
@@ -618,12 +632,12 @@ def create_table_sql(
618632
name: str,
619633
columns: Dict[str, Any],
620634
pk: Optional[Any] = None,
621-
foreign_keys=None,
622-
column_order=None,
623-
not_null=None,
624-
defaults=None,
625-
hash_id=None,
626-
extracts=None,
635+
foreign_keys: Optional[ForeignKeysType] = None,
636+
column_order: Optional[List[str]] = None,
637+
not_null: Iterable[str] = None,
638+
defaults: Optional[Dict[str, Any]] = None,
639+
hash_id: Optional[Any] = None,
640+
extracts: Optional[Union[Dict[str, str], List[str]]] = None,
627641
) -> str:
628642
"Returns the SQL ``CREATE TABLE`` statement for creating the specified table."
629643
foreign_keys = self.resolve_foreign_keys(name, foreign_keys or [])
@@ -656,9 +670,11 @@ def create_table_sql(
656670
validate_column_names(columns.keys())
657671
column_items = list(columns.items())
658672
if column_order is not None:
659-
column_items.sort(
660-
key=lambda p: column_order.index(p[0]) if p[0] in column_order else 999
661-
)
673+
674+
def sort_key(p):
675+
return column_order.index(p[0]) if p[0] in column_order else 999
676+
677+
column_items.sort(key=sort_key)
662678
if hash_id:
663679
column_items.insert(0, (hash_id, str))
664680
pk = hash_id
@@ -725,12 +741,12 @@ def create_table(
725741
name: str,
726742
columns: Dict[str, Any],
727743
pk: Optional[Any] = None,
728-
foreign_keys=None,
729-
column_order=None,
730-
not_null=None,
731-
defaults=None,
732-
hash_id=None,
733-
extracts=None,
744+
foreign_keys: Optional[ForeignKeysType] = None,
745+
column_order: Optional[List[str]] = None,
746+
not_null: Iterable[str] = None,
747+
defaults: Optional[Dict[str, Any]] = None,
748+
hash_id: Optional[Any] = None,
749+
extracts: Optional[Union[Dict[str, str], List[str]]] = None,
734750
) -> "Table":
735751
"""
736752
Create a table with the specified name and the specified ``{column_name: type}`` columns.
@@ -1021,19 +1037,19 @@ def __init__(
10211037
self,
10221038
db: Database,
10231039
name: str,
1024-
pk=None,
1025-
foreign_keys=None,
1026-
column_order=None,
1027-
not_null=None,
1028-
defaults=None,
1029-
batch_size=100,
1030-
hash_id=None,
1031-
alter=False,
1032-
ignore=False,
1033-
replace=False,
1034-
extracts=None,
1035-
conversions=None,
1036-
columns=None,
1040+
pk: Optional[Any] = None,
1041+
foreign_keys: Optional[ForeignKeysType] = None,
1042+
column_order: Optional[List[str]] = None,
1043+
not_null: Iterable[str] = None,
1044+
defaults: Optional[Dict[str, Any]] = None,
1045+
batch_size: int = 100,
1046+
hash_id: Optional[Any] = None,
1047+
alter: bool = False,
1048+
ignore: bool = False,
1049+
replace: bool = False,
1050+
extracts: Optional[Union[Dict[str, str], List[str]]] = None,
1051+
conversions: Optional[dict] = None,
1052+
columns: Optional[Union[Dict[str, Any]]] = None,
10371053
):
10381054
super().__init__(db, name)
10391055
self._defaults = dict(
@@ -1202,14 +1218,14 @@ def triggers_dict(self) -> Dict[str, str]:
12021218

12031219
def create(
12041220
self,
1205-
columns,
1206-
pk=None,
1207-
foreign_keys=None,
1208-
column_order=None,
1209-
not_null=None,
1210-
defaults=None,
1211-
hash_id=None,
1212-
extracts=None,
1221+
columns: Dict[str, Any],
1222+
pk: Optional[Any] = None,
1223+
foreign_keys: Optional[ForeignKeysType] = None,
1224+
column_order: Optional[List[str]] = None,
1225+
not_null: Iterable[str] = None,
1226+
defaults: Optional[Dict[str, Any]] = None,
1227+
hash_id: Optional[Any] = None,
1228+
extracts: Optional[Union[Dict[str, str], List[str]]] = None,
12131229
) -> "Table":
12141230
"""
12151231
Create a table with the specified columns.
@@ -2914,7 +2930,9 @@ def _hash(record):
29142930
).hexdigest()
29152931

29162932

2917-
def resolve_extracts(extracts):
2933+
def resolve_extracts(
2934+
extracts: Optional[Union[Dict[str, str], List[str], Tuple[str]]]
2935+
) -> dict:
29182936
if extracts is None:
29192937
extracts = {}
29202938
if isinstance(extracts, (list, tuple)):

tests/test_tracer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def test_tracer():
1313
("PRAGMA recursive_triggers=on;", None),
1414
("select name from sqlite_master where type = 'view'", None),
1515
("select name from sqlite_master where type = 'table'", None),
16+
("select name from sqlite_master where type = 'view'", None),
1617
("CREATE TABLE [dogs] (\n [name] TEXT\n);\n ", None),
1718
("select name from sqlite_master where type = 'view'", None),
1819
("INSERT INTO [dogs] ([name]) VALUES (?);", ["Cleopaws"]),

0 commit comments

Comments
 (0)