1010from raglite ._database import Chunk , ChunkEmbedding , Eval , IndexMetadata , create_database_engine
1111from raglite ._embed import embed_strings
1212from raglite ._search import vector_search
13+ from raglite ._typing import FloatMatrix
1314
1415
1516def update_query_adapter ( # noqa: C901, PLR0915
1617 * ,
1718 max_triplets : int = 4096 ,
1819 max_triplets_per_eval : int = 64 ,
1920 optimize_top_k : int = 40 ,
21+ optimize_gap : float = 0.05 ,
2022 config : RAGLiteConfig | None = None ,
21- ) -> None :
23+ ) -> FloatMatrix :
2224 """Compute an optimal query adapter and update the database with it.
2325
2426 This function computes an optimal linear transform A, called a 'query adapter', that is used to
@@ -63,6 +65,36 @@ def update_query_adapter( # noqa: C901, PLR0915
6365 between an incorrect (p, n) pair would be B := (p - n)' q < 0. If α = 1, the relevance score gap
6466 would be A := (p - n)' (p - n) / ||p - n|| > 0. For a target relevance score gap of say
6567 C := `optimize_gap` * A (5% by default), the optimal α is then given by αA + (1 - α)B = C => α = (B - C) / (B - A).
68+
69+ Parameters
70+ ----------
71+ max_triplets
72+ The maximum number of (q, p, n) triplets to compute. Each triplet corresponds to a rank-one
73+ update of the query adapter A.
74+ max_triplets_per_eval
75+ The maximum number of (q, p, n) triplets a single eval may contribute to the query adapter.
76+ optimize_top_k
77+ The number of search results per eval to extract (q, p, n) triplets from.
78+ optimize_gap
79+ The strength of the query adapter, expressed as a fraction between 0 and 1 of the maximum
80+ relevance score gap. Should be large enough to correct incorrectly ranked results, but small
81+ enough to not affect correctly ranked results.
82+ config
83+ The RAGLite config to use to construct and store the query adapter.
84+
85+ Raises
86+ ------
87+ ValueError
88+ If no documents have been inserted into the database yet.
89+ ValueError
90+ If there aren't enough evals to compute the query adapter yet.
91+ ValueError
92+ If the `config.vector_search_distance_metric` is not supported.
93+
94+ Returns
95+ -------
96+ FloatMatrix
97+ The query adapter.
6698 """
6799 config = config or RAGLiteConfig ()
68100 config_no_query_adapter = RAGLiteConfig (
@@ -75,9 +107,7 @@ def update_query_adapter( # noqa: C901, PLR0915
75107 if chunk_embedding is None :
76108 error_message = "First run `insert_document()` to insert documents."
77109 raise ValueError (error_message )
78- evals = session .exec (
79- select (Eval ).order_by (Eval .id ).limit (max (8 , max_triplets // max_triplets_per_eval ))
80- ).all ()
110+ evals = session .exec (select (Eval ).order_by (Eval .id ).limit (max_triplets )).all ()
81111 # Exit if there aren't enough evals to compute the query adapter.
82112 embedding_dim = len (chunk_embedding .embedding )
83113 required_evals = np .ceil (embedding_dim / max_triplets_per_eval ) - len (evals )
@@ -146,20 +176,21 @@ def update_query_adapter( # noqa: C901, PLR0915
146176 # TODO: Matmul in float16 is extremely slow compared to single or double precision, why?
147177 gap_before = np .sum ((P - N ) * Q , axis = 1 )
148178 gap_after = 2 * (1 - np .sum (P * N , axis = 1 )) / np .linalg .norm (P - N , axis = 1 )
149- gap_target = 0.05 * gap_after
179+ gap_target = optimize_gap * gap_after
150180 α = (gap_before - gap_target ) / (gap_before - gap_after ) # noqa: PLC2401
151181 MT = (α [:, np .newaxis ] * (P - N )).T @ Q # noqa: N806
152182 s = np .linalg .norm (MT , ord = "fro" ) / np .sqrt (MT .shape [0 ])
153183 MT = np .mean (α ) * (MT / s ) + np .mean (1 - α ) * np .eye (Q .shape [1 ]) # noqa: N806
154- if config .vector_search_index_metric == "dot" :
184+ A_star : FloatMatrix # noqa: N806
185+ if config .vector_search_distance_metric == "dot" :
155186 # Use the relaxed Procrustes solution.
156187 A_star = MT / np .linalg .norm (MT , ord = "fro" ) # noqa: N806
157- elif config .vector_search_index_metric == "cosine" :
188+ elif config .vector_search_distance_metric == "cosine" :
158189 # Use the orthogonal Procrustes solution.
159190 U , _ , VT = np .linalg .svd (MT , full_matrices = False ) # noqa: N806
160191 A_star = U @ VT # noqa: N806
161192 else :
162- error_message = f"Unsupported metric: { config .vector_search_index_metric } "
193+ error_message = f"Unsupported metric: { config .vector_search_distance_metric } "
163194 raise ValueError (error_message )
164195 # Store the optimal query adapter in the database.
165196 index_metadata = session .get (IndexMetadata , "default" ) or IndexMetadata (id = "default" )
@@ -169,3 +200,4 @@ def update_query_adapter( # noqa: C901, PLR0915
169200 session .commit ()
170201 if engine .dialect .name == "duckdb" :
171202 session .execute (text ("CHECKPOINT;" ))
203+ return A_star
0 commit comments