Skip to content

Commit 8a839dc

Browse files
authored
[DiskBBQ] Add bulk scoring for int7 centroid scoring (elastic#138204)
This provides a minor but measurable speed improvement. JMH shows a nicer story: this PR ``` Int7ScorerBenchmark.scoreFromMemorySegmentBulk 384 thrpt 5 57.738 ± 0.510 ops/ms Int7ScorerBenchmark.scoreFromMemorySegmentBulk 782 thrpt 5 25.804 ± 0.196 ops/ms Int7ScorerBenchmark.scoreFromMemorySegmentBulk 1024 thrpt 5 23.813 ± 2.751 ops/ms ``` baseline: ``` Int7ScorerBenchmark.scoreFromMemorySegmentBulk 384 thrpt 5 35.412 ± 0.202 ops/ms Int7ScorerBenchmark.scoreFromMemorySegmentBulk 782 thrpt 5 20.663 ± 0.521 ops/ms Int7ScorerBenchmark.scoreFromMemorySegmentBulk 1024 thrpt 5 19.765 ± 1.296 ops/ms ``` I will need help rebuilding and publishing the binaries :)
1 parent d74334e commit 8a839dc

File tree

6 files changed

+67
-5
lines changed

6 files changed

+67
-5
lines changed

docs/changelog/138204.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 138204
2+
summary: "[DiskBBQ] Add bulk scoring for int7 centroid scoring"
3+
area: Vector Search
4+
type: enhancement
5+
issues: []

libs/native/libraries/build.gradle

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ configurations {
1919
}
2020

2121
var zstdVersion = "1.5.5"
22-
var vecVersion = "1.0.13"
22+
var vecVersion = "1.0.14"
2323

