Skip to content

Commit 2c1f686

Browse files
committed
add index.max_knn_num_candidates settings
1 parent dc013bb commit 2c1f686

File tree

10 files changed

+110
-41
lines changed

10 files changed

+110
-41
lines changed

rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,3 +605,32 @@ setup:
605605
- match: { hits.hits.0._score: $knn_score0 }
606606
- match: { hits.hits.1._score: $knn_score1 }
607607
- match: { hits.hits.2._score: $knn_score2 }
608+
609+
---
610+
"kNN search with num_candidates exceeds max allowed value":
611+
- requires:
612+
reason: 'num_candidates exceeds max allowed value'
613+
test_runner_features: [capabilities]
614+
615+
- do:
616+
indices.create:
617+
index: test_num_candidates
618+
body:
619+
mappings:
620+
properties:
621+
vector:
622+
type: dense_vector
623+
element_type: float
624+
dims: 5
625+
settings:
626+
index.max_knn_num_candidates: 100
627+
- do:
628+
catch: /\[num_candidates\] cannot exceed \[100\]/
629+
search:
630+
index: test_num_candidates
631+
body:
632+
knn:
633+
field: vector
634+
query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ]
635+
k: 2
636+
num_candidates: 200

server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ public final class IndexScopedSettings extends AbstractScopedSettings {
141141
IndexSettings.MAX_REFRESH_LISTENERS_PER_SHARD,
142142
IndexSettings.MAX_SLICES_PER_SCROLL,
143143
IndexSettings.MAX_REGEX_LENGTH_SETTING,
144+
IndexSettings.INDEX_MAX_KNN_NUM_CANDIDATES_SETTING,
144145
ShardsLimitAllocationDecider.INDEX_TOTAL_SHARDS_PER_NODE_SETTING,
145146
IndexSettings.INDEX_GC_DELETES_SETTING,
146147
IndexSettings.INDEX_SOFT_DELETES_SETTING,

server/src/main/java/org/elasticsearch/index/IndexSettings.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,17 @@ public final class IndexSettings {
284284
Property.IndexScope
285285
);
286286

287+
/**
288+
* The maximum number of candidates to be considered for KNN search. The default value is 10_000.
289+
*/
290+
public static final Setting<Integer> INDEX_MAX_KNN_NUM_CANDIDATES_SETTING = Setting.intSetting(
291+
"index.max_knn_num_candidates",
292+
10_000,
293+
1,
294+
Property.Dynamic,
295+
Property.IndexScope
296+
);
297+
287298
public static final TimeValue DEFAULT_REFRESH_INTERVAL = new TimeValue(1, TimeUnit.SECONDS);
288299
public static final Setting<TimeValue> NODE_DEFAULT_REFRESH_INTERVAL_SETTING = Setting.timeSetting(
289300
"node._internal.default_refresh_interval",
@@ -930,6 +941,8 @@ private void setRetentionLeaseMillis(final TimeValue retentionLease) {
930941
*/
931942
private volatile int maxRegexLength;
932943

944+
private volatile int maxKnnNumCandidates;
945+
933946
private final IndexRouting indexRouting;
934947

935948
/**
@@ -1083,6 +1096,7 @@ public IndexSettings(final IndexMetadata indexMetadata, final Settings nodeSetti
10831096
mappingDepthLimit = scopedSettings.get(INDEX_MAPPING_DEPTH_LIMIT_SETTING);
10841097
mappingFieldNameLengthLimit = scopedSettings.get(INDEX_MAPPING_FIELD_NAME_LENGTH_LIMIT_SETTING);
10851098
mappingDimensionFieldsLimit = scopedSettings.get(INDEX_MAPPING_DIMENSION_FIELDS_LIMIT_SETTING);
1099+
maxKnnNumCandidates = scopedSettings.get(INDEX_MAX_KNN_NUM_CANDIDATES_SETTING);
10861100
indexRouting = IndexRouting.fromIndexMetadata(indexMetadata);
10871101
sourceKeepMode = scopedSettings.get(Mapper.SYNTHETIC_SOURCE_KEEP_INDEX_SETTING);
10881102
es87TSDBCodecEnabled = scopedSettings.get(TIME_SERIES_ES87TSDB_CODEC_ENABLED_SETTING);
@@ -1203,6 +1217,7 @@ public IndexSettings(final IndexMetadata indexMetadata, final Settings nodeSetti
12031217
this::setSkipIgnoredSourceWrite
12041218
);
12051219
scopedSettings.addSettingsUpdateConsumer(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_READ_SETTING, this::setSkipIgnoredSourceRead);
1220+
scopedSettings.addSettingsUpdateConsumer(INDEX_MAX_KNN_NUM_CANDIDATES_SETTING, this::setMaxKnnNumCandidates);
12061221
}
12071222

12081223
private void setSearchIdleAfter(TimeValue searchIdleAfter) {
@@ -1821,4 +1836,12 @@ public TimestampBounds getTimestampBounds() {
18211836
public IndexRouting getIndexRouting() {
18221837
return indexRouting;
18231838
}
1839+
1840+
public int getMaxKnnNumCandidates() {
1841+
return maxKnnNumCandidates;
1842+
}
1843+
1844+
public void setMaxKnnNumCandidates(int maxKnnNumCandidates) {
1845+
this.maxKnnNumCandidates = maxKnnNumCandidates;
1846+
}
18241847
}

server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import java.util.Map;
4545

4646
import static org.elasticsearch.index.query.AbstractQueryBuilder.DEFAULT_BOOST;
47+
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD;
4748

4849
/**
4950
* DFS phase of a search request, used to make scoring 100% accurate by collecting additional info from each shard before the query phase.
@@ -177,7 +178,7 @@ private static Timer maybeStartTimer(DfsProfiler profiler, DfsTimingType dtt) {
177178
return null;
178179
};
179180

180-
private static void executeKnnVectorQuery(SearchContext context) throws IOException {
181+
static void executeKnnVectorQuery(SearchContext context) throws IOException {
181182
SearchSourceBuilder source = context.request().source();
182183
if (source == null || source.knnSearch().isEmpty()) {
183184
return;
@@ -186,6 +187,15 @@ private static void executeKnnVectorQuery(SearchContext context) throws IOExcept
186187
SearchExecutionContext searchExecutionContext = context.getSearchExecutionContext();
187188
List<KnnSearchBuilder> knnSearch = source.knnSearch();
188189
List<KnnVectorQueryBuilder> knnVectorQueryBuilders = knnSearch.stream().map(KnnSearchBuilder::toQueryBuilder).toList();
190+
int maxKnnNumCandidates = context.indexShard().indexSettings().getMaxKnnNumCandidates();
191+
for (KnnVectorQueryBuilder knnVectorQueryBuilder : knnVectorQueryBuilders) {
192+
if (knnVectorQueryBuilder.numCands() != null && knnVectorQueryBuilder.numCands() > maxKnnNumCandidates) {
193+
throw new IllegalArgumentException(
194+
"[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + maxKnnNumCandidates + "]"
195+
);
196+
}
197+
}
198+
189199
// Since we apply boost during the DfsQueryPhase, we should not apply boost here:
190200
knnVectorQueryBuilders.forEach(knnVectorQueryBuilder -> knnVectorQueryBuilder.boost(DEFAULT_BOOST));
191201

server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
* Defines a kNN search to run in the search request.
4444
*/
4545
public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewriteable<KnnSearchBuilder> {
46-
public static final int NUM_CANDS_LIMIT = 10_000;
4746
public static final float NUM_CANDS_MULTIPLICATIVE_FACTOR = 1.5f;
4847

4948
public static final ParseField FIELD_FIELD = new ParseField("field");
@@ -264,9 +263,6 @@ private KnnSearchBuilder(
264263
"[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot be less than " + "[" + K_FIELD.getPreferredName() + "]"
265264
);
266265
}
267-
if (numCandidates > NUM_CANDS_LIMIT) {
268-
throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]");
269-
}
270266
if (queryVector == null && queryVectorBuilder == null) {
271267
throw new IllegalArgumentException(
272268
format(
@@ -667,9 +663,7 @@ public Builder rescoreVectorBuilder(RescoreVectorBuilder rescoreVectorBuilder) {
667663
public KnnSearchBuilder build(int size) {
668664
int requestSize = size < 0 ? DEFAULT_SIZE : size;
669665
int adjustedK = k == null ? requestSize : k;
670-
int adjustedNumCandidates = numCandidates == null
671-
? Math.round(Math.min(NUM_CANDS_LIMIT, NUM_CANDS_MULTIPLICATIVE_FACTOR * adjustedK))
672-
: numCandidates;
666+
int adjustedNumCandidates = numCandidates == null ? Math.round(NUM_CANDS_MULTIPLICATIVE_FACTOR * adjustedK) : numCandidates;
673667
return new KnnSearchBuilder(
674668
field,
675669
queryVectorBuilder,

server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,6 @@ public void toSearchRequest(SearchRequestBuilder builder) {
195195

196196
// visible for testing
197197
static class KnnSearch {
198-
private static final int NUM_CANDS_LIMIT = 10000;
199198
static final ParseField FIELD_FIELD = new ParseField("field");
200199
static final ParseField K_FIELD = new ParseField("k");
201200
static final ParseField NUM_CANDS_FIELD = new ParseField("num_candidates");
@@ -253,9 +252,6 @@ public KnnVectorQueryBuilder toQueryBuilder() {
253252
"[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot be less than " + "[" + K_FIELD.getPreferredName() + "]"
254253
);
255254
}
256-
if (numCands > NUM_CANDS_LIMIT) {
257-
throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]");
258-
}
259255
return new KnnVectorQueryBuilder(field, queryVector, numCands, numCands, null, null);
260256
}
261257

server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@
5656
*/
5757
public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBuilder> {
5858
public static final String NAME = "knn";
59-
private static final int NUM_CANDS_LIMIT = 10_000;
6059
private static final float NUM_CANDS_MULTIPLICATIVE_FACTOR = 1.5f;
6160

6261
public static final ParseField FIELD_FIELD = new ParseField("field");
@@ -183,9 +182,6 @@ private KnnVectorQueryBuilder(
183182
if (k != null && k < 1) {
184183
throw new IllegalArgumentException("[" + K_FIELD.getPreferredName() + "] must be greater than 0");
185184
}
186-
if (numCands != null && numCands > NUM_CANDS_LIMIT) {
187-
throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]");
188-
}
189185
if (k != null && numCands != null && numCands < k) {
190186
throw new IllegalArgumentException(
191187
"[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot be less than [" + K_FIELD.getPreferredName() + "]"
@@ -496,7 +492,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
496492
k = Math.min(k, numCands);
497493
}
498494
}
499-
int adjustedNumCands = numCands == null ? Math.round(Math.min(NUM_CANDS_MULTIPLICATIVE_FACTOR * k, NUM_CANDS_LIMIT)) : numCands;
495+
int adjustedNumCands = numCands == null ? Math.round(NUM_CANDS_MULTIPLICATIVE_FACTOR * k) : numCands;
500496
if (fieldType == null) {
501497
return new MatchNoDocsQuery();
502498
}

server/src/test/java/org/elasticsearch/search/dfs/DfsPhaseTests.java

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,20 @@
1717
import org.apache.lucene.search.Query;
1818
import org.apache.lucene.store.Directory;
1919
import org.apache.lucene.tests.index.RandomIndexWriter;
20+
import org.elasticsearch.common.settings.Settings;
21+
import org.elasticsearch.index.IndexSettings;
22+
import org.elasticsearch.index.shard.IndexShard;
23+
import org.elasticsearch.search.builder.SearchSourceBuilder;
2024
import org.elasticsearch.search.internal.ContextIndexSearcher;
25+
import org.elasticsearch.search.internal.SearchContext;
26+
import org.elasticsearch.search.internal.ShardSearchRequest;
2127
import org.elasticsearch.search.profile.Profilers;
2228
import org.elasticsearch.search.profile.SearchProfileDfsPhaseResult;
2329
import org.elasticsearch.search.profile.query.CollectorResult;
2430
import org.elasticsearch.search.profile.query.QueryProfileShardResult;
31+
import org.elasticsearch.search.vectors.KnnSearchBuilder;
2532
import org.elasticsearch.test.ESTestCase;
33+
import org.elasticsearch.test.IndexSettingsModule;
2634
import org.elasticsearch.threadpool.TestThreadPool;
2735
import org.elasticsearch.threadpool.ThreadPool;
2836
import org.junit.After;
@@ -32,6 +40,9 @@
3240
import java.util.List;
3341
import java.util.concurrent.ThreadPoolExecutor;
3442

43+
import static org.mockito.Mockito.mock;
44+
import static org.mockito.Mockito.when;
45+
3546
public class DfsPhaseTests extends ESTestCase {
3647

3748
ThreadPoolExecutor threadPoolExecutor;
@@ -102,4 +113,37 @@ public void testSingleKnnSearch() throws IOException {
102113
reader.close();
103114
}
104115
}
116+
117+
public void testNumCandidatesExceedsMax() {
118+
Settings settings = Settings.builder().put("index.max_knn_num_candidates", 100).build();
119+
IndexSettings indexSettings = IndexSettingsModule.newIndexSettings("test", settings);
120+
121+
SearchContext context = mock(SearchContext.class);
122+
when(context.indexShard()).thenAnswer(invocation -> {
123+
IndexShard mockIndexShard = mock(IndexShard.class);
124+
when(mockIndexShard.indexSettings()).thenReturn(indexSettings);
125+
return mockIndexShard;
126+
});
127+
128+
// 构造超过最大值的查询参数
129+
KnnSearchBuilder queryBuilder = new KnnSearchBuilder(
130+
"float_vector",
131+
new float[] { 0, 0, 0 },
132+
10,
133+
150, // 超过maxKnnNumCandidates的值
134+
null,
135+
null
136+
);
137+
SearchSourceBuilder source = new SearchSourceBuilder();
138+
source.knnSearch(List.of(queryBuilder));
139+
ShardSearchRequest searchRequest = mock(ShardSearchRequest.class);
140+
when(searchRequest.source()).thenReturn(source);
141+
when(context.request()).thenReturn(searchRequest);
142+
143+
// 验证异常抛出
144+
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> DfsPhase.executeKnnVectorQuery(context));
145+
assertEquals("[num_candidates] cannot exceed [100]", e.getMessage());
146+
147+
}
148+
105149
}

server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -238,14 +238,6 @@ public void testNumCandsLessThanK() {
238238
assertThat(e.getMessage(), containsString("[num_candidates] cannot be less than [k]"));
239239
}
240240

241-
public void testNumCandsExceedsLimit() {
242-
IllegalArgumentException e = expectThrows(
243-
IllegalArgumentException.class,
244-
() -> new KnnSearchBuilder("field", randomVector(3), 100, 10002, null, null)
245-
);
246-
assertThat(e.getMessage(), containsString("[num_candidates] cannot exceed [10000]"));
247-
}
248-
249241
public void testInvalidK() {
250242
IllegalArgumentException e = expectThrows(
251243
IllegalArgumentException.class,

server/src/test/java/org/elasticsearch/search/vectors/KnnSearchRequestParserTests.java

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -179,22 +179,6 @@ public void testNumCandsLessThanK() throws IOException {
179179
assertThat(e.getMessage(), containsString("[num_candidates] cannot be less than [k]"));
180180
}
181181

182-
public void testNumCandsExceedsLimit() throws IOException {
183-
XContentType xContentType = randomFrom(XContentType.values());
184-
XContentBuilder builder = XContentBuilder.builder(xContentType.xContent())
185-
.startObject()
186-
.startObject(KnnSearchRequestParser.KNN_SECTION_FIELD.getPreferredName())
187-
.field(KnnSearch.FIELD_FIELD.getPreferredName(), "field")
188-
.field(KnnSearch.K_FIELD.getPreferredName(), 100)
189-
.field(KnnSearch.NUM_CANDS_FIELD.getPreferredName(), 10002)
190-
.field(KnnSearch.QUERY_VECTOR_FIELD.getPreferredName(), new float[] { 1.0f, 2.0f, 3.0f })
191-
.endObject()
192-
.endObject();
193-
194-
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> parseSearchRequest(builder));
195-
assertThat(e.getMessage(), containsString("[num_candidates] cannot exceed [10000]"));
196-
}
197-
198182
public void testInvalidK() throws IOException {
199183
XContentType xContentType = randomFrom(XContentType.values());
200184
XContentBuilder builder = XContentBuilder.builder(xContentType.xContent())

0 commit comments

Comments
 (0)