Skip to content

Commit c8bc314

Browse files
authored
feat: add ability to control the gap in query adapter (#147)
1 parent 0911c5b commit c8bc314

File tree

4 files changed

+44 -12 lines changed

src/raglite/_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class RAGLiteConfig:
6060
# Chunk config used to partition documents into chunks.
6161
chunk_max_size: int = 2048 # Max number of characters per chunk.
6262
# Vector search config.
63-
vector_search_index_metric: Literal["cosine", "dot", "l2"] = "cosine"
63+
vector_search_distance_metric: Literal["cosine", "dot", "l2"] = "cosine"
6464
vector_search_multivector: bool = True
6565
vector_search_query_adapter: bool = True # Only supported for "cosine" and "dot" metrics.
6666
# Reranking config.

src/raglite/_database.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ def create_database_engine(config: RAGLiteConfig | None = None) -> Engine: # no
455455
CREATE INDEX IF NOT EXISTS vector_search_chunk_index ON chunk_embedding
456456
USING hnsw (
457457
(embedding::halfvec({embedding_dim}))
458-
halfvec_{metrics[config.vector_search_index_metric]}_ops
458+
halfvec_{metrics[config.vector_search_distance_metric]}_ops
459459
);
460460
SET hnsw.ef_search = {ef_search};
461461
"""
@@ -505,7 +505,7 @@ def create_database_engine(config: RAGLiteConfig | None = None) -> Engine: # no
505505
CREATE INDEX vector_search_chunk_index
506506
ON chunk_embedding
507507
USING HNSW (embedding)
508-
WITH (metric = '{metrics[config.vector_search_index_metric]}');
508+
WITH (metric = '{metrics[config.vector_search_distance_metric]}');
509509
"""
510510
session.execute(text(create_vector_index_sql))
511511
session.commit()

src/raglite/_query_adapter.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,17 @@
1010
from raglite._database import Chunk, ChunkEmbedding, Eval, IndexMetadata, create_database_engine
1111
from raglite._embed import embed_strings
1212
from raglite._search import vector_search
13+
from raglite._typing import FloatMatrix
1314

1415

1516
def update_query_adapter( # noqa: C901, PLR0915
1617
*,
1718
max_triplets: int = 4096,
1819
max_triplets_per_eval: int = 64,
1920
optimize_top_k: int = 40,
21+
optimize_gap: float = 0.05,
2022
config: RAGLiteConfig | None = None,
21-
) -> None:
23+
) -> FloatMatrix:
2224
"""Compute an optimal query adapter and update the database with it.
2325
2426
This function computes an optimal linear transform A, called a 'query adapter', that is used to
@@ -63,6 +65,36 @@ def update_query_adapter( # noqa: C901, PLR0915
6365
between an incorrect (p, n) pair would be B := (p - n)' q < 0. If α = 1, the relevance score gap
6466
would be A := (p - n)' (p - n) / ||p - n|| > 0. For a target relevance score gap of say
6567
C := 5% * A (the default `optimize_gap` of 0.05), the optimal α is then given by αA + (1 - α)B = C => α = (B - C) / (B - A).
68+
69+
Parameters
70+
----------
71+
max_triplets
72+
The maximum number of (q, p, n) triplets to compute. Each triplet corresponds to a rank-one
73+
update of the query adapter A.
74+
max_triplets_per_eval
75+
The maximum number of (q, p, n) triplets a single eval may contribute to the query adapter.
76+
optimize_top_k
77+
The number of search results per eval to extract (q, p, n) triplets from.
78+
optimize_gap
79+
The strength of the query adapter, expressed as a fraction between 0 and 1 of the maximum
80+
relevance score gap. Should be large enough to correct incorrectly ranked results, but small
81+
enough to not affect correctly ranked results.
82+
config
83+
The RAGLite config to use to construct and store the query adapter.
84+
85+
Raises
86+
------
87+
ValueError
88+
If no documents have been inserted into the database yet.
89+
ValueError
90+
If there aren't enough evals to compute the query adapter yet.
91+
ValueError
92+
If the `config.vector_search_distance_metric` is not supported.
93+
94+
Returns
95+
-------
96+
FloatMatrix
97+
The query adapter.
6698
"""
6799
config = config or RAGLiteConfig()
68100
config_no_query_adapter = RAGLiteConfig(
@@ -75,9 +107,7 @@ def update_query_adapter( # noqa: C901, PLR0915
75107
if chunk_embedding is None:
76108
error_message = "First run `insert_document()` to insert documents."
77109
raise ValueError(error_message)
78-
evals = session.exec(
79-
select(Eval).order_by(Eval.id).limit(max(8, max_triplets // max_triplets_per_eval))
80-
).all()
110+
evals = session.exec(select(Eval).order_by(Eval.id).limit(max_triplets)).all()
81111
# Exit if there aren't enough evals to compute the query adapter.
82112
embedding_dim = len(chunk_embedding.embedding)
83113
required_evals = np.ceil(embedding_dim / max_triplets_per_eval) - len(evals)
@@ -146,20 +176,21 @@ def update_query_adapter( # noqa: C901, PLR0915
146176
# TODO: Matmul in float16 is extremely slow compared to single or double precision, why?
147177
gap_before = np.sum((P - N) * Q, axis=1)
148178
gap_after = 2 * (1 - np.sum(P * N, axis=1)) / np.linalg.norm(P - N, axis=1)
149-
gap_target = 0.05 * gap_after
179+
gap_target = optimize_gap * gap_after
150180
α = (gap_before - gap_target) / (gap_before - gap_after) # noqa: PLC2401
151181
MT = (α[:, np.newaxis] * (P - N)).T @ Q # noqa: N806
152182
s = np.linalg.norm(MT, ord="fro") / np.sqrt(MT.shape[0])
153183
MT = np.mean(α) * (MT / s) + np.mean(1 - α) * np.eye(Q.shape[1]) # noqa: N806
154-
if config.vector_search_index_metric == "dot":
184+
A_star: FloatMatrix # noqa: N806
185+
if config.vector_search_distance_metric == "dot":
155186
# Use the relaxed Procrustes solution.
156187
A_star = MT / np.linalg.norm(MT, ord="fro") # noqa: N806
157-
elif config.vector_search_index_metric == "cosine":
188+
elif config.vector_search_distance_metric == "cosine":
158189
# Use the orthogonal Procrustes solution.
159190
U, _, VT = np.linalg.svd(MT, full_matrices=False) # noqa: N806
160191
A_star = U @ VT # noqa: N806
161192
else:
162-
error_message = f"Unsupported metric: {config.vector_search_index_metric}"
193+
error_message = f"Unsupported metric: {config.vector_search_distance_metric}"
163194
raise ValueError(error_message)
164195
# Store the optimal query adapter in the database.
165196
index_metadata = session.get(IndexMetadata, "default") or IndexMetadata(id="default")
@@ -169,3 +200,4 @@ def update_query_adapter( # noqa: C901, PLR0915
169200
session.commit()
170201
if engine.dialect.name == "duckdb":
171202
session.execute(text("CHECKPOINT;"))
203+
return A_star

src/raglite/_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def vector_search(
5050
corrected_oversample = oversample * config.chunk_max_size / RAGLiteConfig.chunk_max_size
5151
num_hits = round(corrected_oversample) * max(num_results, 10)
5252
dist = ChunkEmbedding.embedding.distance( # type: ignore[attr-defined]
53-
query_embedding, metric=config.vector_search_index_metric
53+
query_embedding, metric=config.vector_search_distance_metric
5454
).label("dist")
5555
sim = (1.0 - dist).label("sim")
5656
top_vectors = select(ChunkEmbedding.chunk_id, sim).order_by(dist).limit(num_hits).subquery()

0 commit comments

Comments
 (0)