Skip to content

Commit bddb36b

Browse files
authored
feat: optimally separate result sets in query adapter (#149)
1 parent c8bc314 commit bddb36b

File tree

2 files changed

+170
-115
lines changed

2 files changed

+170
-115
lines changed

src/raglite/_query_adapter.py

Lines changed: 130 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
"""Compute and update an optimal query adapter."""
22

3+
# ruff: noqa: N806
4+
5+
from dataclasses import replace
6+
37
import numpy as np
8+
from scipy.optimize import lsq_linear
49
from sqlalchemy import text
510
from sqlalchemy.orm.attributes import flag_modified
611
from sqlmodel import Session, col, select
@@ -10,13 +15,32 @@
1015
from raglite._database import Chunk, ChunkEmbedding, Eval, IndexMetadata, create_database_engine
1116
from raglite._embed import embed_strings
1217
from raglite._search import vector_search
13-
from raglite._typing import FloatMatrix
18+
from raglite._typing import FloatMatrix, FloatVector
19+
20+
21+
def _optimize_query_target(
    q: FloatVector,
    P: FloatMatrix,  # noqa: N803
    N: FloatMatrix,  # noqa: N803
    *,
    α: float = 0.05,  # noqa: PLC2401
) -> FloatVector:
    """Compute the target vector closest to q that separates the positives from the negatives.

    Solves the quadratic program

        t* := argmin ½ ‖t - q‖²  s.t.  D t ≥ 0

    where the rows of D are all pairwise combinations pₘ - (1 + α) nₙ of a row pₘ of P and a row
    nₙ of N. The problem is solved through its dual, a nonnegative least squares problem in the
    Lagrange multipliers μ, after which the primal solution is recovered as t* = q + Dᵀ μ*.

    Parameters
    ----------
    q
        The query embedding, a one-dimensional vector.
    P
        The positive embeddings, one per row.
    N
        The negative embeddings, one per row.
    α
        The margin factor applied to the negative embeddings when building the constraints.

    Returns
    -------
    FloatVector
        The optimal target vector t*, cast back to the dtype of q.
    """
    # Convert to double precision for the optimizer.
    q_dtype = q.dtype
    q, P, N = q.astype(np.float64), P.astype(np.float64), N.astype(np.float64)
    # Construct the constraint matrix D := P - (1 + α) * N.  # noqa: RUF003
    D = np.reshape(P[:, np.newaxis, :] - (1.0 + α) * N[np.newaxis, :, :], (-1, P.shape[1]))
    # If there are no constraints, or q already satisfies all of them, then q itself is the exact
    # minimiser (μ* = 0) and there is no need to call the solver. This also avoids handing an
    # empty matrix to the optimizer when P or N has no rows.
    if D.shape[0] == 0 or np.all(D @ q >= 0.0):
        return q.astype(q_dtype)
    # Solve the dual problem min_μ ½ ‖q + Dᵀ μ‖² s.t. μ ≥ 0.
    A, b = D.T, -q
    μ_star = lsq_linear(A, b, bounds=(0.0, np.inf), tol=np.finfo(A.dtype).eps).x  # noqa: PLC2401
    # Recover the primal solution q* = q + Dᵀ μ*.
    q_star: FloatVector = (q + D.T @ μ_star).astype(q_dtype)
    return q_star
1439

1540

16-
def update_query_adapter( # noqa: C901, PLR0915
41+
def update_query_adapter( # noqa: PLR0915
1742
*,
18-
max_triplets: int = 4096,
19-
max_triplets_per_eval: int = 64,
43+
max_evals: int = 4096,
2044
optimize_top_k: int = 40,
2145
optimize_gap: float = 0.05,
2246
config: RAGLiteConfig | None = None,
@@ -28,57 +52,75 @@ def update_query_adapter( # noqa: C901, PLR0915
2852
order to improve the quality of the search results.
2953
3054
Given a set of triplets (qᵢ, pᵢ, nᵢ), we want to find the query adapter A that increases the
31-
score pᵢ'qᵢ of the positive chunk pᵢ and decreases the score nᵢ'qᵢ of the negative chunk nᵢ.
55+
score pᵢᵀqᵢ of the positive chunk pᵢ and decreases the score nᵢᵀqᵢ of the negative chunk nᵢ.
3256
3357
If the nearest neighbour search uses the dot product as its relevance score, we can find the
3458
optimal query adapter by solving the following relaxed Procrustes optimisation problem with a
3559
bound on the Frobenius norm of A:
3660
37-
A* = argmax Σᵢ pᵢ' (A qᵢ) - nᵢ' (A qᵢ)
38-
Σᵢ (pᵢ - nᵢ)' A qᵢ
39-
trace[ (P - N) A Q' ] where Q := [q₁'; ...; qₖ']
40-
P := [p₁'; ...; pₖ']
41-
N := [n₁'; ...; nₖ']
42-
trace[ Q' (P - N) A ]
43-
trace[ M A ] where M := Q' (P - N)
44-
s.t. ||A||_F == 1
45-
= M' / ||M||_F
61+
A* := argmax Σᵢ pᵢᵀ (A qᵢ) - nᵢᵀ (A qᵢ)
62+
Σᵢ (pᵢ - nᵢ)ᵀ A qᵢ
63+
trace[ (P - N) A Qᵀ ] where Q := [q₁ᵀ; ...; qₖᵀ]
64+
P := [p₁ᵀ; ...; pₖᵀ]
65+
N := [n₁ᵀ; ...; nₖᵀ]
66+
trace[ Qᵀ (P - N) A ]
67+
trace[ Mᵀ A ] where M := (P - N)ᵀ Q
68+
s.t. ||A||_F == 1
69+
= M / ||M||_F
4670
4771
If the nearest neighbour search uses the cosine similarity as its relevance score, we can find
4872
the optimal query adapter by solving the following orthogonal Procrustes optimisation problem
49-
with an orthogonality constraint on A:
50-
51-
A* = argmax Σᵢ pᵢ' (A qᵢ) - nᵢ' (A qᵢ)
52-
Σᵢ (pᵢ - nᵢ)' A qᵢ
53-
trace[ (P - N) A Q' ]
54-
trace[ Q' (P - N) A ]
55-
trace[ M A ]
56-
trace[ U Σ V' A ] where U Σ V' := M is the SVD of M
57-
trace[ Σ V' A U ]
58-
s.t. A'A == 𝕀
59-
= V U'
60-
61-
Additionally, we want to limit the effect of A* so that it adjusts q just enough to invert
62-
incorrectly ordered (q, p, n) triplets, but not so much as to affect the correctly ordered ones.
63-
To achieve this, we'll rewrite M as α(M / s) + (1 - α)𝕀, where s scales M to the same norm as 𝕀,
64-
and choose the smallest α that ranks (q, p, n) correctly. If α = 0, the relevance score gap
65-
between an incorrect (p, n) pair would be B := (p - n)' q < 0. If α = 1, the relevance score gap
66-
would be A := (p - n)' (p - n) / ||p - n|| > 0. For a target relevance score gap of say
67-
C := 5% * A, the optimal α is then given by αA + (1 - α)B = C => α = (B - C) / (B - A).
73+
[1] with an orthogonality constraint on A:
74+
75+
A* := argmax Σᵢ pᵢᵀ (A qᵢ) - nᵢᵀ (A qᵢ)
76+
Σᵢ (pᵢ - nᵢ)ᵀ A qᵢ
77+
trace[ (P - N) A Qᵀ ]
78+
trace[ Qᵀ (P - N) A ]
79+
trace[ Mᵀ A ]
80+
trace[ (U Σ Vᵀ)ᵀ A ] where U Σ Vᵀ := M is the SVD of M
81+
trace[ Σ Uᵀ A V ]
82+
s.t. AᵀA == 𝕀
83+
= U Vᵀ
84+
85+
The action of A* is to map a query embedding qᵢ to a target vector t := (pᵢ - nᵢ) that maximally
86+
separates the positive and negative chunks. For a given query embedding qᵢ, a retrieval method
87+
will yield a result set containing both positive and negative chunks. Instead of extracting
88+
multiple triplets (qᵢ, pᵢ, nᵢ) from each such result set, we can compute a single optimal target
89+
vector t* for the query embedding qᵢ as follows:
90+
91+
t* := argmin ½ ||t - qᵢ||²
92+
s.t. Dᵢ t >= 0
93+
94+
where the constraint matrix Dᵢ := [pₘᵀ - (1 + α) nₙᵀ]ₘₙ comprises all pairs of positive and
95+
negative chunk embeddings in the result set corresponding to the query embedding qᵢ. This
96+
optimisation problem expresses the idea that the target vector t* should be as close as
97+
possible to the query embedding qᵢ, while separating all positive and negative chunk embeddings
98+
in the result set by a margin of at least α. To solve this problem, we'll first introduce
99+
a Lagrangian with Lagrange multipliers μ:
100+
101+
L(t, μ) := ½ ||t - qᵢ||² + μᵀ (-Dᵢ t)
102+
103+
Now we can set the gradient of the Lagrangian to zero to find the optimal target vector t*:
104+
105+
∇ₜL = t - qᵢ - Dᵢᵀ μ = 0
106+
t* = qᵢ + Dᵢᵀ μ*
107+
108+
where μ* is the solution to the dual nonnegative least squares problem
109+
110+
μ* := argmin ½ ||qᵢ + Dᵢᵀ μ||²
111+
s.t. μ >= 0
68112
69113
Parameters
70114
----------
71-
max_triplets
72-
The maximum number of (q, p, n) triplets to compute. Each triplet corresponds to a rank-one
73-
update of the query adapter A.
74-
max_triplets_per_eval
75-
The maximum number of (q, p, n) triplets a single eval may contribute to the query adapter.
115+
max_evals
116+
The maximum number of evals to use to compute the query adapter. Each eval corresponds to a
117+
rank-one update of the query adapter A.
76118
optimize_top_k
77-
The number of search results per eval to extract (q, p, n) triplets from.
119+
The number of search results per eval to optimize.
78120
optimize_gap
79-
The strength of the query adapter, expressed as a fraction between 0 and 1 of the maximum
80-
relevance score gap. Should be large enough to correct incorrectly ranked results, but small
81-
enough to not affect correctly ranked results.
121+
The strength of the query adapter, expressed as a nonnegative number. Should be large enough
122+
to correct incorrectly ranked results, but small enough to not affect correctly ranked
123+
results.
82124
config
83125
The RAGLite config to use to construct and store the query adapter.
84126
@@ -87,7 +129,7 @@ def update_query_adapter( # noqa: C901, PLR0915
87129
ValueError
88130
If no documents have been inserted into the database yet.
89131
ValueError
90-
If there aren't enough evals to compute the query adapter yet.
132+
If no evals have been inserted into the database yet.
91133
ValueError
92134
If the `config.vector_search_distance_metric` is not supported.
93135
@@ -97,98 +139,69 @@ def update_query_adapter( # noqa: C901, PLR0915
97139
The query adapter.
98140
"""
99141
config = config or RAGLiteConfig()
100-
config_no_query_adapter = RAGLiteConfig(
101-
**{**config.__dict__, "vector_search_query_adapter": False}
102-
)
142+
config_no_query_adapter = replace(config, vector_search_query_adapter=False)
103143
engine = create_database_engine(config)
104144
with Session(engine) as session:
105145
# Get random evals from the database.
106146
chunk_embedding = session.exec(select(ChunkEmbedding).limit(1)).first()
107147
if chunk_embedding is None:
108148
error_message = "First run `insert_document()` to insert documents."
109149
raise ValueError(error_message)
110-
evals = session.exec(select(Eval).order_by(Eval.id).limit(max_triplets)).all()
111-
# Exit if there aren't enough evals to compute the query adapter.
112-
embedding_dim = len(chunk_embedding.embedding)
113-
required_evals = np.ceil(embedding_dim / max_triplets_per_eval) - len(evals)
114-
if required_evals > 0:
115-
error_message = f"First run `insert_evals()` to generate {required_evals} more evals."
150+
evals = session.exec(select(Eval).order_by(Eval.id).limit(max_evals)).all()
151+
if len(evals) == 0:
152+
error_message = "First run `insert_evals()` to generate evals."
116153
raise ValueError(error_message)
117-
# Loop over the evals to generate (q, p, n) triplets.
118-
Q = np.zeros((0, embedding_dim)) # noqa: N806
119-
P = np.zeros_like(Q) # noqa: N806
120-
N = np.zeros_like(Q) # noqa: N806
121-
for eval_ in tqdm(
122-
evals, desc="Extracting triplets from evals", unit="eval", dynamic_ncols=True
123-
):
154+
# Construct the query and target matrices.
155+
Q = np.zeros((0, len(chunk_embedding.embedding)))
156+
T = np.zeros_like(Q)
157+
for eval_ in tqdm(evals, desc="Optimizing evals", unit="eval", dynamic_ncols=True):
124158
# Embed the question.
125-
question_embedding = embed_strings([eval_.question], config=config)
159+
q = embed_strings([eval_.question], config=config)[0]
126160
# Retrieve chunks that would be used to answer the question.
127161
chunk_ids, _ = vector_search(
128-
question_embedding[0], num_results=optimize_top_k, config=config_no_query_adapter
162+
q, num_results=optimize_top_k, config=config_no_query_adapter
129163
)
130164
retrieved_chunks = session.exec(select(Chunk).where(col(Chunk.id).in_(chunk_ids))).all()
131165
retrieved_chunks = sorted(retrieved_chunks, key=lambda chunk: chunk_ids.index(chunk.id))
132-
# Extract (q, p, n) triplets from the eval.
133-
num_triplets = 0
134-
for i, retrieved_chunk in enumerate(retrieved_chunks):
135-
# Only loop over irrelevant chunks.
136-
if retrieved_chunk.id not in eval_.chunk_ids:
137-
continue
138-
irrelevant_chunk = retrieved_chunk
139-
# Grab the negative chunk embedding of this irrelevant chunk.
140-
n_top = irrelevant_chunk.embedding_matrix[
141-
[np.argmax(irrelevant_chunk.embedding_matrix @ question_embedding.T)]
166+
# Skip this eval if it doesn't contain both relevant and irrelevant chunks.
167+
is_relevant = np.array([chunk.id in eval_.chunk_ids for chunk in retrieved_chunks])
168+
if not np.any(is_relevant) or not np.any(~is_relevant):
169+
continue
170+
# Extract the positive and negative chunk embeddings.
171+
P = np.vstack(
172+
[
173+
chunk.embedding_matrix[[np.argmax(chunk.embedding_matrix @ q)]]
174+
for chunk in np.array(retrieved_chunks)[is_relevant]
142175
]
143-
# Grab the positive chunk embeddings that are ranked lower than the negative one.
144-
p_top = [
145-
chunk.embedding_matrix[
146-
[np.argmax(chunk.embedding_matrix @ question_embedding.T)]
147-
]
148-
for chunk in retrieved_chunks[i + 1 :] # Chunks that are ranked lower.
149-
if chunk is not None and chunk.id in eval_.chunk_ids
176+
)
177+
N = np.vstack(
178+
[
179+
chunk.embedding_matrix[[np.argmax(chunk.embedding_matrix @ q)]]
180+
for chunk in np.array(retrieved_chunks)[~is_relevant]
150181
]
151-
# Ensure that we only have (q, p, n) triplets for which p is ranked lower than n.
152-
p_top = [p for p in p_top if (n_top - p) @ question_embedding.T > 0]
153-
if not p_top:
154-
continue
155-
# Stack the (q, p, n) triplets.
156-
p = np.vstack(p_top)
157-
n = np.repeat(n_top, p.shape[0], axis=0)
158-
q = np.repeat(question_embedding, p.shape[0], axis=0)
159-
num_triplets += p.shape[0]
160-
# Append the (q, p, n) triplets to the Q, P, N matrices.
161-
Q = np.vstack([Q, q]) # noqa: N806
162-
P = np.vstack([P, p]) # noqa: N806
163-
N = np.vstack([N, n]) # noqa: N806
164-
# Stop if we have enough triplets for this eval.
165-
if num_triplets >= max_triplets_per_eval:
166-
break
167-
# Stop if we have enough triplets to compute the query adapter.
168-
if Q.shape[0] > max_triplets:
169-
Q, P, N = Q[:max_triplets, :], P[:max_triplets, :], N[:max_triplets, :] # noqa: N806
170-
break
171-
# Normalise the rows of Q, P, N.
172-
Q /= np.linalg.norm(Q, axis=1, keepdims=True) # noqa: N806
173-
P /= np.linalg.norm(P, axis=1, keepdims=True) # noqa: N806
174-
N /= np.linalg.norm(N, axis=1, keepdims=True) # noqa: N806
175-
# Compute the optimal weighted query adapter A*.
176-
# TODO: Matmul in float16 is extremely slow compared to single or double precision, why?
177-
gap_before = np.sum((P - N) * Q, axis=1)
178-
gap_after = 2 * (1 - np.sum(P * N, axis=1)) / np.linalg.norm(P - N, axis=1)
179-
gap_target = optimize_gap * gap_after
180-
α = (gap_before - gap_target) / (gap_before - gap_after) # noqa: PLC2401
181-
MT = (α[:, np.newaxis] * (P - N)).T @ Q # noqa: N806
182-
s = np.linalg.norm(MT, ord="fro") / np.sqrt(MT.shape[0])
183-
MT = np.mean(α) * (MT / s) + np.mean(1 - α) * np.eye(Q.shape[1]) # noqa: N806
184-
A_star: FloatMatrix # noqa: N806
182+
)
183+
# Compute the optimal target vector t for this query embedding q.
184+
t = _optimize_query_target(q, P, N, α=optimize_gap)
185+
Q = np.vstack([Q, q[np.newaxis, :]])
186+
T = np.vstack([T, t[np.newaxis, :]])
187+
# Normalise the rows of Q and T.
188+
Q /= np.linalg.norm(Q, axis=1, keepdims=True)
189+
if config.vector_search_distance_metric == "cosine":
190+
T /= np.linalg.norm(T, axis=1, keepdims=True)
191+
# Compute the optimal unconstrained query adapter M.
192+
n, d = Q.shape
193+
M = (1 / n) * T.T @ Q
194+
if n < d or np.linalg.matrix_rank(Q) < d:
195+
M += np.eye(d) - Q.T @ np.linalg.pinv(Q @ Q.T) @ Q
196+
# Compute the optimal constrained query adapter A* from M, given the distance metric.
197+
A_star: FloatMatrix
185198
if config.vector_search_distance_metric == "dot":
186199
# Use the relaxed Procrustes solution.
187-
A_star = MT / np.linalg.norm(MT, ord="fro") # noqa: N806
200+
A_star = M / np.linalg.norm(M, ord="fro") * np.sqrt(d)
188201
elif config.vector_search_distance_metric == "cosine":
189202
# Use the orthogonal Procrustes solution.
190-
U, _, VT = np.linalg.svd(MT, full_matrices=False) # noqa: N806
191-
A_star = U @ VT # noqa: N806
203+
U, _, VT = np.linalg.svd(M, full_matrices=False)
204+
A_star = U @ VT
192205
else:
193206
error_message = f"Unsupported metric: {config.vector_search_distance_metric}"
194207
raise ValueError(error_message)
@@ -200,4 +213,6 @@ def update_query_adapter( # noqa: C901, PLR0915
200213
session.commit()
201214
if engine.dialect.name == "duckdb":
202215
session.execute(text("CHECKPOINT;"))
216+
# Clear the index metadata cache to allow the new query adapter to be used.
217+
IndexMetadata._get.cache_clear() # noqa: SLF001
203218
return A_star

tests/test_query_adapter.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
"""Test RAGLite's query adapter."""
2+
3+
from dataclasses import replace
4+
5+
import numpy as np
6+
import pytest
7+
8+
from raglite import RAGLiteConfig, insert_evals, update_query_adapter, vector_search
9+
from raglite._database import IndexMetadata
10+
11+
12+
@pytest.mark.slow
def test_query_adapter(raglite_test_config: RAGLiteConfig) -> None:
    """Check that updating the query adapter stores it and changes the vector search scores."""
    # Derive one config with the query adapter switched on and one with it switched off.
    config_on = replace(raglite_test_config, vector_search_query_adapter=True)
    config_off = replace(raglite_test_config, vector_search_query_adapter=False)

    def load_stored_adapter() -> object:
        # Read whatever is currently stored under the "query_adapter" key of the default index.
        return IndexMetadata.get("default", config=config_off).get("query_adapter")

    def assert_square_and_finite(matrix: object) -> None:
        # A query adapter is expected to be a finite, square, two-dimensional array.
        assert isinstance(matrix, np.ndarray)
        assert matrix.ndim == 2  # noqa: PLR2004
        assert matrix.shape[0] == matrix.shape[1]
        assert np.isfinite(matrix).all()

    # Before the update, no query adapter is stored in the database.
    assert load_stored_adapter() is None
    # Generate the evals that the query adapter is computed from.
    insert_evals(num_evals=2, max_contexts_per_eval=10, config=config_on)
    # Compute the query adapter and validate the returned matrix.
    returned_adapter = update_query_adapter(config=config_on)
    assert_square_and_finite(returned_adapter)
    # After the update, the stored query adapter equals the returned one.
    stored_adapter = load_stored_adapter()
    assert_square_and_finite(stored_adapter)
    assert np.all(returned_adapter == stored_adapter)
    # The query adapter must change the scores that vector search produces.
    query = "How does Einstein define 'simultaneous events' in his special relativity paper?"
    _, scores_with_adapter = vector_search(query, config=config_on)
    _, scores_without_adapter = vector_search(query, config=config_off)
    assert scores_with_adapter != scores_without_adapter

0 commit comments

Comments
 (0)