@@ -3388,7 +3388,7 @@ def insert_all(
33883388
33893389 # Detect if we're using list-based iteration or dict-based iteration
33903390 list_mode = False
3391- column_names = None
3391+ column_names : List [ str ] = []
33923392
33933393 # Fix up any records with square braces in the column names (only for dict mode)
33943394 # We'll handle this differently for list mode
@@ -3423,11 +3423,14 @@ def insert_all(
34233423 else :
34243424 # Dict mode: traditional behavior
34253425 records_iter = itertools .chain ([first_record ], records_iter )
3426- records_iter = fix_square_braces (records_iter )
3426+ records_iter = fix_square_braces (
3427+ cast (Iterable [Dict [str , Any ]], records_iter )
3428+ )
34273429 try :
34283430 first_record = next (records_iter )
34293431 except StopIteration :
34303432 return self
3433+ first_record = cast (Dict [str , Any ], first_record )
34313434 num_columns = len (first_record .keys ())
34323435
34333436 assert (
@@ -3526,13 +3529,16 @@ def insert_all(
35263529 self .last_pk = self .last_rowid
35273530 else :
35283531 # For an upsert use first_record from earlier
3532+ # Note: This code path assumes dict mode; list mode upserts
3533+ # with single records don't populate last_pk correctly yet
3534+ first_record_dict = cast (Dict [str , Any ], first_record )
35293535 if hash_id :
3530- self .last_pk = hash_record (first_record , hash_id_columns )
3536+ self .last_pk = hash_record (first_record_dict , hash_id_columns )
35313537 else :
35323538 self .last_pk = (
3533- first_record [pk ]
3539+ first_record_dict [pk ]
35343540 if isinstance (pk , str )
3535- else tuple (first_record [p ] for p in pk )
3541+ else tuple (first_record_dict [p ] for p in pk )
35363542 )
35373543
35383544 if analyze :
0 commit comments