Skip to content

Commit be22085

Browse files
cbcoutinhojoein
andauthored
fix: apply score_threshold filtering after fusion queries in local mode (#1138)
* fix: apply score_threshold filtering after fusion queries in local mode The local/memory client was not applying score_threshold filtering after RRF and DBSF fusion operations. This caused query_points with prefetch and fusion queries to return results below the specified score_threshold. This fix adds score_threshold filtering after fusion results are computed, matching the behavior of the remote Qdrant server. * tests: simplify score threshold tests, add formula threshold test --------- Co-authored-by: George Panchuk <george.panchuk@qdrant.tech>
1 parent 144f4ea commit be22085

File tree

3 files changed

+218
-1
lines changed

3 files changed

+218
-1
lines changed

qdrant_client/local/local_collection.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -823,6 +823,10 @@ def _merge_sources(
823823
else:
824824
raise ValueError(f"Fusion method {query.fusion} does not exist")
825825

826+
# Apply score_threshold filtering (matching server behavior)
827+
if score_threshold is not None:
828+
fused = [p for p in fused if p.score >= score_threshold]
829+
826830
# Fetch payload and vectors
827831
ids = [point.id for point in fused]
828832
fetched_points = self.retrieve(

tests/congruence_tests/test_query.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,36 @@ def dense_query_dbsf(self, client: QdrantBase) -> models.QueryResponse:
413413
limit=10,
414414
)
415415

416+
def dense_query_rrf_score_threshold(self, client: QdrantBase) -> list[models.ScoredPoint]:
417+
return client.query_points(
418+
collection_name=COLLECTION_NAME,
419+
prefetch=[
420+
models.Prefetch(
421+
query=self.dense_vector_query_text,
422+
using="text",
423+
)
424+
],
425+
query=models.RrfQuery(rrf=models.Rrf(k=1)),
426+
with_payload=True,
427+
limit=10,
428+
score_threshold=0.25, # should return 3 results: 1.0, 0.5, 0.3(3)
429+
).points
430+
431+
def dense_query_formula_score_threshold(self, client: QdrantBase) -> list[models.ScoredPoint]:
432+
return client.query_points(
433+
collection_name=COLLECTION_NAME,
434+
prefetch=[
435+
models.Prefetch(
436+
query=self.dense_vector_query_text,
437+
using="text",
438+
)
439+
],
440+
query=models.FormulaQuery(formula=models.MultExpression(mult=["$score", 1.0])),
441+
with_payload=True,
442+
limit=10,
443+
score_threshold=1.0, # todo: score threshold is not applied in formula queries in core
444+
).points
445+
416446
def deep_dense_queries_rrf(self, client: QdrantBase) -> models.QueryResponse:
417447
return client.query_points(
418448
collection_name=COLLECTION_NAME,
@@ -1297,10 +1327,15 @@ def test_dense_query_fusion():
12971327
compare_clients_results(
12981328
local_client, http_client, grpc_client, searcher.deep_dense_queries_dbsf
12991329
)
1300-
13011330
compare_clients_results(
13021331
local_client, http_client, grpc_client, searcher.dense_query_parametrized_rrf
13031332
)
1333+
compare_clients_results(
1334+
local_client, http_client, grpc_client, searcher.dense_query_rrf_score_threshold
1335+
)
1336+
compare_clients_results(
1337+
local_client, http_client, grpc_client, searcher.dense_query_formula_score_threshold
1338+
)
13041339

13051340

13061341
def test_dense_query_discovery_context():

tests/test_in_memory.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,3 +118,181 @@ def test_sparse_in_memory_key_filter_returns_results(qdrant: QdrantClient):
118118
).points
119119

120120
assert [r.id for r in search_result] == [4, 2]
121+
122+
123+
def test_fusion_rrf_score_threshold(qdrant: QdrantClient):
124+
"""Test that RRF fusion with score_threshold correctly filters results.
125+
126+
RRF scores in local mode are normalized and for 5 points we get roughly:
127+
- ID 1: 1.0
128+
- ID 2: 0.667
129+
- ID 3: 0.5
130+
- ID 5: 0.4
131+
- ID 4: 0.333
132+
133+
A threshold of 0.45 should filter out IDs 4 and 5.
134+
"""
135+
qdrant.create_collection(
136+
collection_name="test_collection",
137+
vectors_config={
138+
"text": models.VectorParams(size=4, distance=models.Distance.COSINE),
139+
"image": models.VectorParams(size=4, distance=models.Distance.COSINE),
140+
},
141+
)
142+
143+
qdrant.upsert(
144+
collection_name="test_collection",
145+
wait=True,
146+
points=[
147+
models.PointStruct(
148+
id=1,
149+
vector={"text": [1.0, 0.0, 0.0, 0.0], "image": [1.0, 0.0, 0.0, 0.0]},
150+
),
151+
models.PointStruct(
152+
id=2,
153+
vector={"text": [0.9, 0.1, 0.0, 0.0], "image": [0.9, 0.1, 0.0, 0.0]},
154+
),
155+
models.PointStruct(
156+
id=3,
157+
vector={"text": [0.5, 0.5, 0.0, 0.0], "image": [0.5, 0.5, 0.0, 0.0]},
158+
),
159+
models.PointStruct(
160+
id=4,
161+
vector={"text": [0.0, 1.0, 0.0, 0.0], "image": [0.0, 1.0, 0.0, 0.0]},
162+
),
163+
models.PointStruct(
164+
id=5,
165+
vector={"text": [0.0, 0.0, 1.0, 0.0], "image": [0.0, 0.0, 1.0, 0.0]},
166+
),
167+
],
168+
)
169+
170+
query_vector = [1.0, 0.0, 0.0, 0.0]
171+
172+
# Without score_threshold - should return all 5 points
173+
result_no_threshold = qdrant.query_points(
174+
collection_name="test_collection",
175+
prefetch=[
176+
models.Prefetch(query=query_vector, using="text", limit=10),
177+
models.Prefetch(query=query_vector, using="image", limit=10),
178+
],
179+
query=models.FusionQuery(fusion=models.Fusion.RRF),
180+
limit=10,
181+
)
182+
assert len(result_no_threshold.points) == 5
183+
184+
# Find points with scores below 0.45 - IDs 4 (0.333) and 5 (0.4) should be filtered
185+
low_score_count = sum(1 for p in result_no_threshold.points if p.score < 0.45)
186+
assert low_score_count == 2, f"Expected 2 low-scoring points, got {low_score_count}"
187+
188+
# With a threshold of 0.45, points with scores below should be filtered
189+
result_with_threshold = qdrant.query_points(
190+
collection_name="test_collection",
191+
prefetch=[
192+
models.Prefetch(query=query_vector, using="text", limit=10),
193+
models.Prefetch(query=query_vector, using="image", limit=10),
194+
],
195+
query=models.FusionQuery(fusion=models.Fusion.RRF),
196+
score_threshold=0.45,
197+
limit=10,
198+
)
199+
200+
# Verify all returned points have score >= threshold
201+
for point in result_with_threshold.points:
202+
assert point.score >= 0.45, f"Score {point.score} is below threshold 0.45"
203+
204+
# Key assertion: filtering should reduce the count from 5 to 3
205+
assert len(result_with_threshold.points) == 3, (
206+
f"Expected 3 points after filtering (threshold 0.45), got {len(result_with_threshold.points)}. "
207+
f"Scores: {[p.score for p in result_no_threshold.points]}"
208+
)
209+
210+
211+
def test_fusion_dbsf_score_threshold(qdrant: QdrantClient):
212+
"""Test that DBSF fusion with score_threshold correctly filters results.
213+
214+
DBSF scores for the test data:
215+
- ID 1: ~1.30
216+
- ID 2: ~1.30
217+
- ID 3: ~1.11
218+
- ID 4: ~0.64
219+
- ID 5: ~0.64
220+
221+
A threshold of 1.0 should filter out IDs 4 and 5.
222+
"""
223+
qdrant.create_collection(
224+
collection_name="test_collection",
225+
vectors_config={
226+
"text": models.VectorParams(size=4, distance=models.Distance.COSINE),
227+
"image": models.VectorParams(size=4, distance=models.Distance.COSINE),
228+
},
229+
)
230+
231+
qdrant.upsert(
232+
collection_name="test_collection",
233+
wait=True,
234+
points=[
235+
models.PointStruct(
236+
id=1,
237+
vector={"text": [1.0, 0.0, 0.0, 0.0], "image": [1.0, 0.0, 0.0, 0.0]},
238+
),
239+
models.PointStruct(
240+
id=2,
241+
vector={"text": [0.9, 0.1, 0.0, 0.0], "image": [0.9, 0.1, 0.0, 0.0]},
242+
),
243+
models.PointStruct(
244+
id=3,
245+
vector={"text": [0.5, 0.5, 0.0, 0.0], "image": [0.5, 0.5, 0.0, 0.0]},
246+
),
247+
models.PointStruct(
248+
id=4,
249+
vector={"text": [0.0, 1.0, 0.0, 0.0], "image": [0.0, 1.0, 0.0, 0.0]},
250+
),
251+
models.PointStruct(
252+
id=5,
253+
vector={"text": [0.0, 0.0, 1.0, 0.0], "image": [0.0, 0.0, 1.0, 0.0]},
254+
),
255+
],
256+
)
257+
258+
query_vector = [1.0, 0.0, 0.0, 0.0]
259+
260+
# Without score_threshold - should return all 5 points
261+
result_no_threshold = qdrant.query_points(
262+
collection_name="test_collection",
263+
prefetch=[
264+
models.Prefetch(query=query_vector, using="text", limit=10),
265+
models.Prefetch(query=query_vector, using="image", limit=10),
266+
],
267+
query=models.FusionQuery(fusion=models.Fusion.DBSF),
268+
limit=10,
269+
)
270+
assert len(result_no_threshold.points) == 5
271+
272+
# Find points with scores below 1.0 - IDs 4 and 5 (~0.64) should be filtered
273+
low_score_count = sum(1 for p in result_no_threshold.points if p.score < 1.0)
274+
assert low_score_count == 2, f"Expected 2 low-scoring points, got {low_score_count}"
275+
276+
# With score_threshold of 1.0, points below should be filtered
277+
result_with_threshold = qdrant.query_points(
278+
collection_name="test_collection",
279+
prefetch=[
280+
models.Prefetch(query=query_vector, using="text", limit=10),
281+
models.Prefetch(query=query_vector, using="image", limit=10),
282+
],
283+
query=models.FusionQuery(fusion=models.Fusion.DBSF),
284+
score_threshold=1.0,
285+
limit=10,
286+
)
287+
288+
# Verify all returned points have score >= threshold
289+
for point in result_with_threshold.points:
290+
assert point.score >= 1.0, f"Score {point.score} is below threshold 1.0"
291+
292+
# Key assertion: filtering should reduce the count from 5 to 3
293+
assert len(result_with_threshold.points) == 3, (
294+
f"Expected 3 points after filtering (threshold 1.0), got {len(result_with_threshold.points)}. "
295+
f"Scores: {[p.score for p in result_no_threshold.points]}"
296+
)
297+
298+

0 commit comments

Comments
 (0)