11"""Compute and update an optimal query adapter."""
22
3+ # ruff: noqa: N806
4+
5+ from dataclasses import replace
6+
37import numpy as np
8+ from scipy .optimize import lsq_linear
49from sqlalchemy import text
510from sqlalchemy .orm .attributes import flag_modified
611from sqlmodel import Session , col , select
1015from raglite ._database import Chunk , ChunkEmbedding , Eval , IndexMetadata , create_database_engine
1116from raglite ._embed import embed_strings
1217from raglite ._search import vector_search
13- from raglite ._typing import FloatMatrix
18+ from raglite ._typing import FloatMatrix , FloatVector
19+
20+
def _optimize_query_target(
    q: FloatVector,
    P: FloatMatrix,  # noqa: N803
    N: FloatMatrix,  # noqa: N803
    *,
    α: float = 0.05,  # noqa: PLC2401
) -> FloatVector:
    """Find the vector closest to q that ranks every positive above every negative.

    Solves the quadratic program

        t* := argmin ½ ‖t - q‖²
              s.t. D t ≥ 0

    where the rows of D are pₘ - (1 + α) nₙ for every pair of a row pₘ of P and a row nₙ of N.
    The problem is solved through its dual, a nonnegative least squares problem in the Lagrange
    multipliers μ, after which the primal solution is recovered as t* = q + Dᵀ μ*.

    Parameters
    ----------
    q
        The query embedding, a vector of length d.
    P
        The positive embeddings, one per row, with d columns.
    N
        The negative embeddings, one per row, with d columns.
    α
        The margin factor applied to the negative embeddings in the constraints.

    Returns
    -------
    FloatVector
        The optimal target vector t*, cast back to the dtype of q. If P or N has no rows, there
        are no constraints to satisfy and a copy of q is returned.
    """
    # Convert to double precision for the optimizer, remembering the dtype to restore afterwards.
    q_dtype = q.dtype
    q, P, N = q.astype(np.float64), P.astype(np.float64), N.astype(np.float64)
    # Without at least one positive and one negative there are no (p, n) pairs and therefore no
    # constraints, so q itself is optimal. Return early instead of handing the solver a system
    # with zero unknowns.
    if P.shape[0] == 0 or N.shape[0] == 0:
        return q.astype(q_dtype)
    # Construct the constraint matrix D := P - (1 + α) * N.  # noqa: RUF003
    # Broadcasting yields one row per (positive, negative) pair, flattened to (m * n, d).
    D = np.reshape(P[:, np.newaxis, :] - (1.0 + α) * N[np.newaxis, :, :], (-1, P.shape[1]))
    # Solve the dual problem min_μ ½ ‖q + Dᵀ μ‖² s.t. μ ≥ 0.
    A, b = D.T, -q
    μ_star = lsq_linear(A, b, bounds=(0.0, np.inf), tol=np.finfo(A.dtype).eps).x  # noqa: PLC2401
    # Recover the primal solution q* = q + Dᵀ μ*.
    q_star: FloatVector = (q + D.T @ μ_star).astype(q_dtype)
    return q_star
