Skip to content

Commit 0cf2ebb

Browse files
authored
Vector rescoring - Simplify code for k == null (elastic#118997)
1 parent 14c21f2 commit 0cf2ebb

File tree

5 files changed

+34
-79
lines changed

5 files changed

+34
-79
lines changed

server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java

Lines changed: 1 addition & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -18,8 +18,6 @@
1818
import org.apache.lucene.search.DoubleValues;
1919
import org.apache.lucene.search.DoubleValuesSource;
2020
import org.apache.lucene.search.IndexSearcher;
21-
import org.elasticsearch.search.profile.query.QueryProfiler;
22-
import org.elasticsearch.search.vectors.QueryProfilerProvider;
2321

2422
import java.io.IOException;
2523
import java.util.Arrays;
@@ -29,12 +27,11 @@
2927
* DoubleValuesSource that is used to calculate scores according to a similarity function for a KnnFloatVectorField, using the
3028
* original vector values stored in the index
3129
*/
32-
public class VectorSimilarityFloatValueSource extends DoubleValuesSource implements QueryProfilerProvider {
30+
public class VectorSimilarityFloatValueSource extends DoubleValuesSource {
3331

3432
private final String field;
3533
private final float[] target;
3634
private final VectorSimilarityFunction vectorSimilarityFunction;
37-
private long vectorOpsCount;
3835

3936
public VectorSimilarityFloatValueSource(String field, float[] target, VectorSimilarityFunction vectorSimilarityFunction) {
4037
this.field = field;
@@ -52,7 +49,6 @@ public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws
5249
return new DoubleValues() {
5350
@Override
5451
public double doubleValue() throws IOException {
55-
vectorOpsCount++;
5652
return vectorSimilarityFunction.compare(target, vectorValues.vectorValue(iterator.index()));
5753
}
5854

@@ -73,11 +69,6 @@ public DoubleValuesSource rewrite(IndexSearcher reader) throws IOException {
7369
return this;
7470
}
7571

76-
@Override
77-
public void profile(QueryProfiler queryProfiler) {
78-
queryProfiler.addVectorOpsCount(vectorOpsCount);
79-
}
80-
8172
@Override
8273
public int hashCode() {
8374
return Objects.hash(field, Arrays.hashCode(target), vectorSimilarityFunction);

server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java

Lines changed: 11 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -487,11 +487,6 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException {
487487
return this;
488488
}
489489

490-
@Override
491-
protected QueryBuilder doIndexMetadataRewrite(QueryRewriteContext context) throws IOException {
492-
return super.doIndexMetadataRewrite(context);
493-
}
494-
495490
@Override
496491
protected Query doToQuery(SearchExecutionContext context) throws IOException {
497492
MappedFieldType fieldType = context.getFieldType(fieldName);
@@ -529,8 +524,8 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
529524
String parentPath = context.nestedLookup().getNestedParent(fieldName);
530525
Float numCandidatesFactor = rescoreVectorBuilder() == null ? null : rescoreVectorBuilder.numCandidatesFactor();
531526

527+
BitSetProducer parentBitSet = null;
532528
if (parentPath != null) {
533-
final BitSetProducer parentBitSet;
534529
final Query parentFilter;
535530
NestedObjectMapper originalObjectMapper = context.nestedScope().getObjectMapper();
536531
if (originalObjectMapper != null) {
@@ -559,17 +554,17 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
559554
// Now join the filterQuery & parentFilter to provide the matching blocks of children
560555
filterQuery = new ToChildBlockJoinQuery(filterQuery, parentBitSet);
561556
}
562-
return vectorFieldType.createKnnQuery(
563-
queryVector,
564-
k,
565-
adjustedNumCands,
566-
numCandidatesFactor,
567-
filterQuery,
568-
vectorSimilarity,
569-
parentBitSet
570-
);
571557
}
572-
return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, numCandidatesFactor, filterQuery, vectorSimilarity, null);
558+
559+
return vectorFieldType.createKnnQuery(
560+
queryVector,
561+
k,
562+
adjustedNumCands,
563+
numCandidatesFactor,
564+
filterQuery,
565+
vectorSimilarity,
566+
parentBitSet
567+
);
573568
}
574569

575570
@Override

server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java

Lines changed: 6 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -32,16 +32,15 @@ public class RescoreKnnVectorQuery extends Query implements QueryProfilerProvide
3232
private final String fieldName;
3333
private final float[] floatTarget;
3434
private final VectorSimilarityFunction vectorSimilarityFunction;
35-
private final Integer k;
35+
private final int k;
3636
private final Query innerQuery;
37-
38-
private QueryProfilerProvider vectorProfiling;
37+
private long vectorOperations = 0;
3938

4039
public RescoreKnnVectorQuery(
4140
String fieldName,
4241
float[] floatTarget,
4342
VectorSimilarityFunction vectorSimilarityFunction,
44-
Integer k,
43+
int k,
4544
Query innerQuery
4645
) {
4746
this.fieldName = fieldName;
@@ -54,19 +53,12 @@ public RescoreKnnVectorQuery(
5453
@Override
5554
public Query rewrite(IndexSearcher searcher) throws IOException {
5655
DoubleValuesSource valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction);
57-
// Vector similarity VectorSimilarityFloatValueSource keep track of the compared vectors - we need that in case we don't need
58-
// to calculate top k and return directly the query to understand how many comparisons were done
59-
vectorProfiling = (QueryProfilerProvider) valueSource;
6056
FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(innerQuery, valueSource);
6157
Query query = searcher.rewrite(functionScoreQuery);
6258

63-
if (k == null) {
64-
// No need to calculate top k - let the request size limit the results.
65-
return query;
66-
}
67-
6859
// Retrieve top k documents from the rescored query
6960
TopDocs topDocs = searcher.search(query, k);
61+
vectorOperations = topDocs.totalHits.value();
7062
ScoreDoc[] scoreDocs = topDocs.scoreDocs;
7163
int[] docIds = new int[scoreDocs.length];
7264
float[] scores = new float[scoreDocs.length];
@@ -82,7 +74,7 @@ public Query innerQuery() {
8274
return innerQuery;
8375
}
8476

85-
public Integer k() {
77+
public int k() {
8678
return k;
8779
}
8880

@@ -92,10 +84,7 @@ public void profile(QueryProfiler queryProfiler) {
9284
queryProfilerProvider.profile(queryProfiler);
9385
}
9486

95-
if (vectorProfiling == null) {
96-
throw new IllegalStateException("Query should have been rewritten");
97-
}
98-
vectorProfiling.profile(queryProfiler);
87+
queryProfiler.addVectorOpsCount(vectorOperations);
9988
}
10089

10190
@Override

server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -456,18 +456,19 @@ public void testRescoreOversampleModifiesNumCandidates() {
456456
);
457457

458458
// Total results is k, internal k is multiplied by oversample
459-
checkRescoreQueryParameters(fieldType, 10, 200, 2.5F, null, 500, 10);
459+
checkRescoreQueryParameters(fieldType, 10, 200, randomInt(), 2.5F, null, 500, 10);
460460
// If numCands < k, update numCands to k
461-
checkRescoreQueryParameters(fieldType, 10, 20, 2.5F, null, 50, 10);
461+
checkRescoreQueryParameters(fieldType, 10, 20, randomInt(), 2.5F, null, 50, 10);
462462
// Oversampling limits for num candidates
463-
checkRescoreQueryParameters(fieldType, 1000, 1000, 11.0F, null, 10000, 1000);
464-
checkRescoreQueryParameters(fieldType, 5000, 7500, 2.5F, null, 10000, 5000);
463+
checkRescoreQueryParameters(fieldType, 1000, 1000, randomInt(), 11.0F, null, 10000, 1000);
464+
checkRescoreQueryParameters(fieldType, 5000, 7500, randomInt(), 2.5F, null, 10000, 5000);
465465
}
466466

467467
private static void checkRescoreQueryParameters(
468468
DenseVectorFieldType fieldType,
469-
Integer k,
469+
int k,
470470
int candidates,
471+
int requestSize,
471472
float numCandsFactor,
472473
Integer expectedK,
473474
int expectedCandidates,

server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java

Lines changed: 10 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -9,8 +9,6 @@
99

1010
package org.elasticsearch.search.vectors;
1111

12-
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
13-
1412
import org.apache.lucene.document.Document;
1513
import org.apache.lucene.document.KnnFloatVectorField;
1614
import org.apache.lucene.index.DirectoryReader;
@@ -33,11 +31,9 @@
3331

3432
import java.io.IOException;
3533
import java.io.UnsupportedEncodingException;
36-
import java.util.ArrayList;
3734
import java.util.Arrays;
3835
import java.util.Collection;
3936
import java.util.HashSet;
40-
import java.util.List;
4137
import java.util.Map;
4238
import java.util.PriorityQueue;
4339
import java.util.stream.Collectors;
@@ -49,21 +45,11 @@
4945
public class RescoreKnnVectorQueryTests extends ESTestCase {
5046

5147
public static final String FIELD_NAME = "float_vector";
52-
private final int numDocs;
53-
private final Integer k;
54-
55-
public RescoreKnnVectorQueryTests(boolean useK) {
56-
this.numDocs = randomIntBetween(10, 100);
57-
this.k = useK ? randomIntBetween(1, numDocs - 1) : null;
58-
}
5948

6049
public void testRescoreDocs() throws Exception {
50+
int numDocs = randomIntBetween(10, 100);
6151
int numDims = randomIntBetween(5, 100);
62-
63-
Integer adjustedK = k;
64-
if (k == null) {
65-
adjustedK = numDocs;
66-
}
52+
int k = randomIntBetween(1, numDocs - 1);
6753

6854
try (Directory d = newDirectory()) {
6955
addRandomDocuments(numDocs, d, numDims);
@@ -77,7 +63,7 @@ public void testRescoreDocs() throws Exception {
7763
FIELD_NAME,
7864
queryVector,
7965
VectorSimilarityFunction.COSINE,
80-
adjustedK,
66+
k,
8167
new MatchAllDocsQuery()
8268
);
8369

@@ -86,7 +72,7 @@ public void testRescoreDocs() throws Exception {
8672
Map<Integer, Float> rescoredDocs = Arrays.stream(docs.scoreDocs)
8773
.collect(Collectors.toMap(scoreDoc -> scoreDoc.doc, scoreDoc -> scoreDoc.score));
8874

89-
assertThat(rescoredDocs.size(), equalTo(adjustedK));
75+
assertThat(rescoredDocs.size(), equalTo(k));
9076

9177
Collection<Float> rescoredScores = new HashSet<>(rescoredDocs.values());
9278

@@ -113,7 +99,7 @@ public void testRescoreDocs() throws Exception {
11399
assertThat(rescoredDocs.size(), equalTo(0));
114100

115101
// Check top scoring docs are contained in rescored docs
116-
for (int i = 0; i < adjustedK; i++) {
102+
for (int i = 0; i < k; i++) {
117103
Float topScore = topK.poll();
118104
if (rescoredScores.contains(topScore) == false) {
119105
fail("Top score " + topScore + " not contained in rescored doc scores " + rescoredScores);
@@ -124,21 +110,23 @@ public void testRescoreDocs() throws Exception {
124110
}
125111

126112
public void testProfiling() throws Exception {
113+
int numDocs = randomIntBetween(10, 100);
127114
int numDims = randomIntBetween(5, 100);
115+
int k = randomIntBetween(1, numDocs - 1);
128116

129117
try (Directory d = newDirectory()) {
130118
addRandomDocuments(numDocs, d, numDims);
131119

132120
try (IndexReader reader = DirectoryReader.open(d)) {
133121
float[] queryVector = randomVector(numDims);
134122

135-
checkProfiling(queryVector, reader, new MatchAllDocsQuery());
136-
checkProfiling(queryVector, reader, new MockQueryProfilerProvider(randomIntBetween(1, 100)));
123+
checkProfiling(k, numDocs, queryVector, reader, new MatchAllDocsQuery());
124+
checkProfiling(k, numDocs, queryVector, reader, new MockQueryProfilerProvider(randomIntBetween(1, 100)));
137125
}
138126
}
139127
}
140128

141-
private void checkProfiling(float[] queryVector, IndexReader reader, Query innerQuery) throws IOException {
129+
private void checkProfiling(int k, int numDocs, float[] queryVector, IndexReader reader, Query innerQuery) throws IOException {
142130
RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery(
143131
FIELD_NAME,
144132
queryVector,
@@ -229,13 +217,4 @@ private static void addRandomDocuments(int numDocs, Directory d, int numDims) th
229217
w.forceMerge(1);
230218
}
231219
}
232-
233-
@ParametersFactory
234-
public static Iterable<Object[]> parameters() {
235-
List<Object[]> params = new ArrayList<>();
236-
params.add(new Object[] { true });
237-
params.add(new Object[] { false });
238-
239-
return params;
240-
}
241220
}

0 commit comments

Comments
 (0)