Skip to content

Commit 17cf038

Browse files
authored
feat: parallelize inserts and add benchmarking (#150)
1 parent bddb36b commit 17cf038

17 files changed

+773
-263
lines changed

README.md

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -143,22 +143,27 @@ my_config = RAGLiteConfig(
143143
144144
Next, insert some documents into the database. RAGLite will take care of the [conversion to Markdown](src/raglite/_markdown.py), [optimal level 4 semantic chunking](src/raglite/_split_chunks.py), and [multi-vector embedding with late chunking](src/raglite/_embed.py):
145145

146-
147146
```python
148-
# Insert a document given its file path
147+
# Insert documents given their file path
149148
from pathlib import Path
150-
from raglite import insert_document
149+
from raglite import Document, insert_documents
151150

152-
insert_document(Path("On the Measure of Intelligence.pdf"), config=my_config)
153-
insert_document(Path("Special Relativity.pdf"), config=my_config)
151+
documents = [
152+
Document.from_path(Path("On the Measure of Intelligence.pdf")),
153+
Document.from_path(Path("Special Relativity.pdf")),
154+
]
155+
insert_documents(documents, config=my_config)
154156

155-
# Insert a document given its Markdown content
156-
markdown_content = """
157+
# Insert documents given their text/plain or text/markdown content
158+
content = """
157159
# ON THE ELECTRODYNAMICS OF MOVING BODIES
158160
## By A. EINSTEIN June 30, 1905
159-
It is known that Maxwell
161+
It is known that Maxwell...
160162
"""
161-
insert_document(markdown_content, config=my_config)
163+
documents = [
164+
Document.from_text(content)
165+
]
166+
insert_documents(documents, config=my_config)
162167
```
163168

164169
### 3. Retrieval-Augmented Generation (RAG)

pyproject.toml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,17 @@ llama-cpp-python = ["llama-cpp-python (>=0.3.9)"]
8282
pandoc = ["pypandoc-binary (>=1.13)"]
8383
# Evaluation:
8484
ragas = ["pandas (>=2.1.1)", "ragas (>=0.1.12)"]
85+
# Benchmarking:
86+
bench = [
87+
"faiss-cpu (>=1.11.0)",
88+
"ir_datasets (>=0.5.10)",
89+
"ir_measures (>=0.3.7)",
90+
"llama-index (>=0.12.39)",
91+
"llama-index-vector-stores-faiss (>=0.4.0)",
92+
"openai (>=1.75.0)",
93+
"pandas (>=2.1.1)",
94+
"python-slugify (>=8.0.4)",
95+
]
8596

8697
[tool.commitizen] # https://commitizen-tools.github.io/commitizen/config/
8798
bump_message = "bump: v$current_version → v$new_version"

src/raglite/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
"""RAGLite."""
22

33
from raglite._config import RAGLiteConfig
4+
from raglite._database import Document
45
from raglite._eval import answer_evals, evaluate, insert_evals
5-
from raglite._insert import insert_document
6+
from raglite._insert import insert_documents
67
from raglite._query_adapter import update_query_adapter
78
from raglite._rag import add_context, async_rag, rag, retrieve_context
89
from raglite._search import (
@@ -20,7 +21,8 @@
2021
# Config
2122
"RAGLiteConfig",
2223
# Insert
23-
"insert_document",
24+
"Document",
25+
"insert_documents",
2426
# Search
2527
"hybrid_search",
2628
"keyword_search",

src/raglite/_bench.py

Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
"""Benchmarking with TREC runs."""
2+
3+
import warnings
4+
from abc import ABC, abstractmethod
5+
from collections.abc import Generator
6+
from dataclasses import replace
7+
from functools import cached_property
8+
from pathlib import Path
9+
from typing import Any
10+
11+
from ir_datasets.datasets.base import Dataset
12+
from ir_measures import ScoredDoc, read_trec_run
13+
from platformdirs import user_data_dir
14+
from slugify import slugify
15+
from tqdm.auto import tqdm
16+
17+
from raglite._config import RAGLiteConfig
18+
19+
20+
class IREvaluator(ABC):
21+
def __init__(
22+
self,
23+
dataset: Dataset,
24+
*,
25+
num_results: int = 10,
26+
insert_variant: str | None = None,
27+
search_variant: str | None = None,
28+
) -> None:
29+
self.dataset = dataset
30+
self.num_results = num_results
31+
self.insert_variant = insert_variant
32+
self.search_variant = search_variant
33+
self.insert_id = (
34+
slugify(self.__class__.__name__.lower().replace("evaluator", ""))
35+
+ (f"_{slugify(insert_variant)}" if insert_variant else "")
36+
+ f"_{slugify(dataset.docs_namespace())}"
37+
)
38+
self.search_id = (
39+
self.insert_id
40+
+ f"@{num_results}"
41+
+ (f"_{slugify(search_variant)}" if search_variant else "")
42+
)
43+
self.cwd = Path(user_data_dir("raglite", ensure_exists=True))
44+
45+
@abstractmethod
46+
def insert_documents(self, max_workers: int | None = None) -> None:
47+
"""Insert all of the dataset's documents into the search index."""
48+
raise NotImplementedError
49+
50+
@abstractmethod
51+
def search(self, query_id: str, query: str, *, num_results: int = 10) -> list[ScoredDoc]:
52+
"""Search for documents given a query."""
53+
raise NotImplementedError
54+
55+
@property
56+
def trec_run_filename(self) -> str:
57+
return f"{self.search_id}.trec"
58+
59+
@property
60+
def trec_run_filepath(self) -> Path:
61+
return self.cwd / self.trec_run_filename
62+
63+
def score(self) -> Generator[ScoredDoc, None, None]:
64+
"""Read or compute a TREC run."""
65+
if self.trec_run_filepath.exists():
66+
yield from read_trec_run(self.trec_run_filepath.as_posix()) # type: ignore[no-untyped-call]
67+
return
68+
if not self.search("q0", next(self.dataset.queries_iter()).text):
69+
self.insert_documents()
70+
with self.trec_run_filepath.open(mode="w") as trec_run_file:
71+
for query in tqdm(
72+
self.dataset.queries_iter(),
73+
total=self.dataset.queries_count(),
74+
desc="Running queries",
75+
unit="query",
76+
dynamic_ncols=True,
77+
):
78+
results = self.search(query.query_id, query.text, num_results=self.num_results)
79+
unique_results = {doc.doc_id: doc for doc in sorted(results, key=lambda d: d.score)}
80+
top_results = sorted(unique_results.values(), key=lambda d: d.score, reverse=True)
81+
top_results = top_results[: self.num_results]
82+
for rank, scored_doc in enumerate(top_results):
83+
trec_line = f"{query.query_id} 0 {scored_doc.doc_id} {rank} {scored_doc.score} {self.trec_run_filename}\n"
84+
trec_run_file.write(trec_line)
85+
yield scored_doc
86+
87+
88+
class RAGLiteEvaluator(IREvaluator):
89+
def __init__(
90+
self,
91+
dataset: Dataset,
92+
*,
93+
num_results: int = 10,
94+
insert_variant: str | None = None,
95+
search_variant: str | None = None,
96+
config: RAGLiteConfig | None = None,
97+
):
98+
super().__init__(
99+
dataset,
100+
num_results=num_results,
101+
insert_variant=insert_variant,
102+
search_variant=search_variant,
103+
)
104+
self.db_filepath = self.cwd / f"{self.insert_id}.db"
105+
db_url = f"duckdb:///{self.db_filepath.as_posix()}"
106+
self.config = replace(config or RAGLiteConfig(), db_url=db_url)
107+
108+
def insert_documents(self, max_workers: int | None = None) -> None:
109+
from raglite import Document, insert_documents
110+
111+
documents = [
112+
Document.from_text(doc.text, id=doc.doc_id) for doc in self.dataset.docs_iter()
113+
]
114+
insert_documents(documents, max_workers=max_workers, config=self.config)
115+
116+
def update_query_adapter(self, num_evals: int = 1024) -> None:
117+
from raglite import insert_evals, update_query_adapter
118+
from raglite._database import IndexMetadata
119+
120+
if (
121+
self.config.vector_search_query_adapter
122+
and IndexMetadata.get(config=self.config).get("query_adapter") is None
123+
):
124+
insert_evals(num_evals=num_evals, config=self.config)
125+
update_query_adapter(config=self.config)
126+
127+
def search(self, query_id: str, query: str, *, num_results: int = 10) -> list[ScoredDoc]:
128+
from raglite import retrieve_chunks, vector_search
129+
130+
self.update_query_adapter()
131+
chunk_ids, scores = vector_search(query, num_results=2 * num_results, config=self.config)
132+
chunks = retrieve_chunks(chunk_ids, config=self.config)
133+
scored_docs = [
134+
ScoredDoc(query_id=query_id, doc_id=chunk.document.id, score=score)
135+
for chunk, score in zip(chunks, scores, strict=True)
136+
]
137+
return scored_docs
138+
139+
140+
class LlamaIndexEvaluator(IREvaluator):
141+
def __init__(
142+
self,
143+
dataset: Dataset,
144+
*,
145+
num_results: int = 10,
146+
insert_variant: str | None = None,
147+
search_variant: str | None = None,
148+
):
149+
super().__init__(
150+
dataset,
151+
num_results=num_results,
152+
insert_variant=insert_variant,
153+
search_variant=search_variant,
154+
)
155+
self.embedder = "text-embedding-3-large"
156+
self.embedder_dim = 3072
157+
self.persist_path = self.cwd / self.insert_id
158+
159+
def insert_documents(self, max_workers: int | None = None) -> None:
160+
# Adapted from https://docs.llamaindex.ai/en/stable/examples/vector_stores/FaissIndexDemo/.
161+
import faiss
162+
from llama_index.core import Document, StorageContext, VectorStoreIndex
163+
from llama_index.embeddings.openai import OpenAIEmbedding
164+
from llama_index.vector_stores.faiss import FaissVectorStore
165+
166+
self.persist_path.mkdir(parents=True, exist_ok=True)
167+
faiss_index = faiss.IndexHNSWFlat(self.embedder_dim, 32, faiss.METRIC_INNER_PRODUCT)
168+
vector_store = FaissVectorStore(faiss_index=faiss_index)
169+
index = VectorStoreIndex.from_documents(
170+
[
171+
Document(id_=doc.doc_id, text=doc.text, metadata={"filename": doc.doc_id})
172+
for doc in self.dataset.docs_iter()
173+
],
174+
storage_context=StorageContext.from_defaults(vector_store=vector_store),
175+
embed_model=OpenAIEmbedding(model=self.embedder, dimensions=self.embedder_dim),
176+
show_progress=True,
177+
)
178+
index.storage_context.persist(persist_dir=self.persist_path)
179+
180+
@cached_property
181+
def index(self) -> Any:
182+
from llama_index.core import StorageContext, load_index_from_storage
183+
from llama_index.embeddings.openai import OpenAIEmbedding
184+
from llama_index.vector_stores.faiss import FaissVectorStore
185+
186+
vector_store = FaissVectorStore.from_persist_dir(persist_dir=self.persist_path.as_posix())
187+
storage_context = StorageContext.from_defaults(
188+
vector_store=vector_store, persist_dir=self.persist_path.as_posix()
189+
)
190+
embed_model = OpenAIEmbedding(model=self.embedder, dimensions=self.embedder_dim)
191+
index = load_index_from_storage(storage_context, embed_model=embed_model)
192+
return index
193+
194+
def search(self, query_id: str, query: str, *, num_results: int = 10) -> list[ScoredDoc]:
195+
if not self.persist_path.exists():
196+
self.insert_documents()
197+
retriever = self.index.as_retriever(similarity_top_k=2 * num_results)
198+
nodes = retriever.retrieve(query)
199+
scored_docs = [
200+
ScoredDoc(
201+
query_id=query_id,
202+
doc_id=node.metadata.get("filename", node.id_),
203+
score=node.score if node.score is not None else 1.0,
204+
)
205+
for node in nodes
206+
]
207+
return scored_docs
208+
209+
210+
class OpenAIVectorStoreEvaluator(IREvaluator):
211+
def __init__(
212+
self,
213+
dataset: Dataset,
214+
*,
215+
num_results: int = 10,
216+
insert_variant: str | None = None,
217+
search_variant: str | None = None,
218+
):
219+
super().__init__(
220+
dataset,
221+
num_results=num_results,
222+
insert_variant=insert_variant,
223+
search_variant=search_variant,
224+
)
225+
self.vector_store_name = dataset.docs_namespace() + (
226+
f"_{slugify(insert_variant)}" if insert_variant else ""
227+
)
228+
229+
@cached_property
230+
def client(self) -> Any:
231+
import openai
232+
233+
return openai.OpenAI()
234+
235+
@property
236+
def vector_store_id(self) -> str | None:
237+
vector_stores = self.client.vector_stores.list()
238+
vector_store = next((vs for vs in vector_stores if vs.name == self.vector_store_name), None)
239+
if vector_store is None:
240+
return None
241+
if vector_store.file_counts.failed > 0:
242+
warnings.warn(
243+
f"Vector store {vector_store.name} has {vector_store.file_counts.failed} failed files.",
244+
stacklevel=2,
245+
)
246+
if vector_store.file_counts.in_progress > 0:
247+
error_message = f"Vector store {vector_store.name} has {vector_store.file_counts.in_progress} files in progress."
248+
raise RuntimeError(error_message)
249+
return vector_store.id # type: ignore[no-any-return]
250+
251+
def insert_documents(self, max_workers: int | None = None) -> None:
252+
import tempfile
253+
from pathlib import Path
254+
255+
vector_store = self.client.vector_stores.create(name=self.vector_store_name)
256+
files, max_files_per_batch = [], 32
257+
with tempfile.TemporaryDirectory() as temp_dir:
258+
for i, doc in tqdm(
259+
enumerate(self.dataset.docs_iter()),
260+
total=self.dataset.docs_count(),
261+
desc="Inserting documents",
262+
unit="document",
263+
dynamic_ncols=True,
264+
):
265+
if not doc.text.strip():
266+
continue
267+
temp_file = Path(temp_dir) / f"{slugify(doc.doc_id)}.txt"
268+
temp_file.write_text(doc.text)
269+
files.append(temp_file.open("rb"))
270+
if len(files) == max_files_per_batch or (i == self.dataset.docs_count() - 1):
271+
self.client.vector_stores.file_batches.upload_and_poll(
272+
vector_store_id=vector_store.id, files=files, max_concurrency=max_workers
273+
)
274+
for f in files:
275+
f.close()
276+
files = []
277+
278+
@cached_property
279+
def filename_to_doc_id(self) -> dict[str, str]:
280+
return {f"{slugify(doc.doc_id)}.txt": doc.doc_id for doc in self.dataset.docs_iter()}
281+
282+
def search(self, query_id: str, query: str, *, num_results: int = 10) -> list[ScoredDoc]:
283+
if not self.vector_store_id:
284+
return []
285+
response = self.client.vector_stores.search(
286+
vector_store_id=self.vector_store_id, query=query, max_num_results=2 * num_results
287+
)
288+
scored_docs = [
289+
ScoredDoc(
290+
query_id=query_id,
291+
doc_id=self.filename_to_doc_id[result.filename],
292+
score=result.score,
293+
)
294+
for result in response
295+
]
296+
return scored_docs

0 commit comments

Comments
 (0)