Skip to content

Commit 42f798b

Browse files
author
Daniele Briggi
committed
feat(cli): repl keep model loaded
1 parent 7fc8738 commit 42f798b

File tree

10 files changed

+146
-84
lines changed

10 files changed

+146
-84
lines changed

.gitignore

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
semsearch/
2-
docs/
3-
samples/headlines
1+
samples
2+
extensions
43

54
# LLM models
65
*.gguf

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ repos:
1515
rev: v2.3.1
1616
hooks:
1717
- id: autoflake
18-
args: ["--remove-all-unused-imports", "--remove-unused-variables", "--ignore-init-module-imports", "--in-place", "--recursive", "."]
18+
args: ["--remove-all-unused-imports", "--remove-unused-variables", "--ignore-init-module-imports", "--in-place", "--recursive", "./src", "./test"]
1919
- repo: https://github.com/pycqa/isort
2020
rev: 6.0.1
2121
hooks:

model_evaluation/ms_marco.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ def create_example_config():
462462
"chunk_overlap": 0,
463463
"weight_fts": 1.0,
464464
"weight_vec": 1.0,
465-
"model_path_or_name": "./models/Qwen/Qwen3-Embedding-0.6B-GGUF/Qwen3-Embedding-0.6B-Q8_0.gguf",
465+
"model_path": "./models/Qwen/Qwen3-Embedding-0.6B-GGUF/Qwen3-Embedding-0.6B-Q8_0.gguf",
466466
"model_options": "",
467467
"model_context_options": "generate_embedding=1,normalize_embedding=1,pooling_type=mean,embedding_type=INT8",
468468
"vector_type": "INT8",

src/sqlite_rag/cli.py

Lines changed: 99 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import json
33
import os
44
import shlex
5-
import sys
65
import time
76
from pathlib import Path
87
from typing import Optional
@@ -16,43 +15,79 @@
1615
from .sqliterag import SQLiteRag
1716

1817

18+
class RAGContext:
    """Hold CLI state and reuse one SQLiteRag instance while in REPL mode.

    In regular (one-shot) mode every call to ``get_rag`` builds a fresh
    instance. After ``enter_repl`` the first instance created is cached
    and handed back on every later call.
    """

    def __init__(self):
        # Cached instance; only populated once REPL mode is active.
        self.rag: Optional["SQLiteRag"] = None
        # Flipped to True by enter_repl().
        self.in_repl = False
        # Database path chosen via the global --database option.
        self.database: str = ""

    def enter_repl(self):
        """Switch to REPL mode so that ``get_rag`` starts caching."""
        self.in_repl = True

    def get_rag(
        self, database_path: str, require_existing: bool = False
    ) -> "SQLiteRag":
        """Create or reuse a SQLiteRag instance.

        Args:
            database_path: Path of the SQLite database to open. When empty,
                the path stored in ``self.database`` is used instead.
            require_existing: Forwarded unchanged to ``SQLiteRag.create``.
                In REPL mode it only takes effect on the call that creates
                the cached instance; later calls return the cache as is.

        Returns:
            A new instance in regular mode, the cached one in REPL mode.

        Raises:
            ValueError: If neither ``database_path`` nor ``self.database``
                provides a path.
        """
        # Use one resolved path for validation and creation in both modes,
        # so the argument is never silently ignored.
        path = database_path or self.database
        if not path:
            raise ValueError("Database path not set. Use --database option.")

        if not self.in_repl:
            # Regular mode - create new instance
            return SQLiteRag.create(path, require_existing=require_existing)

        if self.rag is None:
            self.rag = SQLiteRag.create(path, require_existing=require_existing)
            # Report the database once, when the cached instance is created.
            typer.echo(f"Database: {Path(path).resolve()}")
        return self.rag
46+
47+
48+
rag_context = RAGContext()
49+
50+
1951
class CLI:
2052
"""Main class to handle CLI commands"""
2153

2254
def __init__(self, app: typer.Typer):
2355
self.app = app
2456

