Skip to content

Commit 37273d7

Browse files
authored
Fixed issue #433 - CLI eats cursor (#598)
The issue is that the underlying iterator is not fully consumed within the body of the `with file_progress()` block. Instead, that block creates generator expressions like `docs = (dict(zip(headers, row)) for row in reader)`. These iterables are consumed later, outside the `with file_progress()` block; consuming them drains the underlying iterator, which in turn updates the progress bar. This means that the `ProgressBar.__exit__` method gets called before the last time the `ProgressBar.update` method gets called. The result is that the code to make the cursor invisible (inside the `update()` method) is called after the cleanup code to make it visible (in the `__exit__` method). The fix is to move consumption of the `docs` iterators within the progress bar block. (An additional fix, to make ProgressBar more robust against this kind of misuse, would be to make it refuse to update after its `__exit__` method has been called, just like files cannot be `read()` after they are closed. That requires a change in the click library.)
1 parent b92ea47 commit 37273d7

File tree

1 file changed

+87
-83
lines changed

1 file changed

+87
-83
lines changed

sqlite_utils/cli.py

Lines changed: 87 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,93 +1024,97 @@ def insert_upsert_implementation(
10241024
if flatten:
10251025
docs = (_flatten(doc) for doc in docs)
10261026

1027-
if stop_after:
1028-
docs = itertools.islice(docs, stop_after)
1029-
1030-
if convert:
1031-
variable = "row"
1032-
if lines:
1033-
variable = "line"
1034-
elif text:
1035-
variable = "text"
1036-
fn = _compile_code(convert, imports, variable=variable)
1037-
if lines:
1038-
docs = (fn(doc["line"]) for doc in docs)
1039-
elif text:
1040-
# Special case: this is allowed to be an iterable
1041-
text_value = list(docs)[0]["text"]
1042-
fn_return = fn(text_value)
1043-
if isinstance(fn_return, dict):
1044-
docs = [fn_return]
1027+
if stop_after:
1028+
docs = itertools.islice(docs, stop_after)
1029+
1030+
if convert:
1031+
variable = "row"
1032+
if lines:
1033+
variable = "line"
1034+
elif text:
1035+
variable = "text"
1036+
fn = _compile_code(convert, imports, variable=variable)
1037+
if lines:
1038+
docs = (fn(doc["line"]) for doc in docs)
1039+
elif text:
1040+
# Special case: this is allowed to be an iterable
1041+
text_value = list(docs)[0]["text"]
1042+
fn_return = fn(text_value)
1043+
if isinstance(fn_return, dict):
1044+
docs = [fn_return]
1045+
else:
1046+
try:
1047+
docs = iter(fn_return)
1048+
except TypeError:
1049+
raise click.ClickException(
1050+
"--convert must return dict or iterator"
1051+
)
10451052
else:
1046-
try:
1047-
docs = iter(fn_return)
1048-
except TypeError:
1049-
raise click.ClickException("--convert must return dict or iterator")
1050-
else:
1051-
docs = (fn(doc) or doc for doc in docs)
1052-
1053-
extra_kwargs = {
1054-
"ignore": ignore,
1055-
"replace": replace,
1056-
"truncate": truncate,
1057-
"analyze": analyze,
1058-
}
1059-
if not_null:
1060-
extra_kwargs["not_null"] = set(not_null)
1061-
if default:
1062-
extra_kwargs["defaults"] = dict(default)
1063-
if upsert:
1064-
extra_kwargs["upsert"] = upsert
1065-
1066-
# docs should all be dictionaries
1067-
docs = (verify_is_dict(doc) for doc in docs)
1068-
1069-
# Apply {"$base64": true, ...} decoding, if needed
1070-
docs = (decode_base64_values(doc) for doc in docs)
1071-
1072-
# For bulk_sql= we use cursor.executemany() instead
1073-
if bulk_sql:
1074-
if batch_size:
1075-
doc_chunks = chunks(docs, batch_size)
1076-
else:
1077-
doc_chunks = [docs]
1078-
for doc_chunk in doc_chunks:
1079-
with db.conn:
1080-
db.conn.cursor().executemany(bulk_sql, doc_chunk)
1081-
return
1053+
docs = (fn(doc) or doc for doc in docs)
1054+
1055+
extra_kwargs = {
1056+
"ignore": ignore,
1057+
"replace": replace,
1058+
"truncate": truncate,
1059+
"analyze": analyze,
1060+
}
1061+
if not_null:
1062+
extra_kwargs["not_null"] = set(not_null)
1063+
if default:
1064+
extra_kwargs["defaults"] = dict(default)
1065+
if upsert:
1066+
extra_kwargs["upsert"] = upsert
1067+
1068+
# docs should all be dictionaries
1069+
docs = (verify_is_dict(doc) for doc in docs)
1070+
1071+
# Apply {"$base64": true, ...} decoding, if needed
1072+
docs = (decode_base64_values(doc) for doc in docs)
1073+
1074+
# For bulk_sql= we use cursor.executemany() instead
1075+
if bulk_sql:
1076+
if batch_size:
1077+
doc_chunks = chunks(docs, batch_size)
1078+
else:
1079+
doc_chunks = [docs]
1080+
for doc_chunk in doc_chunks:
1081+
with db.conn:
1082+
db.conn.cursor().executemany(bulk_sql, doc_chunk)
1083+
return
10821084

1083-
try:
1084-
db[table].insert_all(
1085-
docs, pk=pk, batch_size=batch_size, alter=alter, **extra_kwargs
1086-
)
1087-
except Exception as e:
1088-
if (
1089-
isinstance(e, OperationalError)
1090-
and e.args
1091-
and "has no column named" in e.args[0]
1092-
):
1093-
raise click.ClickException(
1094-
"{}\n\nTry using --alter to add additional columns".format(e.args[0])
1085+
try:
1086+
db[table].insert_all(
1087+
docs, pk=pk, batch_size=batch_size, alter=alter, **extra_kwargs
10951088
)
1096-
# If we can find sql= and parameters= arguments, show those
1097-
variables = _find_variables(e.__traceback__, ["sql", "parameters"])
1098-
if "sql" in variables and "parameters" in variables:
1099-
raise click.ClickException(
1100-
"{}\n\nsql = {}\nparameters = {}".format(
1101-
str(e), variables["sql"], variables["parameters"]
1089+
except Exception as e:
1090+
if (
1091+
isinstance(e, OperationalError)
1092+
and e.args
1093+
and "has no column named" in e.args[0]
1094+
):
1095+
raise click.ClickException(
1096+
"{}\n\nTry using --alter to add additional columns".format(
1097+
e.args[0]
1098+
)
11021099
)
1103-
)
1104-
else:
1105-
raise
1106-
if tracker is not None:
1107-
db[table].transform(types=tracker.types)
1108-
1109-
# Clean up open file-like objects
1110-
if sniff_buffer:
1111-
sniff_buffer.close()
1112-
if decoded_buffer:
1113-
decoded_buffer.close()
1100+
# If we can find sql= and parameters= arguments, show those
1101+
variables = _find_variables(e.__traceback__, ["sql", "parameters"])
1102+
if "sql" in variables and "parameters" in variables:
1103+
raise click.ClickException(
1104+
"{}\n\nsql = {}\nparameters = {}".format(
1105+
str(e), variables["sql"], variables["parameters"]
1106+
)
1107+
)
1108+
else:
1109+
raise
1110+
if tracker is not None:
1111+
db[table].transform(types=tracker.types)
1112+
1113+
# Clean up open file-like objects
1114+
if sniff_buffer:
1115+
sniff_buffer.close()
1116+
if decoded_buffer:
1117+
decoded_buffer.close()
11141118

11151119

11161120
def _find_variables(tb, vars):

0 commit comments

Comments
 (0)