Commit 2d22ad5
Author: Daniele Briggi
Message: fix(tests): consider prompts
1 parent: a90e534
11 files changed: +147 -18 lines changed

pyproject.toml
Lines changed: 1 addition & 1 deletion

@@ -35,7 +35,7 @@ dependencies = [
 dev = [
     "pytest",
     "pytest-mock",
-    "pytest-cov",
+    "pytest-cov==6.3.0",
     "black",
     "flake8",
     "bandit",

src/sqlite_rag/chunker.py
Lines changed: 16 additions & 3 deletions

@@ -11,12 +11,15 @@ def __init__(self, conn: sqlite3.Connection, settings: Settings):
         self._conn = conn
         self._settings = settings

-    def chunk(self, text: str) -> list[Chunk]:
+    def chunk(self, text: str, metadata: dict = {}) -> list[Chunk]:
         """Chunk text using Recursive Character Text Splitter."""
+        chunks = []
         if self._get_token_count(text) <= self._settings.chunk_size:
-            return [Chunk(content=text)]
+            chunks = [Chunk(content=text)]
+        else:
+            chunks = self._recursive_split(text)

-        return self._recursive_split(text)
+        return self._enrich_chunk(chunks, metadata)

     def _get_token_count(self, text: str) -> int:
         """Get token count using SQLite AI extension."""

@@ -190,3 +193,13 @@ def _get_overlap_text(self, text: str, max_overlap_tokens: int) -> str:

         # If even single word is too large, return empty
         return ""
+
+    def _enrich_chunk(self, chunks: List[Chunk], metadata: dict) -> List[Chunk]:
+        """Add extra information to chunk which may improve the model embeddings."""
+        for chunk in chunks:
+            if "title" in metadata:
+                chunk.title = metadata["title"]
+            elif "title" in metadata.get("generated", {}):
+                chunk.title = metadata["generated"]["title"]
+
+        return chunks
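
The metadata path gives chunks a title that later feeds the embedding prompt. A minimal usage sketch, assuming a connection with the sqlite-ai extension loaded (chunking calls the extension for token counting) and a Settings object built like the test fixtures below; the setup itself is illustrative, not part of the commit:

    import sqlite3

    from sqlite_rag.chunker import Chunker
    from sqlite_rag.settings import Settings

    conn = sqlite3.connect(":memory:")  # real use needs the sqlite-ai extension loaded
    settings = Settings("test-model", use_prompt_templates=False)
    chunker = Chunker(conn, settings)

    # An explicit "title" key takes precedence over metadata["generated"]["title"].
    chunks = chunker.chunk("Some document text.", {"title": "My Doc"})
    assert all(chunk.title == "My Doc" for chunk in chunks)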

src/sqlite_rag/cli.py
Lines changed: 14 additions & 1 deletion

@@ -149,6 +149,15 @@ def configure_settings(
     use_gpu: Optional[bool] = typer.Option(
         None, help="Whether to allow sqlite-ai extension to use the GPU"
     ),
+    no_prompt_templates: bool = typer.Option(
+        False,
+        "--no-prompt-templates",
+        help="Disable prompt templates for embedding generation",
+    ),
+    prompt_template_retrieval_document: Optional[str] = typer.Option(
+        None,
+        help="Template for retrieval document prompts. Supported placeholders are `{title}` and `{content}`",
+    ),
     prompt_template_retrieval_query: Optional[str] = typer.Option(
         None,
         help="Template for retrieval query prompts, use `{content}` as placeholder",

@@ -176,9 +185,13 @@ def configure_settings(
         "weight_fts": weight_fts,
         "weight_vec": weight_vec,
         "use_gpu": use_gpu,
+        "use_prompt_templates": (
+            False if no_prompt_templates else None
+        ),  # Set only if True
+        "prompt_template_retrieval_document": prompt_template_retrieval_document,
         "prompt_template_retrieval_query": prompt_template_retrieval_query,
     }
-
+    print(updates)
    # Filter out None values (unset options)
    updates = {k: v for k, v in updates.items() if v is not None}
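
A hypothetical exercise of the new options through Typer's test runner, in the same style as the integration tests; the `app` import name is an assumption:

    from typer.testing import CliRunner

    from sqlite_rag.cli import app  # assumed Typer app name

    runner = CliRunner()
    result = runner.invoke(
        app,
        [
            "configure",
            "--prompt-template-retrieval-document",
            "title: {title} | text: {content}",
        ],
    )
    assert result.exit_code == 0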

src/sqlite_rag/engine.py
Lines changed: 11 additions & 5 deletions

@@ -35,7 +35,7 @@ def load_model(self):
         )

     def process(self, document: Document) -> Document:
-        chunks = self._chunker.chunk(document.content)
+        chunks = self._chunker.chunk(document.content, document.metadata)
         chunks = self.generate_embedding(chunks)
         document.chunks = chunks
         return document

@@ -46,12 +46,18 @@ def generate_embedding(self, chunks: list[Chunk]) -> list[Chunk]:
         for chunk in chunks:
             cursor = self._conn.cursor()

-            try:
-                cursor.execute(
-                    "SELECT llm_embed_generate(?) AS embedding", (chunk.content,)
+            # Format using the prompt template if available
+            content = chunk.content
+            if self._settings.use_prompt_templates:
+                title = chunk.title if chunk.title else "none"
+                content = self._settings.prompt_template_retrieval_document.format(
+                    title=title, content=chunk.content
                 )
+
+            try:
+                cursor.execute("SELECT llm_embed_generate(?) AS embedding", (content,))
             except sqlite3.Error as e:
-                print(f"Error generating embedding for chunk\n: ```{chunk.content}```")
+                print(f"Error generating embedding for chunk\n: ```{content}```")
                 raise e

             result = cursor.fetchone()
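
The formatting step is easy to see in isolation. A small sketch of what generate_embedding now passes to llm_embed_generate, using the default document template from Settings; the helper function is illustrative:

    DOC_TEMPLATE = "title: {title} | text: {content}"

    def format_for_embedding(content: str, title: str | None) -> str:
        # Mirrors the engine: a missing title falls back to the literal "none".
        return DOC_TEMPLATE.format(title=title or "none", content=content)

    print(format_for_embedding("Hello world", None))
    # title: none | text: Hello world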

src/sqlite_rag/models/chunk.py
Lines changed: 2 additions & 0 deletions

@@ -8,3 +8,5 @@ class Chunk:
     content: str = ""
     embedding: str | bytes = b""
     core_start_pos: int = 0
+
+    title: str | None = None
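
With the new field a chunk can carry an optional title; a quick illustration:

    from sqlite_rag.models.chunk import Chunk

    untitled = Chunk(content="plain text")                # title defaults to None
    titled = Chunk(content="plain text", title="Intro")   # set by _enrich_chunk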

src/sqlite_rag/settings.py
Lines changed: 6 additions & 0 deletions

@@ -31,6 +31,7 @@ class Settings:
         "distance=cosine"  # e.g. distance=metric,other=value,...
     )

+    # It includes the overlap size but not the prompt template length
     chunk_size: int = 384
     # Tokens overlap between chunks
     chunk_overlap: int = 48

@@ -53,8 +54,13 @@ class Settings:
     # Some models are trained to work better with specific prompts
     # depending on the task. For example, Gemma models work better
     # when the prompt includes a task description.
+    # More: https://huggingface.co/unsloth/embeddinggemma-300m-GGUF#prompt-instructions
     #

+    use_prompt_templates: bool = True
+
+    # Template to index documents for retrieval, use `{title}` with the title or the string `"none"`
+    prompt_template_retrieval_document: str = "title: {title} | text: {content}"
     prompt_template_retrieval_query: str = "task: search result | query: {content}"
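
With these defaults, the two templates render as follows (plain str.format, values illustrative):

    doc_template = "title: {title} | text: {content}"
    query_template = "task: search result | query: {content}"

    print(doc_template.format(title="none", content="SQLite is a database."))
    # title: none | text: SQLite is a database.

    print(query_template.format(content="what is sqlite?"))
    # task: search result | query: what is sqlite?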

src/sqlite_rag/sqliterag.py
Lines changed: 1 addition & 1 deletion

@@ -272,7 +272,7 @@ def search(
         if new_context:
             self._engine.create_new_context()

-        if self._settings.prompt_template_retrieval_query:
+        if self._settings.use_prompt_templates:
             query = self._settings.prompt_template_retrieval_query.format(content=query)

         return self._engine.search(query, top_k=top_k)
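
The query side now keys off the use_prompt_templates switch instead of checking whether the query template string is non-empty. A condensed sketch of the guard, with the surrounding search code assumed:

    def format_query(settings, query: str) -> str:
        # Wrap the query only when prompt templates are enabled.
        if settings.use_prompt_templates:
            return settings.prompt_template_retrieval_query.format(content=query)
        return query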

tests/integration/test_cli.py
Lines changed: 1 addition & 2 deletions

@@ -29,8 +29,7 @@ def test_search_exact_match(self):
                 "configure",
                 "--model-path",
                 str(model_path),
-                "--prompt-template-retrieval-query",
-                "",
+                "--no-prompt-templates",
                 "--other-vector-options",
                 "distance=cosine",
             ],

tests/test_chunker.py
Lines changed: 24 additions & 3 deletions

@@ -42,7 +42,7 @@ def mock_conn():
 @pytest.fixture
 def chunker_large(mock_conn):
     """Fixture providing a chunker with large chunk size."""
-    settings = Settings("test-model")
+    settings = Settings("test-model", use_prompt_templates=False)
     settings.chunk_size = 100
     settings.chunk_overlap = 20
     return Chunker(mock_conn, settings)

@@ -51,7 +51,7 @@ def chunker_large(mock_conn):
 @pytest.fixture
 def chunker_small(mock_conn):
     """Fixture providing a chunker with small chunk size."""
-    settings = Settings("test-model")
+    settings = Settings("test-model", use_prompt_templates=False)
     settings.chunk_size = 25
     settings.chunk_overlap = 5
     return Chunker(mock_conn, settings)

@@ -60,7 +60,7 @@ def chunker_small(mock_conn):
 @pytest.fixture
 def chunker_tiny(mock_conn):
     """Fixture providing a chunker with tiny chunk size."""
-    settings = Settings("test-model")
+    settings = Settings("test-model", use_prompt_templates=False)
     settings.chunk_size = 8
     settings.chunk_overlap = 2
     return Chunker(mock_conn, settings)

@@ -85,6 +85,27 @@ def test_empty_text(self, chunker_large):
         assert len(chunks) == 1
         assert chunks[0].content == ""

+    def test_chunk_enrichness_with_input_title(self, chunker_large):
+        """Test that chunk enrichment adds metadata correctly."""
+        text = "This is a test chunk."
+        metadata = {"title": "Test Title"}
+
+        chunks = chunker_large.chunk(text, metadata)
+
+        assert len(chunks) == 1
+        assert chunks[0].content == text
+        assert chunks[0].title == "Test Title"
+
+    def test_chunk_enrichness_with_generated_title(self, chunker_large):
+        text = "# My title\n\nThis is a paragraph to test chunk."
+        metadata = {"generated": {"title": "My title"}}
+
+        chunks = chunker_large.chunk(text, metadata)
+
+        assert len(chunks) == 1
+        assert chunks[0].content == text
+        assert chunks[0].title == "My title"
+

 class TestParagraphSplitting:
     """Test cases for paragraph-level splitting."""

tests/test_engine.py
Lines changed: 35 additions & 0 deletions

@@ -18,6 +18,41 @@ def test_generate_embedding(self, engine):
         assert result_chunks[0].embedding is not None
         assert isinstance(result_chunks[0].embedding, bytes)

+    @pytest.mark.parametrize("use_prompt_templates", [True, False])
+    def test_generate_embedding_with_prompt_template(
+        self, mocker, use_prompt_templates
+    ):
+        # Arrange
+        mock_conn = mocker.Mock()
+        mock_cursor = mocker.Mock()
+        mock_cursor.fetchone.return_value = {"embedding": b"fake_embedding"}
+        mock_conn.cursor.return_value = mock_cursor
+
+        settings = Settings(
+            use_prompt_templates=use_prompt_templates,
+            prompt_template_retrieval_document="Title: {title}\nContent: {content}",
+        )
+
+        engine = Engine(mock_conn, settings, mocker.Mock())
+
+        chunk = Chunk(
+            content="Test content",
+            title="Test Title",
+        )
+
+        # Act
+        engine.generate_embedding([chunk])
+
+        # Assert - verify cursor.execute was called with formatted template
+        expected_content = (
+            "Title: Test Title\nContent: Test content"
+            if use_prompt_templates
+            else "Test content"
+        )
+        mock_cursor.execute.assert_called_with(
+            "SELECT llm_embed_generate(?) AS embedding", (expected_content,)
+        )
+
     def test_search_with_empty_database(self, engine):
         results = engine.search("nonexistent query", top_k=5)
