Skip to content

Commit 4d701fb

Browse files
author
Daniele Briggi
committed
feat(settings): improve
1 parent 802ff55 commit 4d701fb

File tree

7 files changed

+160
-40
lines changed

7 files changed

+160
-40
lines changed

.devcontainer/devcontainer.json

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,4 @@
1515
]
1616
}
1717
},
18-
"runArgs": ["--network=host"]
1918
}

src/sqlite_rag/cli.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
import json
33
import shlex
44
import sys
5-
from dataclasses import replace
65
from typing import Optional
76

87
import typer
98

10-
from sqlite_rag.settings import Settings
9+
from sqlite_rag.database import Database
10+
from sqlite_rag.settings import SettingsManager
1111

1212
from .sqliterag import SQLiteRag
1313

@@ -40,7 +40,6 @@ def show_settings():
4040
typer.echo(f" {key}: {value}")
4141

4242

43-
# TODO: separate store settings from SQLiteRag.create()?
4443
@app.command("set")
4544
def set_settings(
4645
model_path_or_name: Optional[str] = typer.Option(
@@ -106,9 +105,9 @@ def set_settings(
106105
show_settings()
107106
return
108107

109-
# Create new settings with updated fields
110-
new_settings = replace(Settings(), **updates)
111-
SQLiteRag.create(settings=new_settings)
108+
conn = Database.new_connection()
109+
settings_manager = SettingsManager(conn)
110+
settings_manager.prepare_settings(updates)
112111

113112
show_settings()
114113
typer.echo("Settings updated.")

src/sqlite_rag/engine.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ def load_model(self):
3737
# )
3838

3939
self._conn.execute(
40-
f"SELECT llm_model_load('{self._settings.model_path_or_name}', '{self._settings.model_config}');"
40+
"SELECT llm_model_load(?, ?);",
41+
(self._settings.model_path_or_name, self._settings.model_config),
4142
)
4243

4344
def process(self, document: Document) -> Document:
@@ -52,6 +53,8 @@ def generate_embedding(self, chunks: list[Chunk]) -> list[Chunk]:
5253
for chunk in chunks:
5354
cursor = self._conn.cursor()
5455

56+
self.create_new_context()
57+
5558
try:
5659
cursor.execute(
5760
"SELECT llm_embed_generate(?) AS embedding", (chunk.content,)
@@ -75,23 +78,29 @@ def quantize(self) -> None:
7578

7679
cursor.execute("SELECT vector_quantize('chunks', 'embedding');")
7780

78-
self._conn.commit()
79-
8081
def quantize_preload(self) -> None:
8182
"""Preload quantized vectors into memory for faster search."""
8283
cursor = self._conn.cursor()
8384

8485
cursor.execute("SELECT vector_quantize_preload('chunks', 'embedding');")
8586

86-
self._conn.commit()
87-
8887
def quantize_cleanup(self) -> None:
8988
"""Clean up internal structures related to a previously quantized table/column."""
9089
cursor = self._conn.cursor()
9190

9291
cursor.execute("SELECT vector_quantize_cleanup('chunks', 'embedding');")
9392

94-
self._conn.commit()
93+
def create_new_context(self) -> None:
94+
""""""
95+
cursor = self._conn.cursor()
96+
97+
cursor.execute("SELECT llm_context_create(?);", (self._settings.model_config,))
98+
99+
def free_context(self) -> None:
100+
""""""
101+
cursor = self._conn.cursor()
102+
103+
cursor.execute("SELECT llm_context_free();")
95104

96105
def search(self, query: str, limit: int = 10) -> list[DocumentResult]:
97106
"""Semantic search and full-text search sorted with Reciprocal Rank Fusion."""

src/sqlite_rag/settings.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import sqlite3
3-
from dataclasses import asdict, dataclass, fields
3+
from dataclasses import asdict, dataclass, fields, replace
4+
from typing import Any, Optional
45

56

67
@dataclass
@@ -54,6 +55,34 @@ def _ensure_table_exists(self):
5455
)
5556
self.connection.commit()
5657

58+
def prepare_settings(self, settings: Optional[dict[str, Any]]) -> Settings:
59+
"""Load, initialize or update settings.
60+
61+
If no settings are provided, load the last used settings or use defaults.
62+
If settings are provided, check for critical changes and update them.
63+
"""
64+
current_settings = self.load_settings()
65+
if current_settings:
66+
if settings:
67+
new_settings = replace(current_settings, **settings)
68+
69+
if self.has_critical_changes(new_settings, current_settings):
70+
raise ValueError(
71+
"Critical settings changes detected. Please reset the database."
72+
)
73+
# Update new settings
74+
current_settings = self.store(new_settings)
75+
elif settings:
76+
# Store initial settings with customs
77+
new_settings = replace(Settings(), **settings)
78+
current_settings = self.store(new_settings)
79+
else:
80+
# Store default settings
81+
new_settings = Settings()
82+
current_settings = self.store(new_settings)
83+
84+
return current_settings
85+
5786
def load_settings(self) -> Settings | None:
5887
cursor = self.connection.cursor()
5988

src/sqlite_rag/sqliterag.py

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import sqlite3
2-
from dataclasses import asdict, replace
2+
from dataclasses import asdict
33
from pathlib import Path
44
from typing import Optional
55

@@ -51,30 +51,10 @@ def create(
5151
If no new settings are provided, it uses the default settings or load
5252
the settings used in the last execution."""
5353

54-
conn = sqlite3.connect(db_path)
55-
conn.row_factory = sqlite3.Row
54+
conn = Database.new_connection(db_path)
5655

57-
# Load, initialize or update settings
5856
settings_manager = SettingsManager(conn)
59-
current_settings = settings_manager.load_settings()
60-
if current_settings:
61-
if settings:
62-
settings = replace(current_settings, **asdict(settings))
63-
64-
if settings_manager.has_critical_changes(settings, current_settings):
65-
raise ValueError(
66-
"Critical settings changes detected. Please reset the database."
67-
)
68-
# Update new settings
69-
current_settings = settings_manager.store(settings)
70-
elif settings:
71-
# Store initial settings with customs
72-
settings = replace(Settings(), **asdict(settings))
73-
current_settings = settings_manager.store(settings)
74-
else:
75-
# Store default settings
76-
settings = Settings()
77-
current_settings = settings_manager.store(settings)
57+
current_settings = settings_manager.prepare_settings(settings)
7858

7959
Database.initialize(conn, current_settings)
8060

@@ -97,6 +77,8 @@ def add(
9777

9878
files_to_process = FileReader.collect_files(Path(path), recursive=recursive)
9979

80+
self._engine.create_new_context()
81+
10082
processed = 0
10183
self._logger.info(f"Processing {len(files_to_process)} files...")
10284
for file_path in files_to_process:
@@ -138,6 +120,7 @@ def add_text(
138120
self._ensure_initialized()
139121

140122
document = Document(content=text, uri=uri, metadata=metadata)
123+
self._engine.create_new_context()
141124
document = self._engine.process(document)
142125

143126
self._repository.add_document(document)
@@ -177,6 +160,8 @@ def rebuild(self, remove_missing: bool = False) -> dict:
177160
not_found = 0
178161
removed = 0
179162

163+
self._engine.create_new_context()
164+
180165
for doc in documents:
181166
doc_id = doc.id or ""
182167

@@ -218,8 +203,8 @@ def rebuild(self, remove_missing: bool = False) -> dict:
218203
except Exception as e:
219204
self._logger.error(f"Error processing text document {doc.id}: {e}")
220205

221-
if self._settings.quantize_scan:
222-
self._engine.quantize()
206+
if self._settings.quantize_scan:
207+
self._engine.quantize()
223208

224209
return {
225210
"total": total_docs,
@@ -247,9 +232,13 @@ def reset(self) -> bool:
247232
self._logger.error(f"Error during database reset: {e}")
248233
return False
249234

250-
def search(self, query: str, top_k: int = 10) -> list[DocumentResult]:
235+
def search(
236+
self, query: str, top_k: int = 10, new_context: bool = True
237+
) -> list[DocumentResult]:
251238
"""Search for documents matching the query"""
252239
self._ensure_initialized()
240+
if new_context:
241+
self._engine.create_new_context()
253242

254243
if self._settings.quantize_preload:
255244
self._engine.quantize_preload()

tests/integration/test_cli.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,26 @@
22
import tempfile
33
from pathlib import Path
44

5+
from pytest import fixture
56
from typer.testing import CliRunner
67

78
from sqlite_rag.cli import app
89
from sqlite_rag.settings import Settings
910

1011

12+
@fixture
13+
def temp_dir():
14+
"""Change the current working directory in order to create
15+
the default database in a temporary location."""
16+
with tempfile.TemporaryDirectory() as tmpdir:
17+
original_cwd = os.getcwd()
18+
try:
19+
os.chdir(tmpdir)
20+
yield tmpdir
21+
finally:
22+
os.chdir(original_cwd)
23+
24+
1125
class TestCLI:
1226
def test_search_exact_match(self):
1327
# Use SQLiteRag to set up the test data directly
@@ -69,3 +83,23 @@ def test_search_exact_match(self):
6983
assert "Found 1 documents" in result.stdout
7084
# For exact match with cosine distance, we expect distance close to 0.0
7185
assert "0.000000" in result.stdout or "0.00000" in result.stdout
86+
87+
def test_set_settings(self, temp_dir):
88+
runner = CliRunner()
89+
90+
model_path = "mypath/mymodel.gguf"
91+
92+
result = runner.invoke(
93+
app,
94+
[
95+
"set",
96+
"--model-path-or-name",
97+
model_path,
98+
"--other-vector-config",
99+
"distance=L2",
100+
],
101+
)
102+
assert result.exit_code == 0
103+
104+
assert f"model_path_or_name: {model_path}" in result.stdout
105+
assert "other_vector_config: distance=L2" in result.stdout

tests/test_settings.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,64 @@ def test_has_critical_changes(self, db_conn):
144144
new_settings, current_settings
145145
)
146146
assert has_changes
147+
148+
def test_prepare_settings_with_no_existing_and_no_input(self, db_conn):
149+
"""Test prepare_settings returns default settings when no existing settings and no input"""
150+
settings_manager = SettingsManager(db_conn[0])
151+
152+
result = settings_manager.prepare_settings(None)
153+
154+
defaults = Settings()
155+
assert result.model_path_or_name == defaults.model_path_or_name
156+
assert result.embedding_dim == defaults.embedding_dim
157+
assert result.chunk_size == defaults.chunk_size
158+
159+
def test_prepare_settings_with_no_existing_and_custom_input(self, db_conn):
160+
"""Test prepare_settings stores and returns custom settings when no existing settings"""
161+
settings_manager = SettingsManager(db_conn[0])
162+
163+
result = settings_manager.prepare_settings(
164+
{"chunk_size": 5000, "quantize_scan": False}
165+
)
166+
167+
assert result.chunk_size == 5000
168+
assert result.quantize_scan is False
169+
# Check defaults are preserved
170+
defaults = Settings()
171+
assert result.model_path_or_name == defaults.model_path_or_name
172+
173+
def test_prepare_settings_with_existing_and_no_input(self, db_conn):
174+
"""Test prepare_settings returns existing settings when they exist and no input provided"""
175+
settings_manager = SettingsManager(db_conn[0])
176+
existing = Settings(chunk_size=3000, quantize_scan=False)
177+
settings_manager.store(existing)
178+
179+
result = settings_manager.prepare_settings(None)
180+
181+
assert result.chunk_size == 3000
182+
assert result.quantize_scan is False
183+
184+
def test_prepare_settings_with_existing_and_non_critical_updates(self, db_conn):
185+
"""Test prepare_settings updates non-critical settings when existing settings present"""
186+
settings_manager = SettingsManager(db_conn[0])
187+
existing = Settings(chunk_size=3000, chunk_overlap=100)
188+
settings_manager.store(existing)
189+
190+
result = settings_manager.prepare_settings(
191+
{"chunk_size": 4000, "quantize_scan": False}
192+
)
193+
194+
assert result.chunk_size == 4000
195+
assert result.chunk_overlap == 100
196+
assert result.quantize_scan is False
197+
198+
def test_prepare_settings_with_critical_changes_raises_error(self, db_conn):
199+
"""Test prepare_settings raises ValueError when critical settings change"""
200+
settings_manager = SettingsManager(db_conn[0])
201+
existing = Settings()
202+
settings_manager.store(existing)
203+
204+
import pytest
205+
206+
with pytest.raises(ValueError, match="Critical settings changes detected"):
207+
settings_manager.prepare_settings({"model_path_or_name": "new_model"})

0 commit comments

Comments
 (0)