Skip to content

Commit 59d6c6c

Browse files
committed
improve
1 parent 2639d28 commit 59d6c6c

File tree

7 files changed

+84
-88
lines changed

7 files changed

+84
-88
lines changed

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,10 @@ public static boolean isNotUnitVector(float magnitude) {
123123
// vector
124124
public static final int MAGNITUDE_BYTES = 4;
125125

126+
public static final String M_FIELD = "m";
127+
public static final String EF_CONSTRUCTION_FIELD = "ef_construction";
128+
public static final String MAX_SEARCH_EF_FIELD = "max_search_ef";
129+
126130
private static DenseVectorFieldMapper toType(FieldMapper in) {
127131
return (DenseVectorFieldMapper) in;
128132
}
@@ -230,6 +234,9 @@ public Builder(String name, IndexVersion indexVersionCreated) {
230234
if (v != null) {
231235
v.validateElementType(elementType.getValue());
232236
}
237+
if (v != null) {
238+
v.validateNumCandidates();
239+
}
233240
})
234241
.acceptsNull()
235242
.setMergeValidator(
@@ -1238,7 +1245,7 @@ public void validateDimension(int dim) {
12381245
throw new IllegalArgumentException(type.name + " only supports even dimensions; provided=" + dim);
12391246
}
12401247

1241-
public void validateNumCandidates(int numCands) {
1248+
public void validateNumCandidates() {
12421249

12431250
}
12441251

@@ -1292,9 +1299,9 @@ abstract static class AbstractHnswIndexOptions extends IndexOptions {
12921299
}
12931300

12941301
@Override
1295-
public void validateNumCandidates(int numCands) {
1296-
if (numCands > maxSearchEf) {
1297-
throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + maxSearchEf + "]");
1302+
public void validateNumCandidates() {
1303+
if (maxSearchEf <= 0) {
1304+
throw new IllegalArgumentException("[" + MAX_SEARCH_EF_FIELD + "] must be greater than 0");
12981305
}
12991306
}
13001307

@@ -1337,9 +1344,9 @@ public enum VectorIndexType {
13371344
HNSW("hnsw", false) {
13381345
@Override
13391346
public IndexOptions parseIndexOptions(String fieldName, Map<String, ?> indexOptionsMap) {
1340-
Object mNode = indexOptionsMap.remove("m");
1341-
Object efConstructionNode = indexOptionsMap.remove("ef_construction");
1342-
Object maxSearchEfNode = indexOptionsMap.remove("max_search_ef");
1347+
Object mNode = indexOptionsMap.remove(M_FIELD);
1348+
Object efConstructionNode = indexOptionsMap.remove(EF_CONSTRUCTION_FIELD);
1349+
Object maxSearchEfNode = indexOptionsMap.remove(MAX_SEARCH_EF_FIELD);
13431350
if (mNode == null) {
13441351
mNode = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
13451352
}
@@ -1370,9 +1377,9 @@ public boolean supportsDimension(int dims) {
13701377
INT8_HNSW("int8_hnsw", true) {
13711378
@Override
13721379
public IndexOptions parseIndexOptions(String fieldName, Map<String, ?> indexOptionsMap) {
1373-
Object mNode = indexOptionsMap.remove("m");
1374-
Object efConstructionNode = indexOptionsMap.remove("ef_construction");
1375-
Object maxSearchEfNode = indexOptionsMap.remove("max_search_ef");
1380+
Object mNode = indexOptionsMap.remove(M_FIELD);
1381+
Object efConstructionNode = indexOptionsMap.remove(EF_CONSTRUCTION_FIELD);
1382+
Object maxSearchEfNode = indexOptionsMap.remove(MAX_SEARCH_EF_FIELD);
13761383
Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval");
13771384
if (mNode == null) {
13781385
mNode = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
@@ -1406,9 +1413,9 @@ public boolean supportsDimension(int dims) {
14061413
},
14071414
INT4_HNSW("int4_hnsw", true) {
14081415
public IndexOptions parseIndexOptions(String fieldName, Map<String, ?> indexOptionsMap) {
1409-
Object mNode = indexOptionsMap.remove("m");
1410-
Object efConstructionNode = indexOptionsMap.remove("ef_construction");
1411-
Object maxSearchEfNode = indexOptionsMap.remove("max_search_ef");
1416+
Object mNode = indexOptionsMap.remove(M_FIELD);
1417+
Object efConstructionNode = indexOptionsMap.remove(EF_CONSTRUCTION_FIELD);
1418+
Object maxSearchEfNode = indexOptionsMap.remove(MAX_SEARCH_EF_FIELD);
14121419
Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval");
14131420
if (mNode == null) {
14141421
mNode = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
@@ -1504,9 +1511,9 @@ public boolean supportsDimension(int dims) {
15041511
BBQ_HNSW("bbq_hnsw", true) {
15051512
@Override
15061513
public IndexOptions parseIndexOptions(String fieldName, Map<String, ?> indexOptionsMap) {
1507-
Object mNode = indexOptionsMap.remove("m");
1508-
Object efConstructionNode = indexOptionsMap.remove("ef_construction");
1509-
Object maxSearchEfNode = indexOptionsMap.remove("max_search_ef");
1514+
Object mNode = indexOptionsMap.remove(M_FIELD);
1515+
Object efConstructionNode = indexOptionsMap.remove(EF_CONSTRUCTION_FIELD);
1516+
Object maxSearchEfNode = indexOptionsMap.remove(MAX_SEARCH_EF_FIELD);
15101517
if (mNode == null) {
15111518
mNode = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
15121519
}
@@ -2075,7 +2082,11 @@ public Query createKnnQuery(
20752082
);
20762083
}
20772084

2078-
indexOptions.validateNumCandidates(numCands);
2085+
if (numCands > indexOptions.maxSearchEf()) {
2086+
throw new IllegalArgumentException(
2087+
"[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + indexOptions.maxSearchEf() + "]"
2088+
);
2089+
}
20792090

20802091
return switch (getElementType()) {
20812092
case BYTE -> createKnnByteQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter);

server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -264,9 +264,6 @@ private KnnSearchBuilder(
264264
"[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot be less than " + "[" + K_FIELD.getPreferredName() + "]"
265265
);
266266
}
267-
if (numCandidates > NUM_CANDS_LIMIT) {
268-
throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]");
269-
}
270267
if (queryVector == null && queryVectorBuilder == null) {
271268
throw new IllegalArgumentException(
272269
format(

server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,6 @@ private KnnVectorQueryBuilder(
183183
if (k != null && k < 1) {
184184
throw new IllegalArgumentException("[" + K_FIELD.getPreferredName() + "] must be greater than 0");
185185
}
186-
if (numCands != null && numCands > NUM_CANDS_LIMIT) {
187-
throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]");
188-
}
189186
if (k != null && numCands != null && numCands < k) {
190187
throw new IllegalArgumentException(
191188
"[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot be less than [" + K_FIELD.getPreferredName() + "]"

server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363

6464
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
6565
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
66+
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.NUM_CANDS_LIMIT;
6667
import static org.hamcrest.Matchers.containsString;
6768
import static org.hamcrest.Matchers.equalTo;
6869
import static org.hamcrest.Matchers.instanceOf;
@@ -852,6 +853,23 @@ protected void registerParameters(ParameterChecker checker) throws IOException {
852853
.endObject()
853854
)
854855
);
856+
857+
checker.registerUpdateCheck(
858+
b -> b.field("type", "dense_vector")
859+
.field("dims", dims)
860+
.field("index", true)
861+
.startObject("index_options")
862+
.field("type", "int4_hnsw")
863+
.endObject(),
864+
b -> b.field("type", "dense_vector")
865+
.field("dims", dims)
866+
.field("index", true)
867+
.startObject("index_options")
868+
.field("type", "int4_hnsw")
869+
.field("max_search_ef", 1000)
870+
.endObject(),
871+
m -> assertTrue(m.toString().contains("\"max_search_ef\":1000"))
872+
);
855873
}
856874

857875
@Override
@@ -937,7 +955,7 @@ public void testMergeDims() throws IOException {
937955
.field("type", "int8_hnsw")
938956
.field("m", 16)
939957
.field("ef_construction", 100)
940-
.field("max_search_ef", DenseVectorFieldMapper.NUM_CANDS_LIMIT)
958+
.field("max_search_ef", NUM_CANDS_LIMIT)
941959
.endObject();
942960
b.endObject();
943961
});
@@ -2091,6 +2109,22 @@ public void testInvalidVectorDimensions() {
20912109
}
20922110
}
20932111

2112+
public void testMaxSearchEfBounds() {
2113+
Exception e = expectThrows(MapperParsingException.class, () -> createDocumentMapper(fieldMapping(b -> {
2114+
b.field("type", "dense_vector");
2115+
b.field("dims", dims);
2116+
b.field("index", true);
2117+
b.field("similarity", "dot_product");
2118+
b.startObject("index_options");
2119+
b.field("type", "hnsw");
2120+
b.field("m", 5);
2121+
b.field("ef_construction", 50);
2122+
b.field("max_search_ef", 0); // Invalid value
2123+
b.endObject();
2124+
})));
2125+
assertThat(e.getMessage(), containsString("Failed to parse mapping: [max_search_ef] must be greater than 0"));
2126+
}
2127+
20942128
@Override
20952129
protected IngestScriptSupport ingestScriptSupport() {
20962130
throw new AssumptionViolatedException("not supported");

server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java

Lines changed: 21 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,22 @@ public DenseVectorFieldTypeTests() {
5151

5252
private DenseVectorFieldMapper.IndexOptions randomIndexOptionsNonQuantized() {
5353
return randomFrom(
54-
new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000), randomIntBetween(1_000, 10_000)),
54+
new DenseVectorFieldMapper.HnswIndexOptions(
55+
randomIntBetween(1, 100),
56+
randomIntBetween(1, 10_000),
57+
randomIntBetween(1_000, 10_000)
58+
),
5559
new DenseVectorFieldMapper.FlatIndexOptions()
5660
);
5761
}
5862

5963
private DenseVectorFieldMapper.IndexOptions randomIndexOptionsAll() {
6064
return randomFrom(
61-
new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000), randomIntBetween(1_000, 10_000)),
65+
new DenseVectorFieldMapper.HnswIndexOptions(
66+
randomIntBetween(1, 100),
67+
randomIntBetween(1, 10_000),
68+
randomIntBetween(1_000, 10_000)
69+
),
6270
new DenseVectorFieldMapper.Int8HnswIndexOptions(
6371
randomIntBetween(1, 100),
6472
randomIntBetween(1, 10_000),
@@ -74,7 +82,11 @@ private DenseVectorFieldMapper.IndexOptions randomIndexOptionsAll() {
7482
new DenseVectorFieldMapper.FlatIndexOptions(),
7583
new DenseVectorFieldMapper.Int8FlatIndexOptions(randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true))),
7684
new DenseVectorFieldMapper.Int4FlatIndexOptions(randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true))),
77-
new DenseVectorFieldMapper.BBQHnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000), randomIntBetween(1_000, 10_000)),
85+
new DenseVectorFieldMapper.BBQHnswIndexOptions(
86+
randomIntBetween(1, 100),
87+
randomIntBetween(1, 10_000),
88+
randomIntBetween(1_000, 10_000)
89+
),
7890
new DenseVectorFieldMapper.BBQFlatIndexOptions()
7991
);
8092
}
@@ -93,7 +105,11 @@ private DenseVectorFieldMapper.IndexOptions randomIndexOptionsHnswQuantized() {
93105
randomIntBetween(1_000, 10_000),
94106
randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true))
95107
),
96-
new DenseVectorFieldMapper.BBQHnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000), randomIntBetween(1_000, 10_000))
108+
new DenseVectorFieldMapper.BBQHnswIndexOptions(
109+
randomIntBetween(1, 100),
110+
randomIntBetween(1, 10_000),
111+
randomIntBetween(1_000, 10_000)
112+
)
97113
);
98114
}
99115

@@ -176,37 +192,6 @@ public void testFetchSourceValue() throws IOException {
176192
assertEquals(vector, fetchSourceValue(bft, vector));
177193
}
178194

179-
public void testValidateNumCandidates() {
180-
// Test case where numCands is less than or equal to maxSearchEf
181-
{
182-
int maxSearchEf = randomIntBetween(1, 1000);
183-
int numCands = randomIntBetween(0, maxSearchEf);
184-
DenseVectorFieldMapper.IndexOptions indexOptions = new DenseVectorFieldMapper.HnswIndexOptions(
185-
randomIntBetween(1, 100),
186-
randomIntBetween(1, 10_000),
187-
maxSearchEf
188-
);
189-
indexOptions.validateNumCandidates(numCands);
190-
// No exception should be thrown
191-
}
192-
193-
// Test case where numCands is greater than maxSearchEf
194-
{
195-
int maxSearchEf = randomIntBetween(1, 1000);
196-
int numCands = randomIntBetween(maxSearchEf + 1, maxSearchEf + 1000);
197-
DenseVectorFieldMapper.IndexOptions indexOptions = new DenseVectorFieldMapper.HnswIndexOptions(
198-
randomIntBetween(1, 100),
199-
randomIntBetween(1, 10_000),
200-
maxSearchEf
201-
);
202-
IllegalArgumentException e = expectThrows(
203-
IllegalArgumentException.class,
204-
() -> indexOptions.validateNumCandidates(numCands)
205-
);
206-
assertThat(e.getMessage(), containsString("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + maxSearchEf + "]"));
207-
}
208-
}
209-
210195
public void testCreateNestedKnnQuery() {
211196
BitSetProducer producer = context -> null;
212197

@@ -414,11 +399,7 @@ public void testCreateKnnQuerValidateNumCandidates() {
414399
dims,
415400
true,
416401
VectorSimilarity.COSINE,
417-
new DenseVectorFieldMapper.HnswIndexOptions(
418-
randomIntBetween(1, 100),
419-
randomIntBetween(1, 10_000),
420-
1000
421-
),
402+
new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000), 1000),
422403
Collections.emptyMap()
423404
);
424405
float[] queryVector = new float[dims];

server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -238,14 +238,6 @@ public void testNumCandsLessThanK() {
238238
assertThat(e.getMessage(), containsString("[num_candidates] cannot be less than [k]"));
239239
}
240240

241-
public void testNumCandsExceedsLimit() {
242-
IllegalArgumentException e = expectThrows(
243-
IllegalArgumentException.class,
244-
() -> new KnnSearchBuilder("field", randomVector(3), 100, 10002, null, null)
245-
);
246-
assertThat(e.getMessage(), containsString("[num_candidates] cannot exceed [10000]"));
247-
}
248-
249241
public void testInvalidK() {
250242
IllegalArgumentException e = expectThrows(
251243
IllegalArgumentException.class,

server/src/test/java/org/elasticsearch/search/vectors/KnnSearchRequestParserTests.java

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -179,22 +179,6 @@ public void testNumCandsLessThanK() throws IOException {
179179
assertThat(e.getMessage(), containsString("[num_candidates] cannot be less than [k]"));
180180
}
181181

182-
public void testNumCandsExceedsLimit() throws IOException {
183-
XContentType xContentType = randomFrom(XContentType.values());
184-
XContentBuilder builder = XContentBuilder.builder(xContentType.xContent())
185-
.startObject()
186-
.startObject(KnnSearchRequestParser.KNN_SECTION_FIELD.getPreferredName())
187-
.field(KnnSearch.FIELD_FIELD.getPreferredName(), "field")
188-
.field(KnnSearch.K_FIELD.getPreferredName(), 100)
189-
.field(KnnSearch.NUM_CANDS_FIELD.getPreferredName(), 10002)
190-
.field(KnnSearch.QUERY_VECTOR_FIELD.getPreferredName(), new float[] { 1.0f, 2.0f, 3.0f })
191-
.endObject()
192-
.endObject();
193-
194-
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> parseSearchRequest(builder));
195-
assertThat(e.getMessage(), containsString("[num_candidates] cannot exceed [10000]"));
196-
}
197-
198182
public void testInvalidK() throws IOException {
199183
XContentType xContentType = randomFrom(XContentType.values());
200184
XContentBuilder builder = XContentBuilder.builder(xContentType.xContent())

0 commit comments

Comments
 (0)