Skip to content

Commit 7e9049e

Browse files
Adjust GPU graph building params (#137074)
cuvs 2025.12 (https://github.com/rapidsai/cuvs/pull/1448/files) will provide an API for converting HNSW CPU params to CAGRA params. However, the current ES uses cuvs 2025.10, so we need to adjust the params ourselves. This PR adjusts the params based on the code from the cuvs library.
1 parent b852cf9 commit 7e9049e

5 files changed

Lines changed: 21 additions & 19 deletions

File tree

docs/changelog/137074.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 137074
2+
summary: Adjust GPU graph building params
3+
area: Search
4+
type: enhancement
5+
issues: []

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/GPUPlugin.java

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
package org.elasticsearch.xpack.gpu;
88

99
import org.apache.lucene.codecs.KnnVectorsFormat;
10-
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
1110
import org.elasticsearch.common.settings.Setting;
1211
import org.elasticsearch.common.util.FeatureFlag;
1312
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
@@ -94,13 +93,14 @@ private static KnnVectorsFormat getVectorsFormat(
9493
DenseVectorFieldMapper.DenseVectorIndexOptions indexOptions,
9594
DenseVectorFieldMapper.VectorSimilarity similarity
9695
) {
96+
// TODO: cuvs 2025.12 will provide an API for converting HNSW CPU Params to Cagra params; use that instead
9797
if (indexOptions.getType() == DenseVectorFieldMapper.VectorIndexType.HNSW) {
9898
DenseVectorFieldMapper.HnswIndexOptions hnswIndexOptions = (DenseVectorFieldMapper.HnswIndexOptions) indexOptions;
9999
int efConstruction = hnswIndexOptions.efConstruction();
100-
if (efConstruction == HnswGraphBuilder.DEFAULT_BEAM_WIDTH) {
101-
efConstruction = ES92GpuHnswVectorsFormat.DEFAULT_BEAM_WIDTH; // default value for GPU graph construction is 128
102-
}
103-
return new ES92GpuHnswVectorsFormat(hnswIndexOptions.m(), efConstruction);
100+
int m = hnswIndexOptions.m();
101+
int gpuM = 2 + m * 2 / 3;
102+
int gpuEfConstruction = m + m * efConstruction / 256;
103+
return new ES92GpuHnswVectorsFormat(gpuM, gpuEfConstruction);
104104
} else if (indexOptions.getType() == DenseVectorFieldMapper.VectorIndexType.INT8_HNSW) {
105105
if (similarity == DenseVectorFieldMapper.VectorSimilarity.MAX_INNER_PRODUCT) {
106106
throw new IllegalArgumentException(
@@ -115,16 +115,10 @@ private static KnnVectorsFormat getVectorsFormat(
115115
}
116116
DenseVectorFieldMapper.Int8HnswIndexOptions int8HnswIndexOptions = (DenseVectorFieldMapper.Int8HnswIndexOptions) indexOptions;
117117
int efConstruction = int8HnswIndexOptions.efConstruction();
118-
if (efConstruction == HnswGraphBuilder.DEFAULT_BEAM_WIDTH) {
119-
efConstruction = ES92GpuHnswVectorsFormat.DEFAULT_BEAM_WIDTH; // default value for GPU graph construction is 128
120-
}
121-
return new ES92GpuHnswSQVectorsFormat(
122-
int8HnswIndexOptions.m(),
123-
efConstruction,
124-
int8HnswIndexOptions.confidenceInterval(),
125-
7,
126-
false
127-
);
118+
int m = int8HnswIndexOptions.m();
119+
int gpuM = 2 + m * 2 / 3;
120+
int gpuEfConstruction = m + m * efConstruction / 256;
121+
return new ES92GpuHnswSQVectorsFormat(gpuM, gpuEfConstruction, int8HnswIndexOptions.confidenceInterval(), 7, false);
128122
} else {
129123
throw new IllegalArgumentException(
130124
"GPU vector indexing is not supported on this vector type: [" + indexOptions.getType() + "]"

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ES92GpuHnswVectorsFormat.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
1414
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
1515
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
16+
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
1617
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
1718
import org.apache.lucene.index.SegmentReadState;
1819
import org.apache.lucene.index.SegmentWriteState;
@@ -36,8 +37,9 @@ public class ES92GpuHnswVectorsFormat extends KnnVectorsFormat {
3637
static final String LUCENE99_HNSW_VECTOR_INDEX_EXTENSION = "vex";
3738
static final int LUCENE99_VERSION_CURRENT = VERSION_GROUPVARINT;
3839

39-
static final int DEFAULT_MAX_CONN = 16; // graph degree
40-
public static final int DEFAULT_BEAM_WIDTH = 128; // intermediate graph degree
40+
public static final int DEFAULT_MAX_CONN = (2 + Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN * 2 / 3); // graph degree
41+
public static final int DEFAULT_BEAM_WIDTH = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN + Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN
42+
* Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH / 256; // intermediate graph degree
4143
static final int MIN_NUM_VECTORS_FOR_GPU_BUILD = 2;
4244

4345
private static final FlatVectorsFormat flatVectorsFormat = new Lucene99FlatVectorsFormat(

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ES92GpuHnswVectorsWriter.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,7 @@ private CagraIndex buildGPUIndex(
332332
.withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT)
333333
.withGraphDegree(M)
334334
.withIntermediateGraphDegree(beamWidth)
335+
.withNNDescentNumIterations(5)
335336
.withMetric(distanceType)
336337
.build();
337338

x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/GPUDenseVectorFieldMapperTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ public void testKnnVectorsFormat() throws IOException {
4444
// TODO improve test with custom parameters
4545
KnnVectorsFormat knnVectorsFormat = getKnnVectorsFormat("hnsw");
4646
String expectedStr = "Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, "
47-
+ "maxConn=16, beamWidth=128, flatVectorFormat=Lucene99FlatVectorsFormat)";
47+
+ "maxConn=12, beamWidth=22, flatVectorFormat=Lucene99FlatVectorsFormat)";
4848
assertEquals(expectedStr, knnVectorsFormat.toString());
4949
}
5050

@@ -53,7 +53,7 @@ public void testKnnQuantizedHNSWVectorsFormat() throws IOException {
5353
// TODO improve the test with custom parameters
5454
KnnVectorsFormat knnVectorsFormat = getKnnVectorsFormat("int8_hnsw");
5555
String expectedStr = "Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, "
56-
+ "maxConn=16, beamWidth=128, flatVectorFormat=ES814ScalarQuantizedVectorsFormat";
56+
+ "maxConn=12, beamWidth=22, flatVectorFormat=ES814ScalarQuantizedVectorsFormat";
5757
assertTrue(knnVectorsFormat.toString().startsWith(expectedStr));
5858
}
5959

0 commit comments

Comments
 (0)