Skip to content

Commit 0911c5b

Browse files
authored
feat: improve query adapter algorithm (#146)
1 parent: b0dbad4 · commit: 0911c5b

File tree

6 files changed

+49
-45
lines changed

6 files changed

+49
-45
lines changed

src/raglite/_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
# Lazily load the default search method to avoid circular imports.
2727
# TODO: Replace with search_and_rerank_chunk_spans after benchmarking.
2828
def _vector_search(
29-
query: str, *, num_results: int = 10, config: "RAGLiteConfig | None" = None
29+
query: str, *, num_results: int = 8, config: "RAGLiteConfig | None" = None
3030
) -> tuple[list[ChunkId], list[float]]:
3131
from raglite._search import vector_search
3232

src/raglite/_database.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def from_body(
136136
)
137137

138138
@staticmethod
139-
def extract_heading_lines(doc: str, leading_only: bool = False) -> list[str]: # noqa: FBT001,FBT002
139+
def extract_heading_lines(doc: str, leading_only: bool = False) -> list[str]: # noqa: FBT001, FBT002
140140
"""Extract the leading or final state of the Markdown headings of a document."""
141141
md = MarkdownIt()
142142
heading_lines = [""] * 6

src/raglite/_eval.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from raglite._database import Chunk, Document, Eval, create_database_engine
1717
from raglite._extract import extract_with_llm
1818
from raglite._rag import add_context, rag, retrieve_context
19-
from raglite._search import hybrid_search, retrieve_chunk_spans, vector_search
19+
from raglite._search import retrieve_chunk_spans, vector_search
2020

2121

2222
def insert_evals( # noqa: C901, PLR0912
@@ -95,7 +95,7 @@ def validate_question(cls, value: str) -> str:
9595
else:
9696
question = question_response.question
9797
# Search for candidate chunks to answer the generated question.
98-
candidate_chunk_ids, _ = hybrid_search(
98+
candidate_chunk_ids, _ = vector_search(
9999
query=question, num_results=max_contexts_per_eval, config=config
100100
)
101101
candidate_chunks = [session.get(Chunk, chunk_id) for chunk_id in candidate_chunk_ids]

src/raglite/_extract.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
def extract_with_llm(
1414
return_type: type[T],
1515
user_prompt: str | list[str],
16-
strict: bool = False, # noqa: FBT001,FBT002
16+
strict: bool = False, # noqa: FBT001, FBT002
1717
config: RAGLiteConfig | None = None,
1818
**kwargs: Any,
1919
) -> T:

src/raglite/_query_adapter.py

Lines changed: 42 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from raglite._search import vector_search
1313

1414

15-
def update_query_adapter( # noqa: PLR0915, C901
15+
def update_query_adapter( # noqa: C901, PLR0915
1616
*,
1717
max_triplets: int = 4096,
1818
max_triplets_per_eval: int = 64,
@@ -78,59 +78,63 @@ def update_query_adapter( # noqa: PLR0915, C901
7878
evals = session.exec(
7979
select(Eval).order_by(Eval.id).limit(max(8, max_triplets // max_triplets_per_eval))
8080
).all()
81-
if len(evals) * max_triplets_per_eval < len(chunk_embedding.embedding):
82-
error_message = "First run `insert_evals()` to generate sufficient evals."
81+
# Exit if there aren't enough evals to compute the query adapter.
82+
embedding_dim = len(chunk_embedding.embedding)
83+
required_evals = np.ceil(embedding_dim / max_triplets_per_eval) - len(evals)
84+
if required_evals > 0:
85+
error_message = f"First run `insert_evals()` to generate {required_evals} more evals."
8386
raise ValueError(error_message)
8487
# Loop over the evals to generate (q, p, n) triplets.
85-
Q = np.zeros((0, len(chunk_embedding.embedding))) # noqa: N806
88+
Q = np.zeros((0, embedding_dim)) # noqa: N806
8689
P = np.zeros_like(Q) # noqa: N806
8790
N = np.zeros_like(Q) # noqa: N806
8891
for eval_ in tqdm(
8992
evals, desc="Extracting triplets from evals", unit="eval", dynamic_ncols=True
9093
):
9194
# Embed the question.
92-
question_embedding = embed_strings([eval_.question], config=config)[0]
95+
question_embedding = embed_strings([eval_.question], config=config)
9396
# Retrieve chunks that would be used to answer the question.
9497
chunk_ids, _ = vector_search(
95-
question_embedding, num_results=optimize_top_k, config=config_no_query_adapter
98+
question_embedding[0], num_results=optimize_top_k, config=config_no_query_adapter
9699
)
97100
retrieved_chunks = session.exec(select(Chunk).where(col(Chunk.id).in_(chunk_ids))).all()
98-
# Extract (q, p, n) triplets by comparing the retrieved chunks with the eval.
101+
retrieved_chunks = sorted(retrieved_chunks, key=lambda chunk: chunk_ids.index(chunk.id))
102+
# Extract (q, p, n) triplets from the eval.
99103
num_triplets = 0
100104
for i, retrieved_chunk in enumerate(retrieved_chunks):
101-
# Select irrelevant chunks.
105+
# Only loop over irrelevant chunks.
102106
if retrieved_chunk.id not in eval_.chunk_ids:
103-
# Look up all positive chunks (each represented by the mean of its multi-vector
104-
# embedding) that are ranked lower than this negative one (represented by the
105-
# embedding in the multi-vector embedding that best matches the query).
106-
p_mean = [
107-
np.mean(chunk.embedding_matrix, axis=0, keepdims=True)
108-
for chunk in retrieved_chunks[i + 1 :]
109-
if chunk is not None and chunk.id in eval_.chunk_ids
107+
continue
108+
irrelevant_chunk = retrieved_chunk
109+
# Grab the negative chunk embedding of this irrelevant chunk.
110+
n_top = irrelevant_chunk.embedding_matrix[
111+
[np.argmax(irrelevant_chunk.embedding_matrix @ question_embedding.T)]
112+
]
113+
# Grab the positive chunk embeddings that are ranked lower than the negative one.
114+
p_top = [
115+
chunk.embedding_matrix[
116+
[np.argmax(chunk.embedding_matrix @ question_embedding.T)]
110117
]
111-
n_top = retrieved_chunk.embedding_matrix[
112-
np.argmax(retrieved_chunk.embedding_matrix @ question_embedding.T),
113-
np.newaxis,
114-
:,
115-
]
116-
# Filter out any (p, n, q) triplets for which the mean positive embedding ranks
117-
# higher than the top negative one.
118-
p_mean = [p_e for p_e in p_mean if (n_top - p_e) @ question_embedding.T > 0]
119-
if not p_mean:
120-
continue
121-
# Stack the (p, n, q) triplets.
122-
p = np.vstack(p_mean)
123-
n = np.repeat(n_top, p.shape[0], axis=0)
124-
q = np.repeat(question_embedding, p.shape[0], axis=0)
125-
num_triplets += p.shape[0]
126-
# Append the (query, positive, negative) tuples to the Q, P, N matrices.
127-
Q = np.vstack([Q, q]) # noqa: N806
128-
P = np.vstack([P, p]) # noqa: N806
129-
N = np.vstack([N, n]) # noqa: N806
130-
# Check if we have sufficient triplets for this eval.
131-
if num_triplets >= max_triplets_per_eval:
132-
break
133-
# Check if we have sufficient triplets to compute the query adapter.
118+
for chunk in retrieved_chunks[i + 1 :] # Chunks that are ranked lower.
119+
if chunk is not None and chunk.id in eval_.chunk_ids
120+
]
121+
# Ensure that we only have (q, p, n) triplets for which p is ranked lower than n.
122+
p_top = [p for p in p_top if (n_top - p) @ question_embedding.T > 0]
123+
if not p_top:
124+
continue
125+
# Stack the (q, p, n) triplets.
126+
p = np.vstack(p_top)
127+
n = np.repeat(n_top, p.shape[0], axis=0)
128+
q = np.repeat(question_embedding, p.shape[0], axis=0)
129+
num_triplets += p.shape[0]
130+
# Append the (q, p, n) triplets to the Q, P, N matrices.
131+
Q = np.vstack([Q, q]) # noqa: N806
132+
P = np.vstack([P, p]) # noqa: N806
133+
N = np.vstack([N, n]) # noqa: N806
134+
# Stop if we have enough triplets for this eval.
135+
if num_triplets >= max_triplets_per_eval:
136+
break
137+
# Stop if we have enough triplets to compute the query adapter.
134138
if Q.shape[0] > max_triplets:
135139
Q, P, N = Q[:max_triplets, :], P[:max_triplets, :], N[:max_triplets, :] # noqa: N806
136140
break

src/raglite/_search.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def rerank_chunks(
280280
def search_and_rerank_chunks(
281281
query: str,
282282
*,
283-
num_results: int = 10,
283+
num_results: int = 8,
284284
oversample: int = 4,
285285
search: BasicSearchMethod = hybrid_search,
286286
config: RAGLiteConfig | None = None,
@@ -294,7 +294,7 @@ def search_and_rerank_chunks(
294294
def search_and_rerank_chunk_spans( # noqa: PLR0913
295295
query: str,
296296
*,
297-
num_results: int = 10,
297+
num_results: int = 8,
298298
oversample: int = 4,
299299
neighbors: tuple[int, ...] | None = (-1, 1),
300300
search: BasicSearchMethod = hybrid_search,

0 commit comments

Comments (0)