Skip to content

Commit 7f18eee

Browse files
author
Daniele Briggi
committed
refact(chunker): better support to prompts
Prompts are now counted toward the chunk size, as is the overlap text. A chunk's content is now stored without the overlap; it is intended to be used as a preview.
1 parent 5b83f4e commit 7f18eee

File tree

16 files changed

+425
-227
lines changed

16 files changed

+425
-227
lines changed

.github/workflows/pypi-package.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,4 @@ jobs:
5555
# Prevent the workflow from failing if the version has already been published
5656
skip-existing: true
5757
# Upload to Test Pypi for testing
58-
repository-url: https://test.pypi.org/legacy/
58+
#repository-url: https://test.pypi.org/legacy/

src/sqlite_rag/chunker.py

Lines changed: 80 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import math
22
import sqlite3
3-
from typing import List
3+
from typing import List, Optional
4+
5+
from sqlite_rag.models.document import Document
46

57
from .models.chunk import Chunk
68
from .settings import Settings
@@ -13,16 +15,47 @@ def __init__(self, conn: sqlite3.Connection, settings: Settings):
1315
self._conn = conn
1416
self._settings = settings
1517

16-
def chunk(self, text: str, metadata: dict = {}) -> list[Chunk]:
18+
def chunk(self, document: Document) -> list[Chunk]:
1719
"""Chunk text using Recursive Character Text Splitter."""
18-
chunks = []
20+
chunk = self._create_chunk(document.content, title=document.get_title())
21+
22+
if (
23+
self._get_token_count(chunk.get_embedding_text())
24+
<= self._settings.chunk_size
25+
):
26+
return [chunk]
27+
28+
return self._recursive_split(document)
29+
30+
def _create_chunk(
31+
self,
32+
content: str,
33+
head_overlap_text: str = "",
34+
title: Optional[str] = None,
35+
) -> Chunk:
36+
prompt = None
37+
if self._settings.use_prompt_templates:
38+
prompt = self._settings.prompt_template_retrieval_document
39+
40+
return Chunk(
41+
content=content,
42+
head_overlap_text=head_overlap_text,
43+
prompt=prompt,
44+
title=title,
45+
)
1946

20-
if self._get_token_count(text) <= self._settings.chunk_size:
21-
chunks = [Chunk(content=text)]
22-
else:
23-
chunks = self._recursive_split(text)
47+
def _get_effective_chunk_size(self, prompt: str) -> int:
48+
"""Calculate effective chunk size considering overlap and other
49+
prompt data useful to the model.
2450
25-
return self._enrich_chunk(chunks, metadata)
51+
Args:
52+
prompt: The prompt template without content.
53+
"""
54+
if self._settings.chunk_size <= self._settings.chunk_overlap:
55+
raise ValueError("Chunk size must be greater than chunk overlap.")
56+
57+
prompt_size = self._get_token_count(prompt)
58+
return self._settings.chunk_size - self._settings.chunk_overlap - prompt_size
2659

2760
def _get_token_count(self, text: str) -> int:
2861
"""Get token count using SQLite AI extension."""
@@ -42,7 +75,7 @@ def _estimate_tokens_count(self, text: str) -> int:
4275
# This is a simple heuristic; adjust as needed
4376
return (len(text) + 3) // self.ESTIMATE_CHARS_PER_TOKEN
4477

45-
def _recursive_split(self, text: str) -> List[Chunk]:
78+
def _recursive_split(self, document: Document) -> List[Chunk]:
4679
"""Recursively split text into chunks with overlap."""
4780
separators = [
4881
"\n\n", # Double newlines (paragraphs)
@@ -59,32 +92,47 @@ def _recursive_split(self, text: str) -> List[Chunk]:
5992
"", # Character level (fallback)
6093
]
6194

