Skip to content

Commit f83fdee

Browse files
committed
Add bfloat16 native tests
1 parent a0e2bbd commit f83fdee

15 files changed

Lines changed: 1397 additions & 5 deletions

File tree

benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/BenchmarkUtils.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
2727
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
2828
import org.apache.lucene.util.quantization.ScalarQuantizer;
29+
import org.elasticsearch.index.codec.vectors.BFloat16;
30+
import org.elasticsearch.index.codec.vectors.es93.OffHeapBFloat16VectorValues;
2931
import org.elasticsearch.simdvec.VectorScorerFactory;
3032

3133
import java.io.IOException;
@@ -65,6 +67,16 @@ static void writeFloatVectorData(Directory dir, float[][] vectors) throws IOExce
6567
}
6668
}
6769

70+
static void writeBFloat16VectorData(Directory dir, float[][] vectors) throws IOException {
71+
try (IndexOutput out = dir.createOutput("vector.data", IOContext.DEFAULT)) {
72+
ByteBuffer buffer = ByteBuffer.allocate(vectors[0].length * BFloat16.BYTES).order(ByteOrder.LITTLE_ENDIAN);
73+
for (float[] vector : vectors) {
74+
BFloat16.floatToBFloat16(vector, buffer.asShortBuffer());
75+
out.writeBytes(buffer.array(), buffer.capacity());
76+
}
77+
}
78+
}
79+
6880
static void writeByteVectorData(Directory dir, byte[][] vectors) throws IOException {
6981
try (IndexOutput out = dir.createOutput("vector.data", IOContext.DEFAULT)) {
7082
for (byte[] vector : vectors) {
@@ -97,6 +109,11 @@ static FloatVectorValues floatVectorValues(int dims, int size, IndexInput in, Ve
97109
return new OffHeapFloatVectorValues.DenseOffHeapVectorValues(dims, size, slice, dims * Float.BYTES, null, sim);
98110
}
99111

112+
static FloatVectorValues bfloat16VectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException {
113+
var slice = in.slice("values", 0, in.length());
114+
return new OffHeapBFloat16VectorValues.DenseOffHeapVectorValues(dims, size, slice, dims * BFloat16.BYTES, null, sim);
115+
}
116+
100117
static ByteVectorValues byteVectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException {
101118
var slice = in.slice("values", 0, in.length());
102119
return new OffHeapByteVectorValues.DenseOffHeapVectorValues(dims, size, slice, dims, null, sim);
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.benchmark.vector.scorer;
11+
12+
import org.apache.lucene.store.Directory;
13+
import org.apache.lucene.store.IOContext;
14+
import org.apache.lucene.store.IndexInput;
15+
import org.apache.lucene.store.MMapDirectory;
16+
import org.apache.lucene.util.VectorUtil;
17+
import org.apache.lucene.util.hnsw.RandomVectorScorer;
18+
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
19+
import org.elasticsearch.benchmark.Utils;
20+
import org.elasticsearch.core.IOUtils;
21+
import org.elasticsearch.index.codec.vectors.BFloat16;
22+
import org.elasticsearch.simdvec.VectorScorerFactory;
23+
import org.elasticsearch.simdvec.VectorSimilarityType;
24+
import org.openjdk.jmh.annotations.Benchmark;
25+
import org.openjdk.jmh.annotations.BenchmarkMode;
26+
import org.openjdk.jmh.annotations.Fork;
27+
import org.openjdk.jmh.annotations.Measurement;
28+
import org.openjdk.jmh.annotations.Mode;
29+
import org.openjdk.jmh.annotations.OutputTimeUnit;
30+
import org.openjdk.jmh.annotations.Param;
31+
import org.openjdk.jmh.annotations.Scope;
32+
import org.openjdk.jmh.annotations.Setup;
33+
import org.openjdk.jmh.annotations.State;
34+
import org.openjdk.jmh.annotations.TearDown;
35+
import org.openjdk.jmh.annotations.Warmup;
36+
37+
import java.io.IOException;
38+
import java.nio.file.Files;
39+
import java.nio.file.Path;
40+
import java.util.concurrent.ThreadLocalRandom;
41+
import java.util.concurrent.TimeUnit;
42+
43+
import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.bfloat16VectorValues;
44+
import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.getScorerFactoryOrDie;
45+
import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.luceneScoreSupplier;
46+
import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.luceneScorer;
47+
import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.supportsHeapSegments;
48+
import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.writeBFloat16VectorData;
49+
import static org.elasticsearch.benchmark.vector.scorer.ScalarOperations.dotProduct;
50+
import static org.elasticsearch.benchmark.vector.scorer.ScalarOperations.squareDistance;
51+
52+
@Fork(value = 1, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
53+
@Warmup(iterations = 3, time = 3)
54+
@Measurement(iterations = 5, time = 3)
55+
@BenchmarkMode(Mode.Throughput)
56+
@OutputTimeUnit(TimeUnit.SECONDS)
57+
@State(Scope.Thread)
58+
public class VectorScorerBFloat16Benchmark {
59+
60+
static {
61+
Utils.configureBenchmarkLogging();
62+
}
63+
64+
@Param({ "96", "768", "1024" })
65+
public int dims;
66+
public static int numVectors = 2; // there are only two vectors to compare
67+
68+
@Param
69+
public VectorImplementation implementation;
70+
71+
@Param({ "DOT_PRODUCT", "EUCLIDEAN" })
72+
public VectorSimilarityType function;
73+
74+
private Path path;
75+
private Directory dir;
76+
private IndexInput in;
77+
78+
private static class ScalarDotProduct implements UpdateableRandomVectorScorer {
79+
private final float[] vec1;
80+
private final float[] vec2;
81+
82+
private ScalarDotProduct(float[] vec1, float[] vec2) {
83+
this.vec1 = vec1;
84+
this.vec2 = vec2;
85+
}
86+
87+
@Override
88+
public float score(int ordinal) {
89+
return VectorUtil.normalizeToUnitInterval(dotProduct(vec1, vec2));
90+
}
91+
92+
@Override
93+
public int maxOrd() {
94+
return 0;
95+
}
96+
97+
@Override
98+
public void setScoringOrdinal(int targetOrd) {}
99+
}
100+
101+
private static class ScalarSquareDistance implements UpdateableRandomVectorScorer {
102+
private final float[] vec1;
103+
private final float[] vec2;
104+
105+
private ScalarSquareDistance(float[] vec1, float[] vec2) {
106+
this.vec1 = vec1;
107+
this.vec2 = vec2;
108+
}
109+
110+
@Override
111+
public float score(int ordinal) {
112+
return VectorUtil.normalizeDistanceToUnitInterval(squareDistance(vec1, vec2));
113+
}
114+
115+
@Override
116+
public int maxOrd() {
117+
return 0;
118+
}
119+
120+
@Override
121+
public void setScoringOrdinal(int targetOrd) {}
122+
}
123+
124+
private UpdateableRandomVectorScorer scorer;
125+
private RandomVectorScorer queryScorer;
126+
127+
static class VectorData {
128+
private final float[][] vectorData;
129+
private final float[] queryVector;
130+
131+
VectorData(int dims) {
132+
this(dims, 2);
133+
}
134+
135+
VectorData(int dims, int numVectors) {
136+
vectorData = new float[numVectors][dims];
137+
ThreadLocalRandom random = ThreadLocalRandom.current();
138+
for (int v = 0; v < numVectors; v++) {
139+
for (int d = 0; d < dims; d++) {
140+
vectorData[v][d] = BFloat16.truncateToBFloat16(random.nextFloat());
141+
}
142+
}
143+
144+
queryVector = new float[dims];
145+
for (int i = 0; i < dims; i++) {
146+
// query uses full floats
147+
queryVector[i] = random.nextFloat();
148+
}
149+
}
150+
151+
}
152+
153+
@Setup
154+
public void setup() throws IOException {
155+
setup(new VectorData(dims, numVectors));
156+
}
157+
158+
public void setup(VectorData vectorData) throws IOException {
159+
VectorScorerFactory factory = getScorerFactoryOrDie();
160+
161+
path = Files.createTempDirectory("BFloat16ScorerBenchmark");
162+
dir = new MMapDirectory(path);
163+
writeBFloat16VectorData(dir, vectorData.vectorData);
164+
165+
in = dir.openInput("vector.data", IOContext.DEFAULT);
166+
var values = bfloat16VectorValues(dims, numVectors, in, function.function());
167+
168+
switch (implementation) {
169+
case SCALAR:
170+
float[] vec1 = values.vectorValue(0).clone();
171+
float[] vec2 = values.vectorValue(1).clone();
172+
173+
scorer = switch (function) {
174+
case DOT_PRODUCT -> new ScalarDotProduct(vec1, vec2);
175+
case EUCLIDEAN -> new ScalarSquareDistance(vec1, vec2);
176+
default -> throw new IllegalArgumentException(function + " not supported");
177+
};
178+
break;
179+
case LUCENE:
180+
scorer = luceneScoreSupplier(values, function.function()).scorer();
181+
if (supportsHeapSegments()) {
182+
queryScorer = luceneScorer(values, function.function(), vectorData.queryVector);
183+
}
184+
break;
185+
case NATIVE:
186+
scorer = factory.getBFloat16VectorScorerSupplier(function, in, values).orElseThrow().scorer();
187+
if (supportsHeapSegments()) {
188+
queryScorer = factory.getBFloat16VectorScorer(function.function(), values, vectorData.queryVector).orElseThrow();
189+
}
190+
break;
191+
}
192+
193+
scorer.setScoringOrdinal(0);
194+
}
195+
196+
@TearDown
197+
public void teardown() throws IOException {
198+
IOUtils.close(dir, in);
199+
IOUtils.rm(path);
200+
}
201+
202+
@Benchmark
203+
public float score() throws IOException {
204+
return scorer.score(1);
205+
}
206+
207+
@Benchmark
208+
public float scoreQuery() throws IOException {
209+
return queryScorer.score(1);
210+
}
211+
}

0 commit comments

Comments
 (0)