
Commit 7c3eff3

Author: Daniele Briggi (committed)

refact(ask): flow of text generation is integrated in the cli and sqliterag module

1 parent 2b120a1 · commit 7c3eff3

10 files changed: +661 -305 lines changed

src/sqlite_rag/chunker.py

Lines changed: 5 additions & 4 deletions
@@ -1,8 +1,8 @@
 import math
-import sqlite3
 from typing import List, Optional

 from sqlite_rag.models.document import Document
+from sqlite_rag.models.llm_model import LLMModel

 from .models.chunk import Chunk
 from .settings import Settings
@@ -11,8 +11,8 @@
 class Chunker:
     ESTIMATE_CHARS_PER_TOKEN = 4

-    def __init__(self, conn: sqlite3.Connection, settings: Settings):
-        self._conn = conn
+    def __init__(self, llm_model: LLMModel, settings: Settings):
+        self.llm_model = llm_model
         self._settings = settings

     def chunk(self, document: Document) -> list[Chunk]:
@@ -67,7 +67,8 @@ def _get_token_count(self, text: str) -> int:
         if len(text) > self._settings.chunk_size * self.ESTIMATE_CHARS_PER_TOKEN * 2:
             return self._estimate_tokens_count(text)

-        cursor = self._conn.execute("SELECT llm_token_count(?) AS count", (text,))
+        conn = self.llm_model.ensure_loaded()
+        cursor = conn.execute("SELECT llm_token_count(?) AS count", (text,))
         return cursor.fetchone()["count"]

     def _estimate_tokens_count(self, text: str) -> int:

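Note: the change above removes the Chunker's direct dependency on a sqlite3 connection; exact token counting now asks the LLM wrapper for a loaded connection. A minimal wiring sketch follows. Only Chunker(llm_model, settings), chunk(), and ensure_loaded() appear in the diff; the LLMModel constructor and the Document/Settings field values shown here are assumptions for illustration.

# Hypothetical wiring sketch (constructor and field names are assumed, not from this commit):
from sqlite_rag.chunker import Chunker
from sqlite_rag.models.document import Document
from sqlite_rag.models.llm_model import LLMModel
from sqlite_rag.settings import Settings

settings = Settings()                   # defaults; chunk_size etc. left as configured
llm_model = LLMModel(settings)          # assumed constructor; see models/llm_model.py for the real one
chunker = Chunker(llm_model, settings)

doc = Document(uri="README.md", content="long text ...")   # assumed fields
chunks = chunker.chunk(doc)

# Exact token counts route through the model-owned connection, as in the diff:
#   conn = llm_model.ensure_loaded()
#   conn.execute("SELECT llm_token_count(?) AS count", (text,)).fetchone()["count"]
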
src/sqlite_rag/cli.py

Lines changed: 78 additions & 99 deletions
@@ -1,10 +1,13 @@
 #!/usr/bin/env python3
+import itertools
 import json
 import os
 import shlex
+import sys
+import threading
 import time
 from pathlib import Path
-from typing import Optional
+from typing import Any, Optional

 import typer
 from prompt_toolkit import prompt
@@ -13,6 +16,10 @@
 from sqlite_rag.database import Database
 from sqlite_rag.settings import SettingsManager

+from .cli_configure import (
+    build_configure_signature,
+    filter_setting_updates,
+)
 from .formatters import get_formatter
 from .sqliterag import SQLiteRag

@@ -112,74 +119,8 @@ def show_settings(ctx: typer.Context):
 @app.command("configure")
 def configure_settings(
     ctx: typer.Context,
-    force: bool = typer.Option(
-        False,
-        "-f",
-        "--force",
-        help="Force update even if critical settings change (like model or embedding dimension)",
-    ),
-    model_path: Optional[str] = typer.Option(
-        None, help="Path to the embedding model file (.gguf)"
-    ),
-    model_options: Optional[str] = typer.Option(
-        None,
-        help="options specific for the model: See: https://github.com/sqliteai/sqlite-ai/blob/main/API.md#llm_model_loadpath-text-options-text",
-    ),
-    model_context_options: Optional[str] = typer.Option(
-        None,
-        help="Options specific for model context creation. See: https://github.com/sqliteai/sqlite-ai/blob/main/API.md#llm_context_createcontext_settings-text",
-    ),
-    embedding_dim: Optional[int] = typer.Option(
-        None, help="Dimension of the embedding vectors"
-    ),
-    vector_type: Optional[str] = typer.Option(
-        None, help="Vector storage type (FLOAT16, FLOAT32, etc.)"
-    ),
-    other_vector_options: Optional[str] = typer.Option(
-        None, help="Additional vector configuration"
-    ),
-    chunk_size: Optional[int] = typer.Option(
-        None, help="Size of text chunks for processing"
-    ),
-    chunk_overlap: Optional[int] = typer.Option(
-        None, help="Token overlap between consecutive chunks"
-    ),
-    quantize_scan: Optional[bool] = typer.Option(
-        None, help="Whether to quantize vector for faster search"
-    ),
-    quantize_preload: Optional[bool] = typer.Option(
-        None, help="Whether to preload quantized vectors in memory for faster search"
-    ),
-    weight_fts: Optional[float] = typer.Option(
-        None, help="Weight for full-text search results"
-    ),
-    weight_vec: Optional[float] = typer.Option(
-        None, help="Weight for vector search results"
-    ),
-    use_gpu: Optional[bool] = typer.Option(
-        None, help="Whether to allow sqlite-ai extension to use the GPU"
-    ),
-    no_prompt_templates: bool = typer.Option(
-        False,
-        "--no-prompt-templates",
-        help="Disable prompt templates for embedding generation",
-    ),
-    prompt_template_retrieval_document: Optional[str] = typer.Option(
-        None,
-        help="Template for retrieval document prompts. Supported placeholders are `{title}` and `{content}`",
-    ),
-    prompt_template_retrieval_query: Optional[str] = typer.Option(
-        None,
-        help="Template for retrieval query prompts, use `{content}` as placeholder",
-    ),
-    max_document_size_bytes: Optional[int] = typer.Option(
-        None,
-        help="Maximum size of a document to process (in bytes) before being truncated",
-    ),
-    max_chunks_per_document: Optional[int] = typer.Option(
-        None,
-        help="Maximum number of chunks to generate per document (0 for no limit)",
-    ),
+    force: bool = False,
+    **settings_values: Any,
 ):
     """Configure settings for the RAG system.

@@ -189,32 +130,7 @@ def configure_settings(
     """
     rag_context = ctx.obj["rag_context"]

-    # Build updates dict from all provided parameters
-    updates = {
-        "model_path": model_path,
-        "model_options": model_options,
-        "model_context_options": model_context_options,
-        "use_gpu": use_gpu,
-        "embedding_dim": embedding_dim,
-        "vector_type": vector_type,
-        "other_vector_options": other_vector_options,
-        "chunk_size": chunk_size,
-        "chunk_overlap": chunk_overlap,
-        "quantize_scan": quantize_scan,
-        "quantize_preload": quantize_preload,
-        "weight_fts": weight_fts,
-        "weight_vec": weight_vec,
-        "use_prompt_templates": (
-            False if no_prompt_templates else None
-        ),  # Set only if True
-        "prompt_template_retrieval_document": prompt_template_retrieval_document,
-        "prompt_template_retrieval_query": prompt_template_retrieval_query,
-        "max_document_size_bytes": max_document_size_bytes,
-        "max_chunks_per_document": max_chunks_per_document,
-    }
-    print(updates)
-    # Filter out None values (unset options)
-    updates = {k: v for k, v in updates.items() if v is not None}
+    updates = filter_setting_updates(settings_values)

     if not updates:
         typer.echo("No settings provided to configure.")
@@ -229,6 +145,9 @@ def configure_settings(
     typer.echo("Settings updated.")


+configure_settings.__signature__ = build_configure_signature()
+
+
 @app.command()
 def add(
     ctx: typer.Context,
@@ -472,18 +391,78 @@ def search(
 def ask(
     ctx: typer.Context,
     question: str,
+    use_last_chat: bool = typer.Option(
+        False,
+        "--use-last-chat",
+        help="Reuse the previous chat session (REPL mode only)",
+    ),
 ):
     """Ask a question and get an answer using the LLM"""
     rag_context = ctx.obj["rag_context"]
+
+    if use_last_chat and not rag_context.in_repl:
+        raise typer.BadParameter(
+            "--use-last-chat is only available when running the REPL."
+        )
+
     start_time = time.time()

     rag = rag_context.get_rag(require_existing=True)
-    answer = rag.ask(question)
+    cursor = rag.ask(question, reuse_chat=use_last_chat)

-    elapsed_time = time.time() - start_time
+    spinner_stop = threading.Event()
+
+    def spinner() -> None:
+        frames = itertools.cycle("\\|/-")
+        while not spinner_stop.is_set():
+            sys.stdout.write(f"\rthinking {next(frames)}")
+            sys.stdout.flush()
+            time.sleep(0.1)
+        sys.stdout.write("\r" + " " * 20 + "\r")
+        sys.stdout.flush()

-    typer.echo(answer)
-    typer.echo(f"{elapsed_time:.3f} seconds")
+    spinner_thread = threading.Thread(target=spinner, daemon=True)
+    spinner_thread.start()
+
+    has_tokens = False
+    token_count = 0
+    try:
+        while True:
+            row = cursor.fetchone()
+            if row is None:
+                break
+
+            token = row["reply"]
+            if token is None:
+                continue
+
+            if not has_tokens:
+                spinner_stop.set()
+                spinner_thread.join()
+                sys.stdout.write("\n")
+                has_tokens = True
+
+            sys.stdout.write(token)
+            sys.stdout.flush()
+            token_count += 1
+    finally:
+        cursor.close()
+
+    spinner_stop.set()
+    spinner_thread.join()
+
+    if has_tokens:
+        sys.stdout.write("\n")
+        sys.stdout.flush()
+    else:
+        typer.echo("\nNo response received.")
+
+    elapsed_time = time.time() - start_time
+    stats_line = f"{elapsed_time:.3f} seconds"
+    if token_count > 0 and elapsed_time > 0:
+        tokens_per_sec = token_count / elapsed_time
+        stats_line = f"{stats_line} ({token_count} tokens, {tokens_per_sec:.2f} tok/s)"
+    typer.echo(stats_line)


 @app.command()

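Note: the reworked ask command treats the value returned by rag.ask() as a DB-API-style cursor that yields one generated token per row in a "reply" column and signals completion when fetchone() returns None. The self-contained sketch below illustrates that assumed consumption contract; FakeCursor and stream_answer are illustrative stand-ins, not part of the commit.

# Illustrative only: a stand-in cursor mimicking the assumed token-streaming contract.
class FakeCursor:
    def __init__(self, tokens):
        self._tokens = iter(tokens)

    def fetchone(self):
        token = next(self._tokens, None)
        return None if token is None else {"reply": token}

    def close(self):
        pass


def stream_answer(cursor) -> int:
    """Print tokens as they arrive and return how many were emitted."""
    count = 0
    try:
        while True:
            row = cursor.fetchone()
            if row is None:          # generation finished
                break
            token = row["reply"]
            if token is None:        # skip placeholder rows
                continue
            print(token, end="", flush=True)
            count += 1
    finally:
        cursor.close()
    print()
    return count


stream_answer(FakeCursor(["Streaming ", "works ", "token ", "by ", "token."]))
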
src/sqlite_rag/cli_configure.py

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
+import inspect
+from dataclasses import fields
+from typing import Any, Optional, Union, get_args, get_origin
+
+import typer
+
+from .settings import Settings
+
+SETTINGS_FIELDS = tuple(fields(Settings))
+SETTINGS_FIELD_NAMES = {field.name for field in SETTINGS_FIELDS}
+
+
+def _strip_optional(annotation: Any) -> Any:
+    origin = get_origin(annotation)
+    if origin is Union:
+        args = [arg for arg in get_args(annotation) if arg is not type(None)]
+        if len(args) == 1:
+            return args[0]
+    return annotation
+
+
+def _is_bool(annotation: Any) -> bool:
+    return _strip_optional(annotation) is bool
+
+
+def _option_help(field_obj) -> str:
+    return field_obj.metadata.get(
+        "help", f"Override {field_obj.name.replace('_', ' ')}"
+    )
+
+
+def _cli_name(field_obj) -> str:
+    return field_obj.metadata.get("cli_name", field_obj.name.replace("_", "-"))
+
+
+def _build_setting_parameter(field_obj) -> inspect.Parameter:
+    option_names = []
+    cli_name = _cli_name(field_obj)
+    if _is_bool(field_obj.type):
+        option_names.append(f"--{cli_name}/--no-{cli_name}")
+    else:
+        option_names.append(f"--{cli_name}")
+
+    option = typer.Option(
+        None,
+        *option_names,
+        help=_option_help(field_obj),
+        show_default=False,
+    )
+
+    annotation = Optional[_strip_optional(field_obj.type)]
+    return inspect.Parameter(
+        field_obj.name,
+        inspect.Parameter.KEYWORD_ONLY,
+        default=option,
+        annotation=annotation,
+    )
+
+
+def build_configure_signature() -> inspect.Signature:
+    params = [
+        inspect.Parameter(
+            "ctx",
+            inspect.Parameter.POSITIONAL_OR_KEYWORD,
+            annotation=typer.Context,
+        ),
+        inspect.Parameter(
+            "force",
+            inspect.Parameter.KEYWORD_ONLY,
+            default=typer.Option(
+                False,
+                "-f",
+                "--force",
+                help=(
+                    "Force update even if critical settings change "
+                    "(like model or embedding dimension)"
+                ),
+            ),
+            annotation=bool,
+        ),
+    ]
+
+    for field_obj in SETTINGS_FIELDS:
+        params.append(_build_setting_parameter(field_obj))
+
+    return inspect.Signature(params)
+
+
+def filter_setting_updates(settings_values: dict[str, Any]) -> dict[str, Any]:
+    """Return only the settings provided as CLI overrides."""
+    return {
+        key: value
+        for key, value in settings_values.items()
+        if key in SETTINGS_FIELD_NAMES and value is not None
+    }

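Note: cli_configure.py generates one Typer option per Settings dataclass field and attaches them by overriding the command's __signature__; Typer reads the signature when it builds the Click command, so the injected inspect.Parameter objects become real CLI flags. The standalone sketch below shows the same pattern with a toy dataclass; DemoSettings and the demo option names are placeholders, not the project's Settings.

# Standalone sketch of the dynamic-signature pattern (toy dataclass, assumed behavior).
import inspect
from dataclasses import dataclass, fields
from typing import Any, Optional

import typer

app = typer.Typer()


@dataclass
class DemoSettings:
    chunk_size: int = 512
    use_gpu: bool = False


@app.command()
def configure(**values: Any):
    # Keep only options the user actually passed (None means "not provided").
    updates = {k: v for k, v in values.items() if v is not None}
    typer.echo(f"updates: {updates}")


def build_signature() -> inspect.Signature:
    params = []
    for f in fields(DemoSettings):
        cli = f.name.replace("_", "-")
        flag = f"--{cli}/--no-{cli}" if f.type is bool else f"--{cli}"
        params.append(
            inspect.Parameter(
                f.name,
                inspect.Parameter.KEYWORD_ONLY,
                default=typer.Option(None, flag, show_default=False),
                annotation=Optional[f.type],
            )
        )
    return inspect.Signature(params)


configure.__signature__ = build_signature()

if __name__ == "__main__":
    app()  # e.g. `python demo.py --chunk-size 1024 --use-gpu` -> updates: {'chunk_size': 1024, 'use_gpu': True}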