Skip to content

Commit 80fd54d

Browse files
author
Daniele Briggi
committed
refact(quantize): manage it manually
- detach from quantize_scan setting
1 parent d72ea20 commit 80fd54d

File tree

5 files changed

+110
-37
lines changed

5 files changed

+110
-37
lines changed

src/sqlite_rag/cli.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def add(
137137
rag.add(
138138
path,
139139
recursive=recursive,
140-
absolute_paths=absolute_paths,
140+
use_absolute_paths=absolute_paths,
141141
metadata=json.loads(metadata or "{}"),
142142
)
143143

@@ -345,6 +345,29 @@ def search(
345345
typer.echo(f"{idx:<3} {snippet:<60} {uri:<40}")
346346

347347

348+
@app.command()
349+
def quantize(
350+
cleanup: bool = typer.Option(
351+
False,
352+
"--cleanup",
353+
help="Clean up quantization structures instead of creating them",
354+
)
355+
):
356+
"""Quantize vectors for faster search or clean up quantization structures"""
357+
rag = SQLiteRag.create()
358+
359+
if cleanup:
360+
typer.echo("Cleaning up quantization structures...")
361+
rag.quantize_cleanup()
362+
typer.echo("Quantization cleanup completed.")
363+
else:
364+
typer.echo("Starting vector quantization...")
365+
rag.quantize_vectors()
366+
typer.echo(
367+
"Vector quantization completed. Now you can search with `--quantize-scan` and `--quantize-preload` enabled."
368+
)
369+
370+
348371
@app.command("download-model")
349372
def download_model(
350373
model_id: str = typer.Argument(

src/sqlite_rag/engine.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import sqlite3
44
from pathlib import Path
55

6+
from sqlite_rag.logger import Logger
67
from sqlite_rag.models.document_result import DocumentResult
78

89
from .chunker import Chunker
@@ -19,6 +20,7 @@ def __init__(self, conn: sqlite3.Connection, settings: Settings, chunker: Chunke
1920
self._conn = conn
2021
self._settings = settings
2122
self._chunker = chunker
23+
self._logger = Logger()
2224

2325
def load_model(self):
2426
"""Load the model from the specified path
@@ -77,6 +79,9 @@ def quantize(self) -> None:
7779

7880
cursor.execute("SELECT vector_quantize('chunks', 'embedding');")
7981

82+
self._conn.commit()
83+
self._logger.debug("Quantization completed.")
84+
8085
def quantize_preload(self) -> None:
8186
"""Preload quantized vectors into memory for faster search."""
8287
cursor = self._conn.cursor()
@@ -89,6 +94,8 @@ def quantize_cleanup(self) -> None:
8994

9095
cursor.execute("SELECT vector_quantize_cleanup('chunks', 'embedding');")
9196

97+
self._conn.commit()
98+
9299
def create_new_context(self) -> None:
93100
""""""
94101
cursor = self._conn.cursor()

src/sqlite_rag/sqliterag.py

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,6 @@ def _ensure_initialized(self):
3232
if not self.ready:
3333
self._engine.load_model()
3434

35-
if self._settings.quantize_scan:
36-
# TODO: quantize again if already quantized?
37-
# TODO: DO NOT REPEAT FOR NOTHING, it takes time
38-
self._engine.quantize()
39-
else:
40-
self._engine.quantize_cleanup()
41-
4235
self.ready = True
4336

4437
@staticmethod
@@ -64,7 +57,7 @@ def add(
6457
self,
6558
path: str,
6659
recursive: bool = False,
67-
absolute_paths: bool = True,
60+
use_absolute_paths: bool = True,
6861
metadata: dict = {},
6962
) -> int:
7063
"""Add the file content into the database"""
@@ -81,36 +74,36 @@ def add(
8174

8275
processed = 0
8376
self._logger.info(f"Processing {len(files_to_process)} files...")
84-
for file_path in files_to_process:
85-
content = FileReader.parse_file(file_path)
77+
try:
78+
for file_path in files_to_process:
79+
content = FileReader.parse_file(file_path)
8680

87-
if not content:
88-
self._logger.warning(f"Skipping empty file: {file_path}")
89-
continue
81+
if not content:
82+
self._logger.warning(f"Skipping empty file: {file_path}")
83+
continue
9084

91-
uri = (
92-
str(file_path.absolute())
93-
if absolute_paths
94-
else str(file_path.relative_to(parent))
95-
)
96-
document = Document(content=content, uri=uri, metadata=metadata)
85+
uri = (
86+
str(file_path.absolute())
87+
if use_absolute_paths
88+
else str(file_path.relative_to(parent))
89+
)
90+
document = Document(content=content, uri=uri, metadata=metadata)
9791

98-
exists = self._repository.document_exists_by_hash(document.hash())
99-
if exists:
100-
self._logger.info(f"Unchanged: {file_path}")
101-
continue
92+
exists = self._repository.document_exists_by_hash(document.hash())
93+
if exists:
94+
self._logger.info(f"Unchanged: {file_path}")
95+
continue
10296

103-
self._logger.info(f"Processing: {file_path}")
104-
document = self._engine.process(document)
97+
self._logger.info(f"Processing: {file_path}")
98+
document = self._engine.process(document)
10599

106-
self._repository.add_document(document)
100+
self._repository.add_document(document)
107101

108-
# TODO: try/expect and run in the finally block?
102+
processed += 1
103+
finally:
109104
if self._settings.quantize_scan:
110105
self._engine.quantize()
111106

112-
processed += 1
113-
114107
self._engine.free_context()
115108

116109
return processed
@@ -247,7 +240,7 @@ def search(
247240
if new_context:
248241
self._engine.create_new_context()
249242

250-
if self._settings.quantize_preload:
243+
if self._settings.quantize_scan and self._settings.quantize_preload:
251244
self._engine.quantize_preload()
252245

253246
return self._engine.search(query, limit=top_k)
@@ -257,7 +250,17 @@ def get_settings(self) -> dict:
257250
versions = self._engine.versions()
258251
return {**versions, **asdict(self._settings)}
259252

260-
def destroy(self) -> None:
253+
def quantize_vectors(self) -> None:
254+
"""Quantize vectors for faster search"""
255+
self._ensure_initialized()
256+
self._engine.quantize()
257+
258+
def quantize_cleanup(self) -> None:
259+
"""Clean up quantization structures"""
260+
self._ensure_initialized()
261+
self._engine.quantize_cleanup()
262+
263+
def close(self) -> None:
261264
"""Free up resources"""
262265
if self._conn:
263266
self._engine.close()

tests/integration/test_cli.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def test_search_exact_match(self):
4646
"set",
4747
"--model-path-or-name",
4848
str(model_path),
49-
"--other-vector-config",
49+
"--other-vector-options",
5050
"distance=cosine",
5151
],
5252
)
@@ -95,11 +95,11 @@ def test_set_settings(self, temp_dir):
9595
"set",
9696
"--model-path-or-name",
9797
model_path,
98-
"--other-vector-config",
98+
"--other-vector-options",
9999
"distance=L2",
100100
],
101101
)
102102
assert result.exit_code == 0
103103

104104
assert f"model_path_or_name: {model_path}" in result.stdout
105-
assert "other_vector_config: distance=L2" in result.stdout
105+
assert "other_vector_options: distance=L2" in result.stdout

tests/test_sqlite_rag.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import tempfile
44
from pathlib import Path
55

6+
import pytest
7+
68
from sqlite_rag import SQLiteRag
79

810

@@ -94,7 +96,7 @@ def test_add_with_absolute_paths_option_true(self):
9496

9597
rag = SQLiteRag.create(":memory:")
9698

97-
rag.add(temp_file_path, absolute_paths=True)
99+
rag.add(temp_file_path, use_absolute_paths=True)
98100

99101
conn = rag._conn
100102
cursor = conn.execute("SELECT uri FROM documents")
@@ -109,7 +111,7 @@ def test_add_with_absolute_paths_option_false(self):
109111

110112
rag = SQLiteRag.create(":memory:")
111113

112-
rag.add(str(temp_file_path), absolute_paths=False)
114+
rag.add(str(temp_file_path), use_absolute_paths=False)
113115

114116
conn = rag._conn
115117
cursor = conn.execute("SELECT uri FROM documents")
@@ -540,3 +542,41 @@ def test_search_exact_match(self):
540542
assert expected_string == results[0].document.content
541543
assert 1 == results[0].vec_rank
542544
assert 0.0 == results[0].vec_distance
545+
546+
@pytest.mark.parametrize(
547+
"quantize_scan", [True, False], ids=["quantize", "no-quantize"]
548+
)
549+
def test_search_samples_exact_match_by_scan_type(self, quantize_scan: bool):
550+
# Test that searching for exact content from sample files returns distance 0
551+
# FTS not included in the combined score
552+
settings = {
553+
"other_vector_options": "distance=cosine",
554+
"weight_fts": 0.0,
555+
"quantize_scan": quantize_scan,
556+
}
557+
558+
temp_file_path = os.path.join(tempfile.mkdtemp(), "mydb.db")
559+
rag = SQLiteRag.create(temp_file_path, settings=settings)
560+
561+
# Index all sample files
562+
samples_dir = Path(__file__).parent / "assets" / "samples"
563+
rag.add(str(samples_dir))
564+
565+
# Get all sample files and test each one
566+
sample_files = list(samples_dir.glob("*.txt"))
567+
568+
for sample_file in sample_files:
569+
file_content = sample_file.read_text(encoding="utf-8")
570+
571+
# Search for the exact content
572+
results = rag.search(file_content, top_k=2)
573+
574+
assert len(results) == 2
575+
576+
first_result = results[0]
577+
assert first_result.vec_distance == 0.0
578+
assert first_result.document.content == file_content
579+
580+
# Second result should have distance > 0
581+
second_result = results[1]
582+
assert second_result.vec_distance and second_result.vec_distance > 0.0

0 commit comments

Comments
 (0)