1439
1540
16- def update_query_adapter ( # noqa: C901, PLR0915
41+ def update_query_adapter ( # noqa: PLR0915
1742 * ,
18- max_triplets : int = 4096 ,
19- max_triplets_per_eval : int = 64 ,
43+ max_evals : int = 4096 ,
2044 optimize_top_k : int = 40 ,
2145 optimize_gap : float = 0.05 ,
2246 config : RAGLiteConfig | None = None ,
@@ -28,57 +52,75 @@ def update_query_adapter( # noqa: C901, PLR0915
2852 order to improve the quality of the search results.
2953
3054 Given a set of triplets (qᵢ, pᵢ, nᵢ), we want to find the query adapter A that increases the
31- score pᵢ'qᵢ of the positive chunk pᵢ and decreases the score nᵢ'qᵢ of the negative chunk nᵢ.
55+ score pᵢᵀqᵢ of the positive chunk pᵢ and decreases the score nᵢᵀqᵢ of the negative chunk nᵢ.
3256
3357 If the nearest neighbour search uses the dot product as its relevance score, we can find the
3458 optimal query adapter by solving the following relaxed Procrustes optimisation problem with a
3559 bound on the Frobenius norm of A:
3660
37- A* = argmax Σᵢ pᵢ' (A qᵢ) - nᵢ' (A qᵢ)
38- Σᵢ (pᵢ - nᵢ)' A qᵢ
39- trace[ (P - N) A Q' ] where Q := [q₁' ; ...; qₖ' ]
40- P := [p₁' ; ...; pₖ' ]
41- N := [n₁' ; ...; nₖ' ]
42- trace[ Q' (P - N) A ]
43- trace[ M A ] where M := Q' (P - N)
44- s.t. ||A||_F == 1
45- = M' / ||M||_F
61+ A* := argmax Σᵢ pᵢᵀ (A qᵢ) - nᵢᵀ (A qᵢ)
62+ Σᵢ (pᵢ - nᵢ)ᵀ A qᵢ
63+ trace[ (P - N) A Qᵀ ] where Q := [q₁ᵀ ; ...; qₖᵀ ]
64+ P := [p₁ᵀ ; ...; pₖᵀ ]
65+ N := [n₁ᵀ ; ...; nₖᵀ ]
66+ trace[ Qᵀ (P - N) A ]
67+ trace[ Mᵀ A ] where M := (P - N)ᵀ Q
68+ s.t. ||A||_F == 1
69+ = M / ||M||_F
4670
4771 If the nearest neighbour search uses the cosine similarity as its relevance score, we can find
4872 the optimal query adapter by solving the following orthogonal Procrustes optimisation problem
49- with an orthogonality constraint on A:
50-
51- A* = argmax Σᵢ pᵢ' (A qᵢ) - nᵢ' (A qᵢ)
52- Σᵢ (pᵢ - nᵢ)' A qᵢ
53- trace[ (P - N) A Q' ]
54- trace[ Q' (P - N) A ]
55- trace[ M A ]
56- trace[ U Σ V' A ] where U Σ V' := M is the SVD of M
57- trace[ Σ V' A U ]
58- s.t. A'A == 𝕀
59- = V U'
60-
61- Additionally, we want to limit the effect of A* so that it adjusts q just enough to invert
62- incorrectly ordered (q, p, n) triplets, but not so much as to affect the correctly ordered ones.
63- To achieve this, we'll rewrite M as α(M / s) + (1 - α)𝕀, where s scales M to the same norm as 𝕀,
64- and choose the smallest α that ranks (q, p, n) correctly. If α = 0, the relevance score gap
65- between an incorrect (p, n) pair would be B := (p - n)' q < 0. If α = 1, the relevance score gap
66- would be A := (p - n)' (p - n) / ||p - n|| > 0. For a target relevance score gap of say
67- C := 5% * A, the optimal α is then given by αA + (1 - α)B = C => α = (B - C) / (B - A).
73+ [1] with an orthogonality constraint on A:
74+
75+ A* := argmax Σᵢ pᵢᵀ (A qᵢ) - nᵢᵀ (A qᵢ)
76+ Σᵢ (pᵢ - nᵢ)ᵀ A qᵢ
77+ trace[ (P - N) A Qᵀ ]
78+ trace[ Qᵀ (P - N) A ]
79+ trace[ Mᵀ A ]
80+ trace[ (U Σ Vᵀ)ᵀ A ] where U Σ Vᵀ := M is the SVD of M
81+ trace[ Σ Uᵀ A V ]
82+ s.t. AᵀA == 𝕀
83+ = U Vᵀ
84+
85+ The action of A* is to map a query embedding qᵢ to a target vector t := (pᵢ - nᵢ) that maximally
86+ separates the positive and negative chunks. For a given query embedding qᵢ, a retrieval method
87+ will yield a result set containing both positive and negative chunks. Instead of extracting
88+ multiple triplets (qᵢ, pᵢ, nᵢ) from each such result set, we can compute a single optimal target
89+ vector t* for the query embedding qᵢ as follows:
90+
91+ t* := argmin ½ ||t - qᵢ||²
92+ s.t. Dᵢ t >= 0
93+
94+ where the constraint matrix Dᵢ := [pₘᵀ - (1 + α) nₙᵀ]ₘₙ comprises all pairs of positive and
95+ negative chunk embeddings in the result set corresponding to the query embedding qᵢ. This
96+ optimisation problem expresses the idea that the target vector t* should be as close as
97+ possible to the query embedding qᵢ, while separating all positive and negative chunk embeddings
98+ in the result set by a margin of at least α. To solve this problem, we'll first introduce
99+ a Lagrangian with Lagrange multipliers μ:
100+
101+ L(t, μ) := ½ ||t - qᵢ||² + μᵀ (-Dᵢ t)
102+
103+ Now we can set the gradient of the Lagrangian to zero to find the optimal target vector t*:
104+
105+ ∇ₜL = t - qᵢ - Dᵢᵀ μ = 0
106+ t* = qᵢ + Dᵢᵀ μ*
107+
108+ where μ* is the solution to the dual nonnegative least squares problem
109+
110+ μ* := argmin ½ ||qᵢ + Dᵢᵀ μ||²
111+ s.t. μ >= 0
68112
69113 Parameters
70114 ----------
71- max_triplets
72- The maximum number of (q, p, n) triplets to compute. Each triplet corresponds to a rank-one
73- update of the query adapter A.
74- max_triplets_per_eval
75- The maximum number of (q, p, n) triplets a single eval may contribute to the query adapter.
115+ max_evals
116+ The maximum number of evals to use to compute the query adapter. Each eval corresponds to a
117+ rank-one update of the query adapter A.
76118 optimize_top_k
77- The number of search results per eval to extract (q, p, n) triplets from .
119+ The number of search results per eval to optimize .
78120 optimize_gap
79- The strength of the query adapter, expressed as a fraction between 0 and 1 of the maximum
80- relevance score gap. Should be large enough to correct incorrectly ranked results, but small
81- enough to not affect correctly ranked results.
121+ The strength of the query adapter, expressed as a nonnegative number. Should be large enough
122+ to correct incorrectly ranked results, but small enough to not affect correctly ranked
123+ results.
82124 config
83125 The RAGLite config to use to construct and store the query adapter.
84126
@@ -87,7 +129,7 @@ def update_query_adapter( # noqa: C901, PLR0915
87129 ValueError
88130 If no documents have been inserted into the database yet.
89131 ValueError
90- If there aren't enough evals to compute the query adapter yet.
132+ If no evals have been inserted into the database yet.
91133 ValueError
92134 If the `config.vector_search_distance_metric` is not supported.
93135
@@ -97,98 +139,69 @@ def update_query_adapter( # noqa: C901, PLR0915
97139 The query adapter.
98140 """
99141 config = config or RAGLiteConfig ()
100- config_no_query_adapter = RAGLiteConfig (
101- ** {** config .__dict__ , "vector_search_query_adapter" : False }
102- )
142+ config_no_query_adapter = replace (config , vector_search_query_adapter = False )
103143 engine = create_database_engine (config )
104144 with Session (engine ) as session :
105145 # Get random evals from the database.
106146 chunk_embedding = session .exec (select (ChunkEmbedding ).limit (1 )).first ()
107147 if chunk_embedding is None :
108148 error_message = "First run `insert_document()` to insert documents."
109149 raise ValueError (error_message )
110- evals = session .exec (select (Eval ).order_by (Eval .id ).limit (max_triplets )).all ()
111- # Exit if there aren't enough evals to compute the query adapter.
112- embedding_dim = len (chunk_embedding .embedding )
113- required_evals = np .ceil (embedding_dim / max_triplets_per_eval ) - len (evals )
114- if required_evals > 0 :
115- error_message = f"First run `insert_evals()` to generate { required_evals } more evals."
150+ evals = session .exec (select (Eval ).order_by (Eval .id ).limit (max_evals )).all ()
151+ if len (evals ) == 0 :
152+ error_message = "First run `insert_evals()` to generate evals."
116153 raise ValueError (error_message )
117- # Loop over the evals to generate (q, p, n) triplets.
118- Q = np .zeros ((0 , embedding_dim )) # noqa: N806
119- P = np .zeros_like (Q ) # noqa: N806
120- N = np .zeros_like (Q ) # noqa: N806
121- for eval_ in tqdm (
122- evals , desc = "Extracting triplets from evals" , unit = "eval" , dynamic_ncols = True
123- ):
154+ # Construct the query and target matrices.
155+ Q = np .zeros ((0 , len (chunk_embedding .embedding )))
156+ T = np .zeros_like (Q )
157+ for eval_ in tqdm (evals , desc = "Optimizing evals" , unit = "eval" , dynamic_ncols = True ):
124158 # Embed the question.
125- question_embedding = embed_strings ([eval_ .question ], config = config )
159+ q = embed_strings ([eval_ .question ], config = config )[ 0 ]
126160 # Retrieve chunks that would be used to answer the question.
127161 chunk_ids , _ = vector_search (
128- question_embedding [ 0 ] , num_results = optimize_top_k , config = config_no_query_adapter
162+ q , num_results = optimize_top_k , config = config_no_query_adapter
129163 )
130164 retrieved_chunks = session .exec (select (Chunk ).where (col (Chunk .id ).in_ (chunk_ids ))).all ()
131165 retrieved_chunks = sorted (retrieved_chunks , key = lambda chunk : chunk_ids .index (chunk .id ))
132- # Extract (q, p, n) triplets from the eval.
133- num_triplets = 0
134- for i , retrieved_chunk in enumerate (retrieved_chunks ):
135- # Only loop over irrelevant chunks.
136- if retrieved_chunk .id not in eval_ .chunk_ids :
137- continue
138- irrelevant_chunk = retrieved_chunk
139- # Grab the negative chunk embedding of this irrelevant chunk.
140- n_top = irrelevant_chunk .embedding_matrix [
141- [np .argmax (irrelevant_chunk .embedding_matrix @ question_embedding .T )]
166+ # Skip this eval if it doesn't contain both relevant and irrelevant chunks.
167+ is_relevant = np .array ([chunk .id in eval_ .chunk_ids for chunk in retrieved_chunks ])
168+ if not np .any (is_relevant ) or not np .any (~ is_relevant ):
169+ continue
170+ # Extract the positive and negative chunk embeddings.
171+ P = np .vstack (
172+ [
173+ chunk .embedding_matrix [[np .argmax (chunk .embedding_matrix @ q )]]
174+ for chunk in np .array (retrieved_chunks )[is_relevant ]
142175 ]
143- # Grab the positive chunk embeddings that are ranked lower than the negative one.
144- p_top = [
145- chunk .embedding_matrix [
146- [np .argmax (chunk .embedding_matrix @ question_embedding .T )]
147- ]
148- for chunk in retrieved_chunks [i + 1 :] # Chunks that are ranked lower.
149- if chunk is not None and chunk .id in eval_ .chunk_ids
176+ )
177+ N = np .vstack (
178+ [
179+ chunk .embedding_matrix [[np .argmax (chunk .embedding_matrix @ q )]]
180+ for chunk in np .array (retrieved_chunks )[~ is_relevant ]
150181 ]
151- # Ensure that we only have (q, p, n) triplets for which p is ranked lower than n.
152- p_top = [p for p in p_top if (n_top - p ) @ question_embedding .T > 0 ]
153- if not p_top :
154- continue
155- # Stack the (q, p, n) triplets.
156- p = np .vstack (p_top )
157- n = np .repeat (n_top , p .shape [0 ], axis = 0 )
158- q = np .repeat (question_embedding , p .shape [0 ], axis = 0 )
159- num_triplets += p .shape [0 ]
160- # Append the (q, p, n) triplets to the Q, P, N matrices.
161- Q = np .vstack ([Q , q ]) # noqa: N806
162- P = np .vstack ([P , p ]) # noqa: N806
163- N = np .vstack ([N , n ]) # noqa: N806
164- # Stop if we have enough triplets for this eval.
165- if num_triplets >= max_triplets_per_eval :
166- break
167- # Stop if we have enough triplets to compute the query adapter.
168- if Q .shape [0 ] > max_triplets :
169- Q , P , N = Q [:max_triplets , :], P [:max_triplets , :], N [:max_triplets , :] # noqa: N806
170- break
171- # Normalise the rows of Q, P, N.
172- Q /= np .linalg .norm (Q , axis = 1 , keepdims = True ) # noqa: N806
173- P /= np .linalg .norm (P , axis = 1 , keepdims = True ) # noqa: N806
174- N /= np .linalg .norm (N , axis = 1 , keepdims = True ) # noqa: N806
175- # Compute the optimal weighted query adapter A*.
176- # TODO: Matmul in float16 is extremely slow compared to single or double precision, why?
177- gap_before = np .sum ((P - N ) * Q , axis = 1 )
178- gap_after = 2 * (1 - np .sum (P * N , axis = 1 )) / np .linalg .norm (P - N , axis = 1 )
179- gap_target = optimize_gap * gap_after
180- α = (gap_before - gap_target ) / (gap_before - gap_after ) # noqa: PLC2401
181- MT = (α [:, np .newaxis ] * (P - N )).T @ Q # noqa: N806
182- s = np .linalg .norm (MT , ord = "fro" ) / np .sqrt (MT .shape [0 ])
183- MT = np .mean (α ) * (MT / s ) + np .mean (1 - α ) * np .eye (Q .shape [1 ]) # noqa: N806
184- A_star : FloatMatrix # noqa: N806
182+ )
183+ # Compute the optimal target vector t for this query embedding q.
184+ t = _optimize_query_target (q , P , N , α = optimize_gap )
185+ Q = np .vstack ([Q , q [np .newaxis , :]])
186+ T = np .vstack ([T , t [np .newaxis , :]])
187+ # Normalise the rows of Q and T.
188+ Q /= np .linalg .norm (Q , axis = 1 , keepdims = True )
189+ if config .vector_search_distance_metric == "cosine" :
190+ T /= np .linalg .norm (T , axis = 1 , keepdims = True )
191+ # Compute the optimal unconstrained query adapter M.
192+ n , d = Q .shape
193+ M = (1 / n ) * T .T @ Q
194+ if n < d or np .linalg .matrix_rank (Q ) < d :
195+ M += np .eye (d ) - Q .T @ np .linalg .pinv (Q @ Q .T ) @ Q
196+ # Compute the optimal constrained query adapter A* from M, given the distance metric.
197+ A_star : FloatMatrix
185198 if config .vector_search_distance_metric == "dot" :
186199 # Use the relaxed Procrustes solution.
187- A_star = MT / np .linalg .norm (MT , ord = "fro" ) # noqa: N806
200+ A_star = M / np .linalg .norm (M , ord = "fro" ) * np . sqrt ( d )
188201 elif config .vector_search_distance_metric == "cosine" :
189202 # Use the orthogonal Procrustes solution.
190- U , _ , VT = np .linalg .svd (MT , full_matrices = False ) # noqa: N806
191- A_star = U @ VT # noqa: N806
203+ U , _ , VT = np .linalg .svd (M , full_matrices = False )
204+ A_star = U @ VT
192205 else :
193206 error_message = f"Unsupported metric: { config .vector_search_distance_metric } "
194207 raise ValueError (error_message )
@@ -200,4 +213,6 @@ def update_query_adapter( # noqa: C901, PLR0915
200213 session .commit ()
201214 if engine .dialect .name == "duckdb" :
202215 session .execute (text ("CHECKPOINT;" ))
216+ # Clear the index metadata cache to allow the new query adapter to be used.
217+ IndexMetadata ._get .cache_clear () # noqa: SLF001
203218 return A_star
0 commit comments