Skip to content

Commit fa16bca

Browse files
Adjust GPU graph building params (elastic#137074) (elastic#137093)
cuvs 2025.12 https://github.com/rapidsai/cuvs/pull/1448/files will provide an API for converting HNSW CPU Params to Cagra params. But for current ES that uses 2025.10 version, we need to adjust params ourselves. This PR adjust params based on the code from the cuvs library.
1 parent 092131d commit fa16bca

File tree

5 files changed

+21
-19
lines changed

5 files changed

+21
-19
lines changed

docs/changelog/137074.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 137074
2+
summary: Adjust GPU graph building params
3+
area: Search
4+
type: enhancement
5+
issues: []

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/GPUPlugin.java

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
package org.elasticsearch.xpack.gpu;
88

99
import org.apache.lucene.codecs.KnnVectorsFormat;
10-
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
1110
import org.elasticsearch.common.settings.Setting;
1211
import org.elasticsearch.common.util.FeatureFlag;
1312
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
@@ -92,13 +91,14 @@ private static KnnVectorsFormat getVectorsFormat(
9291
DenseVectorFieldMapper.DenseVectorIndexOptions indexOptions,
9392
DenseVectorFieldMapper.VectorSimilarity similarity
9493
) {
94+
// TODO: cuvs 2025.12 will provide an API for converting HNSW CPU Params to Cagra params; use that instead
9595
if (indexOptions.getType() == DenseVectorFieldMapper.VectorIndexType.HNSW) {
9696
DenseVectorFieldMapper.HnswIndexOptions hnswIndexOptions = (DenseVectorFieldMapper.HnswIndexOptions) indexOptions;
9797
int efConstruction = hnswIndexOptions.efConstruction();
98-
if (efConstruction == HnswGraphBuilder.DEFAULT_BEAM_WIDTH) {
99-
efConstruction = ES92GpuHnswVectorsFormat.DEFAULT_BEAM_WIDTH; // default value for GPU graph construction is 128
100-
}
101-
return new ES92GpuHnswVectorsFormat(hnswIndexOptions.m(), efConstruction);
98+
int m = hnswIndexOptions.m();
99+
int gpuM = 2 + m * 2 / 3;
100+
int gpuEfConstruction = m + m * efConstruction / 256;
101+
return new ES92GpuHnswVectorsFormat(gpuM, gpuEfConstruction);
102102
} else if (indexOptions.getType() == DenseVectorFieldMapper.VectorIndexType.INT8_HNSW) {
103103
if (similarity == DenseVectorFieldMapper.VectorSimilarity.MAX_INNER_PRODUCT) {
104104
throw new IllegalArgumentException(
@@ -113,16 +113,10 @@ private static KnnVectorsFormat getVectorsFormat(
113113
}
114114
DenseVectorFieldMapper.Int8HnswIndexOptions int8HnswIndexOptions = (DenseVectorFieldMapper.Int8HnswIndexOptions) indexOptions;
115115
int efConstruction = int8HnswIndexOptions.efConstruction();
116-
if (efConstruction == HnswGraphBuilder.DEFAULT_BEAM_WIDTH) {
117-
efConstruction = ES92GpuHnswVectorsFormat.DEFAULT_BEAM_WIDTH; // default value for GPU graph construction is 128
118-
}
119-
return new ES92GpuHnswSQVectorsFormat(
120-
int8HnswIndexOptions.m(),
121-
efConstruction,
122-
int8HnswIndexOptions.confidenceInterval(),
123-
7,
124-
false
125-
);
116+
int m = int8HnswIndexOptions.m();
117+
int gpuM = 2 + m * 2 / 3;
118+
int gpuEfConstruction = m + m * efConstruction / 256;
119+
return new ES92GpuHnswSQVectorsFormat(gpuM, gpuEfConstruction, int8HnswIndexOptions.confidenceInterval(), 7, false);
126120
} else {
127121
throw new IllegalArgumentException(
128122
"GPU vector indexing is not supported on this vector type: [" + indexOptions.getType() + "]"

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ES92GpuHnswVectorsFormat.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
1414
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
1515
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
16+
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
1617
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
1718
import org.apache.lucene.index.SegmentReadState;
1819
import org.apache.lucene.index.SegmentWriteState;
@@ -36,8 +37,9 @@ public class ES92GpuHnswVectorsFormat extends KnnVectorsFormat {
3637
static final String LUCENE99_HNSW_VECTOR_INDEX_EXTENSION = "vex";
3738
static final int LUCENE99_VERSION_CURRENT = VERSION_GROUPVARINT;
3839

39-
static final int DEFAULT_MAX_CONN = 16; // graph degree
40-
public static final int DEFAULT_BEAM_WIDTH = 128; // intermediate graph degree
40+
public static final int DEFAULT_MAX_CONN = (2 + Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN * 2 / 3); // graph degree
41+
public static final int DEFAULT_BEAM_WIDTH = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN + Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN
42+
* Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH / 256; // intermediate graph degree
4143
static final int MIN_NUM_VECTORS_FOR_GPU_BUILD = 2;
4244

4345
private static final FlatVectorsFormat flatVectorsFormat = new Lucene99FlatVectorsFormat(

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ES92GpuHnswVectorsWriter.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,7 @@ private CagraIndex buildGPUIndex(
332332
.withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT)
333333
.withGraphDegree(M)
334334
.withIntermediateGraphDegree(beamWidth)
335+
.withNNDescentNumIterations(5)
335336
.withMetric(distanceType)
336337
.build();
337338

x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/GPUDenseVectorFieldMapperTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ public void testKnnVectorsFormat() throws IOException {
4444
// TODO improve test with custom parameters
4545
KnnVectorsFormat knnVectorsFormat = getKnnVectorsFormat("hnsw");
4646
String expectedStr = "Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, "
47-
+ "maxConn=16, beamWidth=128, flatVectorFormat=Lucene99FlatVectorsFormat)";
47+
+ "maxConn=12, beamWidth=22, flatVectorFormat=Lucene99FlatVectorsFormat)";
4848
assertEquals(expectedStr, knnVectorsFormat.toString());
4949
}
5050

@@ -53,7 +53,7 @@ public void testKnnQuantizedHNSWVectorsFormat() throws IOException {
5353
// TOD improve the test with custom parameters
5454
KnnVectorsFormat knnVectorsFormat = getKnnVectorsFormat("int8_hnsw");
5555
String expectedStr = "Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, "
56-
+ "maxConn=16, beamWidth=128, flatVectorFormat=ES814ScalarQuantizedVectorsFormat";
56+
+ "maxConn=12, beamWidth=22, flatVectorFormat=ES814ScalarQuantizedVectorsFormat";
5757
assertTrue(knnVectorsFormat.toString().startsWith(expectedStr));
5858
}
5959

0 commit comments

Comments
 (0)