Skip to content

Commit 4862fbf

Browse files
author
Daniele Briggi
committed
fix(tests): failing
1 parent 42f798b commit 4862fbf

File tree

5 files changed

+126
-110
lines changed

5 files changed

+126
-110
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ repos:
1515
rev: v2.3.1
1616
hooks:
1717
- id: autoflake
18-
args: ["--remove-all-unused-imports", "--remove-unused-variables", "--ignore-init-module-imports", "--in-place", "--recursive", "./src", "./test"]
18+
args: ["--remove-all-unused-imports", "--remove-unused-variables", "--ignore-init-module-imports", "--in-place", "--recursive", "./src", "./tests"]
1919
- repo: https://github.com/pycqa/isort
2020
rev: 6.0.1
2121
hooks:

src/sqlite_rag/cli.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,35 +14,38 @@
1414
from .formatters import get_formatter
1515
from .sqliterag import SQLiteRag
1616

17+
DEFAULT_DATABASE_PATH = "./sqliterag.sqlite"
18+
1719

1820
class RAGContext:
1921
"""Manage CLI state and RAG object reuse"""
2022

2123
def __init__(self):
2224
self.rag: Optional[SQLiteRag] = None
2325
self.in_repl = False
24-
self.database: str = ""
26+
self.database_path: str = ""
2527

2628
def enter_repl(self):
2729
"""Enter REPL mode"""
2830
self.in_repl = True
2931

30-
def get_rag(self, database_path: str, require_existing: bool = False) -> SQLiteRag:
32+
def get_rag(self, require_existing: bool = False) -> SQLiteRag:
3133
"""Create or reuse SQLiteRag instance"""
32-
if not self.database:
34+
if not self.database_path:
3335
raise ValueError("Database path not set. Use --database option.")
3436

3537
if self.in_repl:
3638
if self.rag is None:
37-
typer.echo(f"Debug: Using database path: {self.database}")
3839
self.rag = SQLiteRag.create(
39-
self.database, require_existing=require_existing
40+
self.database_path, require_existing=require_existing
4041
)
41-
typer.echo(f"Database: {Path(self.database).resolve()}")
4242
return self.rag
4343
else:
4444
# Regular mode - create new instance
45-
return SQLiteRag.create(database_path, require_existing=require_existing)
45+
typer.echo(f"Database: {Path(self.database_path).resolve()}")
46+
return SQLiteRag.create(
47+
self.database_path, require_existing=require_existing
48+
)
4649

4750

4851
rag_context = RAGContext()
@@ -66,28 +69,32 @@ def __call__(self, *args, **kwds):
6669
def main(
6770
ctx: typer.Context,
6871
database: str = typer.Option(
69-
"./sqliterag.sqlite",
72+
DEFAULT_DATABASE_PATH,
7073
"--database",
7174
"-db",
7275
help="Path to the SQLite database file",
7376
),
7477
):
7578
"""SQLite RAG - Retrieval Augmented Generation with SQLite"""
7679
ctx.ensure_object(dict)
77-
rag_context.database = database
7880
ctx.obj["rag_context"] = rag_context
7981

82+
if not rag_context.in_repl:
83+
rag_context.database_path = database
84+
8085
# If no subcommand was invoked, enter REPL mode
81-
if ctx.invoked_subcommand is None:
86+
if ctx.invoked_subcommand is None and not rag_context.in_repl:
8287
rag_context.enter_repl()
88+
typer.echo(f"Database: {Path(database).resolve()}")
89+
8390
repl_mode()
8491

8592

8693
@app.command("settings")
8794
def show_settings(ctx: typer.Context):
8895
"""Show current settings"""
8996
rag_context = ctx.obj["rag_context"]
90-
rag = rag_context.get_rag(rag_context.database, require_existing=True)
97+
rag = rag_context.get_rag(require_existing=True)
9198
current_settings = rag.get_settings()
9299

