Skip to content

Commit 6f77261

Browse files
committed
replace the check from dfs phase to KnnVectorQueryBuilder doToQuery
1 parent 0573a3f commit 6f77261

File tree

4 files changed

+28
-55
lines changed

4 files changed

+28
-55
lines changed

server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
import java.util.Map;
4545

4646
import static org.elasticsearch.index.query.AbstractQueryBuilder.DEFAULT_BOOST;
47-
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD;
4847

4948
/**
5049
* DFS phase of a search request, used to make scoring 100% accurate by collecting additional info from each shard before the query phase.
@@ -178,7 +177,7 @@ private static Timer maybeStartTimer(DfsProfiler profiler, DfsTimingType dtt) {
178177
return null;
179178
};
180179

181-
static void executeKnnVectorQuery(SearchContext context) throws IOException {
180+
private static void executeKnnVectorQuery(SearchContext context) throws IOException {
182181
SearchSourceBuilder source = context.request().source();
183182
if (source == null || source.knnSearch().isEmpty()) {
184183
return;
@@ -187,15 +186,6 @@ static void executeKnnVectorQuery(SearchContext context) throws IOException {
187186
SearchExecutionContext searchExecutionContext = context.getSearchExecutionContext();
188187
List<KnnSearchBuilder> knnSearch = source.knnSearch();
189188
List<KnnVectorQueryBuilder> knnVectorQueryBuilders = knnSearch.stream().map(KnnSearchBuilder::toQueryBuilder).toList();
190-
int maxKnnNumCandidates = context.indexShard().indexSettings().getMaxKnnNumCandidates();
191-
for (KnnVectorQueryBuilder knnVectorQueryBuilder : knnVectorQueryBuilders) {
192-
if (knnVectorQueryBuilder.numCands() != null && knnVectorQueryBuilder.numCands() > maxKnnNumCandidates) {
193-
throw new IllegalArgumentException(
194-
"[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + maxKnnNumCandidates + "]"
195-
);
196-
}
197-
}
198-
199189
// Since we apply boost during the DfsQueryPhase, we should not apply boost here:
200190
knnVectorQueryBuilders.forEach(knnVectorQueryBuilder -> knnVectorQueryBuilder.boost(DEFAULT_BOOST));
201191

server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -482,7 +482,6 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException {
482482

483483
@Override
484484
protected Query doToQuery(SearchExecutionContext context) throws IOException {
485-
MappedFieldType fieldType = context.getFieldType(fieldName);
486485
int k;
487486
if (this.k != null) {
488487
k = this.k;
@@ -492,7 +491,15 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
492491
k = Math.min(k, numCands);
493492
}
494493
}
495-
int adjustedNumCands = numCands == null ? Math.round(NUM_CANDS_MULTIPLICATIVE_FACTOR * k) : numCands;
494+
495+
int maxKnnNumCandidates = context.getIndexSettings().getMaxKnnNumCandidates();
496+
if (numCands != null && numCands > maxKnnNumCandidates) {
497+
throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + maxKnnNumCandidates + "]");
498+
}
499+
500+
int adjustedNumCands = numCands == null ? Math.round(Math.min(NUM_CANDS_MULTIPLICATIVE_FACTOR * k, maxKnnNumCandidates)) : numCands;
501+
502+
MappedFieldType fieldType = context.getFieldType(fieldName);
496503
if (fieldType == null) {
497504
return new MatchNoDocsQuery();
498505
}

server/src/test/java/org/elasticsearch/search/dfs/DfsPhaseTests.java

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,12 @@
1717
import org.apache.lucene.search.Query;
1818
import org.apache.lucene.store.Directory;
1919
import org.apache.lucene.tests.index.RandomIndexWriter;
20-
import org.elasticsearch.common.settings.Settings;
21-
import org.elasticsearch.index.IndexSettings;
22-
import org.elasticsearch.index.shard.IndexShard;
23-
import org.elasticsearch.search.builder.SearchSourceBuilder;
2420
import org.elasticsearch.search.internal.ContextIndexSearcher;
25-
import org.elasticsearch.search.internal.SearchContext;
26-
import org.elasticsearch.search.internal.ShardSearchRequest;
2721
import org.elasticsearch.search.profile.Profilers;
2822
import org.elasticsearch.search.profile.SearchProfileDfsPhaseResult;
2923
import org.elasticsearch.search.profile.query.CollectorResult;
3024
import org.elasticsearch.search.profile.query.QueryProfileShardResult;
31-
import org.elasticsearch.search.vectors.KnnSearchBuilder;
3225
import org.elasticsearch.test.ESTestCase;
33-
import org.elasticsearch.test.IndexSettingsModule;
3426
import org.elasticsearch.threadpool.TestThreadPool;
3527
import org.elasticsearch.threadpool.ThreadPool;
3628
import org.junit.After;
@@ -40,9 +32,6 @@
4032
import java.util.List;
4133
import java.util.concurrent.ThreadPoolExecutor;
4234

43-
import static org.mockito.Mockito.mock;
44-
import static org.mockito.Mockito.when;
45-
4635
public class DfsPhaseTests extends ESTestCase {
4736

4837
ThreadPoolExecutor threadPoolExecutor;
@@ -113,35 +102,4 @@ public void testSingleKnnSearch() throws IOException {
113102
reader.close();
114103
}
115104
}
116-
117-
public void testNumCandidatesExceedsMax() {
118-
Settings settings = Settings.builder().put("index.max_knn_num_candidates", 100).build();
119-
IndexSettings indexSettings = IndexSettingsModule.newIndexSettings("test", settings);
120-
121-
SearchContext context = mock(SearchContext.class);
122-
when(context.indexShard()).thenAnswer(invocation -> {
123-
IndexShard mockIndexShard = mock(IndexShard.class);
124-
when(mockIndexShard.indexSettings()).thenReturn(indexSettings);
125-
return mockIndexShard;
126-
});
127-
128-
KnnSearchBuilder queryBuilder = new KnnSearchBuilder(
129-
"float_vector",
130-
new float[] { 0, 0, 0 },
131-
10,
132-
150, // 超过maxKnnNumCandidates的值
133-
null,
134-
null
135-
);
136-
SearchSourceBuilder source = new SearchSourceBuilder();
137-
source.knnSearch(List.of(queryBuilder));
138-
ShardSearchRequest searchRequest = mock(ShardSearchRequest.class);
139-
when(searchRequest.source()).thenReturn(source);
140-
when(context.request()).thenReturn(searchRequest);
141-
142-
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> DfsPhase.executeKnnVectorQuery(context));
143-
assertEquals("[num_candidates] cannot exceed [100]", e.getMessage());
144-
145-
}
146-
147105
}

server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import org.elasticsearch.common.io.stream.BytesStreamOutput;
2222
import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
2323
import org.elasticsearch.common.io.stream.StreamInput;
24+
import org.elasticsearch.common.settings.Settings;
25+
import org.elasticsearch.index.IndexSettings;
2426
import org.elasticsearch.index.mapper.MapperService;
2527
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
2628
import org.elasticsearch.index.query.InnerHitsRewriteContext;
@@ -34,6 +36,7 @@
3436
import org.elasticsearch.index.query.TermQueryBuilder;
3537
import org.elasticsearch.test.AbstractBuilderTestCase;
3638
import org.elasticsearch.test.AbstractQueryTestCase;
39+
import org.elasticsearch.test.IndexSettingsModule;
3740
import org.elasticsearch.test.TransportVersionUtils;
3841
import org.elasticsearch.xcontent.XContentBuilder;
3942
import org.elasticsearch.xcontent.XContentFactory;
@@ -53,6 +56,8 @@
5356
import static org.hamcrest.Matchers.hasSize;
5457
import static org.hamcrest.Matchers.instanceOf;
5558
import static org.hamcrest.Matchers.nullValue;
59+
import static org.mockito.Mockito.mock;
60+
import static org.mockito.Mockito.when;
5661

5762
abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCase<KnnVectorQueryBuilder> {
5863
private static final String VECTOR_FIELD = "vector";
@@ -458,4 +463,17 @@ public void testRewriteWithQueryVectorBuilder() throws Exception {
458463
assertThat(rewritten.filterQueries(), hasSize(numFilters));
459464
assertThat(rewritten.filterQueries(), equalTo(filters));
460465
}
466+
467+
public void testMaxNumCandidatesExceeded() {
468+
Settings settings = Settings.builder().put("index.max_knn_num_candidates", 100).build();
469+
IndexSettings indexSettings = IndexSettingsModule.newIndexSettings("test", settings);
470+
471+
SearchExecutionContext context = mock(SearchExecutionContext.class);
472+
when(context.getIndexSettings()).thenReturn(indexSettings);
473+
474+
KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, 5, 150, null, null);
475+
476+
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context));
477+
assertThat(e.getMessage(), containsString("[num_candidates] cannot exceed [100]"));
478+
}
461479
}

0 commit comments

Comments
 (0)