Skip to content

Commit dc9947a

Browse files
committed
1 parent 77359be commit dc9947a

File tree

7 files changed

+110
-61
lines changed

7 files changed

+110
-61
lines changed

sqlite_utils/cli.py

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -905,7 +905,7 @@ def inner(fn):
905905
required=True,
906906
),
907907
click.argument("table"),
908-
click.argument("file", type=click.File("rb"), required=True),
908+
click.argument("file", type=click.File("rb", lazy=True), required=True),
909909
click.option(
910910
"--pk",
911911
help="Columns to use as the primary key, e.g. id",
@@ -2000,6 +2000,7 @@ def memory(
20002000
for i, path in enumerate(paths):
20012001
# Path may have a :format suffix
20022002
fp = None
2003+
should_close_fp = False
20032004
if ":" in path and path.rsplit(":", 1)[-1].upper() in Format.__members__:
20042005
path, suffix = path.rsplit(":", 1)
20052006
format = Format[suffix.upper()]
@@ -2017,29 +2018,32 @@ def memory(
20172018
file_table = stem
20182019
stem_counts[stem] = stem_counts.get(stem, 1) + 1
20192020
fp = file_path.open("rb")
2020-
rows, format_used = rows_from_file(fp, format=format, encoding=encoding)
2021-
tracker = None
2022-
if format_used in (Format.CSV, Format.TSV) and not no_detect_types:
2023-
tracker = TypeTracker()
2024-
rows = tracker.wrap(rows)
2025-
if flatten:
2026-
rows = (_flatten(row) for row in rows)
2027-
2028-
db[file_table].insert_all(rows, alter=True)
2029-
if tracker is not None:
2030-
db[file_table].transform(types=tracker.types)
2031-
# Add convenient t / t1 / t2 views
2032-
view_names = ["t{}".format(i + 1)]
2033-
if i == 0:
2034-
view_names.append("t")
2035-
for view_name in view_names:
2036-
if not db[view_name].exists():
2037-
db.create_view(
2038-
view_name, "select * from {}".format(quote_identifier(file_table))
2039-
)
2040-
2041-
if fp:
2042-
fp.close()
2021+
should_close_fp = True
2022+
try:
2023+
rows, format_used = rows_from_file(fp, format=format, encoding=encoding)
2024+
tracker = None
2025+
if format_used in (Format.CSV, Format.TSV) and not no_detect_types:
2026+
tracker = TypeTracker()
2027+
rows = tracker.wrap(rows)
2028+
if flatten:
2029+
rows = (_flatten(row) for row in rows)
2030+
2031+
db[file_table].insert_all(rows, alter=True)
2032+
if tracker is not None:
2033+
db[file_table].transform(types=tracker.types)
2034+
# Add convenient t / t1 / t2 views
2035+
view_names = ["t{}".format(i + 1)]
2036+
if i == 0:
2037+
view_names.append("t")
2038+
for view_name in view_names:
2039+
if not db[view_name].exists():
2040+
db.create_view(
2041+
view_name,
2042+
"select * from {}".format(quote_identifier(file_table)),
2043+
)
2044+
finally:
2045+
if should_close_fp and fp:
2046+
fp.close()
20432047

20442048
if analyze:
20452049
_analyze(db, tables=None, columns=None, save=False)

sqlite_utils/utils.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,29 @@
99
import os
1010
import sys
1111
from . import recipes
12-
from typing import Dict, cast, BinaryIO, Iterable, Optional, Tuple, Type
12+
from typing import Dict, cast, BinaryIO, Iterable, Iterator, Optional, Tuple, Type
13+
14+
15+
class _CloseableIterator(Iterator[dict]):
16+
"""Iterator wrapper that closes a file when iteration is complete."""
17+
18+
def __init__(self, iterator: Iterator[dict], closeable: io.IOBase):
19+
self._iterator = iterator
20+
self._closeable = closeable
21+
22+
def __iter__(self) -> "_CloseableIterator":
23+
return self
24+
25+
def __next__(self) -> dict:
26+
try:
27+
return next(self._iterator)
28+
except StopIteration:
29+
self._closeable.close()
30+
raise
31+
32+
def close(self) -> None:
33+
self._closeable.close()
34+
1335

1436
import click
1537

@@ -299,7 +321,8 @@ class Format(enum.Enum):
299321
reader = csv.DictReader(decoded_fp, dialect=dialect)
300322
else:
301323
reader = csv.DictReader(decoded_fp)
302-
return _extra_key_strategy(reader, ignore_extras, extras_key), Format.CSV
324+
rows = _extra_key_strategy(reader, ignore_extras, extras_key)
325+
return _CloseableIterator(iter(rows), decoded_fp), Format.CSV
303326
elif format == Format.TSV:
304327
rows = rows_from_file(
305328
fp, format=Format.CSV, dialect=csv.excel_tab, encoding=encoding

tests/test_cli.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@ def _supports_pragma_function_list():
1919
db = Database(memory=True)
2020
try:
2121
db.execute("select * from pragma_function_list()")
22+
return True
2223
except Exception:
2324
return False
24-
return True
25+
finally:
26+
db.close()
2527

2628

2729
def _has_compiled_ext():

tests/test_introspect.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,14 @@
22
import pytest
33

44

5+
def _check_supports_strict():
6+
"""Check if SQLite supports strict tables without leaking the database."""
7+
db = Database(memory=True)
8+
result = db.supports_strict
9+
db.close()
10+
return result
11+
12+
513
def test_table_names(existing_db):
614
assert ["foo"] == existing_db.table_names()
715

@@ -282,7 +290,7 @@ def test_use_rowid(fresh_db):
282290

283291

284292
@pytest.mark.skipif(
285-
not Database(memory=True).supports_strict,
293+
not _check_supports_strict(),
286294
reason="Needs SQLite version that supports strict",
287295
)
288296
@pytest.mark.parametrize(

tests/test_plugins.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@ def _supports_pragma_function_list():
99
db = Database(memory=True)
1010
try:
1111
db.execute("select * from pragma_function_list()")
12+
return True
1213
except Exception:
1314
return False
14-
return True
15+
finally:
16+
db.close()
1517

1618

1719
def test_register_commands():

tests/test_recreate.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,11 @@ def test_recreate_ignored_for_in_memory():
1414

1515
def test_recreate_not_allowed_for_connection():
1616
conn = sqlite3.connect(":memory:")
17-
with pytest.raises(AssertionError):
18-
Database(conn, recreate=True)
17+
try:
18+
with pytest.raises(AssertionError):
19+
Database(conn, recreate=True)
20+
finally:
21+
conn.close()
1922

2023

2124
@pytest.mark.parametrize(

tests/test_register_function.py

Lines changed: 37 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -42,39 +42,46 @@ def to_lower(s):
4242

4343

4444
def test_register_function_deterministic_tries_again_if_exception_raised(fresh_db):
45+
# Save the original connection so we can close it later
46+
original_conn = fresh_db.conn
4547
fresh_db.conn = MagicMock()
4648
fresh_db.conn.create_function = MagicMock()
4749

48-
@fresh_db.register_function(deterministic=True)
49-
def to_lower_2(s):
50-
return s.lower()
51-
52-
fresh_db.conn.create_function.assert_called_with(
53-
"to_lower_2", 1, to_lower_2, deterministic=True
54-
)
55-
56-
first = True
57-
58-
def side_effect(*args, **kwargs):
59-
# Raise exception only first time this is called
60-
nonlocal first
61-
if first:
62-
first = False
63-
raise sqlite3.NotSupportedError()
64-
65-
# But if sqlite3.NotSupportedError is raised, it tries again
66-
fresh_db.conn.create_function.reset_mock()
67-
fresh_db.conn.create_function.side_effect = side_effect
68-
69-
@fresh_db.register_function(deterministic=True)
70-
def to_lower_3(s):
71-
return s.lower()
72-
73-
# Should have been called once with deterministic=True and once without
74-
assert fresh_db.conn.create_function.call_args_list == [
75-
call("to_lower_3", 1, to_lower_3, deterministic=True),
76-
call("to_lower_3", 1, to_lower_3),
77-
]
50+
try:
51+
52+
@fresh_db.register_function(deterministic=True)
53+
def to_lower_2(s):
54+
return s.lower()
55+
56+
fresh_db.conn.create_function.assert_called_with(
57+
"to_lower_2", 1, to_lower_2, deterministic=True
58+
)
59+
60+
first = True
61+
62+
def side_effect(*args, **kwargs):
63+
# Raise exception only first time this is called
64+
nonlocal first
65+
if first:
66+
first = False
67+
raise sqlite3.NotSupportedError()
68+
69+
# But if sqlite3.NotSupportedError is raised, it tries again
70+
fresh_db.conn.create_function.reset_mock()
71+
fresh_db.conn.create_function.side_effect = side_effect
72+
73+
@fresh_db.register_function(deterministic=True)
74+
def to_lower_3(s):
75+
return s.lower()
76+
77+
# Should have been called once with deterministic=True and once without
78+
assert fresh_db.conn.create_function.call_args_list == [
79+
call("to_lower_3", 1, to_lower_3, deterministic=True),
80+
call("to_lower_3", 1, to_lower_3),
81+
]
82+
finally:
83+
# Close the original connection that was replaced with the mock
84+
original_conn.close()
7885

7986

8087
def test_register_function_replace(fresh_db):

0 commit comments

Comments
 (0)