93100
typer.echo("Current settings:")
@@ -168,7 +175,7 @@ def configure_settings(
168175
show_settings(ctx)
169176
return
170177

171-
conn = Database.new_connection(rag_context.database)
178+
conn = Database.new_connection(rag_context.database_path)
172179
settings_manager = SettingsManager(conn)
173180
settings_manager.configure(updates)
174181

@@ -199,7 +206,7 @@ def add(
199206
rag_context = ctx.obj["rag_context"]
200207
start_time = time.time()
201208

202-
rag = rag_context.get_rag(rag_context.database)
209+
rag = rag_context.get_rag()
203210
rag.add(
204211
path,
205212
recursive=recursive,
@@ -225,7 +232,7 @@ def add_text(
225232
):
226233
"""Add a text to the database"""
227234
rag_context = ctx.obj["rag_context"]
228-
rag = rag_context.get_rag(rag_context.database)
235+
rag = rag_context.get_rag()
229236
rag.add_text(text, uri=uri, metadata=json.loads(metadata or "{}"))
230237
typer.echo("Text added.")
231238

@@ -234,7 +241,7 @@ def add_text(
234241
def list_documents(ctx: typer.Context):
235242
"""List all documents in the database"""
236243
rag_context = ctx.obj["rag_context"]
237-
rag = rag_context.get_rag(rag_context.database, require_existing=True)
244+
rag = rag_context.get_rag(require_existing=True)
238245
documents = rag.list_documents()
239246

240247
if not documents:
@@ -266,7 +273,7 @@ def remove(
266273
):
267274
"""Remove document by path or UUID"""
268275
rag_context = ctx.obj["rag_context"]
269-
rag = rag_context.get_rag(rag_context.database, require_existing=True)
276+
rag = rag_context.get_rag(require_existing=True)
270277

271278
# Find the document first
272279
document = rag.find_document(identifier)
@@ -311,7 +318,7 @@ def rebuild(
311318
):
312319
"""Rebuild embeddings and full-text index"""
313320
rag_context = ctx.obj["rag_context"]
314-
rag = rag_context.get_rag(rag_context.database, require_existing=True)
321+
rag = rag_context.get_rag(require_existing=True)
315322

316323
typer.echo("Rebuild process...")
317324

@@ -331,7 +338,7 @@ def reset(
331338
):
332339
"""Reset/clear the entire database"""
333340
rag_context = ctx.obj["rag_context"]
334-
rag = rag_context.get_rag(rag_context.database, require_existing=True)
341+
rag = rag_context.get_rag(require_existing=True)
335342

336343
# Show warning and ask for confirmation unless -y flag is used
337344
if not yes:
@@ -374,7 +381,7 @@ def search(
374381
rag_context = ctx.obj["rag_context"]
375382
start_time = time.time()
376383

377-
rag = rag_context.get_rag(rag_context.database, require_existing=True)
384+
rag = rag_context.get_rag(require_existing=True)
378385
results = rag.search(query, top_k=limit)
379386

380387
search_time = time.time() - start_time
@@ -397,7 +404,7 @@ def quantize(
397404
):
398405
"""Quantize vectors for faster search or clean up quantization structures"""
399406
rag_context = ctx.obj["rag_context"]
400-
rag = rag_context.get_rag(rag_context.database, require_existing=True)
407+
rag = rag_context.get_rag(require_existing=True)
401408

402409
if cleanup:
403410
typer.echo("Cleaning up quantization structures...")

tests/integration/test_cli.py

Lines changed: 93 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,105 +1,114 @@
1-
import os
21
import tempfile
32
from pathlib import Path
43

5-
from pytest import fixture
64
from typer.testing import CliRunner
75

86
from sqlite_rag.cli import app
97
from sqlite_rag.settings import Settings
108

119

12-
@fixture
13-
def temp_dir():
14-
"""Change the current working directory in order to create
15-
the default database in a temporary location."""
16-
with tempfile.TemporaryDirectory() as tmpdir:
17-
original_cwd = os.getcwd()
18-
try:
19-
os.chdir(tmpdir)
20-
yield tmpdir
21-
finally:
22-
os.chdir(original_cwd)
23-
24-
2510
class TestCLI:
2611
def test_search_exact_match(self):
27-
# Use SQLiteRag to set up the test data directly
28-
with tempfile.TemporaryDirectory() as tmpdir:
29-
doc1_content = "The quick brown fox jumps over the lazy dog"
30-
doc2_content = (
31-
"How much wood would a woodchuck chuck if a woodchuck could chuck wood?"
32-
)
12+
"""Test adding documents and searching for an exact match."""
13+
doc1_content = "The quick brown fox jumps over the lazy dog"
14+
doc2_content = (
15+
"How much wood would a woodchuck chuck if a woodchuck could chuck wood?"
16+
)
17+
18+
with tempfile.NamedTemporaryFile(suffix=".tempdb") as tmp_db:
3319

3420
runner = CliRunner()
3521

3622
model_path = Path(Settings().model_path).absolute()
3723

38-
# Change to the temporary directory so CLI finds the database
39-
original_cwd = os.getcwd()
40-
try:
41-
os.chdir(tmpdir)
42-
# CWD has changed so the model must be referenced by absolute path
43-
result = runner.invoke(
44-
app,
45-
[
46-
"set",
47-
"--model-path-or-name",
48-
str(model_path),
49-
"--other-vector-options",
50-
"distance=cosine",
51-
],
52-
)
53-
assert result.exit_code == 0
54-
55-
# Add
56-
result = runner.invoke(
57-
app,
58-
[
59-
"add-text",
60-
doc1_content,
61-
],
62-
)
63-
assert result.exit_code == 0
64-
65-
result = runner.invoke(
66-
app,
67-
[
68-
"add-text",
69-
doc2_content,
70-
],
71-
)
72-
assert result.exit_code == 0
73-
74-
# Search
75-
result = runner.invoke(
76-
app, ["search", doc1_content, "--debug", "--limit", "1"]
77-
)
78-
finally:
79-
os.chdir(original_cwd)
24+
result = runner.invoke(
25+
app,
26+
[
27+
"--database",
28+
tmp_db.name,
29+
"configure",
30+
"--model-path",
31+
str(model_path),
32+
"--other-vector-options",
33+
"distance=cosine",
34+
],
35+
)
36+
assert result.exit_code == 0
37+
38+
# Add
39+
result = runner.invoke(
40+
app,
41+
[
42+
"--database",
43+
tmp_db.name,
44+
"add-text",
45+
doc1_content,
46+
],
47+
)
48+
assert result.exit_code == 0
49+
50+
result = runner.invoke(
51+
app,
52+
[
53+
"--database",
54+
tmp_db.name,
55+
"add-text",
56+
doc2_content,
57+
],
58+
)
59+
assert result.exit_code == 0
60+
61+
# Search
62+
result = runner.invoke(
63+
app,
64+
[
65+
"--database",
66+
tmp_db.name,
67+
"search",
68+
doc1_content,
69+
"--debug",
70+
"--limit",
71+
"1",
72+
],
73+
)
8074

8175
# Assert CLI command executed successfully
8276
assert result.exit_code == 0
83-
assert "Found 1 documents" in result.stdout
77+
assert "Search Results (1 matches)" in result.stdout
8478
# For exact match with cosine distance, we expect distance close to 0.0
85-
assert "0.000000" in result.stdout or "0.00000" in result.stdout
86-
87-
def test_set_settings(self, temp_dir):
88-
runner = CliRunner()
89-
90-
model_path = "mypath/mymodel.gguf"
91-
92-
result = runner.invoke(
93-
app,
94-
[
95-
"set",
96-
"--model-path-or-name",
97-
model_path,
98-
"--other-vector-options",
99-
"distance=L2",
100-
],
101-
)
102-
assert result.exit_code == 0
79+
assert "Vector: 0.000000" in result.stdout or "0.00000" in result.stdout
80+
81+
def test_set_settings(self):
82+
with tempfile.NamedTemporaryFile(suffix=".tempdb") as tmp_db:
83+
runner = CliRunner()
84+
85+
model_path = "mypath/mymodel.gguf"
86+
87+
result = runner.invoke(
88+
app,
89+
[
90+
"--database",
91+
tmp_db.name,
92+
"configure",
93+
"--model-path",
94+
model_path,
95+
"--other-vector-options",
96+
"distance=L2",
97+
],
98+
)
99+
assert result.exit_code == 0
100+
101+
assert f"model_path: {model_path}" in result.stdout
102+
assert "other_vector_options: distance=L2" in result.stdout
103+
104+
def test_change_database_path(self):
105+
with tempfile.NamedTemporaryFile(suffix=".tempdb") as tmp_db:
106+
runner = CliRunner()
107+
108+
result = runner.invoke(
109+
app,
110+
["--database", tmp_db.name, "settings"],
111+
)
112+
assert result.exit_code == 0
103113

104-
assert f"model_path: {model_path}" in result.stdout
105-
assert "other_vector_options: distance=L2" in result.stdout
114+
assert f"Database: {tmp_db.name}" in result.stdout

tests/test_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def test_search_semantic_result(self, db_conn):
9191
engine.quantize()
9292

9393
# Act
94-
results = engine.search("lumberjack", limit=5)
94+
results = engine.search("about lumberjack", limit=5)
9595

9696
assert len(results) > 0
9797
assert doc3_id == results[0].document.id

0 commit comments

Comments
 (0)