2424
repositories {
2525
exclusiveContent {

libs/native/src/main/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ public interface VectorSimilarityFunctions {
3030
*/
3131
MethodHandle dotProductHandle7u();
3232

33+
MethodHandle dotProductHandle7uBulk();
34+
3335
/**
3436
* Produces a method handle returning the square distance of byte (unsigned int7) vectors.
3537
*

libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ public final class JdkVectorLibrary implements VectorLibrary {
3232
static final Logger logger = LogManager.getLogger(JdkVectorLibrary.class);
3333

3434
static final MethodHandle dot7u$mh;
35+
static final MethodHandle dot7uBulk$mh;
3536
static final MethodHandle sqr7u$mh;
3637
static final MethodHandle cosf32$mh;
3738
static final MethodHandle dotf32$mh;
@@ -53,6 +54,11 @@ public final class JdkVectorLibrary implements VectorLibrary {
5354
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
5455
LinkerHelperUtil.critical()
5556
);
57+
dot7uBulk$mh = downcallHandle(
58+
"dot7u_bulk_2",
59+
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS),
60+
LinkerHelperUtil.critical()
61+
);
5662
sqr7u$mh = downcallHandle(
5763
"sqr7u_2",
5864
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
@@ -79,6 +85,11 @@ public final class JdkVectorLibrary implements VectorLibrary {
7985
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
8086
LinkerHelperUtil.critical()
8187
);
88+
dot7uBulk$mh = downcallHandle(
89+
"dot7u_bulk",
90+
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS),
91+
LinkerHelperUtil.critical()
92+
);
8293
sqr7u$mh = downcallHandle(
8394
"sqr7u",
8495
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
@@ -108,6 +119,7 @@ public final class JdkVectorLibrary implements VectorLibrary {
108119
enable them in your OS/Hypervisor/VM/container""");
109120
}
110121
dot7u$mh = null;
122+
dot7uBulk$mh = null;
111123
sqr7u$mh = null;
112124
cosf32$mh = null;
113125
dotf32$mh = null;
@@ -142,6 +154,10 @@ static int dotProduct7u(MemorySegment a, MemorySegment b, int length) {
142154
return dot7u(a, b, length);
143155
}
144156

157+
static void dotProduct7uBulk(MemorySegment a, MemorySegment b, int length, int count, MemorySegment result) {
158+
dot7uBulk(a, b, length, count, result);
159+
}
160+
145161
/**
146162
* Computes the square distance of given unsigned int7 byte vectors.
147163
*
@@ -210,6 +226,14 @@ private static int dot7u(MemorySegment a, MemorySegment b, int length) {
210226
}
211227
}
212228

229+
private static void dot7uBulk(MemorySegment a, MemorySegment b, int length, int count, MemorySegment result) {
230+
try {
231+
JdkVectorLibrary.dot7uBulk$mh.invokeExact(a, b, length, count, result);
232+
} catch (Throwable t) {
233+
throw new AssertionError(t);
234+
}
235+
}
236+
213237
private static int sqr7u(MemorySegment a, MemorySegment b, int length) {
214238
try {
215239
return (int) JdkVectorLibrary.sqr7u$mh.invokeExact(a, b, length);
@@ -243,6 +267,7 @@ private static float sqrf32(MemorySegment a, MemorySegment b, int length) {
243267
}
244268

245269
static final MethodHandle DOT_HANDLE_7U;
270+
static final MethodHandle DOT_HANDLE_7U_BULK;
246271
static final MethodHandle SQR_HANDLE_7U;
247272
static final MethodHandle COS_HANDLE_FLOAT32;
248273
static final MethodHandle DOT_HANDLE_FLOAT32;
@@ -253,6 +278,11 @@ private static float sqrf32(MemorySegment a, MemorySegment b, int length) {
253278
var lookup = MethodHandles.lookup();
254279
var mt = MethodType.methodType(int.class, MemorySegment.class, MemorySegment.class, int.class);
255280
DOT_HANDLE_7U = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProduct7u", mt);
281+
DOT_HANDLE_7U_BULK = lookup.findStatic(
282+
JdkVectorSimilarityFunctions.class,
283+
"dotProduct7uBulk",
284+
MethodType.methodType(void.class, MemorySegment.class, MemorySegment.class, int.class, int.class, MemorySegment.class)
285+
);
256286
SQR_HANDLE_7U = lookup.findStatic(JdkVectorSimilarityFunctions.class, "squareDistance7u", mt);
257287

258288
mt = MethodType.methodType(float.class, MemorySegment.class, MemorySegment.class, int.class);
@@ -269,6 +299,11 @@ public MethodHandle dotProductHandle7u() {
269299
return DOT_HANDLE_7U;
270300
}
271301

302+
@Override
303+
public MethodHandle dotProductHandle7uBulk() {
304+
return DOT_HANDLE_7U_BULK;
305+
}
306+
272307
@Override
273308
public MethodHandle squareDistanceHandle7u() {
274309
return SQR_HANDLE_7U;

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Similarities.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,23 @@ public class Similarities {
2222
.orElseThrow(AssertionError::new);
2323

2424
static final MethodHandle DOT_PRODUCT_7U = DISTANCE_FUNCS.dotProductHandle7u();
25+
static final MethodHandle DOT_PRODUCT_7U_BULK = DISTANCE_FUNCS.dotProductHandle7uBulk();
2526
static final MethodHandle SQUARE_DISTANCE_7U = DISTANCE_FUNCS.squareDistanceHandle7u();
2627

28+
static void dotProduct7uBulk(MemorySegment a, MemorySegment b, int length, int count, MemorySegment scores) {
29+
try {
30+
DOT_PRODUCT_7U_BULK.invokeExact(a, b, length, count, scores);
31+
} catch (Throwable e) {
32+
if (e instanceof Error err) {
33+
throw err;
34+
} else if (e instanceof RuntimeException re) {
35+
throw re;
36+
} else {
37+
throw new RuntimeException(e);
38+
}
39+
}
40+
}
41+
2742
static int dotProduct7u(MemorySegment a, MemorySegment b, int length) {
2843
try {
2944
return (int) DOT_PRODUCT_7U.invokeExact(a, b, length);

libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,19 @@ private long nativeInt7DotProduct(byte[] q) throws IOException {
4848
return res;
4949
}
5050

51+
private void nativeInt7DotProductBulk(byte[] q, int count, float[] scores) throws IOException {
52+
final MemorySegment scoresSegment = MemorySegment.ofArray(scores);
53+
final MemorySegment segment = memorySegment.asSlice(in.getFilePointer(), dimensions * count);
54+
final MemorySegment querySegment = MemorySegment.ofArray(q);
55+
Similarities.dotProduct7uBulk(segment, querySegment, dimensions, count, scoresSegment);
56+
in.skipBytes(dimensions * count);
57+
}
58+
5159
@Override
5260
public void int7DotProductBulk(byte[] q, int count, float[] scores) throws IOException {
5361
assert q.length == dimensions;
5462
if (NATIVE_SUPPORTED) {
55-
// TODO: can we speed up bulks in native code?
56-
for (int i = 0; i < count; i++) {
57-
scores[i] = nativeInt7DotProduct(q);
58-
}
63+
nativeInt7DotProductBulk(q, count, scores);
5964
} else {
6065
panamaInt7DotProductBulk(q, count, scores);
6166
}

0 commit comments

Comments
 (0)