62-
chunks = self._split_text_with_separators(text, separators)
63-
return self._apply_overlap(chunks)
95+
empty_chunk = self._create_chunk("", title=document.get_title())
96+
effective_chunk_size = max(
97+
1, self._get_effective_chunk_size(empty_chunk.get_embedding_text())
98+
)
99+
100+
chunks_content = self._split_text_with_separators(
101+
document.content, separators, effective_chunk_size
102+
)
103+
overlaps = self._create_overlaps(chunks_content)
104+
105+
assert len(chunks_content) == len(overlaps), "Mismatch in chunks and overlaps"
106+
return [
107+
self._create_chunk(
108+
content=chunk, head_overlap_text=overlap, title=document.get_title()
109+
)
110+
for chunk, overlap in zip(chunks_content, overlaps)
111+
]
64112

65113
def _split_text_with_separators(
66-
self, text: str, separators: List[str]
67-
) -> List[Chunk]:
68-
"""Split text using hierarchical separators."""
114+
self, text: str, separators: List[str], effective_chunk_size: int
115+
) -> List[str]:
116+
"""Split text using hierarchical separators.
117+
Args:
118+
text: The text to split.
119+
separators: List of separators to use in order.
120+
effective_chunk_size: Reserved space for actual chunk content.
121+
"""
69122
chunks = []
70123

71124
if self._settings.chunk_size <= self._settings.chunk_overlap:
72125
raise ValueError("Chunk size must be greater than chunk overlap.")
73126

74127
if not separators:
75128
# Fallback: character-level splitting
76-
return self._split_by_characters(text)
129+
return self._split_by_characters(text, effective_chunk_size)
77130

78131
separator = separators[0]
79132
remaining_separators = separators[1:]
80133

81134
if separator == "":
82-
return self._split_by_characters(text)
83-
84-
# Reserve space for overlap
85-
effective_chunk_size = max(
86-
1, self._settings.chunk_size - self._settings.chunk_overlap
87-
)
135+
return self._split_by_characters(text, effective_chunk_size)
88136

