Skip to content

Commit 10f9d7e

Browse files
author
Daniele Briggi
committed
feat(prompts): support to model's prompts
1 parent e146b3e commit 10f9d7e

File tree

9 files changed

+64
-12
lines changed

9 files changed

+64
-12
lines changed

model_evaluation/configs/gemma_300M_Q8_650rows.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"rag_settings": {
44
"chunk_size": 1000,
55
"chunk_overlap": 0,
6-
"model_path_or_name": "./../models/unsloth/embeddinggemma-300m-GGUF/embeddinggemma-300M-Q8_0.gguf",
6+
"model_path": "./../models/unsloth/embeddinggemma-300m-GGUF/embeddinggemma-300M-Q8_0.gguf",
77
"model_options": "",
88
"model_context_options": "generate_embedding=1,normalize_embedding=1,pooling_type=mean,embedding_type=INT8",
99
"vector_type": "INT8",

model_evaluation/configs/qwen3_Q8_650rows.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"rag_settings": {
44
"chunk_size": 1000,
55
"chunk_overlap": 0,
6-
"model_path_or_name": "./../models/Qwen/Qwen3-Embedding-0.6B-GGUF/Qwen3-Embedding-0.6B-Q8_0.gguf",
6+
"model_path": "./../models/Qwen/Qwen3-Embedding-0.6B-GGUF/Qwen3-Embedding-0.6B-Q8_0.gguf",
77
"model_options": "",
88
"model_context_options": "generate_embedding=1,normalize_embedding=1,pooling_type=last,embedding_type=INT8",
99
"vector_type": "INT8",

model_evaluation/ms_marco.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,10 @@ def test_ms_marco_processing(
212212

213213

214214
def evaluate_search_quality(
215-
limit_rows=None, database_path="ms_marco_test.sqlite", output_file=None
215+
limit_rows=None,
216+
database_path="ms_marco_test.sqlite",
217+
output_file=None,
218+
rag_settings=None,
216219
):
217220
"""Evaluate search quality using proper metrics"""
218221

@@ -243,7 +246,7 @@ def output(text):
243246
output(f"Evaluating on {len(df)} queries")
244247

245248
# Create RAG instance
246-
rag = SQLiteRag.create(database_path)
249+
rag = SQLiteRag.create(database_path, settings=rag_settings)
247250
memory_monitor.record() # After RAG initialization
248251

249252
# Metrics for different top-k values
@@ -275,6 +278,8 @@ def output(text):
275278
total_queries += 1
276279

277280
# Perform search
281+
# EmbeddingGemma works better with a task-specific prefix
282+
# query_text = f"task: search result | query: {query_text}"
278283
search_results = rag.search(query_text, top_k=10)
279284

280285
# Check results for each k value
@@ -562,6 +567,7 @@ def main():
562567
limit_rows=args.limit_rows,
563568
database_path=database_path,
564569
output_file=output_file,
570+
rag_settings=rag_settings,
565571
)
566572

567573

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ dependencies = [
3434
[project.optional-dependencies]
3535
dev = [
3636
"pytest",
37+
"pytest-mock",
3738
"pytest-cov",
3839
"black",
3940
"flake8",

src/sqlite_rag/cli.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,10 @@ def configure_settings(
148148
use_gpu: Optional[bool] = typer.Option(
149149
None, help="Whether to allow sqlite-ai extension to use the GPU"
150150
),
151+
prompt_template_retrieval_query: Optional[str] = typer.Option(
152+
None,
153+
help="Template for retrieval query prompts, use {content} as placeholder",
154+
),
151155
):
152156
"""Configure settings for the RAG system.
153157
@@ -171,6 +175,7 @@ def configure_settings(
171175
"weight_fts": weight_fts,
172176
"weight_vec": weight_vec,
173177
"use_gpu": use_gpu,
178+
"prompt_template_retrieval_query": prompt_template_retrieval_query,
174179
}
175180

176181
# Filter out None values (unset options)
@@ -404,6 +409,11 @@ def search(
404409
@app.command()
405410
def quantize(
406411
ctx: typer.Context,
412+
preload: bool = typer.Option(
413+
False,
414+
"--preload",
415+
help="Preload quantized vectors into memory for faster search",
416+
),
407417
cleanup: bool = typer.Option(
408418
False,
409419
"--cleanup",
@@ -420,9 +430,14 @@ def quantize(
420430
typer.echo("Quantization cleanup completed.")
421431
else:
422432
typer.echo("Starting vector quantization...")
433+
423434
rag.quantize_vectors()
435+
if preload:
436+
typer.echo("Preloading quantized vectors into memory...")
437+
rag.quantize_preload()
438+
424439
typer.echo(
425-
"Vector quantization completed. Now you can search with `--quantize-scan` and `--quantize-preload` enabled."
440+
"Vector quantization completed. Now you can search with `--quantize-scan` enabled."
426441
)
427442

428443

src/sqlite_rag/engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __init__(self, conn: sqlite3.Connection, settings: Settings, chunker: Chunke
2525
def load_model(self):
2626
"""Load the model from the specified path."""
2727

28-
model_path = Path(self._settings.model_path)
28+
model_path = Path(self._settings.model_path).resolve()
2929
if not model_path.exists():
3030
raise FileNotFoundError(f"Model file not found at {model_path}")
3131

src/sqlite_rag/settings.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ class Settings:
2121
"generate_embedding=1,normalize_embedding=1,pooling_type=mean,embedding_type=INT8"
2222
)
2323

24+
# Allow the sqlite-ai extension to use the GPU
25+
# See: https://github.com/sqliteai/sqlite-ai
26+
use_gpu = False
27+
2428
vector_type: str = "INT8"
2529
embedding_dim: int = 768
2630
other_vector_options: str = (
@@ -44,9 +48,14 @@ class Settings:
4448
weight_fts: float = 1.0
4549
weight_vec: float = 1.0
4650

47-
# Allow the sqlite-ai extension to use the GPU
48-
# See: https://github.com/sqliteai/sqlite-ai
49-
use_gpu = False
51+
#
52+
# Prompt templates
53+
# Some models are trained to work better with specific prompts
54+
# depending on the task. For example, Gemma models work better
55+
# when the prompt includes a task description.
56+
#
57+
58+
prompt_template_retrieval_query: str = "task: search result | query: {content}"
5059

5160

5261
class SettingsManager:
@@ -92,7 +101,7 @@ def configure(
92101
)
93102
else:
94103
raise ValueError(
95-
"Critical settings changes detected. Please reset the database."
104+
"Critical settings changes detected. Please force the settings update or reset the database."
96105
)
97106
# Update new settings
98107
current_settings = self.store(new_settings)

src/sqlite_rag/sqliterag.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,8 +265,8 @@ def search(
265265
if new_context:
266266
self._engine.create_new_context()
267267

268-
if self._settings.quantize_scan and self._settings.quantize_preload:
269-
self._engine.quantize_preload()
268+
if self._settings.prompt_template_retrieval_query:
269+
query = self._settings.prompt_template_retrieval_query.format(content=query)
270270

271271
return self._engine.search(query, top_k=top_k)
272272

tests/test_sqlite_rag.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,3 +580,24 @@ def test_search_samples_exact_match_by_scan_type(self, quantize_scan: bool):
580580
# Second result should have distance > 0
581581
second_result = results[1]
582582
assert second_result.vec_distance and second_result.vec_distance > 0.0
583+
584+
def test_search_uses_retrieval_query_template(self, mocker):
585+
template = "task: search | Do something with {content}"
586+
587+
settings = {"prompt_template_retrieval_query": template}
588+
589+
rag = SQLiteRag.create(":memory:", settings=settings)
590+
591+
mock_engine = mocker.Mock()
592+
mock_engine.search.return_value = []
593+
594+
rag._engine = mock_engine
595+
596+
query = "test query"
597+
rag.search(query)
598+
599+
# Assert that engine.search was called with the formatted template
600+
expected_query = rag._settings.prompt_template_retrieval_query.format(
601+
content=query
602+
)
603+
mock_engine.search.assert_called_once_with(expected_query, top_k=10)

0 commit comments

Comments
 (0)