2557
def __call__(self, *args, **kwds):
26-
if len(sys.argv) == 1:
27-
repl_mode()
28-
else:
29-
self.app()
58+
self.app()
3059

3160

3261
app = typer.Typer()
3362
cli = CLI(app)
3463

3564

36-
# Global database option
37-
def database_option():
38-
return typer.Option(
65+
@app.callback(invoke_without_command=True)
66+
def main(
67+
ctx: typer.Context,
68+
database: str = typer.Option(
3969
"./sqliterag.sqlite",
4070
"--database",
4171
"-db",
4272
help="Path to the SQLite database file",
43-
)
44-
73+
),
74+
):
75+
"""SQLite RAG - Retrieval Augmented Generation with SQLite"""
76+
ctx.ensure_object(dict)
77+
rag_context.database = database
78+
ctx.obj["rag_context"] = rag_context
4579

46-
def show_database_path(db_path: str):
47-
"""Display current database path"""
48-
typer.echo(f"Database: {Path(db_path).resolve()}")
80+
# If no subcommand was invoked, enter REPL mode
81+
if ctx.invoked_subcommand is None:
82+
rag_context.enter_repl()
83+
repl_mode()
4984

5085

5186
@app.command("settings")
52-
def show_settings(database: str = database_option()):
87+
def show_settings(ctx: typer.Context):
5388
"""Show current settings"""
54-
show_database_path(database)
55-
rag = SQLiteRag.create(database, require_existing=True)
89+
rag_context = ctx.obj["rag_context"]
90+
rag = rag_context.get_rag(rag_context.database, require_existing=True)
5691
current_settings = rag.get_settings()
5792

5893
typer.echo("Current settings:")
@@ -62,9 +97,14 @@ def show_settings(database: str = database_option()):
6297