89137
splits = text.split(separator)
90138
current_chunk = ""
@@ -97,12 +145,12 @@ def _split_text_with_separators(
97145
else:
98146
# Save current chunk if it exists
99147
if current_chunk:
100-
chunks.append(Chunk(content=current_chunk.strip()))
148+
chunks.append(current_chunk)
101149

102150
# If a single split is too large, recursively split it
103151
if self._get_token_count(split) > effective_chunk_size:
104152
sub_chunks = self._split_text_with_separators(
105-
split, remaining_separators
153+
split, remaining_separators, effective_chunk_size
106154
)
107155
chunks.extend(sub_chunks)
108156
current_chunk = ""
@@ -111,19 +159,14 @@ def _split_text_with_separators(
111159

112160
# Add final chunk
113161
if current_chunk:
114-
chunks.append(Chunk(content=current_chunk.strip()))
162+
chunks.append(current_chunk)
115163

116164
return chunks
117165

118-
def _split_by_characters(self, text: str) -> List[Chunk]:
166+
def _split_by_characters(self, text: str, effective_chunk_size: int) -> List[str]:
119167
"""Split text at character level when no separators work."""
120168
chunks = []
121169

122-
# Reserve space for overlap
123-
effective_chunk_size = max(
124-
1, self._settings.chunk_size - self._settings.chunk_overlap
125-
)
126-
127170
total_tokens = self._get_token_count(text)
128171
chars_per_token = (
129172
math.ceil(len(text) / total_tokens)
@@ -151,40 +194,29 @@ def _split_by_characters(self, text: str) -> List[Chunk]:
151194
chunk_text = text[start:end]
152195

153196
if chunk_text.strip():
154-
chunks.append(Chunk(content=chunk_text.strip()))
197+
chunks.append(chunk_text)
155198

156199
start = end
157200

158201
return chunks
159202

160-
def _apply_overlap(self, chunks: List[Chunk]) -> List[Chunk]:
203+
def _create_overlaps(self, chunks: List[str]) -> List[str]:
161204
"""Apply overlap between consecutive chunks."""
162205
if len(chunks) <= 1 or self._settings.chunk_overlap <= 0:
163-
return chunks
206+
# Empty overlap for each chunk
207+
return [""] * len(chunks)
164208

165-
overlapped_chunks = [chunks[0]] # First chunk has no overlap
209+
overlapped_chunks = [""] # First chunk has no overlap
166210

167211
for i in range(1, len(chunks)):
168-
current_content = chunks[i].content
169-
prev_content = chunks[i - 1].content
212+
prev_content = chunks[i - 1]
170213

171214
# Get overlap text from end of previous chunk
172215
overlap_text = self._get_overlap_text(
173216
prev_content, self._settings.chunk_overlap
174217
)
175218

176-
if overlap_text:
177-
combined_content = overlap_text + " " + current_content
178-
# Core content starts after overlap and separator
179-
core_start_pos = len(overlap_text) + 1
180-
else:
181-
combined_content = current_content
182-
# No overlap, core starts at beginning
183-
core_start_pos = 0
184-
185-
overlapped_chunks.append(
186-
Chunk(content=combined_content, core_start_pos=core_start_pos)
187-
)
219+
overlapped_chunks.append(overlap_text)
188220

189221
return overlapped_chunks
190222

@@ -202,13 +234,3 @@ def _get_overlap_text(self, text: str, max_overlap_tokens: int) -> str:
202234

203235
# If even a single word is too large, return empty
204236
return ""
205-
206-
def _enrich_chunk(self, chunks: List[Chunk], metadata: dict) -> List[Chunk]:
207-
"""Add extra information to chunk which may improve the model embeddings."""
208-
for chunk in chunks:
209-
if "title" in metadata:
210-
chunk.title = metadata["title"]
211-
elif "title" in metadata.get("generated", {}):
212-
chunk.title = metadata["generated"]["title"]
213-
214-
return chunks

src/sqlite_rag/cli.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,12 @@ def configure_settings(
117117
None, help="Path to the embedding model file (.gguf)"
118118
),
119119
model_options: Optional[str] = typer.Option(
120-
None, help="options specific for the model: See: https://github.com/sqliteai/sqlite-ai/blob/main/API.md#llm_model_loadpath-text-options-text"
120+
None,
121+
help="options specific for the model: See: https://github.com/sqliteai/sqlite-ai/blob/main/API.md#llm_model_loadpath-text-options-text",
121122
),
122123
model_context_options: Optional[str] = typer.Option(
123-
None, help="Options specific for model context creation. See: https://github.com/sqliteai/sqlite-ai/blob/main/API.md#llm_context_createcontext_settings-text"
124+
None,
125+
help="Options specific for model context creation. See: https://github.com/sqliteai/sqlite-ai/blob/main/API.md#llm_context_createcontext_settings-text",
124126
),
125127
embedding_dim: Optional[int] = typer.Option(
126128
None, help="Dimension of the embedding vectors"

src/sqlite_rag/database.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ def _create_schema(conn: sqlite3.Connection, settings: Settings):
8383
document_id TEXT,
8484
content TEXT,
8585
embedding BLOB,
86-
core_start_pos INTEGER DEFAULT 0,
8786
FOREIGN KEY (document_id) REFERENCES documents (id) ON DELETE CASCADE
8887
);
8988
"""

src/sqlite_rag/engine.py

Lines changed: 22 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from sqlite_rag.models.document_result import DocumentResult
88

99
from .chunker import Chunker
10-
from .models.chunk import Chunk
1110
from .models.document import Document
1211
from .settings import Settings
1312

@@ -16,8 +15,6 @@ class Engine:
1615
# Considered a good default to normilize the score for RRF
1716
DEFAULT_RRF_K = 60
1817

19-
GENERATED_TITLE_MAX_CHARS = 100
20-
2118
def __init__(self, conn: sqlite3.Connection, settings: Settings, chunker: Chunker):
2219
self._conn = conn
2320
self._settings = settings
@@ -37,43 +34,38 @@ def load_model(self):
3734
)
3835

3936
def process(self, document: Document) -> Document:
40-
chunks = self._chunker.chunk(document.content, document.metadata)
37+
if not document.get_title():
38+
document.set_generated_title()
39+
40+
chunks = self._chunker.chunk(document)
4141

4242
if self._settings.max_chunks_per_document > 0:
4343
chunks = chunks[: self._settings.max_chunks_per_document]
4444

45-
chunks = self.generate_embedding(chunks)
45+
for chunk in chunks:
46+
chunk.title = document.get_title()
47+
chunk.embedding = self.generate_embedding(chunk.get_embedding_text())
48+
4649
document.chunks = chunks
50+
4751
return document
4852

49-
def generate_embedding(self, chunks: list[Chunk]) -> list[Chunk]:
53+
def generate_embedding(self, text: str) -> bytes:
5054
"""Generate embedding for the given text."""
55+
cursor = self._conn.cursor()
5156

52-
for chunk in chunks:
53-
cursor = self._conn.cursor()
54-
55-
# Format using the prompt template if available
56-
content = chunk.content
57-
if self._settings.use_prompt_templates:
58-
title = chunk.title if chunk.title else "none"
59-
content = self._settings.prompt_template_retrieval_document.format(
60-
title=title, content=chunk.content
61-
)
62-
63-
try:
64-
cursor.execute("SELECT llm_embed_generate(?) AS embedding", (content,))
65-
except sqlite3.Error as e:
66-
print(f"Error generating embedding for chunk\n: ```{content}```")
67-
raise e
68-
69-
result = cursor.fetchone()
57+
try:
58+
cursor.execute("SELECT llm_embed_generate(?) AS embedding", (text,))
59+
except sqlite3.Error as e:
60+
print(f"Error generating embedding for text\n: ```{text}```")
61+
raise e
7062

71-
if result is None:
72-
raise RuntimeError("Failed to generate embedding.")
63+
result = cursor.fetchone()
7364

74-
chunk.embedding = result["embedding"]
65+
if result is None:
66+
raise RuntimeError("Failed to generate embedding.")
7567

76-
return chunks
68+
return result["embedding"]
7769

7870
def quantize(self) -> None:
7971
"""Quantize stored vector for faster search via quantized scan."""
@@ -114,7 +106,7 @@ def free_context(self) -> None:
114106

115107
def search(self, query: str, top_k: int = 10) -> list[DocumentResult]:
116108
"""Semantic search and full-text search sorted with Reciprocal Rank Fusion."""
117-
query_embedding = self.generate_embedding([Chunk(content=query)])[0].embedding
109+
query_embedding = self.generate_embedding(query)
118110

119111
# Clean up and split into words
120112
# '*' is used to match while typing
@@ -172,7 +164,6 @@ def search(self, query: str, top_k: int = 10) -> list[DocumentResult]:
172164
documents.content as document_content,
173165
documents.metadata,
174166
chunks.content AS snippet,
175-
chunks.core_start_pos,
176167
vec_rank,
177168
fts_rank,
178169
combined_rank,
@@ -203,8 +194,7 @@ def search(self, query: str, top_k: int = 10) -> list[DocumentResult]:
203194
content=row["document_content"],
204195
metadata=json.loads(row["metadata"]) if row["metadata"] else {},
205196
),
206-
# remove overlapping text from the snippet
207-
snippet=row["snippet"][row["core_start_pos"] :],
197+
snippet=row["snippet"],
208198
vec_rank=row["vec_rank"],
209199
fts_rank=row["fts_rank"],
210200
combined_rank=row["combined_rank"],
@@ -227,24 +217,6 @@ def versions(self) -> dict:
227217
"vector_version": row["vector_version"],
228218
}
229219

230-
def extract_document_title(
231-
self, text: str, fallback_first_line: bool = False
232-
) -> str | None:
233-
"""Extract title from markdown content."""
234-
# Look for first level-1 heading
235-
match = re.search(r"^# (.+)$", text, re.MULTILINE)
236-
if match:
237-
return match.group(1).strip()
238-
239-
# Fallback: first non-empty line
240-
if fallback_first_line:
241-
for line in text.splitlines():
242-
line = line.strip()
243-
if line:
244-
return line[:self.GENERATED_TITLE_MAX_CHARS]
245-
246-
return None
247-
248220
def close(self):
249221
"""Close the database connection."""
250222
if self._conn:

0 commit comments

Comments
 (0)