Skip to content

Commit 2b120a1

Browse files
author
Daniele Briggi
committed
feat(chat): experiments
1 parent 8bd55db commit 2b120a1

File tree

7 files changed

+159
-9
lines changed

7 files changed

+159
-9
lines changed

documentation_ai.sqlite

15.6 MB
Binary file not shown.

src/sqlite_rag/cli.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,24 @@ def search(
468468
typer.echo(f"{search_time:.3f} seconds")
469469

470470

471+
@app.command()
def ask(
    ctx: typer.Context,
    question: str,
):
    """Ask a question and get an answer using the LLM"""
    # The shared RAG context is stashed on the typer context by the app callback.
    rag_context = ctx.obj["rag_context"]

    # Time the whole round trip, including opening the existing RAG store.
    started = time.time()
    rag = rag_context.get_rag(require_existing=True)
    answer = rag.ask(question)
    elapsed = time.time() - started

    typer.echo(answer)
    typer.echo(f"{elapsed:.3f} seconds")
487+
488+
471489
@app.command()
472490
def quantize(
473491
ctx: typer.Context,

src/sqlite_rag/engine.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,79 @@ def search_sentences(
316316

317317
return sentences[:top_k]
318318

319+
def create_new_chat(self) -> None:
    """Create a new LLM chat context with empty history.

    NOTE(review): currently a no-op stub — the context/chat creation calls
    below are commented out, and `ask()` performs the equivalent
    `llm_context_create` / `llm_chat_create` calls inline instead. Confirm
    whether this method should own that lifecycle before relying on it.
    """
    # self._conn.execute(
    #     "SELECT llm_context_create(?);", (self._settings.other_gen_context_options,)
    # )
    # self._conn.execute("SELECT llm_chat_create();")
325+
326+
def ask(self, query: str) -> str:
    """Generate an answer to the query using the LLM.

    Retrieves the top matching documents, builds a grounded prompt from
    their content, loads the generation model, context, chat and sampler
    chain via the sqlite-ai extension, and returns the chat response.

    Args:
        query: The user question.

    Returns:
        The LLM response text.

    NOTE(review): the `print(...)` calls are debug output from the
    experimental phase; consider replacing them with `logging` later.
    """
    # Retrieve a wider candidate set, then keep only the best 3 for the prompt.
    results = self.search(query, top_k=10)
    results = results[:3]

    context = ""
    for result in results:
        # if result.combined_rank < 0.3:
        # Debug: show ranking details for each document used as context.
        print(
            f"doc uri: {result.document.uri}, vector: {result.vec_distance}, fts: {result.fts_score}, score: {result.combined_rank}"
        )
        # Truncate each document to its first 5000 chars and escape newlines
        # so the prompt stays compact.
        preview = result.document.content[:5000].replace("\n", "\\n")
        context += f"{preview}\n\n"

    # Fall back to the raw query when retrieval produced no context.
    prompt = query
    if context != "":
        # prompt = f"""You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.
        prompt = f"""Answer the question based only on the following documents.
Answer with the summary of the documents provided.
Do **NOT** include any introductory phrases, titles, or prefixes such as "Answer:", "The answer is:", "Final Answer:", or "Based on the context,". Start your response with the answer itself:

{context}

{query}
"""

    # Debug: dump the final prompt and its token count.
    print("---\n", prompt)
    print(
        "token count:",
        self._conn.execute(
            "SELECT llm_token_count(?) AS token_count;", (prompt,)
        ).fetchone()["token_count"],
    )

    # Load the generation model and create a fresh context + chat for this
    # call (see also the currently stubbed create_new_chat()).
    self._conn.execute(
        "SELECT llm_model_load(?, ?);",
        (self._settings.gen_model_path, self._settings.other_gen_model_options),
    )
    self._conn.execute(
        "SELECT llm_context_create(?);", (self._settings.other_gen_context_options,)
    )
    self._conn.execute("SELECT llm_chat_create();")

    # Configure the sampler chain. executescript() discards result rows;
    # only the extension-side side effects matter here.
    self._conn.executescript(
        """
        SELECT llm_sampler_init_temp(1.0);
        SELECT llm_sampler_init_top_k(64);
        SELECT llm_sampler_init_top_p(0.95, 1);
        SELECT llm_sampler_init_min_p(0.0, 1);
        SELECT llm_sampler_init_dist(-1);
        SELECT llm_sampler_init_penalties(1024, 1.1, 0.0, 0.0);
        """
    )

    r = self._conn.execute("SELECT llm_chat_respond(?) AS response;", (prompt,))

    response = r.fetchone()[0]
    # Debug: report the response token count.
    print(
        "token count:",
        self._conn.execute(
            "SELECT llm_token_count(?) AS token_count;", (response,)
        ).fetchone()["token_count"],
    )

    return response
391+
319392
def versions(self) -> dict:
320393
"""Get versions of the loaded extensions."""
321394
cursor = self._conn.cursor()

src/sqlite_rag/settings.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,29 @@ class Settings:
7676
# Zero means no limit
max_chunks_per_document: int = 1000
# Number of top sentences to return per document
top_k_sentences: int = 10

#
# Text generation
#

# Smaller alternative model kept for quick experiments:
# gen_model_path: str = (
#     "./models/unsloth/gemma-3-270m-it-GGUF/gemma-3-270m-it-Q8_0.gguf"
# )
# Path to the GGUF model used for answer generation.
gen_model_path: str = "./models/unsloth/gemma-3-1b-it-GGUF/gemma-3-1b-it-Q8_0.gguf"

# Extra options passed to llm_model_load().
# See: https://github.com/sqliteai/sqlite-ai/blob/main/API.md#llm_model_loadpath-text-options-text
other_gen_model_options: str = ""
# Extra options passed to llm_context_create().
# NOTE(review): this string repeats context_size/max_tokens/n_predict with
# values that differ from the fields below, so
# get_generation_context_options() emits each key twice — confirm which set
# the extension honors and drop the other.
# See: https://github.com/sqliteai/sqlite-ai/blob/main/API.md#llm_context_createoptions-text
other_gen_context_options: str = (
    "n_ctx=6000,context_size=6000,max_tokens=3000,n_threads=8,n_predict=800"
)

# Context window size for generation — presumably in tokens; TODO confirm.
context_size: int = 2048
# Max input tokens to the model for generation
max_tokens: int = 2048
# Max number of tokens to generate per response.
n_predict: int = 400
80102

81103
def get_embeddings_context_options(self) -> str:
82104
"""Get the context options for embeddings generation."""
@@ -94,6 +116,20 @@ def get_embeddings_context_options(self) -> str:
94116
else ""
95117
)
96118

119+
def get_generation_context_options(self) -> str:
    """Build the comma-separated option string for text-generation contexts.

    Combines the dedicated generation fields with any free-form extras from
    ``other_gen_context_options`` (appended last, when non-empty).
    """
    parts = [
        f"context_size={self.context_size}",
        f"max_tokens={self.max_tokens}",
        f"n_predict={self.n_predict}",
    ]
    if self.other_gen_context_options:
        parts.append(self.other_gen_context_options)
    return ",".join(parts)
132+
97133
def get_vector_init_options(self) -> str:
98134
"""Get the vector init options for the vector store."""
99135
options = {"type": self.vector_type, "dimension": self.embedding_dim}

src/sqlite_rag/sqliterag.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,21 @@ def search(
318318

319319
return self._engine.search(query, top_k=top_k)
320320

321+
def ask(self, question: str, new_context: bool = True) -> str:
    """Generate an answer to the question using the LLM.

    Args:
        question: The question string
        new_context: Whether to create a new LLM context for this question

    Returns:
        The generated answer text from the engine.

    NOTE(review): `create_new_context` is not visible on the engine in this
    change (only `create_new_chat` is, and it is currently a stub) — confirm
    the method exists, otherwise this raises AttributeError whenever
    `new_context` is True (the default).
    """
    self._ensure_initialized()
    if new_context:
        self._engine.create_new_context()

    self._engine.create_new_chat()

    return self._engine.ask(question)
335+
321336
def get_settings(self) -> dict:
322337
"""Get settings and more useful information"""
323338
versions = self._engine.versions()

tests/conftest.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import sqlite3
2-
import tempfile
32
from collections.abc import Generator
43

54
import pytest
@@ -13,17 +12,17 @@
1312

1413
@pytest.fixture
def db_conn():
    """Yield an initialized (connection, settings) pair on a temporary database.

    Uses a throwaway temp file instead of the checked-in
    ``./documentation_ai.sqlite`` so tests stay isolated, are repeatable, and
    never mutate a committed fixture artifact; the file is deleted when the
    context manager exits.
    """
    # Local import so this fixture is self-contained.
    import tempfile

    with tempfile.NamedTemporaryFile(suffix=".db") as tmp_db:
        settings = Settings()

        conn = sqlite3.connect(tmp_db.name)
        # Row factory lets tests access columns by name.
        conn.row_factory = sqlite3.Row

        Database.initialize(conn, settings)

        yield conn, settings

        conn.close()
2726

2827

2928
@pytest.fixture

tests/integration/test_engine.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,3 +320,12 @@ def test_search_sentences(self, db_conn):
320320
assert len(results) > 0
321321
assert results[0].start_offset == 61 # it's the second sentence
322322
assert results[0].end_offset == 89
323+
324+
325+
class TestEngineAsk:
    """Integration tests for Engine.ask (end-to-end LLM answer generation)."""

    def test_ask(self, engine: Engine):
        """Smoke-test that ask() produces a string answer.

        NOTE(review): depends on the `engine` fixture pointing at a database
        that already contains relevant documents and on the configured GGUF
        model being available locally — this is an integration/experiment
        test, not a unit test.
        """
        engine.create_new_chat()

        result = engine.ask("what's the difference between offsync and sqlite sync?")
        assert isinstance(result, str)
        # Manual inspection aid while the feature is experimental.
        print(result)

0 commit comments

Comments
 (0)