6398
@app.command("configure")
6499
def configure_settings(
65-
database: str = database_option(),
66-
model_path_or_name: Optional[str] = typer.Option(
67-
None, help="Path to the embedding model file or Hugging Face model name"
100+
ctx: typer.Context,
101+
force: bool = typer.Option(
102+
False,
103+
"--force",
104+
help="Force update even if critical settings change (like model or embedding dimension)",
105+
),
106+
model_path: Optional[str] = typer.Option(
107+
None, help="Path to the embedding model file (.gguf)"
68108
),
69109
model_config: Optional[str] = typer.Option(
70110
None, help="Model configuration parameters"
@@ -103,11 +143,11 @@ def configure_settings(
103143
and search weights. Only specify the options you want to change.
104144
Use 'sqlite-rag settings' to view current values.
105145
"""
106-
show_database_path(database)
146+
rag_context = ctx.obj["rag_context"]
107147

108148
# Build updates dict from all provided parameters
109149
updates = {
110-
"model_path_or_name": model_path_or_name,
150+
"model_path": model_path,
111151
"model_config": model_config,
112152
"embedding_dim": embedding_dim,
113153
"vector_type": vector_type,
@@ -125,21 +165,21 @@ def configure_settings(
125165

126166
if not updates:
127167
typer.echo("No settings provided to configure.")
128-
show_settings(database)
168+
show_settings(ctx)
129169
return
130170

131-
conn = Database.new_connection(database)
171+
conn = Database.new_connection(rag_context.database)
132172
settings_manager = SettingsManager(conn)
133-
settings_manager.prepare_settings(updates)
173+
settings_manager.configure(updates)
134174

135-
show_settings(database)
175+
show_settings(ctx)
136176
typer.echo("Settings updated.")
137177

138178

139179
@app.command()
140180
def add(
181+
ctx: typer.Context,
141182
path: str = typer.Argument(..., help="File or directory path to add"),
142-
database: str = database_option(),
143183
recursive: bool = typer.Option(
144184
False, "-r", "--recursive", help="Recursively add all files in directories"
145185
),
@@ -156,10 +196,10 @@ def add(
156196
),
157197
):
158198
"""Add a file path to the database"""
159-
show_database_path(database)
199+
rag_context = ctx.obj["rag_context"]
160200
start_time = time.time()
161201

162-
rag = SQLiteRag.create(database)
202+
rag = rag_context.get_rag(rag_context.database)
163203
rag.add(
164204
path,
165205
recursive=recursive,
@@ -173,9 +213,9 @@ def add(
173213

174214
@app.command()
175215
def add_text(
216+
ctx: typer.Context,
176217
text: str,
177218
uri: Optional[str] = None,
178-
database: str = database_option(),
179219
metadata: Optional[str] = typer.Option(
180220
None,
181221
"--metadata",
@@ -184,17 +224,17 @@ def add_text(
184224
),
185225
):
186226
"""Add a text to the database"""
187-
show_database_path(database)
188-
rag = SQLiteRag.create(database)
227+
rag_context = ctx.obj["rag_context"]
228+
rag = rag_context.get_rag(rag_context.database)
189229
rag.add_text(text, uri=uri, metadata=json.loads(metadata or "{}"))
190230
typer.echo("Text added.")
191231

192232

193233
@app.command("list")
194-
def list_documents(database: str = database_option()):
234+
def list_documents(ctx: typer.Context):
195235
"""List all documents in the database"""
196-
show_database_path(database)
197-
rag = SQLiteRag.create(database, require_existing=True)
236+
rag_context = ctx.obj["rag_context"]
237+
rag = rag_context.get_rag(rag_context.database, require_existing=True)
198238
documents = rag.list_documents()
199239

200240
if not documents:
@@ -220,13 +260,13 @@ def list_documents(database: str = database_option()):
220260

221261
@app.command()
222262
def remove(
263+
ctx: typer.Context,
223264
identifier: str,
224-
database: str = database_option(),
225265
yes: bool = typer.Option(False, "-y", "--yes", help="Skip confirmation prompt"),
226266
):
227267
"""Remove document by path or UUID"""
228-
show_database_path(database)
229-
rag = SQLiteRag.create(database, require_existing=True)
268+
rag_context = ctx.obj["rag_context"]
269+
rag = rag_context.get_rag(rag_context.database, require_existing=True)
230270

231271
# Find the document first
232272
document = rag.find_document(identifier)
@@ -264,14 +304,14 @@ def remove(
264304

265305
@app.command()
266306
def rebuild(
267-
database: str = database_option(),
307+
ctx: typer.Context,
268308
remove_missing: bool = typer.Option(
269309
False, "--remove-missing", help="Remove documents whose files are not found"
270310
),
271311
):
272312
"""Rebuild embeddings and full-text index"""
273-
show_database_path(database)
274-
rag = SQLiteRag.create(database, require_existing=True)
313+
rag_context = ctx.obj["rag_context"]
314+
rag = rag_context.get_rag(rag_context.database, require_existing=True)
275315

276316
typer.echo("Rebuild process...")
277317

@@ -286,12 +326,12 @@ def rebuild(
286326

287327
@app.command()
288328
def reset(
289-
database: str = database_option(),
329+
ctx: typer.Context,
290330
yes: bool = typer.Option(False, "-y", "--yes", help="Skip confirmation prompt"),
291331
):
292332
"""Reset/clear the entire database"""
293-
show_database_path(database)
294-
rag = SQLiteRag.create(database, require_existing=True)
333+
rag_context = ctx.obj["rag_context"]
334+
rag = rag_context.get_rag(rag_context.database, require_existing=True)
295335

296336
# Show warning and ask for confirmation unless -y flag is used
297337
if not yes:
@@ -317,8 +357,8 @@ def reset(
317357

318358
@app.command()
319359
def search(
360+
ctx: typer.Context,
320361
query: str,
321-
database: str = database_option(),
322362
limit: int = typer.Option(10, help="Number of results to return"),
323363
debug: bool = typer.Option(
324364
False,
@@ -331,10 +371,10 @@ def search(
331371
),
332372
):
333373
"""Search for documents using hybrid vector + full-text search"""
334-
show_database_path(database)
374+
rag_context = ctx.obj["rag_context"]
335375
start_time = time.time()
336376

337-
rag = SQLiteRag.create(database, require_existing=True)
377+
rag = rag_context.get_rag(rag_context.database, require_existing=True)
338378
results = rag.search(query, top_k=limit)
339379

340380
search_time = time.time() - start_time
@@ -348,16 +388,16 @@ def search(
348388

349389
@app.command()
350390
def quantize(
351-
database: str = database_option(),
391+
ctx: typer.Context,
352392
cleanup: bool = typer.Option(
353393
False,
354394
"--cleanup",
355395
help="Clean up quantization structures instead of creating them",
356396
),
357397
):
358398
"""Quantize vectors for faster search or clean up quantization structures"""
359-
show_database_path(database)
360-
rag = SQLiteRag.create(database, require_existing=True)
399+
rag_context = ctx.obj["rag_context"]
400+
rag = rag_context.get_rag(rag_context.database, require_existing=True)
361401

362402
if cleanup:
363403
typer.echo("Cleaning up quantization structures...")
@@ -426,6 +466,11 @@ def download_model(
426466
def repl_mode():
427467
"""Interactive REPL mode"""
428468
typer.echo("Entering interactive mode. Type 'help' for commands or 'exit' to quit.")
469+
typer.echo(
470+
"Note: --database and configure commands are not available in REPL mode."
471+
)
472+
473+
disabled_features = ["configure", "--database", "-db"]
429474

430475
while True:
431476
try:
@@ -448,6 +493,11 @@ def repl_mode():
448493
try:
449494
# Parse command and delegate to typer app
450495
args = shlex.split(command)
496+
# Check for disabled commands in REPL
497+
if args and args[0] in disabled_features:
498+
typer.echo("Error: command is not available in REPL mode")
499+
continue
500+
451501
app(args, standalone_mode=False)
452502
except SystemExit:
453503
# Typer raises SystemExit on errors, catch it to stay in REPL

src/sqlite_rag/database.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def initialize(conn: sqlite3.Connection, settings: Settings) -> sqlite3.Connecti
2121
conn.enable_load_extension(True)
2222
try:
2323
conn.load_extension(
24-
str(importlib.resources.files("sqliteai.binaries.gpu") / "ai")
24+
str(importlib.resources.files("sqliteai.binaries.cpu") / "ai")
2525
)
2626
conn.load_extension(
2727
str(importlib.resources.files("sqlite-vector.binaries") / "vector")

src/sqlite_rag/engine.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,22 +26,22 @@ def load_model(self):
2626
"""Load the model from the specified path.
2727
A missing file raises FileNotFoundError; the Hugging Face download fallback is currently commented out."""
2828

29-
model_path = Path(self._settings.model_path_or_name)
29+
model_path = Path(self._settings.model_path)
3030
if not model_path.exists():
3131
raise FileNotFoundError(f"Model file not found at {model_path}")
3232

33-
# model_path = self.settings.model_path_or_name
34-
# if not Path(self.settings.model_path_or_name).exists():
33+
# model_path = self.settings.model_path
34+
# if not Path(self.settings.model_path).exists():
3535
# # check if exists locally or try to download it from Hugging Face
3636
# model_path = hf_hub_download(
37-
# repo_id=self.settings.model_path_or_name,
37+
# repo_id=self.settings.model_path,
3838
# filename="model-q4_0.gguf", # GGUF format
3939
# cache_dir="./models"
4040
# )
4141

4242
self._conn.execute(
4343
"SELECT llm_model_load(?, ?);",
44-
(self._settings.model_path_or_name, self._settings.model_options),
44+
(self._settings.model_path, self._settings.model_options),
4545
)
4646

4747
def process(self, document: Document) -> Document:

0 commit comments

Comments
 (0)