147
147
Trigger = namedtuple ("Trigger" , ("name" , "table" , "sql" ))
148
148
149
149
150
+ ForeignKeysType = Union [
151
+ Iterable [str ],
152
+ Iterable [ForeignKey ],
153
+ Iterable [Tuple [str , str ]],
154
+ Iterable [Tuple [str , str , str ]],
155
+ Iterable [Tuple [str , str , str , str ]],
156
+ ]
157
+
158
+
150
159
class Default :
151
160
pass
152
161
@@ -572,18 +581,22 @@ def execute_returning_dicts(
572
581
) -> List [dict ]:
573
582
return list (self .query (sql , params ))
574
583
575
- def resolve_foreign_keys (self , name , foreign_keys ):
576
- # foreign_keys may be a list of strcolumn names, a list of ForeignKey tuples,
584
+ def resolve_foreign_keys (
585
+ self , name : str , foreign_keys : ForeignKeysType
586
+ ) -> List [ForeignKey ]:
587
+ # foreign_keys may be a list of column names, a list of ForeignKey tuples,
577
588
# a list of tuple-pairs or a list of tuple-triples. We want to turn
578
589
# it into a list of ForeignKey tuples
590
+ table = cast (Table , self [name ])
579
591
if all (isinstance (fk , ForeignKey ) for fk in foreign_keys ):
580
- return foreign_keys
592
+ return cast ( List [ ForeignKey ], foreign_keys )
581
593
if all (isinstance (fk , str ) for fk in foreign_keys ):
582
594
# It's a list of columns
583
595
fks = []
584
596
for column in foreign_keys :
585
- other_table = self [name ].guess_foreign_table (column )
586
- other_column = self [name ].guess_foreign_column (other_table )
597
+ column = cast (str , column )
598
+ other_table = table .guess_foreign_table (column )
599
+ other_column = table .guess_foreign_column (other_table )
587
600
fks .append (ForeignKey (name , column , other_table , other_column ))
588
601
return fks
589
602
assert all (
@@ -596,6 +609,7 @@ def resolve_foreign_keys(self, name, foreign_keys):
596
609
3 ,
597
610
), "foreign_keys= should be a list of tuple pairs or triples"
598
611
if len (tuple_or_list ) == 3 :
612
+ tuple_or_list = cast (Tuple [str , str , str ], tuple_or_list )
599
613
fks .append (
600
614
ForeignKey (
601
615
name , tuple_or_list [0 ], tuple_or_list [1 ], tuple_or_list [2 ]
@@ -608,7 +622,7 @@ def resolve_foreign_keys(self, name, foreign_keys):
608
622
name ,
609
623
tuple_or_list [0 ],
610
624
tuple_or_list [1 ],
611
- self [ name ] .guess_foreign_column (tuple_or_list [1 ]),
625
+ table .guess_foreign_column (tuple_or_list [1 ]),
612
626
)
613
627
)
614
628
return fks
@@ -618,12 +632,12 @@ def create_table_sql(
618
632
name : str ,
619
633
columns : Dict [str , Any ],
620
634
pk : Optional [Any ] = None ,
621
- foreign_keys = None ,
622
- column_order = None ,
623
- not_null = None ,
624
- defaults = None ,
625
- hash_id = None ,
626
- extracts = None ,
635
+ foreign_keys : Optional [ ForeignKeysType ] = None ,
636
+ column_order : Optional [ List [ str ]] = None ,
637
+ not_null : Iterable [ str ] = None ,
638
+ defaults : Optional [ Dict [ str , Any ]] = None ,
639
+ hash_id : Optional [ Any ] = None ,
640
+ extracts : Optional [ Union [ Dict [ str , str ], List [ str ]]] = None ,
627
641
) -> str :
628
642
"Returns the SQL ``CREATE TABLE`` statement for creating the specified table."
629
643
foreign_keys = self .resolve_foreign_keys (name , foreign_keys or [])
@@ -656,9 +670,11 @@ def create_table_sql(
656
670
validate_column_names (columns .keys ())
657
671
column_items = list (columns .items ())
658
672
if column_order is not None :
659
- column_items .sort (
660
- key = lambda p : column_order .index (p [0 ]) if p [0 ] in column_order else 999
661
- )
673
+
674
+ def sort_key (p ):
675
+ return column_order .index (p [0 ]) if p [0 ] in column_order else 999
676
+
677
+ column_items .sort (key = sort_key )
662
678
if hash_id :
663
679
column_items .insert (0 , (hash_id , str ))
664
680
pk = hash_id
@@ -725,12 +741,12 @@ def create_table(
725
741
name : str ,
726
742
columns : Dict [str , Any ],
727
743
pk : Optional [Any ] = None ,
728
- foreign_keys = None ,
729
- column_order = None ,
730
- not_null = None ,
731
- defaults = None ,
732
- hash_id = None ,
733
- extracts = None ,
744
+ foreign_keys : Optional [ ForeignKeysType ] = None ,
745
+ column_order : Optional [ List [ str ]] = None ,
746
+ not_null : Iterable [ str ] = None ,
747
+ defaults : Optional [ Dict [ str , Any ]] = None ,
748
+ hash_id : Optional [ Any ] = None ,
749
+ extracts : Optional [ Union [ Dict [ str , str ], List [ str ]]] = None ,
734
750
) -> "Table" :
735
751
"""
736
752
Create a table with the specified name and the specified ``{column_name: type}`` columns.
@@ -1021,19 +1037,19 @@ def __init__(
1021
1037
self ,
1022
1038
db : Database ,
1023
1039
name : str ,
1024
- pk = None ,
1025
- foreign_keys = None ,
1026
- column_order = None ,
1027
- not_null = None ,
1028
- defaults = None ,
1029
- batch_size = 100 ,
1030
- hash_id = None ,
1031
- alter = False ,
1032
- ignore = False ,
1033
- replace = False ,
1034
- extracts = None ,
1035
- conversions = None ,
1036
- columns = None ,
1040
+ pk : Optional [ Any ] = None ,
1041
+ foreign_keys : Optional [ ForeignKeysType ] = None ,
1042
+ column_order : Optional [ List [ str ]] = None ,
1043
+ not_null : Iterable [ str ] = None ,
1044
+ defaults : Optional [ Dict [ str , Any ]] = None ,
1045
+ batch_size : int = 100 ,
1046
+ hash_id : Optional [ Any ] = None ,
1047
+ alter : bool = False ,
1048
+ ignore : bool = False ,
1049
+ replace : bool = False ,
1050
+ extracts : Optional [ Union [ Dict [ str , str ], List [ str ]]] = None ,
1051
+ conversions : Optional [ dict ] = None ,
1052
+ columns : Optional [ Union [ Dict [ str , Any ]]] = None ,
1037
1053
):
1038
1054
super ().__init__ (db , name )
1039
1055
self ._defaults = dict (
@@ -1202,14 +1218,14 @@ def triggers_dict(self) -> Dict[str, str]:
1202
1218
1203
1219
def create (
1204
1220
self ,
1205
- columns ,
1206
- pk = None ,
1207
- foreign_keys = None ,
1208
- column_order = None ,
1209
- not_null = None ,
1210
- defaults = None ,
1211
- hash_id = None ,
1212
- extracts = None ,
1221
+ columns : Dict [ str , Any ] ,
1222
+ pk : Optional [ Any ] = None ,
1223
+ foreign_keys : Optional [ ForeignKeysType ] = None ,
1224
+ column_order : Optional [ List [ str ]] = None ,
1225
+ not_null : Iterable [ str ] = None ,
1226
+ defaults : Optional [ Dict [ str , Any ]] = None ,
1227
+ hash_id : Optional [ Any ] = None ,
1228
+ extracts : Optional [ Union [ Dict [ str , str ], List [ str ]]] = None ,
1213
1229
) -> "Table" :
1214
1230
"""
1215
1231
Create a table with the specified columns.
@@ -2914,7 +2930,9 @@ def _hash(record):
2914
2930
).hexdigest ()
2915
2931
2916
2932
2917
- def resolve_extracts (extracts ):
2933
+ def resolve_extracts (
2934
+ extracts : Optional [Union [Dict [str , str ], List [str ], Tuple [str ]]]
2935
+ ) -> dict :
2918
2936
if extracts is None :
2919
2937
extracts = {}
2920
2938
if isinstance (extracts , (list , tuple )):
0 commit comments